import math
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GenerationMixin, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from .configuration_qmoe import QMoEConfig


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    Computes ``x / sqrt(mean(x**2, -1) + eps) * scale``. There is no bias and
    no mean-centering.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return (x / rms) * self.scale


class DenseNoBias(nn.Module):
    """Bias-free linear map ``x @ kernel``.

    The kernel is stored as (in_features, out_features), the transpose of
    ``nn.Linear.weight``. Presumably this matches the layout of the
    checkpoint this model was ported from; confirm before renaming or
    transposing, because the parameter name/shape is part of the state dict.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.kernel = nn.Parameter(torch.empty(in_features, out_features, dtype=torch.float32))
        nn.init.normal_(self.kernel, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.kernel


def causal_mask(t: int, *, device: torch.device) -> torch.Tensor:
    """Return a (t, t) boolean lower-triangular mask (True = query may see key)."""
    return torch.tril(torch.ones((t, t), dtype=torch.bool, device=device))


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with bias-free Q/K/V/output projections.

    No key/value cache is kept; every call recomputes attention over the
    whole input.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError('d_model must be divisible by num_heads')
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = DenseNoBias(d_model, d_model)
        self.k_proj = DenseNoBias(d_model, d_model)
        self.v_proj = DenseNoBias(d_model, d_model)
        self.out_proj = DenseNoBias(d_model, d_model)

    def forward(self, x: torch.Tensor, *, attn_mask: torch.Tensor) -> torch.Tensor:
        """Apply self-attention.

        Args:
            x: Input of shape (b, t, d_model).
            attn_mask: Boolean mask, True where a query position may attend to
                a key position. It may have one of three shapes:
                (t, t), shared across the batch; (b, t, t), per example; or any
                4-D shape broadcastable to (b, num_heads, t, t).

        Returns:
            Tensor of shape (b, t, d_model).
        """
        b, t, d = x.shape
        q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(b, t, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(b, t, self.num_heads, self.head_dim)
        scale = 1.0 / math.sqrt(self.head_dim)
        att = torch.einsum('bthd,bshd->bhts', q, k) * scale

        if attn_mask.dim() == 2:
            mask = attn_mask.view(1, 1, t, t)
        elif attn_mask.dim() == 3:
            mask = attn_mask.unsqueeze(1)  # broadcast over heads
        else:
            mask = attn_mask
        # Fill with the dtype's most negative finite value. A literal such as
        # -1e30 cannot be represented in float16, and masked_fill raises on
        # overflow, so it breaks half-precision inference.
        att = att.masked_fill(~mask.bool(), torch.finfo(att.dtype).min)
        att = torch.softmax(att, dim=-1)
        out = torch.einsum('bhts,bshd->bthd', att, v).reshape(b, t, d)
        return self.out_proj(out)


class Router(nn.Module):
    """Top-k softmax router.

    Returns, per token, the indices of the ``top_k`` most probable experts and
    their probabilities renormalised to sum to 1 over the selected experts.
    """

    def __init__(self, d_model: int, num_experts: int, top_k: int):
        super().__init__()
        if not 1 <= top_k <= num_experts:
            raise ValueError(f'top_k must be in [1, num_experts={num_experts}], got {top_k}')
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = DenseNoBias(d_model, num_experts)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        logits = self.gate(x)
        probs = torch.softmax(logits, dim=-1)
        topk_vals, topk_idx = torch.topk(probs, k=self.top_k, dim=-1)
        # Renormalise over the selected experts; the clamp guards division by ~0.
        denom = topk_vals.sum(dim=-1, keepdim=True).clamp_min(1e-6)
        gates = topk_vals / denom
        return topk_idx, gates


class ExpertMLPBank(nn.Module):
    """A bank of ``num_experts`` two-layer SiLU MLPs with stacked weights.

    The parameters ``w1`` (E, d_model, hidden), ``b1`` (E, hidden),
    ``w2`` (E, hidden, d_model) and ``b2`` (E, d_model) hold one MLP per
    expert.
    """

    def __init__(self, d_model: int, hidden_dim: int, num_experts: int):
        super().__init__()
        self.w1 = nn.Parameter(torch.empty(num_experts, d_model, hidden_dim, dtype=torch.float32))
        self.b1 = nn.Parameter(torch.zeros(num_experts, hidden_dim, dtype=torch.float32))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, d_model, dtype=torch.float32))
        self.b2 = nn.Parameter(torch.zeros(num_experts, d_model, dtype=torch.float32))
        nn.init.normal_(self.w1, std=0.02)
        nn.init.normal_(self.w2, std=0.02)

    def forward(self, x: torch.Tensor, expert_idx: torch.Tensor) -> torch.Tensor:
        """Run row ``i`` of ``x`` through expert ``expert_idx[i]``.

        Args:
            x: (N, d_model) rows.
            expert_idx: (N,) integer expert index per row.

        Returns:
            (N, d_model) expert outputs, aligned with the rows of ``x``.
        """
        num_rows = x.shape[0]
        if num_rows == 0:
            return x.new_zeros((0, self.w2.shape[-1]))

        # Group rows by expert and run one dense matmul per expert. Gathering
        # a full weight matrix per row would allocate N * d_model * hidden
        # elements, which is prohibitive for realistic N.
        order = torch.argsort(expert_idx)
        counts = torch.bincount(expert_idx, minlength=self.w1.shape[0]).tolist()
        chunks = []
        for e, rows in enumerate(torch.split(order, counts)):
            if rows.numel() == 0:
                continue
            h = F.silu(x.index_select(0, rows) @ self.w1[e] + self.b1[e])
            chunks.append(h @ self.w2[e] + self.b2[e])

        # chunks are in expert-sorted order; undo the permutation.
        y_sorted = torch.cat(chunks, dim=0)
        return y_sorted.index_select(0, torch.argsort(order))


class MoEFeedForward(nn.Module):
    """Mixture-of-experts feed-forward layer.

    Each token is sent to its ``top_k`` routed experts, and the outputs are
    combined with the router's renormalised gate weights.
    """

    def __init__(self, d_model: int, hidden_dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.router = Router(d_model=d_model, num_experts=num_experts, top_k=top_k)
        self.experts = ExpertMLPBank(d_model=d_model, hidden_dim=hidden_dim, num_experts=num_experts)
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (b, t, d) -> (b, t, d)."""
        b, t, d = x.shape
        n = b * t
        k = self.top_k
        topk_idx, gates = self.router(x)

        # Dispatch every (token, slot) pair in a single experts call:
        # row i*k + j is token i routed to its j-th expert.
        x_rep = x.reshape(n, 1, d).expand(n, k, d).reshape(n * k, d)
        y = self.experts(x_rep, topk_idx.reshape(n * k))
        y = y.reshape(n, k, d) * gates.reshape(n, k, 1)
        return y.sum(dim=1).reshape(b, t, d)


