# Source: WeDLM-CFM-11B-Base / modeling_wedlm.py
# Hugging Face Hub page header (not part of the module): "win10's picture",
# "Upload folder using huggingface_hub", revision 866fe09 (verified).
# coding=utf-8
# Copyright 2024 The WeDLM team and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch WeDLM model."""
from typing import Optional, Tuple, Union, Dict, List, Callable
import math
import torch
from torch import nn
import torch.nn.functional as F
from transformers import PreTrainedModel, GenerationMixin
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
from transformers.utils.generic import check_model_inputs
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
# Import attention-related utilities
try:
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
except ImportError:
FlashAttentionKwargs = dict
from .configuration_wedlm import WeDLMConfig
import logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): this sets the level of this module's logger to DEBUG at import time, regardless of
# how the host application configured logging — confirm this is intended outside debugging sessions.
logger.setLevel(logging.DEBUG)
# ============================================================================
# Flow Matching / Rectified Flow helpers
# ============================================================================
class WeDLMFlowTimeEmbedding(nn.Module):
    """Timestep conditioning for Flow Matching / Rectified Flow.

    Timesteps are expected as floats normalized to [0, 1]. They are clamped to that range, multiplied
    by a configurable scale, expanded into sinusoidal features, and passed through a small MLP
    (Linear -> SiLU -> Linear) whose output width is the model hidden size.
    """

    def __init__(self, config: WeDLMConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Optional config fields; the literal fallbacks apply when the attribute is absent.
        self.time_embed_dim = int(getattr(config, "flow_time_embedding_dim", 256))
        self.max_period = int(getattr(config, "flow_time_embedding_max_period", 10000))
        self.time_scale = float(getattr(config, "flow_time_scale", 1000.0))

        self.linear_1 = nn.Linear(self.time_embed_dim, self.hidden_size)
        self.act = nn.SiLU()
        self.linear_2 = nn.Linear(self.hidden_size, self.hidden_size)

    @staticmethod
    def _sinusoidal_embedding(timesteps: torch.Tensor, dim: int, max_period: int) -> torch.Tensor:
        """Build sinusoidal features for a batch of timesteps.

        Args:
            timesteps: Tensor of timesteps; anything that is not 1-D is flattened.
            dim: Width of the returned embedding.
            max_period: Controls the lowest frequency of the sinusoids.

        Returns:
            A float32 tensor of shape (num_timesteps, dim): cosines, then sines, then a single
            zero column when `dim` is odd.
        """
        flat = timesteps if timesteps.ndim == 1 else timesteps.view(-1)
        half = dim // 2
        steps = torch.arange(0, half, device=flat.device, dtype=torch.float32)
        # `max(half, 1)` keeps the division defined when dim < 2.
        freqs = torch.exp(-math.log(max_period) * steps / max(half, 1))
        angles = flat.to(torch.float32)[:, None] * freqs[None]

        pieces = [torch.cos(angles), torch.sin(angles)]
        if dim % 2 == 1:
            pieces.append(torch.zeros((angles.shape[0], 1), device=flat.device, dtype=torch.float32))
        return torch.cat(pieces, dim=-1)

    @staticmethod
    def _cast_to_layer_dtype(tensor: torch.Tensor, layer: nn.Module) -> torch.Tensor:
        """Cast `tensor` to the dtype of `layer`'s base weight, if it has one.

        A wrapped layer (e.g. under PEFT/LoRA) exposes the real Linear as `base_layer`; its weights
        can be low precision while the incoming features are float32, and F.linear needs both
        dtypes to match.
        """
        base = getattr(layer, "base_layer", layer)
        weight = getattr(base, "weight", None)
        return tensor if weight is None else tensor.to(dtype=weight.dtype)

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Embed timesteps (expected in [0, 1]) into vectors of size `hidden_size`."""
        t = timesteps.to(dtype=torch.float32)
        if t.ndim != 1:
            # Covers both scalars (-> one element) and higher-rank inputs (-> flattened).
            t = t.view(-1)
        # Scale the normalized time to a more typical diffusion timestep range.
        scaled = t.clamp(0.0, 1.0) * self.time_scale
        features = self._sinusoidal_embedding(scaled, self.time_embed_dim, self.max_period)

        hidden = self.linear_1(self._cast_to_layer_dtype(features, self.linear_1))
        hidden = self.act(hidden)
        return self.linear_2(self._cast_to_layer_dtype(hidden, self.linear_2))
# ============================================================================
# ============================================================================
# Core Components (self-contained, no Qwen2 dependency)
# ============================================================================
class WeDLMMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    The activation is looked up from `config.hidden_act` (SwiGLU when that activation is SiLU).
    All three projections are bias-free.
    """

    def __init__(self, config: WeDLMConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        model_dim, inner_dim = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to `x` and return a tensor with the same trailing width."""
        gate = self.act_fn(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
class WeDLMRMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel scale (no mean subtraction, no bias)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize the last dimension by its RMS and multiply by `weight`.

        The statistics are computed in float32; the normalized values are cast back to the
        input dtype before the learned scale is applied.
        """
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self) -> str:
        """Show the weight shape and epsilon in the module's repr."""
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.variance_epsilon}"
class WeDLMRotaryEmbedding(nn.Module):
    """Rotary position embedding: produces the (cos, sin) tables shared by the decoder layers."""

    def __init__(self, config: WeDLMConfig, device=None):
        super().__init__()
        self.config = config
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        # The RoPE variant comes from `config.rope_scaling` when that is a dict ("rope_type" key,
        # falling back to "type"); anything else means plain RoPE.
        scaling_cfg = getattr(config, "rope_scaling", None)
        if isinstance(scaling_cfg, dict):
            self.rope_type = scaling_cfg.get("rope_type", scaling_cfg.get("type", "default"))
        else:
            self.rope_type = "default"

        # Plain RoPE is computed locally; other variants are delegated to transformers.
        if self.rope_type == "default":
            init_fn = self._compute_default_rope_parameters
        else:
            init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = init_fn(config, device)

        # Non-persistent: the buffer is rebuilt from the config instead of being saved in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @staticmethod
    def _compute_default_rope_parameters(
        config: WeDLMConfig,
        device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, float]:
        """Compute the inverse frequencies for default (unscaled) RoPE.

        Args:
            config: Model configuration; reads `rope_theta` and the head dimension
                (`head_dim` when set, otherwise `hidden_size // num_attention_heads`).
            device: Device to place the tensor on.

        Returns:
            Tuple of (inv_freq tensor, attention scaling factor, always 1.0 here).
        """
        theta = config.rope_theta
        rotary_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        even_indices = torch.arange(0, rotary_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
        inv_freq = 1.0 / (theta ** (even_indices / rotary_dim))
        return inv_freq, 1.0

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute rotary position embeddings.

        Args:
            x: Tensor used only for its dtype and device.
            position_ids: Position indices; indexed as (batch, seq_len).

        Returns:
            Tuple of (cos, sin) tensors cast to `x.dtype`.
        """
        batch = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        # Autocast is disabled so the angle computation stays in float32 for numerical stability.
        with torch.amp.autocast(device_type=device_type, enabled=False):
            angles = (inv_freq.float() @ positions.float()).transpose(1, 2)
            table = torch.cat((angles, angles), dim=-1)
            cos = table.cos() * self.attention_scaling
            sin = table.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# ============================================================================
# Attention Utilities
# ============================================================================
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dimension at its midpoint and return ``[-second_part, first_part]``."""
    midpoint = x.shape[-1] // 2
    first_part, second_part = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-second_part, first_part), dim=-1)
