import torch
import safetensors.torch
import concurrent.futures
import zlib
import logging
from typing import Dict, Tuple
from pathlib import Path

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler()],
)

class AdvancedModelParameters:
    def __init__(self, num_shards=2089, base_filename="charm15", hidden_size=16384, layers_per_shard=100):
        """Initialize model parameters for a massive transformer model."""
        self.num_shards = num_shards
        self.base_filename = base_filename
        self.hidden_size = hidden_size
        self.layers_per_shard = layers_per_shard
        self.ffn_multiplier = 4  # FFN intermediate size = 4 * hidden_size
        self.shape = (hidden_size, hidden_size)  # square attention projections
        self.dtype = torch.float16
        self.base_path = Path("model_shards")
        self.base_path.mkdir(parents=True, exist_ok=True)

    def generate_layer_parameters(self, layer_idx: int) -> Dict[str, torch.Tensor]:
        """Generate parameters for a single transformer layer."""
        params = {}
        prefix = f"layer_{layer_idx}"

        # Attention projections: query, key, value, and output, each hidden x hidden.
        for name in ["query_weight", "key_weight", "value_weight", "output_weight"]:
            params[f"{prefix}.attention.{name}"] = torch.randn(
                self.shape, dtype=self.dtype
            ) * (1.0 / self.hidden_size ** 0.5)

        # Feed-forward network: hidden -> intermediate -> hidden.
        intermediate_size = self.hidden_size * self.ffn_multiplier
        params[f"{prefix}.ffn.intermediate_weight"] = torch.randn(
            self.hidden_size, intermediate_size, dtype=self.dtype
        ) * (1.0 / self.hidden_size ** 0.5)
        params[f"{prefix}.ffn.output_weight"] = torch.randn(
            intermediate_size, self.hidden_size, dtype=self.dtype
        ) * (1.0 / intermediate_size ** 0.5)

        return params

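    # The 1/sqrt(fan_in) factor above is a LeCun-style scaled initialization: it
    # keeps activation variance roughly constant from layer to layer, and for
    # float16 weights it also keeps values comfortably inside the representable
    # range. This is an explanatory note, not part of the original design.
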
    def generate_shard_parameters(self, shard_index: int) -> Dict[str, torch.Tensor]:
        """Generate parameters for a single shard (shard indices are 1-based)."""
        params = {}
        start_layer = (shard_index - 1) * self.layers_per_shard
        end_layer = start_layer + self.layers_per_shard

        for layer_idx in range(start_layer, end_layer):
            params.update(self.generate_layer_parameters(layer_idx))

        # The first shard also carries the embeddings and the (untied) output head.
        if shard_index == 1:
            params["embedding.word_embeddings"] = torch.randn(
                50000, self.hidden_size, dtype=self.dtype
            ) * (1.0 / self.hidden_size ** 0.5)
            params["embedding.position_embeddings"] = torch.randn(
                4096, self.hidden_size, dtype=self.dtype
            ) * (1.0 / self.hidden_size ** 0.5)
            params["output_layer"] = torch.randn(
                self.hidden_size, 50000, dtype=self.dtype
            ) * (1.0 / self.hidden_size ** 0.5)

        return params

    def compress_tensor(self, tensor: torch.Tensor) -> bytes:
        """zlib-compress a tensor's raw bytes.

        Standalone utility: the safetensors format stores uncompressed tensors,
        so this is not applied on the save path below.
        """
        tensor_bytes = tensor.numpy().tobytes()
        return zlib.compress(tensor_bytes, level=9)

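    # A round trip for compress_tensor might look like the sketch below, assuming
    # the caller has kept the original shape and dtype (the names `blob` and
    # `shape` are illustrative, not part of this module):
    #
    #   import numpy as np
    #   raw = np.frombuffer(zlib.decompress(blob), dtype=np.float16)
    #   tensor = torch.from_numpy(raw.copy()).reshape(shape)
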
    def save_single_shard(self, shard_index: int) -> None:
        """Save a single model shard in safetensors format."""
        params = self.generate_shard_parameters(shard_index)
        filename = self.base_path / f"{self.base_filename}_{shard_index}_of_{self.num_shards}.safetensors"

        # safetensors requires tensor values (not compressed bytes) and
        # string-valued metadata, so tensors are stored as-is and the metadata
        # fields are stringified.
        metadata = {
            "shard_index": str(shard_index),
            "total_shards": str(self.num_shards),
            "layers": str(self.layers_per_shard),
            "hidden_size": str(self.hidden_size),
        }
        safetensors.torch.save_file(params, str(filename), metadata=metadata)
        logging.info(f"[✔] Shard {shard_index}/{self.num_shards} saved: {filename}")

    def save_sharded_parameters(self) -> None:
        """Save all shards in parallel."""
        logging.info(f"Starting to save {self.num_shards} shards...")
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Consume the map iterator so any exception raised in a worker
            # propagates instead of being silently dropped.
            list(executor.map(self.save_single_shard, range(1, self.num_shards + 1)))
        logging.info("All shards saved successfully.")

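    # A ThreadPoolExecutor is a reasonable default here because torch's tensor
    # kernels, zlib, and file I/O release the GIL for much of the work; if shard
    # generation ever becomes CPU-bound in pure Python, a ProcessPoolExecutor is
    # the usual alternative. This is a design note, not a claim from the source.
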
    def estimate_parameters(self) -> Tuple[int, float]:
        """Estimate total parameters and memory usage (GB at 2 bytes/param)."""
        # Per layer: four hidden x hidden attention projections plus the two FFN
        # projections (hidden -> 4*hidden and 4*hidden -> hidden), i.e. 12 * hidden^2.
        params_per_layer = (
            4 * (self.hidden_size * self.hidden_size) +
            self.hidden_size * (self.hidden_size * self.ffn_multiplier) +
            (self.hidden_size * self.ffn_multiplier) * self.hidden_size
        )
        params_per_shard = params_per_layer * self.layers_per_shard
        total_params = params_per_shard * self.num_shards

        # Embeddings and the output head live only in shard 1, counted once here.
        total_params += (
            50000 * self.hidden_size +
            4096 * self.hidden_size +
            self.hidden_size * 50000
        )

        memory_gb = (total_params * 2) / 1024**3  # float16 = 2 bytes per parameter
        return total_params, memory_gb

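    # Sanity check for the defaults (hidden_size=16384, 2089 shards x 100 layers):
    # 12 * 16384^2 ~ 3.22e9 parameters per layer, so roughly 6.7e14 parameters in
    # total, which is about 1.35e15 bytes at 2 bytes each. Writing every shard
    # therefore needs petabyte-scale storage.
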
def main():
    """Main execution flow."""
    model_storage = AdvancedModelParameters(
        num_shards=2089,
        base_filename="charm15",
        hidden_size=16384,
        layers_per_shard=100
    )

    total_params, memory_gb = model_storage.estimate_parameters()
    logging.info(f"Estimated total parameters: {total_params:,}")
    logging.info(f"Estimated memory usage: {memory_gb:.2f} GB")

    model_storage.save_sharded_parameters()


if __name__ == "__main__":
    main()
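
# Sketch: reloading a saved shard and its metadata, using the naming scheme
# produced by save_single_shard above (shown for shard 1 of 2089):
#
#   from safetensors import safe_open
#
#   tensors = safetensors.torch.load_file("model_shards/charm15_1_of_2089.safetensors")
#   with safe_open("model_shards/charm15_1_of_2089.safetensors", framework="pt") as f:
#       print(f.metadata())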