import torch
from safetensors.torch import save_file, load_file
from typing import Dict, Optional, Tuple, Any
import logging
import time
import json
import yaml
import sys
from pathlib import Path
from dataclasses import dataclass, asdict
import numpy as np
from tqdm import tqdm
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
from torch.nn.init import xavier_uniform_, kaiming_uniform_


logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - [%(processName)s:%(threadName)s] - %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler("transformer_shard_builder.log", mode="a"),
    ],
)


@dataclass
class ModelConfig:
    """Configuration for transformer model parameters and sharding."""
    num_layers: int = 48
    hidden_size: int = 8192
    heads: int = 64
    seq_length: int = 4096
    vocab_size: int = 50000
    dtype: str = "float16"
    ffn_multiplier: int = 4
    total_shards: int = 8  # must not exceed num_layers, so every shard holds at least one layer
    base_path: str = "model_shards"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    seed: Optional[int] = 42
    init_method: str = "xavier"
    shard_compression: bool = True  # reserved; currently unused (safetensors files are not compressed)
    validation_threshold: float = 100.0  # max |value| allowed in a tensor before a warning is logged


class TransformerShardBuilder:
    """Build, shard, validate, and save a large transformer model."""

    def __init__(self, config: Optional[ModelConfig] = None):
        """Initialize with configuration and set up the environment."""
        self.config = config or ModelConfig()
        self.dtype = getattr(torch, self.config.dtype)
        self.device = torch.device(self.config.device)
        self.base_path = Path(self.config.base_path)
        self.weights: Dict[int, Dict[str, torch.Tensor]] = {}
        self.metadata: Dict[str, Any] = {}

        self._validate_config()
        self._setup_environment()
        self._calculate_sharding()

    def _validate_config(self) -> None:
        """Validate configuration parameters."""
        checks = [
            (self.config.num_layers > 0, "Number of layers must be positive"),
            (self.config.hidden_size % self.config.heads == 0, "Hidden size must be divisible by heads"),
            (self.config.seq_length > 0, "Sequence length must be positive"),
            (self.config.vocab_size > 0, "Vocab size must be positive"),
            (self.config.total_shards > 0, "Total shards must be positive"),
            (self.config.ffn_multiplier > 1, "FFN multiplier must be greater than 1"),
            (self.config.init_method in ["xavier", "kaiming", "normal"], "Invalid initialization method"),
        ]
        for condition, message in checks:
            if not condition:
                raise ValueError(message)
        if self.config.num_layers < self.config.total_shards:
            raise ValueError("Number of layers must be >= total shards")

    def _setup_environment(self) -> None:
        """Set up the random seed, device, and output directory."""
        if self.config.seed is not None:
            torch.manual_seed(self.config.seed)
            np.random.seed(self.config.seed)
        self.base_path.mkdir(parents=True, exist_ok=True)
        logging.info(f"Environment setup: device={self.device}, base_path={self.base_path}")
        if self.device.type == "cuda":
            # mem_get_info() returns (free, total) device memory in bytes.
            free_bytes, total_bytes = torch.cuda.mem_get_info()
            logging.info(f"CUDA memory: {free_bytes / 1024**3:.2f} GB free of {total_bytes / 1024**3:.2f} GB")

    def _calculate_sharding(self) -> None:
        """Calculate layer distribution across shards."""
        self.layers_per_shard = self.config.num_layers // self.config.total_shards
        self.remaining_layers = self.config.num_layers % self.config.total_shards
        logging.info(f"Sharding: {self.layers_per_shard} layers/shard, {self.remaining_layers} extra")

    def _initialize_tensor(self, *shape) -> torch.Tensor:
        """Initialize a tensor using the configured method."""
        tensor = torch.empty(*shape, dtype=self.dtype, device=self.device)
        if len(shape) > 1 and self.config.init_method == "xavier":
            xavier_uniform_(tensor)
        elif len(shape) > 1 and self.config.init_method == "kaiming":
            kaiming_uniform_(tensor, a=0, mode="fan_in", nonlinearity="relu")
        else:
            # 1-D tensors (and the "normal" method) use a scaled normal distribution.
            tensor.normal_(0, 1.0 / self.config.hidden_size ** 0.5)
        return tensor

    def _create_attention_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
        """Create attention weights for a layer."""
        weights = {}
        prefix = f"layer_{layer_idx}.attention"
        head_dim = self.config.hidden_size // self.config.heads

        for name in ["query_weight", "key_weight", "value_weight", "output_weight"]:
            weights[f"{prefix}.{name}"] = self._initialize_tensor(self.config.hidden_size, self.config.hidden_size)
            weights[f"{prefix}.{name}_bias"] = torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device)
        weights[f"{prefix}.head_scale"] = torch.ones(self.config.heads, head_dim, dtype=self.dtype, device=self.device)
        return weights

    def _create_ffn_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
        """Create FFN weights for a layer."""
        weights = {}
        prefix = f"layer_{layer_idx}.ffn"
        intermediate_size = self.config.hidden_size * self.config.ffn_multiplier

        weights[f"{prefix}.intermediate_weight"] = self._initialize_tensor(self.config.hidden_size, intermediate_size)
        weights[f"{prefix}.intermediate_bias"] = torch.zeros(intermediate_size, dtype=self.dtype, device=self.device)
        weights[f"{prefix}.output_weight"] = self._initialize_tensor(intermediate_size, self.config.hidden_size)
        weights[f"{prefix}.output_bias"] = torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device)
        return weights

    def _create_norm_block(self, layer_idx: int) -> Dict[str, torch.Tensor]:
        """Create normalization weights."""
        prefix = f"layer_{layer_idx}"
        return {
            f"{prefix}.norm_1_weight": torch.ones(self.config.hidden_size, dtype=self.dtype, device=self.device),
            f"{prefix}.norm_1_bias": torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device),
            f"{prefix}.norm_2_weight": torch.ones(self.config.hidden_size, dtype=self.dtype, device=self.device),
            f"{prefix}.norm_2_bias": torch.zeros(self.config.hidden_size, dtype=self.dtype, device=self.device),
        }

    def _create_embedding_output(self) -> Dict[str, torch.Tensor]:
        """Create embedding and output layers for the first shard."""
        return {
            "embedding.word_embeddings": self._initialize_tensor(self.config.vocab_size, self.config.hidden_size),
            "embedding.position_embeddings": self._initialize_tensor(self.config.seq_length, self.config.hidden_size),
            "embedding.token_type_embeddings": self._initialize_tensor(self.config.seq_length, self.config.hidden_size),
            "output_layer.weight": self._initialize_tensor(self.config.hidden_size, self.config.vocab_size),
            "output_layer.bias": torch.zeros(self.config.vocab_size, dtype=self.dtype, device=self.device),
        }

    def build_shard(self, shard_idx: int) -> Dict[str, torch.Tensor]:
        """Build weights for a specific shard (shard indices are 1-based)."""
        weights = {}
        start_time = time.time()

        start_layer = (shard_idx - 1) * self.layers_per_shard
        end_layer = start_layer + self.layers_per_shard
        if shard_idx == self.config.total_shards:
            end_layer += self.remaining_layers

        for i in tqdm(range(start_layer, end_layer), desc=f"Shard {shard_idx} layers"):
            weights.update(self._create_attention_block(i))
            weights.update(self._create_ffn_block(i))
            weights.update(self._create_norm_block(i))

        if shard_idx == 1:
            weights.update(self._create_embedding_output())

        elapsed = time.time() - start_time
        self.metadata[f"shard_{shard_idx}"] = {"build_time": elapsed, "num_layers": end_layer - start_layer}
        logging.info(f"Shard {shard_idx} built with {len(weights)} tensors in {elapsed:.2f}s")
        return weights

    def save_shard(self, shard_idx: int, weights: Dict[str, torch.Tensor]) -> None:
        """Save a single shard with metadata."""
        shard_path = self.base_path / f"model_{shard_idx}_of_{self.config.total_shards}.safetensors"
        start_time = time.time()

        try:
            # safetensors accepts only string-to-string metadata, so serialize non-string values.
            shard_metadata = {
                "shard_idx": str(shard_idx),
                "total_shards": str(self.config.total_shards),
                "config": json.dumps(asdict(self.config)),
                **{k: json.dumps(v) for k, v in self.metadata.get(f"shard_{shard_idx}", {}).items()},
            }
            save_file(weights, str(shard_path), metadata=shard_metadata)
            elapsed = time.time() - start_time
            logging.info(f"Shard {shard_idx} saved to {shard_path} in {elapsed:.2f}s")
        except Exception as e:
            logging.error(f"Shard {shard_idx} save failed: {e}")
            raise RuntimeError(f"Failed to save shard {shard_idx}: {e}") from e

    def build_and_save_all_shards(self, parallel: bool = True) -> None:
        """Build and save all shards, optionally in parallel."""
        start_time = time.time()
        failed = []

        if parallel and mp.cpu_count() > 1:
            with ThreadPoolExecutor(max_workers=min(mp.cpu_count(), self.config.total_shards)) as executor:
                futures = {
                    executor.submit(self.build_shard, i): i
                    for i in range(1, self.config.total_shards + 1)
                }
                for future in as_completed(futures):
                    shard_idx = futures[future]
                    try:
                        weights = future.result()
                        self.save_shard(shard_idx, weights)
                    except Exception as e:
                        logging.error(f"Parallel shard {shard_idx} failed: {e}")
                        failed.append(shard_idx)
        else:
            for shard_idx in tqdm(range(1, self.config.total_shards + 1), desc="Building shards"):
                weights = self.build_shard(shard_idx)
                self.save_shard(shard_idx, weights)

        if failed:
            raise RuntimeError(f"Shards failed to build/save: {sorted(failed)}")

        total_time = time.time() - start_time
        self.metadata["total_build_time"] = total_time
        logging.info(f"All {self.config.total_shards} shards completed in {total_time:.2f}s")

    def validate_shard(self, shard_idx: int) -> bool:
        """Validate a shard's weights after loading them from disk."""
        shard_path = self.base_path / f"model_{shard_idx}_of_{self.config.total_shards}.safetensors"
        try:
            weights = load_file(str(shard_path), device="cpu")
            all_valid = True
            for name, tensor in weights.items():
                if torch.isnan(tensor).any() or torch.isinf(tensor).any():
                    logging.warning(f"Invalid values in {name} (shard {shard_idx})")
                    all_valid = False
                elif torch.max(torch.abs(tensor)) > self.config.validation_threshold:
                    # Unusually large magnitudes are logged but do not fail validation.
                    logging.warning(f"Large values in {name} (shard {shard_idx})")
            return all_valid
        except Exception as e:
            logging.error(f"Validation failed for shard {shard_idx}: {e}")
            return False

    def compute_checksum(self, shard_idx: int) -> str:
        """Compute the SHA-256 checksum of a shard file."""
        shard_path = self.base_path / f"model_{shard_idx}_of_{self.config.total_shards}.safetensors"
        sha256 = hashlib.sha256()
        with open(shard_path, "rb") as f:
            # Read in 1 MiB chunks; shard files can be several gigabytes.
            for chunk in iter(lambda: f.read(1 << 20), b""):
                sha256.update(chunk)
        return sha256.hexdigest()
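    # The digest matches what coreutils produces, so a shard can be cross-checked
    # from the shell (the path below is illustrative, not a file this script is
    # guaranteed to have written):
    #   sha256sum model_shards/model_1_of_8.safetensors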

    def export_metadata(self, output_path: str | Path = "model_metadata.json") -> None:
        """Export metadata to a JSON file."""
        output_path = Path(output_path)
        with open(output_path, "w") as f:
            json.dump(self.metadata, f, indent=2)
        logging.info(f"Metadata exported to {output_path}")

    @classmethod
    def from_yaml(cls, yaml_path: str | Path) -> "TransformerShardBuilder":
        """Initialize from a YAML config file."""
        with open(yaml_path, "r") as f:
            config_dict = yaml.safe_load(f)
        return cls(ModelConfig(**config_dict))
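    # Example YAML accepted by from_yaml. The keys mirror ModelConfig fields; the
    # values shown are illustrative:
    #
    #   num_layers: 48
    #   hidden_size: 8192
    #   heads: 64
    #   seq_length: 4096
    #   vocab_size: 50000
    #   total_shards: 8
    #   base_path: model_shards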


def estimate_model_size(config: ModelConfig) -> Tuple[int, float]:
    """Estimate total model size in parameters and GiB without allocating tensors."""
    h = config.hidden_size
    intermediate = h * config.ffn_multiplier
    # Per layer: 4 attention matrices + biases + head_scale, 2 FFN matrices + biases,
    # and 2 pairs of norm weights/biases.
    attention = 4 * h * h + 4 * h + h
    ffn = 2 * h * intermediate + intermediate + h
    norms = 4 * h
    params = config.num_layers * (attention + ffn + norms)
    # Shard 1 additionally carries the embedding and output tensors.
    params += config.vocab_size * h          # word embeddings
    params += 2 * config.seq_length * h      # position + token-type embeddings
    params += h * config.vocab_size + config.vocab_size  # output projection + bias
    bytes_size = params * torch.empty(0, dtype=getattr(torch, config.dtype)).element_size()
    return params, bytes_size / 1024**3
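
# Rough sanity check under the defaults above (hand-computed, not authoritative):
# each layer holds 4*h^2 + 5*h attention params, 8*h^2 + 5*h FFN params, and 4*h
# norm params; with h = 8192 and 48 layers that is ~38.7B params, plus ~0.9B for
# the embedding/output tensors in shard 1, i.e. ~39.5B params (~74 GiB at float16).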


def main():
    """Main execution flow: estimate, build, validate, and export metadata."""
    try:
        # total_shards must not exceed num_layers (see _validate_config).
        config = ModelConfig(
            num_layers=48,
            hidden_size=8192,
            heads=64,
            seq_length=4096,
            vocab_size=50000,
            total_shards=8,
            base_path="model_shards_large",
        )
        builder = TransformerShardBuilder(config)

        num_params, size_gb = estimate_model_size(config)
        logging.info(f"Estimated size: {num_params:,} parameters, {size_gb:.2f} GiB")

        builder.build_and_save_all_shards(parallel=True)

        logging.info("Validating shards...")
        for shard in tqdm(range(1, config.total_shards + 1), desc="Validating"):
            if builder.validate_shard(shard):
                checksum = builder.compute_checksum(shard)
                logging.info(f"Shard {shard} validated, checksum: {checksum[:8]}...")
            else:
                logging.warning(f"Shard {shard} validation failed")

        builder.export_metadata()

        return 0
    except Exception as e:
        logging.error(f"Execution failed: {e}")
        return 1


if __name__ == "__main__":
    sys.exit(main())