def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    unsqueeze_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors with the given (cos, sin) tables.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; a singleton axis is inserted at `unsqueeze_dim` so it broadcasts over heads.
        sin: Sine table; handled like `cos`.
        position_ids: Accepted but not used by this function.
        unsqueeze_dim: Axis at which `cos`/`sin` receive the extra singleton dimension.

    Returns:
        Tuple of (rotated query, rotated key).
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states: torch.Tensor) -> torch.Tensor:
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head `n_rep` times so the head count matches the query heads (GQA).

    Takes a (batch, num_key_value_heads, seq_len, head_dim) tensor and returns
    (batch, num_key_value_heads * n_rep, seq_len, head_dim); the result equals
    ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``. With ``n_rep == 1`` the
    input tensor itself is returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the heads.
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Plain matmul/softmax attention.

    Args:
        module: Calling attention module; supplies `num_key_value_groups` and the `training` flag.
        query: Query states.
        key: Key states; heads are repeated to match the query heads.
        value: Value states; heads are repeated like `key`.
        attention_mask: Optional additive mask, cropped on its last axis to the key length.
        scaling: Factor applied to the raw attention scores.
        dropout: Dropout probability on the attention weights (active only while training).
        **kwargs: Ignored; accepted so all attention backends share one call signature.

    Returns:
        Tuple of (attention output with the head and sequence axes swapped, attention weights).
    """
    groups = module.num_key_value_groups
    keys = repeat_kv(key, groups)
    values = repeat_kv(value, groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    # Softmax runs in float32, then the weights are cast back to the query dtype.
    weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    weights = nn.functional.dropout(weights, p=dropout, training=module.training)

    context = torch.matmul(weights, values).transpose(1, 2).contiguous()
    return context, weights
# ============================================================================
# Attention Layer
# ============================================================================
class WeDLMAttention(nn.Module):
    """
    WeDLM Attention module.

    Supports both:
    - Qwen2.5 style: with QKV bias, no QK Norm
    - Qwen3 style: configurable QKV bias, with QK Norm

    Config fields read here: `hidden_size`, `num_attention_heads`, `num_key_value_heads`,
    `attention_dropout`, and optionally `head_dim`, `layer_types`, `attention_bias`, `qk_norm`,
    `rms_norm_eps` (only with QK norm) and `sliding_window` (only for sliding-attention layers).
    """

    def __init__(self, config: "WeDLMConfig", layer_idx: int):
        super().__init__()
        # A config may define `layer_types` as None; treat that like a missing attribute
        # instead of failing on `None[layer_idx]`.
        layer_types = getattr(config, "layer_types", None)
        self.layer_type = layer_types[layer_idx] if layer_types is not None else None
        self.config = config
        self.layer_idx = layer_idx
        # Same fallback rule as WeDLMRotaryEmbedding: a config that defines `head_dim = None`
        # must fall back to hidden_size // num_attention_heads. A plain getattr default would
        # return None here and crash on the `** -0.5` below.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        # Support configurable attention_bias (Qwen2.5: True, Qwen3: False by default)
        attention_bias = getattr(config, "attention_bias", True)
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)

        # Support optional QK Norm (Qwen3 feature)
        self.qk_norm = getattr(config, "qk_norm", False)
        if self.qk_norm:
            self.q_norm = WeDLMRMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = WeDLMRMSNorm(self.head_dim, eps=config.rms_norm_eps)

        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional["Cache"] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run self-attention over `hidden_states`.

        Args:
            hidden_states: Input whose last axis is the model width; leading axes are preserved.
            position_embeddings: Tuple of (cos, sin) rotary tables.
            attention_mask: Mask forwarded unchanged to the attention backend.
            past_key_values: Optional KV cache; when given it is updated with this layer's
                keys/values and the cached tensors are used for attention.
            cache_position: Cache positions forwarded to the cache update.
            **kwargs: Extra arguments forwarded to the attention backend.

        Returns:
            Tuple of (output projected back to the model width, attention weights as returned
            by the backend).
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        if self.qk_norm:
            # Qwen3 style: apply norm after projection, before transpose
            query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
            key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        else:
            # Qwen2 style: no norm
            query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Select attention implementation: eager unless the configured backend is registered.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager" and self.config._attn_implementation in ALL_ATTENTION_FUNCTIONS:
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
# ============================================================================
# Decoder Layer
# ============================================================================
class WeDLMDecoderLayer(GradientCheckpointingLayer):
    """One pre-norm decoder block: RMSNorm -> self-attention -> residual, RMSNorm -> MLP -> residual."""

    def __init__(self, config: WeDLMConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = WeDLMAttention(config=config, layer_idx=layer_idx)
        self.mlp = WeDLMMLP(config)
        self.input_layernorm = WeDLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = WeDLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Used by the parent model as the key for picking this layer's attention mask.
        self.attention_type = config.layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run the block on `hidden_states`.

        Args:
            hidden_states: Input tensor of shape `(batch, seq_len, embed_dim)`
            attention_mask: Attention mask forwarded to the attention module
            position_ids: Position indices (accepted; not consumed inside this block)
            past_key_values: Cached past key and value projection states
            output_attentions: Whether to append the attention weights to the returned tuple
            use_cache: Whether to use KV cache (accepted; not consumed inside this block)
            cache_position: Position in the cache
            position_embeddings: Tuple of (cos, sin) for rotary embeddings

        Returns:
            `(hidden_states,)`, or `(hidden_states, attention_weights)` when `output_attentions` is set.
        """
        # Attention sub-block (pre-norm + residual).
        attn_input = self.input_layernorm(hidden_states)
        attn_output, self_attn_weights = self.self_attn(
            hidden_states=attn_input,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            cache_position=cache_position,
            **kwargs,
        )
        after_attention = hidden_states + attn_output

        # Feed-forward sub-block (pre-norm + residual).
        mlp_output = self.mlp(self.post_attention_layernorm(after_attention))
        block_output = after_attention + mlp_output

        if output_attentions:
            return (block_output, self_attn_weights)
        return (block_output,)
