# Charm_15 / model.safetensors
# GeminiFan207's picture
# Update model.safetensors
# ced2d25 verified
# raw
# history blame
# 15.5 kB
import torch
from safetensors.torch import save_file, load_file
from typing import Dict, Optional, Tuple, List, Union, Any
import logging
import time
import json
import yaml
import os
from pathlib import Path
import sys
import shutil
from dataclasses import dataclass, asdict
import numpy as np
from tqdm import tqdm
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
from torch.nn.init import xavier_uniform_, kaiming_uniform_
# Configure root logging: INFO and above go both to stdout and, in append
# mode, to "transformer_shard_builder.log". A plain FileHandler is used, so the
# log file is NOT rotated. Each record carries the process and thread name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - [%(processName)s:%(threadName)s] - %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler("transformer_shard_builder.log", mode="a"),
    ],
    level=logging.INFO,
)
@dataclass
class ModelConfig:
    """Hyper-parameters and sharding options consumed by the shard builder.

    Field order is part of the positional constructor signature and must not
    be changed.

    NOTE(review): with the defaults below ``num_layers`` (48) is smaller than
    ``total_shards`` (278), which the builder's ``_validate_config`` in this
    file rejects -- confirm which of the two defaults is intended.
    """

    # --- architecture ---
    num_layers: int = 48          # number of transformer layers
    hidden_size: int = 8192       # model width; must be divisible by `heads`
    heads: int = 64               # attention heads
    seq_length: int = 4096        # rows of the position embedding table
    vocab_size: int = 50000       # rows of the word embedding table
    dtype: str = "float16"        # name of a torch dtype, resolved via getattr(torch, ...)
    ffn_multiplier: int = 4       # FFN intermediate size = hidden_size * ffn_multiplier

    # --- sharding / output ---
    total_shards: int = 278       # number of .safetensors files to produce
    base_path: str = "model_shards"  # output directory for the shard files

    # --- runtime ---
    # Evaluated once, when the class body is executed (i.e. at import time).
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    seed: Optional[int] = 42      # None leaves the RNGs unseeded
    init_method: str = "xavier"  # Options: "xavier", "kaiming", "normal"
    shard_compression: bool = True
    validation_threshold: float = 1e-5  # magnitude above which validation logs a warning
