# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only
from typing import ClassVar, List, Optional

from pydantic import BaseModel
from transformers import PretrainedConfig

from mergekit.architecture.base import (
    ModuleArchitecture,
    WeightInfo,
)
from mergekit.architecture.json_definitions import NAME_TO_ARCH

# Dense Mistral definition from the JSON registry; Mixtral reuses its
# attention / norm / embedding weights and only replaces the MLP.
MISTRAL_INFO = NAME_TO_ARCH["MistralForCausalLM"][0]
MISTRAL_MODULE_ARCH = MISTRAL_INFO.modules["default"].architecture


class MixtralModuleArchitecture(ModuleArchitecture, BaseModel):
    """Mixtral MoE architecture: dense Mistral with the MLP swapped for
    a block-sparse expert array plus a router gate."""

    ARCHITECTURE_NAME: ClassVar[str] = "MixtralForCausalLM"
    # Number of experts in each MoE layer (from config.num_local_experts).
    num_local_experts: int

    def name(self) -> str:
        return "mixtral"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        """Alternate constructor reading expert count from a HF config."""
        return MixtralModuleArchitecture(num_local_experts=config.num_local_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Embeddings are identical to dense Mistral.
        return MISTRAL_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Final norm / LM head are identical to dense Mistral.
        return MISTRAL_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return MISTRAL_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        """Per-expert MLP and router weights for layer ``index``, plus
        Mistral's non-MLP (attention/norm) weights for the same layer."""
        prefix = f"model.layers.{index}"
        # One (w1, w2, w3) triple per expert.
        res = [
            WeightInfo(
                name=f"{prefix}.block_sparse_moe.experts.{expert_idx}.{param}.weight"
            )
            for expert_idx in range(self.num_local_experts)
            for param in ("w1", "w2", "w3")
        ]
        # Router weight.
        res.append(WeightInfo(name=f"{prefix}.block_sparse_moe.gate.weight"))
        # Keep Mistral's attention/norm weights; its dense `.mlp.` weights are
        # superseded by the experts above. The `or []` guards against a None
        # return, which the Optional signature of layer_weights permits.
        res.extend(
            weight_info
            for weight_info in (MISTRAL_MODULE_ARCH.layer_weights(index, config) or [])
            if ".mlp." not in weight_info.name
        )
        return res


# Dense Qwen3 definition from the registry; the MoE variant reuses its
# non-MLP weights.
QWEN3_INFO = NAME_TO_ARCH["Qwen3ForCausalLM"][0]
QWEN3_MODULE_ARCH = QWEN3_INFO.modules["default"].architecture


class Qwen3MoeModuleArchitecture(ModuleArchitecture, BaseModel):
    """Qwen3 MoE architecture: dense Qwen3 with per-expert MLPs under
    ``.mlp.experts.*`` and a router at ``.mlp.gate``."""

    ARCHITECTURE_NAME: ClassVar[str] = "Qwen3MoeForCausalLM"
    # Number of experts in each MoE layer (from config.num_experts).
    num_experts: int

    def name(self) -> str:
        return "qwen3_moe"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        """Alternate constructor reading expert count from a HF config."""
        return Qwen3MoeModuleArchitecture(num_experts=config.num_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Embeddings are identical to dense Qwen3.
        return QWEN3_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Final norm / LM head are identical to dense Qwen3.
        return QWEN3_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return QWEN3_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        """Per-expert MLP and router weights for layer ``index``, plus
        Qwen3's non-MLP weights for the same layer."""
        prefix = f"model.layers.{index}"
        res = [
            WeightInfo(name=f"{prefix}.mlp.experts.{expert_idx}.{param}.weight")
            for expert_idx in range(self.num_experts)
            for param in ("up_proj", "gate_proj", "down_proj")
        ]
        res.append(WeightInfo(name=f"{prefix}.mlp.gate.weight"))
        # Drop the dense `.mlp.` weights from the delegate (replaced by the
        # experts above); `or []` guards against a None return.
        res.extend(
            weight_info
            for weight_info in (QWEN3_MODULE_ARCH.layer_weights(index, config) or [])
            if ".mlp." not in weight_info.name
        )
        return res


# Partial AFMoE definition from the registry (everything except the
# per-expert weights, which are appended below).
AFMOE_PARTIAL_INFO = NAME_TO_ARCH["_AfmoePartialForCausalLM"][0]
AFMOE_PARTIAL_MODULE_ARCH = AFMOE_PARTIAL_INFO.modules["default"].architecture


class AfmoeModuleArchitecture(ModuleArchitecture, BaseModel):
    """AFMoE architecture: the partial AFMoE definition extended with
    optional per-expert MLP weights."""

    ARCHITECTURE_NAME: ClassVar[str] = "AfmoeForCausalLM"
    # Number of experts in each MoE layer (from config.num_experts).
    num_experts: int

    def name(self) -> str:
        return "afmoe"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        """Alternate constructor reading expert count from a HF config."""
        return AfmoeModuleArchitecture(num_experts=config.num_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return AFMOE_PARTIAL_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return AFMOE_PARTIAL_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return AFMOE_PARTIAL_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        """All partial-AFMoE weights for layer ``index`` plus one optional
        (up_proj, gate_proj, down_proj) triple per expert."""
        # Copy the delegate's list so we never mutate a list we don't own;
        # `or []` also guards against a None return.
        res = list(AFMOE_PARTIAL_MODULE_ARCH.layer_weights(index, config) or [])
        prefix = f"model.layers.{index}"
        res.extend(
            WeightInfo(
                name=f"{prefix}.mlp.experts.{expert_idx}.{param}.weight",
                # Optional: not every checkpoint materializes every expert.
                optional=True,
            )
            for expert_idx in range(self.num_experts)
            for param in ("up_proj", "gate_proj", "down_proj")
        )
        return res
# Base dense Llama definition from the registry; the MoE variant reuses its
# attention, norm, embedding, and head weights.
LLAMA_INFO = NAME_TO_ARCH["LlamaForCausalLM"][0]
LLAMA_MODULE_ARCH = LLAMA_INFO.modules["default"].architecture


class LlamaMoeModuleArchitecture(ModuleArchitecture, BaseModel):
    """Llama-based MoE: dense Llama layers with the MLP replaced by a
    Mixtral-style block-sparse expert array plus a router gate."""

    # Architecture name written to the output config.json.
    ARCHITECTURE_NAME: ClassVar[str] = "LlamaMoeForCausalLM"
    # Number of experts in each MoE layer.
    num_experts: int

    def name(self) -> str:
        return "llama_moe"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        """Alternate constructor reading ``num_experts`` from a HF config.

        Falls back to 8 when the key is absent — TODO confirm this default
        matches the intended target configurations.
        """
        return LlamaMoeModuleArchitecture(num_experts=getattr(config, "num_experts", 8))

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Standard Llama embeddings / leading norms.
        return LLAMA_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Standard Llama final norm / LM head.
        return LLAMA_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return LLAMA_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        """Per-expert MLP and router weights for layer ``index``, plus
        Llama's non-MLP (attention/norm) weights for the same layer."""
        prefix = f"model.layers.{index}"
        # Expert weights: the dense MLP projections mapped into an expert array.
        res = [
            WeightInfo(
                name=f"{prefix}.block_sparse_moe.experts.{expert_idx}.{param}.weight"
            )
            for expert_idx in range(self.num_experts)
            for param in ("gate_proj", "up_proj", "down_proj")
        ]
        # Router (gate) weight.
        res.append(WeightInfo(name=f"{prefix}.block_sparse_moe.gate.weight"))
        # Keep Llama's attention/norm weights; its dense `.mlp.` weights are
        # superseded by the experts above. The `or []` guards against a None
        # return, which the Optional signature of layer_weights permits.
        res.extend(
            weight_info
            for weight_info in (LLAMA_MODULE_ARCH.layer_weights(index, config) or [])
            if ".mlp." not in weight_info.name
        )
        return res