# ============================================================================
# Model Classes
# ============================================================================
@auto_docstring
class WeDLMPreTrainedModel(PreTrainedModel):
    """Base class for WeDLM models.

    Declares class-level attributes only; it defines no methods of its own.
    NOTE(review): the flags below are presumably consumed by the `PreTrainedModel` machinery in
    transformers — confirm their exact semantics against the installed transformers version.
    """
    # Configuration class associated with this model family.
    config_class = WeDLMConfig
    # Attribute name under which task models hold the base model (see `self.model` in the subclasses).
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Decoder layers are listed as modules that should not be split (presumably across devices — TODO confirm).
    _no_split_modules = ["WeDLMDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # Attention-backend capability flags.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Maps output names to the module classes whose outputs provide them.
    _can_record_outputs = {
        "hidden_states": WeDLMDecoderLayer,
        "attentions": WeDLMAttention,
    }
@auto_docstring
class WeDLMModel(WeDLMPreTrainedModel):
    """
    Bare WeDLM decoder stack (token embedding, decoder layers, final RMSNorm) returning raw hidden states.
    """

    def __init__(self, config: WeDLMConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [WeDLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = WeDLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = WeDLMRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # A sliding-window mask is only built when at least one layer asks for it.
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Unset flags fall back to the model configuration.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Exactly one of the two input forms must be supplied.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            seen = 0 if past_key_values is None else past_key_values.get_seq_length()
            cache_position = torch.arange(seen, seen + inputs_embeds.shape[1], device=inputs_embeds.device)

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Prepare attention masks. A dict passed as `attention_mask` is used as the ready-made
        # per-layer-type mapping; anything else is turned into one here.
        if isinstance(attention_mask, dict):
            causal_mask_mapping = attention_mask
        else:
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            causal_mask_mapping = {"full_attention": create_causal_mask(**mask_kwargs)}
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # Create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # Decoder layers; per-layer outputs are collected only when requested.
        hidden_trace = [] if output_hidden_states else None
        attention_trace = [] if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                hidden_trace.append(hidden_states)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                attention_trace.append(layer_outputs[1])

        hidden_states = self.norm(hidden_states)

        # The final, normalized hidden state is appended as the last entry.
        if output_hidden_states:
            hidden_trace.append(hidden_states)

        all_hidden_states = tuple(hidden_trace) if output_hidden_states else None
        all_self_attns = tuple(attention_trace) if output_attentions else None

        if not return_dict:
            return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
@auto_docstring
class WeDLMForCausalLM(WeDLMPreTrainedModel, GenerationMixin):
"""
WeDLM Model for Flow-Matching language modeling (Rectified Flow in token-embedding space).
- Training (`labels` provided): optimizes a Flow Matching objective on a selected subset of token positions.
Large-vocabulary projection (`lm_head`) is skipped by default during training for lower cost.
- Inference (no `labels`): behaves like a standard causal LM (returns logits).
- Fast decoding: use `generate_wedlm` (Flow-Matching block decoding).
"""
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: WeDLMConfig):
    """Build the base model, the vocabulary head, and the Flow Matching modules.

    Args:
        config: Model configuration; `hidden_size` and `vocab_size` size the heads created here.
    """
    super().__init__(config)
    self.model = WeDLMModel(config)
    self.vocab_size = config.vocab_size
    # Token discretization head (used for inference / evaluation / discretization at the end of each flow block)
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
    # Flow Matching modules
    # Timestep embedding used to condition the inputs on the flow time.
    self.flow_time_embed = WeDLMFlowTimeEmbedding(config)
    # Bias-free hidden-to-hidden projection.
    # NOTE(review): presumably the velocity head of the flow objective — its use is outside this method.
    self.flow_head = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
    # Initialize weights and apply final processing
    self.post_init()
# ----------------------------
# Embedding plumbing
# ----------------------------
def get_input_embeddings(self):
    """Return the token embedding module of the wrapped base model (`self.model.embed_tokens`)."""
    return self.model.embed_tokens
def set_input_embeddings(self, value):
    """Replace the token embedding module of the wrapped base model with `value`."""
    self.model.embed_tokens = value
def get_output_embeddings(self):
    """Return the vocabulary projection head (`self.lm_head`)."""
    return self.lm_head
def set_output_embeddings(self, new_embeddings):
    """Replace the vocabulary projection head with `new_embeddings`."""
    self.lm_head = new_embeddings
def set_decoder(self, decoder):
    """Replace the wrapped base model (`self.model`) with `decoder`."""
    self.model = decoder
def get_decoder(self):
    """Return the wrapped base model (`self.model`)."""
    return self.model
# ----------------------------
# Flow Matching utilities
# ----------------------------
def _select_flow_targets(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor],
labels: Optional[torch.LongTensor],
flow_target_mask: Optional[torch.BoolTensor],
) -> torch.BoolTensor:
"""
Determine which token positions participate in Flow Matching loss.
Priority:
1) `flow_target_mask` argument
2) `config.flow_train_strategy`
"""
bsz, seq_len = input_ids.shape
device = input_ids.device
if attention_mask is not None:
valid = attention_mask.to(dtype=torch.bool, device=device)
else:
# Best-effort fallback: treat non-pad as valid.
pad_id = getattr(self.config, "pad_token_id", None)
if pad_id is None:
valid = torch.ones((bsz, seq_len), dtype=torch.bool, device=device)
else:
valid = input_ids.ne(pad_id)
if labels is not None:
valid = valid & labels.ne(-100)
if flow_target_mask is not None:
target = flow_target_mask.to(dtype=torch.bool, device=device)
return target & valid
strategy = str(getattr(self.config, "flow_train_strategy", "suffix_block")).lower()
min_targets = int(getattr(self.config, "flow_train_min_target_tokens", 1))
if strategy == "random":
ratio = float(getattr(self.config, "flow_train_mask_ratio", 0.15))
# Sample only from valid positions; guarantee at least `min_targets` if possible.
rand = torch.rand((bsz, seq_len), device=device)
target = (rand < ratio) & valid
if min_targets > 0:
for b in range(bsz):
if valid[b].any() and target[b].sum().item() < min_targets:
valid_idx = valid[b].nonzero(as_tuple=True)[0]
# Select the last positions (deterministic tie-break) to fill up.
need = min(min_targets - target[b].sum().item(), valid_idx.numel())
if need > 0:
target[b, valid_idx[-need:]] = True
return target
# Default: suffix_block
block_size = int(getattr(self.config, "flow_train_block_size", 64))
target = torch.zeros((bsz, seq_len), dtype=torch.bool, device=device)
for b in range(bsz):
valid_idx = valid[b].nonzero(as_tuple=True)[0]
if valid_idx.numel() == 0:
continue
# Do not target the very first valid token by default (no context); if needed, user can pass flow_target_mask.
if valid_idx.numel() == 1:
continue
# Suffix contiguous block among valid positions.
k = min(block_size, valid_idx.numel() - 1)
k = max(k, min_targets)
k = min(k, valid_idx.numel() - 1) # ensure at least one context token remains
if k <= 0:
continue
target_pos = valid_idx[-k:]
target[b, target_pos] = True
return target
def _normalize_timesteps(
self,
timesteps: Optional[torch.FloatTensor],
target_mask: torch.BoolTensor,
) -> torch.FloatTensor:
"""
Returns per-token normalized timesteps t in [0, 1], shape (bsz, seq_len).
"""
device = target_mask.device
bsz, seq_len = target_mask.shape
if timesteps is None:
t = torch.rand((bsz, seq_len), device=device, dtype=torch.float32)
# Non-target positions: t=1.0 (data endpoint) so that any accidental use is benign.
t = torch.where(target_mask, t, torch.ones_like(t))
return t
t_in = timesteps.to(device=device, dtype=torch.float32)
if t_in.ndim == 0:
t = t_in.view(1, 1).expand(bsz, seq_len)
elif t_in.ndim == 1:
if t_in.shape[0] == 1 and bsz > 1:
t = t_in.view(1, 1).expand(bsz, seq_len)
elif t_in.shape[0] == bsz:
t = t_in.view(bsz, 1).expand(bsz, seq_len)
else:
raise ValueError(
f"flow_timesteps must be scalar, shape (bsz,), or shape (bsz, seq_len); got {tuple(t_in.shape)}"
)
elif t_in.ndim == 2:
if t_in.shape != (bsz, seq_len):
raise ValueError(
f"flow_timesteps must have shape (bsz, seq_len) == {(bsz, seq_len)}; got {tuple(t_in.shape)}"
)
t = t_in
else:
raise ValueError(
f"flow_timesteps must be scalar, 1D, or 2D; got ndim={t_in.ndim} with shape {tuple(t_in.shape)}"
)
# Clamp into [0,1] for numerical safety.
t = torch.clamp(t, 0.0, 1.0)
t = torch.where(target_mask, t, torch.ones_like(t))
return t
def _build_flow_inputs(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor],
labels: Optional[torch.LongTensor],
flow_target_mask: Optional[torch.BoolTensor],
flow_timesteps: Optional[torch.FloatTensor],
flow_noise: Optional[torch.FloatTensor],
) -> Tuple[torch.FloatTensor, torch.BoolTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""
Prepare (inputs_embeds, target_mask, t, noise, clean_embeds) for Flow Matching training.
"""
clean_embeds = self.model.embed_tokens(input_ids)
bsz, seq_len, hidden = clean_embeds.shape
device = clean_embeds.device
target_mask = self._select_flow_targets(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
flow_target_mask=flow_target_mask,
)
t = self._normalize_timesteps(flow_timesteps, target_mask=target_mask)
sigma = float(getattr(self.config, "flow_init_sigma", 1.0))
if flow_noise is None:
noise = torch.randn_like(clean_embeds, dtype=torch.float32) * sigma
else:
noise = flow_noise.to(device=device, dtype=torch.float32)
if noise.shape != clean_embeds.shape:
raise ValueError(
f"flow_noise must have the same shape as token embeddings {tuple(clean_embeds.shape)}; got {tuple(noise.shape)}"
)
# Rectified Flow path: X_t = (1 - t) * X_0 + t * X_1
# Here X_0 is noise, X_1 is data (token embeddings).
t_exp = t.unsqueeze(-1)
x_t = (1.0 - t_exp) * noise + t_exp * clean_embeds.to(dtype=torch.float32)
# Time conditioning is provided by adding a learned timestep embedding to the *input embeddings*
# for flow-target positions (so the Transformer can use t).
inputs_embeds = clean_embeds.to(dtype=torch.float32)
if target_mask.any():
# Compute time embedding only for targets to reduce overhead.
t_flat = t[target_mask].reshape(-1)
time_cond = self.flow_time_embed(t_flat).to(dtype=inputs_embeds.dtype)
inputs_embeds = inputs_embeds.clone()
inputs_embeds[target_mask] = x_t[target_mask] + time_cond
else:
inputs_embeds = inputs_embeds.clone()
return inputs_embeds.to(dtype=clean_embeds.dtype), target_mask, t, noise, clean_embeds
# ----------------------------
# Decoding: Flow Matching block decoding
# ----------------------------
def _top_k_top_p_filtering(
self,
logits: torch.Tensor,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("inf"),
) -> torch.Tensor:
"""Apply top-k and/or nucleus (top-p) filtering to logits."""
if top_k is not None and top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1).values[..., -1, None]
logits = logits.masked_fill(indices_to_remove, filter_value)
if top_p is not None and 0.0 < top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(indices_to_remove, filter_value)
return logits
def _sample_from_logits(
self,
logits: torch.Tensor,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 0,
) -> torch.Tensor:
"""Sample token IDs from logits with temperature + (top-k, top-p) filtering."""
if temperature is None or temperature <= 0:
temperature = 1.0
logits = logits / float(temperature)
logits = self._top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
probs = F.softmax(logits, dim=-1)
return torch.multinomial(probs, num_samples=1).squeeze(-1)
@torch.no_grad()
def generate_wedlm(
    self,
    input_ids: torch.LongTensor,
    max_new_tokens: int,
    block_size: Optional[int] = None,
    num_steps: Optional[int] = None,
    flow_init_sigma: Optional[float] = None,
    discretization: Optional[str] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    return_stats: bool = True,
    **kwargs,
) -> Union[torch.LongTensor, Dict]:
    """
    Flow-Matching block decoding.

    Generates `block_size` tokens per block using `num_steps` Euler steps, and runs the vocabulary projection
    (`lm_head`) only once per block (final discretization). Each batch row is decoded independently.

    Args:
        input_ids: Prompt token IDs, indexed as `(batch, seq)`. Leading and trailing `pad_token_id` entries
            are stripped per row before decoding, so both left- and right-padded batches are accepted.
        max_new_tokens: Upper bound on generated tokens per row.
        block_size: Tokens produced per block (config `flow_block_size`, default 64).
        num_steps: Euler steps per block (config `flow_inference_steps`, default 8).
        flow_init_sigma: Scale of the Gaussian noise the block state starts from (config `flow_init_sigma`).
        discretization: `"sample"` draws tokens via `_sample_from_logits`; anything else uses argmax.
        temperature, top_p, top_k: Sampling controls, only used when `discretization == "sample"`.
        pad_token_id: Padding ID used for prompt stripping and for padding the output batch.
        eos_token_id: A single ID, or a list/tuple/tensor of IDs, that terminates a row.
        return_stats: When True return a dict with `"sequences"` and `"stats"`; otherwise only the tensor.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        The `(batch, max_len)` tensor of prompt + generated IDs, padded on the right with `pad_token_id`
        (0 when no pad ID is known), or a dict wrapping it together with decoding statistics.

    Raises:
        ValueError: If `block_size` or `num_steps` is not a positive integer.
    """
    device = input_ids.device
    config = self.config

    # Resolve every unset option from the model config.
    if pad_token_id is None:
        pad_token_id = config.pad_token_id
    if eos_token_id is None:
        eos_token_id = getattr(config, "eos_token_id", None)
    if block_size is None:
        block_size = getattr(config, "flow_block_size", 64)
    if num_steps is None:
        num_steps = getattr(config, "flow_inference_steps", 8)
    if flow_init_sigma is None:
        flow_init_sigma = getattr(config, "flow_init_sigma", 1.0)
    if discretization is None:
        discretization = str(getattr(config, "flow_discretization", "argmax")).lower()
    if temperature is None:
        temperature = float(getattr(config, "flow_temperature", 1.0))
    if top_p is None:
        top_p = float(getattr(config, "flow_top_p", 1.0))
    if top_k is None:
        top_k = int(getattr(config, "flow_top_k", 0))

    block_size = int(block_size)
    num_steps = int(num_steps)
    flow_init_sigma = float(flow_init_sigma)
    # Both values are used as divisors below; fail with a clear message instead of ZeroDivisionError.
    if block_size <= 0:
        raise ValueError(f"block_size must be a positive integer, got {block_size}.")
    if num_steps <= 0:
        raise ValueError(f"num_steps must be a positive integer, got {num_steps}.")

    # Normalise EOS to a 1-D tensor so a single ID and a list of IDs share one code path.
    eos_ids = None
    if eos_token_id is not None:
        if isinstance(eos_token_id, torch.Tensor):
            eos_values = eos_token_id.reshape(-1).tolist()
        elif isinstance(eos_token_id, (list, tuple)):
            eos_values = list(eos_token_id)
        else:
            eos_values = [eos_token_id]
        if eos_values:
            eos_ids = torch.tensor([int(v) for v in eos_values], dtype=torch.long, device=device)

    # PEFT/LoRA wrappers expose the wrapped Linear as `base_layer`; its weight dtype is what the
    # matmul input has to match.
    flow_weight = getattr(getattr(self.flow_head, "base_layer", self.flow_head), "weight", None)
    lm_weight = getattr(getattr(self.lm_head, "base_layer", self.lm_head), "weight", None)

    batch_size = input_ids.shape[0]
    num_blocks = max(0, (max_new_tokens + block_size - 1) // block_size)
    dt = 1.0 / float(num_steps)
    all_generated: List[torch.Tensor] = []
    all_sample_stats: List[Dict] = []
    total_flow_evals = 0  # backbone forward passes actually executed

    for batch_idx in range(batch_size):
        sample_ids = input_ids[batch_idx]

        # Strip padding on either side of the prompt. Taking the span from the first to the last
        # non-pad token handles left padding and keeps any pad-valued token inside the prompt.
        prefix_ids = sample_ids
        if pad_token_id is not None:
            non_pad = sample_ids.ne(pad_token_id).nonzero(as_tuple=True)[0]
            if non_pad.numel() > 0:
                first = int(non_pad[0].item())
                last = int(non_pad[-1].item())
                prefix_ids = sample_ids[first : last + 1]
        prefix_ids = prefix_ids.clone()
        prefix_length = prefix_ids.shape[0]

        sample_stats = {
            "input_length": prefix_length,
            "num_blocks": num_blocks,
            "block_size": block_size,
            "num_steps": num_steps,
            "sigma": float(flow_init_sigma),
            "generated_tokens": 0,
            "blocks": [],
        }
        current_ids = prefix_ids

        for block_idx in range(num_blocks):
            cur_block = min(block_size, max_new_tokens - block_idx * block_size)
            if cur_block <= 0:
                break

            # State variable: current embedding estimates for the block (initialized from Gaussian noise)
            x = torch.randn((cur_block, config.hidden_size), device=device, dtype=torch.float32) * flow_init_sigma

            # The context does not change during the Euler steps of one block, so its embeddings,
            # mask and positions are built once per block. Embeddings stay in the model dtype to
            # avoid dtype mismatches in Linear layers when AMP is off.
            context_embeds = self.model.embed_tokens(current_ids.unsqueeze(0))
            ctx_dtype = context_embeds.dtype
            seq_len = context_embeds.shape[1] + cur_block
            attention_mask = torch.ones((1, seq_len), dtype=torch.long, device=device)
            position_ids = torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(0)

            # Euler integration from t=0 (noise) to t=1 (data)
            for step in range(num_steps):
                t = float(step) / float(num_steps)
                # Per-token timesteps (all tokens share the same t within a step).
                t_tensor = torch.full((cur_block,), t, device=device, dtype=torch.float32)
                t_cond = self.flow_time_embed(t_tensor).to(dtype=ctx_dtype)  # (cur_block, H)

                # Context tokens are discrete; block tokens are continuous (x) + time conditioning.
                block_embeds = (x.to(dtype=ctx_dtype) + t_cond).view(1, cur_block, -1)
                inputs_embeds = torch.cat([context_embeds, block_embeds], dim=1)

                outputs = self.model(
                    input_ids=None,
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    use_cache=False,
                    return_dict=True,
                )
                total_flow_evals += 1

                # Predict velocity in the head's dtype, then accumulate in fp32.
                h_block = outputs.last_hidden_state[:, -cur_block:, :]  # (1, cur_block, H)
                if flow_weight is not None:
                    h_block = h_block.to(dtype=flow_weight.dtype)
                v = self.flow_head(h_block).to(dtype=torch.float32).squeeze(0)  # (cur_block, H)
                x = x + dt * v

            # Discretize final embeddings into token IDs
            x_lm = x if lm_weight is None else x.to(dtype=lm_weight.dtype)
            logits = self.lm_head(x_lm).to(dtype=torch.float32)  # (cur_block, vocab)
            if discretization == "sample":
                next_ids = self._sample_from_logits(logits, temperature=temperature, top_p=top_p, top_k=top_k)
            else:
                next_ids = torch.argmax(logits, dim=-1)

            # Early stop on EOS within the block: keep tokens up to and including the first EOS.
            hit_eos = False
            if eos_ids is not None:
                eos_positions = torch.isin(next_ids, eos_ids).nonzero(as_tuple=True)[0]
                if eos_positions.numel() > 0:
                    next_ids = next_ids[: int(eos_positions[0].item()) + 1]
                    hit_eos = True

            current_ids = torch.cat([current_ids, next_ids.to(dtype=torch.long)], dim=0)
            emitted = int(next_ids.numel())
            sample_stats["generated_tokens"] += emitted
            sample_stats["blocks"].append(
                {
                    "block_idx": block_idx,
                    "target_block_size": int(cur_block),
                    "actual_block_tokens": emitted,
                }
            )
            if hit_eos:
                break

        sample_stats["output_length"] = int(current_ids.numel())
        all_generated.append(current_ids)
        all_sample_stats.append(sample_stats)

    # Right-pad every row to the longest one so they can be stacked.
    max_len = max((seq.numel() for seq in all_generated), default=0)
    fill_id = int(pad_token_id) if pad_token_id is not None else 0
    padded = []
    for seq in all_generated:
        missing = max_len - seq.numel()
        if missing > 0:
            tail = torch.full((missing,), fill_id, dtype=torch.long, device=device)
            seq = torch.cat([seq, tail], dim=0)
        padded.append(seq)
    sequences = torch.stack(padded, dim=0) if padded else torch.empty((0, 0), dtype=torch.long, device=device)

    if not return_stats:
        return sequences

    return {
        "sequences": sequences,
        "stats": {
            "batch_size": int(batch_size),
            "max_new_tokens": int(max_new_tokens),
            "block_size": int(block_size),
            "num_steps": int(num_steps),
            "discretization": discretization,
            "temperature": float(temperature),
            "top_p": float(top_p),
            "top_k": int(top_k),
            # Counts the passes that really ran, so rows that stop early on EOS are not over-counted.
            "total_flow_evals": int(total_flow_evals),
            "per_sample_stats": all_sample_stats,
        },
    }
# ----------------------------
# Generate (override to use Flow Matching by default)
# ----------------------------
@torch.no_grad()
def generate(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    generation_config=None,
    max_new_tokens: Optional[int] = None,
    max_length: Optional[int] = None,
    do_sample: Optional[bool] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    use_flow_matching: bool = True,
    block_size: Optional[int] = None,
    num_steps: Optional[int] = None,
    flow_init_sigma: Optional[float] = None,
    streamer=None,
    **kwargs,
):
    """
    Override generate() to use Flow Matching decoding by default.

    Set `use_flow_matching=False` to fall back to standard AR generation; in that case every
    explicitly provided argument and all extra `**kwargs` are forwarded to the parent `generate()`.

    On the Flow Matching path, unset options are resolved first from `generation_config` (when given)
    and then from the model config. `max_new_tokens` falls back to `max_length - prompt_length`, then
    to 256, and is clamped to at least 1. `do_sample=True` selects sampled discretization, otherwise
    argmax. Extra `**kwargs` are not used on this path, except that `inputs` is accepted as an alias
    for `input_ids`.

    If a `streamer` is given, only the first sequence of the batch is streamed: the prompt is pushed
    first, then each generated token, then `end()` is called.

    Returns:
        The parent `generate()` result on the AR path; otherwise the tensor of prompt + generated
        token IDs produced by `generate_wedlm`.

    Raises:
        ValueError: If Flow Matching decoding is requested without `input_ids`.
    """
    # Fall back to standard AR generation if requested
    if not use_flow_matching:
        explicit = {
            "max_new_tokens": max_new_tokens,
            "max_length": max_length,
            "do_sample": do_sample,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "pad_token_id": pad_token_id,
            "eos_token_id": eos_token_id,
        }
        # Forward only what the caller actually set: an explicit None would overwrite the
        # corresponding generation-config default (e.g. max_length, eos_token_id) in the parent.
        explicit = {name: value for name, value in explicit.items() if value is not None}
        return super().generate(
            input_ids=input_ids,
            generation_config=generation_config,
            streamer=streamer,
            **explicit,
            **kwargs,
        )

    # The parent API names its first argument `inputs`; accept it as an alias.
    if input_ids is None:
        input_ids = kwargs.get("inputs")
    if input_ids is None:
        raise ValueError("Flow-Matching generation requires `input_ids`.")

    # Extract parameters from generation_config if provided
    if generation_config is not None:
        if max_new_tokens is None:
            max_new_tokens = getattr(generation_config, "max_new_tokens", None)
        if max_length is None:
            max_length = getattr(generation_config, "max_length", None)
        if do_sample is None:
            do_sample = getattr(generation_config, "do_sample", None)
        if temperature is None:
            temperature = getattr(generation_config, "temperature", None)
        if top_p is None:
            top_p = getattr(generation_config, "top_p", None)
        if top_k is None:
            top_k = getattr(generation_config, "top_k", None)
        if pad_token_id is None:
            pad_token_id = getattr(generation_config, "pad_token_id", None)
        if eos_token_id is None:
            eos_token_id = getattr(generation_config, "eos_token_id", None)

    # Determine max_new_tokens
    if max_new_tokens is None:
        if max_length is not None:
            max_new_tokens = max_length - input_ids.shape[1]
        else:
            max_new_tokens = 256  # Default
    max_new_tokens = max(1, max_new_tokens)

    # Map do_sample to discretization
    discretization = "sample" if do_sample else "argmax"

    # Use config defaults for flow parameters
    if block_size is None:
        block_size = getattr(self.config, "flow_block_size", 64)
    if num_steps is None:
        num_steps = getattr(self.config, "flow_inference_steps", 8)
    if flow_init_sigma is None:
        flow_init_sigma = getattr(self.config, "flow_init_sigma", 1.0)
    if temperature is None:
        temperature = getattr(self.config, "flow_temperature", 1.0)
    if top_p is None:
        top_p = getattr(self.config, "flow_top_p", 1.0)
    if top_k is None:
        top_k = getattr(self.config, "flow_top_k", 0)
    if pad_token_id is None:
        pad_token_id = self.config.pad_token_id
    if eos_token_id is None:
        eos_token_id = self.config.eos_token_id

    # Call Flow Matching generation. Stats are requested because they carry the exact prompt and
    # output lengths of each row, which the streaming code below needs.
    output = self.generate_wedlm(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        block_size=block_size,
        num_steps=num_steps,
        flow_init_sigma=flow_init_sigma,
        discretization=discretization,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        return_stats=True,
    )
    sequences = output["sequences"]

    # Handle streamer if provided (basic support: first sequence only, after decoding finished)
    if streamer is not None:
        per_sample = output["stats"]["per_sample_stats"]
        if per_sample:
            # generate_wedlm strips prompt padding, so the generated span is delimited by the
            # recorded lengths rather than by the padded width of `input_ids`.
            start = int(per_sample[0]["input_length"])
            stop = int(per_sample[0]["output_length"])
            first_row = sequences[0].cpu()
            # Streamers treat the first chunk as the prompt (and may skip it), so send it first.
            streamer.put(first_row[:start].unsqueeze(0))
            for token_id in first_row[start:stop]:
                streamer.put(token_id.reshape(1))
        streamer.end()

    return sequences
# ----------------------------
# Forward
# ----------------------------
@can_return_tuple
@auto_docstring
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    # Flow Matching controls (training-time)
    flow_target_mask: Optional[torch.BoolTensor] = None,
    flow_timesteps: Optional[torch.FloatTensor] = None,
    flow_noise: Optional[torch.FloatTensor] = None,
    return_logits: bool = False,
    **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
    """
    When `labels` is provided: computes Flow Matching loss (and logits only if `return_logits` is set).
    Otherwise: returns logits like a standard causal LM, with `loss=None`.
    Args:
        flow_target_mask (`torch.BoolTensor`, *optional*):
            Boolean mask indicating which positions are flow targets.
        flow_timesteps (`torch.FloatTensor`, *optional*):
            Timesteps for flow matching, in range [0, 1].
        flow_noise (`torch.FloatTensor`, *optional*):
            Noise tensor for flow matching interpolation.
        return_logits (`bool`, defaults to `False`):
            If True, also compute vocabulary logits during Flow-Matching training.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # ------------------------------------------------------------
    # Flow Matching training path (primary)
    # ------------------------------------------------------------
    if labels is not None:
        if input_ids is None:
            raise ValueError("Flow-Matching training requires input_ids (token IDs) when labels is provided.")
        if inputs_embeds is not None:
            raise ValueError("Do not pass inputs_embeds when training with labels; Flow-Matching builds embeds internally.")

        # Replace target positions with noised embeddings + time conditioning; the clean embeddings
        # and the noise are returned so the velocity target can be formed below.
        inputs_embeds, target_mask, t, noise, clean_embeds = self._build_flow_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            flow_target_mask=flow_target_mask,
            flow_timesteps=flow_timesteps,
            flow_noise=flow_noise,
        )
        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state

        loss = None
        logits = None
        if target_mask.any():
            # Target velocity: d/dt X_t = X_1 - X_0 (data - noise) on straight-line path.
            v_target = (clean_embeds.to(dtype=torch.float32) - noise.to(dtype=torch.float32))[target_mask]

            # PEFT/LoRA can wrap `flow_head` and keep base weights in bf16/fp16.
            # Ensure dtype alignment for the Linear matmul.
            hs = hidden_states[target_mask]
            base_flow = getattr(self.flow_head, "base_layer", self.flow_head)
            w_flow = getattr(base_flow, "weight", None)
            if w_flow is not None:
                hs = hs.to(dtype=w_flow.dtype)
            v_pred = self.flow_head(hs).to(dtype=torch.float32)

            flow_loss = F.mse_loss(v_pred, v_target, reduction="mean")
            w = float(getattr(self.config, "flow_loss_weight", 1.0))
            loss = w * flow_loss
        else:
            # No valid targets -> zero loss (avoid NaNs). The zero is derived from the hidden
            # states so it stays attached to the autograd graph and `loss.backward()` still
            # works for such a batch; scaling before summing keeps fp16 from overflowing.
            loss = (hidden_states * 0.0).sum()

        if return_logits:
            slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
            logits = self.lm_head(hidden_states[:, slice_indices, :])

        if not return_dict:
            output = (logits,) + (outputs.past_key_values, outputs.hidden_states, outputs.attentions)
            return (loss,) + output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # ------------------------------------------------------------
    # Standard causal LM path (no labels): logits for evaluation / AR generation
    # ------------------------------------------------------------
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        cache_position=cache_position,
        **kwargs,
    )
    hidden_states = outputs.last_hidden_state
    # An int keeps the last `logits_to_keep` positions (0 keeps all of them, since slice(-0, None)
    # is the full range); a tensor is used directly as an index.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    logits = self.lm_head(hidden_states[:, slice_indices, :])

    if not return_dict:
        output = (logits,) + (outputs.past_key_values, outputs.hidden_states, outputs.attentions)
        return output
    return CausalLMOutputWithPast(
        loss=None,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    **kwargs,
):
    """Build the keyword arguments for one forward pass during AR generation.

    Args:
        input_ids: Token IDs generated so far, indexed as `(batch, seq)`.
        past_key_values: Cache from earlier steps, or `None`.
        attention_mask: Optional 0/1 mask over the sequence.
        inputs_embeds: Optional embeddings; used instead of `input_ids` only on the
            step where `cache_position` starts at 0.
        cache_position: Indices of the positions to process this step. Required
            whenever `past_key_values` is given.
        position_ids: Optional explicit positions; derived from `attention_mask`
            when omitted.
        use_cache: Passed through to the model.
        **kwargs: Accepted for call-site compatibility and not forwarded.

    Returns:
        A dict with the keys `input_ids`, `inputs_embeds`, `position_ids`,
        `cache_position`, `past_key_values`, `use_cache` and `attention_mask`.
    """
    # With a cache, keep only the tokens that have not been processed yet.
    if past_key_values is not None:
        if inputs_embeds is not None:
            input_ids = input_ids[:, -cache_position.shape[0] :]
        elif input_ids.shape[1] != cache_position.shape[0]:
            input_ids = input_ids[:, cache_position]

    if attention_mask is not None and position_ids is None:
        # Positions count attended tokens only; masked slots get a dummy value of 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # Embeddings can only stand in for token IDs on the very first step.
    first_step_with_embeds = (
        inputs_embeds is not None and cache_position is not None and cache_position[0] == 0
    )
    if first_step_with_embeds:
        model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
    else:
        model_inputs = {
            "input_ids": input_ids.clone(memory_format=torch.contiguous_format),
            "inputs_embeds": None,
        }

    # The remaining entries are the same for every cache type, so one code path is enough.
    model_inputs.update(
        {
            "position_ids": position_ids,
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
        }
    )
    return model_inputs
# Names exported by `from <module> import *`.
__all__ = ["WeDLMConfig", "WeDLMPreTrainedModel", "WeDLMModel", "WeDLMForCausalLM"]