class Block(nn.Module):
    """Pre-norm transformer block: x + Attn(RMSNorm(x)), then x + MoE(RMSNorm(x))."""

    def __init__(self, d_model: int, num_heads: int, hidden_dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.rmsnorm_0 = RMSNorm(d_model)
        self.attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.rmsnorm_1 = RMSNorm(d_model)
        self.moe = MoEFeedForward(d_model=d_model, hidden_dim=hidden_dim, num_experts=num_experts, top_k=top_k)

    def forward(self, x: torch.Tensor, *, attn_mask: torch.Tensor) -> torch.Tensor:
        h = self.rmsnorm_0(x)
        x = x + self.attn(h, attn_mask=attn_mask)
        h = self.rmsnorm_1(x)
        x = x + self.moe(h)
        return x


class QMoEForCausalLM(PreTrainedModel, GenerationMixin):
    """Decoder-only MoE language model for the Hugging Face ``PreTrainedModel`` API.

    The model uses learned absolute position embeddings, so inputs are
    limited to ``config.max_seq_len`` tokens. No KV cache is kept, so each
    generation step re-encodes the full sequence.
    """

    config_class = QMoEConfig
    main_input_name = 'input_ids'

    def __init__(self, config: QMoEConfig):
        super().__init__(config)
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.blocks = nn.ModuleList(
            [
                Block(config.d_model, config.num_heads, config.ffn_dim, config.num_experts, config.moe_top_k)
                for _ in range(config.num_layers)
            ]
        )
        self.rmsnorm_f = RMSNorm(config.d_model)
        # NOTE(review): lm_head is a DenseNoBias (``kernel`` of shape
        # (d_model, vocab)), not an nn.Linear. Hugging Face weight tying
        # presumably cannot tie it to tok_emb; confirm the config disables
        # tie_word_embeddings.
        self.lm_head = DenseNoBias(config.d_model, config.vocab_size)
        self.post_init()

    def get_input_embeddings(self):
        return self.tok_emb

    def set_input_embeddings(self, value):
        self.tok_emb = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No cache: always feed the full sequence so far.
        return {'input_ids': input_ids, 'attention_mask': attention_mask}

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Compute next-token logits and, optionally, the shifted LM loss.

        Args:
            input_ids: (b, t) token ids, where t <= config.max_seq_len.
            attention_mask: Optional (b, t) mask with 1 for real tokens and
                0 for padding. Padding keys are hidden from every query, and
                positions are counted over real tokens only, so left-padded
                prompts embed the same as unpadded ones. ``None`` (or all
                ones) means plain causal attention with positions 0..t-1.
            labels: Optional (b, t) targets. They are shifted internally by
                one, and -100 entries are ignored.

        Raises:
            ValueError: If ``input_ids`` is missing or longer than
                ``config.max_seq_len``.
        """
        if input_ids is None:
            raise ValueError('input_ids is required')
        b, t = input_ids.shape
        if t > self.config.max_seq_len:
            raise ValueError(f'sequence length {t} exceeds max_seq_len={self.config.max_seq_len}')
        device = input_ids.device

        attn_mask = causal_mask(t, device=device)
        if attention_mask is None:
            pos_idx = torch.arange(t, device=device).unsqueeze(0)
        else:
            keep = attention_mask.to(device=device, dtype=torch.bool)
            # Position of each token among the real tokens of its row. For an
            # all-ones mask this equals arange(t).
            pos_idx = (keep.long().cumsum(dim=-1) - 1).clamp_min(0)
            # (b, t, t): causal AND key is a real token. The diagonal is kept so
            # pad queries still have one valid key and their softmax row is not
            # fully masked.
            attn_mask = (attn_mask.unsqueeze(0) & keep.unsqueeze(1)) | torch.eye(
                t, dtype=torch.bool, device=device
            )

        x = self.tok_emb(input_ids) + self.pos_emb(pos_idx)
        for blk in self.blocks:
            x = blk(x, attn_mask=attn_mask)
        x = self.rmsnorm_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Predict token i+1 from position i.
            shift_logits = logits[:, :-1, :]
            shift_labels = labels[:, 1:].to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithCrossAttentions(logits=logits, loss=loss)