class TransformerShardBuilder:
"""Advanced class to build, shard, validate, and save a large transformer model."""
def __init__(self, config: Optional[ModelConfig] = None):
"""Initialize with configuration and setup environment."""
self.config = config or ModelConfig()
self.dtype = getattr(torch, self.config.dtype)
self.device = torch.device(self.config.device)
self.base_path = Path(self.config.base_path)
self.weights: Dict[int, Dict[str, torch.Tensor]] = {} # Shard-indexed weights
self.metadata: Dict[str, Any] = {}
self._validate_config()
self._setup_environment()
self._calculate_sharding()
def _validate_config(self) -> None:
"""Validate configuration parameters."""
checks = [
(self.config.num_layers > 0, "Number of layers must be positive"),
(self.config.hidden_size % self.config.heads == 0, "Hidden size must be divisible by heads"),
(self.config.seq_length > 0, "Sequence length must be positive"),
(self.config.vocab_size > 0, "Vocab size must be positive"),
(self.config.total_shards > 0, "Total shards must be positive"),
(self.config.ffn_multiplier > 1, "FFN multiplier must be greater than 1"),
(self.config.init_method in ["xavier", "kaiming", "normal"], "Invalid initialization method")
]
for condition, message in checks:
if not condition:
raise ValueError(message)
if self.config.num_layers < self.config.total_shards:
raise ValueError("Number of layers must be >= total shards")
def _setup_environment(self) -> None:
"""Setup random seed, device, and directories."""
if self.config.seed is not None:
torch.manual_seed(self.config.seed)
np.random.seed(self.config.seed)
self.base_path.mkdir(parents=True, exist_ok=True)
logging.info(f"Environment setup: device={self.device}, base_path={self.base_path}")
if self.device.type == "cuda":
logging.info(f"CUDA Memory: {torch.cuda.memory_available() / 1024**3:.2f} GB free")
def _calculate_sharding(self) -> None:
"""Calculate layer distribution across shards."""
self.layers_per_shard = self.config.num_layers // self.config.total_shards
self.remaining_layers = self.config.num_layers % self.config.total_shards
logging.info(f"Sharding: {self.layers_per_shard} layers/shard, {self.remaining_layers} extra")
def _initialize_tensor(self, *shape) -> torch.Tensor:
"""Initialize tensor based on configured method."""
tensor = torch.empty(*shape, dtype=self.dtype, device=self.device)
if self.config.init_method == "xavier":
if len(shape) > 1:
xavier_uniform_(tensor)
else:
tensor.normal_(0, 1.0 / self.config.hidden_size ** 0.5)
elif self.config.init_method == "kaiming":
if len(shape) > 1:
kaiming_uniform_(tensor, a=0, mode="fan_in", nonlinearity="relu")
else:
tensor.normal_(0, 1.0 / self.config.hidden_size ** 0.5)
else: # normal
tensor.normal_(0, 1.0 / self.config.hidden_size ** 0.5)
return tensor
def _create_attention_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
"""Create attention weights for a layer."""
weights = {}
prefix = f"layer_{layer_idx}.attention"
head_dim = self.config.hidden_size // self.config.heads
for name in ["query_weight", "key_weight", "value_weight", "output_weight"]:
weights[f"{prefix}.{name}"] = self._initialize_tensor(self.config.hidden_size, self.config.hidden_size)
weights[f"{prefix}.{name}_bias"] = torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device)
weights[f"{prefix}.head_scale"] = torch.ones(self.config.heads, head_dim, dtype=self.dtype, device=self.device)
return weights
def _create_ffn_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
"""Create FFN weights for a layer."""
weights = {}
prefix = f"layer_{layer_idx}.ffn"
intermediate_size = self.config.hidden_size * self.config.ffn_multiplier
weights[f"{prefix}.intermediate_weight"] = self._initialize_tensor(self.config.hidden_size, intermediate_size)
weights[f"{prefix}.intermediate_bias"] = torch.zeros(intermediate_size, dtype=self.dtype, device=self.device)
weights[f"{prefix}.output_weight"] = self._initialize_tensor(intermediate_size, self.config.hidden_size)
weights[f"{prefix}.output_bias"] = torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device)
return weights
def _create_norm_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
"""Create normalization weights."""
prefix = f"layer_{layer_idx}"
return {
f"{prefix}.norm_1_weight": torch.ones(self.config.hidden_size, dtype=self.dtype, device=self.device),
f"{prefix}.norm_1_bias": torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device),
f"{prefix}.norm_2_weight": torch.ones(self.config.hidden_size, dtype=self.dtype, device=self.device),
f"{prefix}.norm_2_bias": torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device)
}
def _create_embedding_output(self) -> Dict[str, torch.Tensor]:
"""Create embedding and output layers for first shard."""
weights = {
"embedding.word_embeddings": self._initialize_tensor(self.config.vocab_size, self.config.hidden_size),
"embedding.position_embeddings": self._initialize_tensor(self.config.seq_length, self.config.hidden_size),
"embedding.token_type_embeddings": self._initialize_tensor(self.config.seq_length, self.config.hidden_size),
"output_layer.weight": self._initialize_tensor(self.config.hidden_size, self.config.vocab_size),
"output_layer.bias": torch.zeros(self.config.vocab_size, dtype=self.dtype, device=self.device)
}
return weights
def build_shard(self, shard_idx: int) -> Dict[str, torch.Tensor]:
"""Build weights for a specific shard."""
weights = {}
start_time = time.time()
start_layer = (shard_idx - 1) * self.layers_per_shard
end_layer = start_layer + self.layers_per_shard
if shard_idx == self.config.total_shards:
end_layer += self.remaining_layers
for i in tqdm(range(start_layer, end_layer), desc=f"Shard {shard_idx} layers"):
weights.update(self._create_attention_block(i))
weights.update(self._create_ffn_block(i))
weights.update(self._create_norm_block(i))
if shard_idx == 1:
weights.update(self._create_embedding_output())
elapsed = time.time() - start_time
self.metadata[f"shard_{shard_idx}"] = {"build_time": elapsed, "num_layers": end_layer - start_layer}
logging.info(f"Shard {shard_idx} built with {len(weights)} tensors in {elapsed:.2f}s")
return weights
def save_shard(self, shard_idx: int, weights: Dict[str, torch.Tensor]) -> None:
"""Save a single shard with metadata."""
shard_path = self.base_path / f"model_{shard_idx}_of_{self.config.total_shards}.safetensors"
start_time = time.time()
try:
shard_metadata = {
"shard_idx": shard_idx,
"total_shards": self.config.total_shards,
"config": asdict(self.config),
**self.metadata.get(f"shard_{shard_idx}", {})
}
save_file(weights, str(shard_path), metadata=shard_metadata)
elapsed = time.time() - start_time
logging.info(f"Shard {shard_idx} saved to {shard_path} in {elapsed:.2f}s")
except Exception as e:
logging.error(f"Shard {shard_idx} save failed: {str(e)}")
raise RuntimeError(f"Failed to save shard {shard_idx}: {str(e)}") from e
def build_and_save_all_shards(self, parallel: bool = True) -> None:
"""Build and save all shards, optionally in parallel."""
start_time = time.time()
if parallel and mp.cpu_count() > 1:
with ThreadPoolExecutor(max_workers=min(mp.cpu_count(), self.config.total_shards)) as executor:
futures = {
executor.submit(self.build_shard, i): i
for i in range(1, self.config.total_shards + 1)
}
for future in as_completed(futures):
shard_idx = futures[future]
try:
weights = future.result()
self.save_shard(shard_idx, weights)
except Exception as e:
logging.error(f"Parallel shard {shard_idx} failed: {str(e)}")
else:
for shard_idx in tqdm(range(1, self.config.total_shards + 1), desc="Building shards"):
weights = self.build_shard(shard_idx)
self.save_shard(shard_idx, weights)
total_time = time.time() - start_time
self.metadata["total_build_time"] = total_time
logging.info(f"All {self.config.total_shards} shards completed in {total_time:.2f}s")
def validate_shard(self, shard_idx: int) -> bool:
"""Validate a shard's weights after loading."""
shard_path = self.base_path / f"model_{shard_idx}_of_{self.config.total_shards}.safetensors"
try:
weights = load_file(str(shard_path), device="cpu") # Load to CPU for validation
all_valid = True
for name, tensor in weights.items():
if torch.isnan(tensor).any() or torch.isinf(tensor).any():
logging.warning(f"Invalid values in {name} (shard {shard_idx})")
all_valid = False
elif torch.max(torch.abs(tensor)) > self.config.validation_threshold:
logging.warning(f"Large values in {name} (shard {shard_idx})")
return all_valid
except Exception as e:
logging.error(f"Validation failed for shard {shard_idx}: {str(e)}")
return False
def compute_checksum(self, shard_idx: int) -> str:
"""Compute SHA256 checksum of a shard file."""
shard_path = self.base_path / f"model_{shard_idx}_of_{self.config.total_shards}.safetensors"
sha256 = hashlib.sha256()
with open(shard_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
sha256.update(chunk)
return sha256.hexdigest()
def export_metadata(self, output_path: str | Path = "model_metadata.json") -> None:
"""Export metadata to JSON file."""
output_path = Path(output_path)
with open(output_path, "w") as f:
json.dump(self.metadata, f, indent=2)
logging.info(f"Metadata exported to {output_path}")
@classmethod
def from_yaml(cls, yaml_path: str | Path) -> "TransformerShardBuilder":
"""Initialize from YAML config file."""
with open(yaml_path, "r") as f:
config_dict = yaml.safe_load(f)
return cls(ModelConfig(**config_dict))
def estimate_model_size(config: ModelConfig) -> Tuple[int, float]:
    """Count the parameters and bytes of the model described by *config*.

    This is not an analytical estimate. A ``TransformerShardBuilder`` is
    constructed and every shard is actually built, one after another, and the
    element counts and byte sizes of the resulting tensors are summed.

    Returns:
        ``(parameter_count, size_in_GiB)``.
    """
    shard_builder = TransformerShardBuilder(config)
    total_params = 0
    total_bytes = 0
    for shard_idx in range(1, config.total_shards + 1):
        for tensor in shard_builder.build_shard(shard_idx).values():
            count = tensor.numel()
            total_params += count
            total_bytes += tensor.element_size() * count
    return total_params, total_bytes / 1024**3
def main():
    """Estimate, build, save and validate the sharded model, then export metadata.

    Returns:
        ``0`` when every shard was built and passed validation. ``1`` when
        any step raised or at least one shard failed validation.
    """
    try:
        # Custom configuration.
        # The builder's validation in this file rejects configs where
        # num_layers < total_shards. The previous value of 278 shards for 48
        # layers therefore always failed, so one shard per layer is used.
        config = ModelConfig(
            num_layers=48,
            hidden_size=8192,
            heads=64,
            seq_length=4096,
            vocab_size=50000,
            total_shards=48,
            base_path="model_shards_large",
        )
        builder = TransformerShardBuilder(config)
        # Size estimation
        num_params, size_gb = estimate_model_size(config)
        logging.info(f"Estimated size: {num_params:,} parameters, {size_gb:.2f} GB")
        # Build and save all shards
        builder.build_and_save_all_shards(parallel=True)
        # Validate all shards
        logging.info("Validating shards...")
        invalid_shards = []
        for shard in tqdm(range(1, config.total_shards + 1), desc="Validating"):
            if builder.validate_shard(shard):
                checksum = builder.compute_checksum(shard)
                logging.info(f"Shard {shard} validated, checksum: {checksum[:8]}...")
            else:
                logging.warning(f"Shard {shard} validation failed")
                invalid_shards.append(shard)
        # Export metadata even when some shards are invalid, so the run can be inspected.
        builder.export_metadata()
        # Reflect validation failures in the exit status instead of always reporting success.
        return 1 if invalid_shards else 0
    except Exception as e:
        logging.error(f"Execution failed: {str(e)}")
        return 1


if __name__ == "__main__":
    sys.exit(main())
def main():
    """Run the build with the default configuration and validate the result.

    This second definition rebinds ``main`` when the module is imported. The
    previous body referenced ``TransformerModelBuilder`` and the methods
    ``build_model`` / ``validate_model`` / ``save_model``, none of which exist
    in this file. It now uses ``TransformerShardBuilder``.

    Returns:
        ``0`` on success, ``1`` if validation fails or any step raises.
    """
    try:
        # Default configuration
        # NOTE(review): the defaults must satisfy the builder's validation
        # (num_layers >= total_shards); otherwise construction raises and
        # this function returns 1.
        config = ModelConfig()
        builder = TransformerShardBuilder(config)
        # Estimate size
        num_params, size_gb = estimate_model_size(config)
        logging.info(f"Estimated model size: {num_params:,} parameters, {size_gb:.2f} GB")
        # Build and save; shards can only be validated once they are on disk.
        builder.build_and_save_all_shards()
        # Validate every shard (no short-circuit, so all problems get logged).
        results = [
            builder.validate_shard(shard)
            for shard in range(1, config.total_shards + 1)
        ]
        if all(results):
            logging.info("Model validation passed")
            builder.export_metadata()
        else:
            logging.warning("Model validation failed")
            return 1
        return 0
    except Exception as e:
        logging.error(f"Execution failed: {str(e)}")
        return 1


# When run as a script, the earlier `__main__` guard in this file calls
# sys.exit() before this point, so this guard only matters if that one is removed.
if __name__ == "__main__":
    sys.exit(main())