import json
import os

import torch
import tqdm
import transformers

from mergekit.architecture import arch_info_for_config
from mergekit.moe.arch import MoEOutputArchitecture
from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
from mergekit.moe.config import MoEMergeConfig
from mergekit.options import MergeOptions


class LlamaMoE(MoEOutputArchitecture):
    def name(self) -> str:
        return "LlamaMoE"

    def supports_config(
        self,
        config: MoEMergeConfig,
        explain: bool = False,
        trust_remote_code: bool = False,
    ) -> bool:
        # Ensure the base model is a Llama model
        model_cfg = config.base_model.config(trust_remote_code=trust_remote_code)
        if model_cfg.model_type != "llama":
            if explain:
                print("LlamaMoE only supports Llama base models")
            return False
        return True

    def write_model(
        self,
        out_path: str,
        config: MoEMergeConfig,
        merge_options: MergeOptions,
        router_weights: list[torch.Tensor],
        shared_router_weights=None,
    ):
        base_model = config.base_model
        base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)

        # 1. Generate the config.json
        # Note: most Llama MoEs use the Mixtral architecture name for
        # compatibility with existing loaders.
        out_cfg = base_cfg.to_dict()
        out_cfg["architectures"] = ["MixtralForCausalLM"]
        # model_type must also be "mixtral" so AutoConfig/AutoModel dispatch
        # to the Mixtral classes rather than Llama.
        out_cfg["model_type"] = "mixtral"
        out_cfg["num_local_experts"] = len(config.experts)
        out_cfg["num_experts_per_tok"] = config.experts_per_token
        out_dtype = select_dtype(config, base_cfg)

        # Persist the merged config; without this step no config.json is
        # actually written to the output directory.
        os.makedirs(out_path, exist_ok=True)
        with open(os.path.join(out_path, "config.json"), "w", encoding="utf-8") as fp:
            json.dump(out_cfg, fp, indent=2)

        # 2. Initialize IO
        loaders, base_loader, writer = initialize_io(config, out_path, merge_options)

        # 3. Map tensors
        for weight_info in tqdm.tqdm(
            arch_info_for_config(base_cfg).all_weights(base_cfg), desc="Weights"
        ):
            tensor_name = weight_info.name
            if ".mlp." in tensor_name:
                # MLP weights are written once per expert under Mixtral naming.
                for expert_idx, expert in enumerate(config.experts):
                    # Map Llama's gate_proj/up_proj/down_proj to Mixtral's w1/w3/w2
                    expert_name = tensor_name.replace(
                        ".mlp.gate_proj", f".block_sparse_moe.experts.{expert_idx}.w1"
                    )
                    expert_name = expert_name.replace(
                        ".mlp.down_proj", f".block_sparse_moe.experts.{expert_idx}.w2"
                    )
                    expert_name = expert_name.replace(
                        ".mlp.up_proj", f".block_sparse_moe.experts.{expert_idx}.w3"
                    )
                    expert_loader = loaders.get(expert.source_model)
                    copy_tensor_out(
                        weight_info,
                        expert_loader,
                        writer,
                        expert=expert,
                        output_name=expert_name,
                        out_dtype=out_dtype,
                    )
            else:
                # Copy attention and norm weights from the base model unchanged.
                copy_tensor_out(weight_info, base_loader, writer, out_dtype=out_dtype)

        # 4. Write router weights
        for layer_idx, weight in enumerate(router_weights):
            writer.save_tensor(
                f"model.layers.{layer_idx}.block_sparse_moe.gate.weight",
                weight.to(dtype=out_dtype),
            )
        writer.finalize()
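
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of mergekit's public API surface).
# Mixtral's router is nn.Linear(hidden_size, num_local_experts, bias=False),
# so each entry of `router_weights` is expected to be a per-layer tensor of
# shape [num_experts, hidden_size]. A minimal sketch, assuming `config` is a
# parsed MoEMergeConfig and `merge_options` a MergeOptions instance, with
# randomly initialized gates standing in for real router weights:
#
#   base_cfg = config.base_model.config()
#   router_weights = [
#       torch.randn(len(config.experts), base_cfg.hidden_size)
#       for _ in range(base_cfg.num_hidden_layers)
#   ]
#   LlamaMoE().write_model("./llama-moe-out", config, merge_options, router_weights)
# ---------------------------------------------------------------------------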