"""Backbone components for Mimi models - shared attention transformers."""

import math
from typing import Optional, Union

import torch
from torch import nn
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.masking_utils import create_causal_mask
from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.utils import logging


try:
    from .configuration_mimi import MimiConfig
    from .modeling_mimi_clean import (
        MimiAttention,
        MimiMLP,
        MimiLayerScale,
        MimiRotaryEmbedding,
        apply_rotary_pos_emb,
        MIMI_ATTENTION_CLASSES,
    )
except ImportError:
    # Fallback for running the modules as plain scripts (not as part of a package).
    from configuration_mimi import MimiConfig
    from modeling_mimi_clean import (
        MimiAttention,
        MimiMLP,
        MimiLayerScale,
        MimiRotaryEmbedding,
        apply_rotary_pos_emb,
        MIMI_ATTENTION_CLASSES,
    )

logger = logging.get_logger(__name__)


def _as_cache(
    past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]],
    use_cache: bool,
) -> Optional[Cache]:
    """
    Normalize a caller-supplied key/value cache into a `Cache` instance.

    Args:
        past_key_values: `None`, a `Cache`, or a legacy tuple/list of per-layer `(key, value)` tensors.
        use_cache: when `True` and no cache was supplied, a fresh `DynamicCache` is created.

    Returns:
        A `Cache`, or `None` when no cache was supplied and `use_cache` is `False`.
    """
    if past_key_values is None:
        return DynamicCache() if use_cache else None
    if isinstance(past_key_values, Cache):
        return past_key_values
    # Legacy format: convert regardless of `use_cache`, because callers immediately call
    # `.get_seq_length()` on the result, which a raw tuple does not provide.
    logger.warning_once(
        "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
        "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
        "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
    )
    return DynamicCache.from_legacy_cache(past_key_values)


class CausalAttentionTransformer(nn.Module):
    """
    Standard causal (decoder-only) attention transformer made of `config.num_hidden_layers`
    [`MimiTransformerLayer`] blocks, each using self-attention only.

    Args:
        config: MimiConfig
    """

    def __init__(self, config: MimiConfig):
        super().__init__()
        self.layers = nn.ModuleList(
            [MimiTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[tuple, BaseModelOutputWithPast]:
        """
        Run the input through every transformer layer.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Input embeddings or hidden states from a previous module.
            attention_mask (`torch.Tensor`, *optional*):
                Padding mask (1 = keep, 0 = masked) forwarded to `create_causal_mask`, which combines it with
                the causal structure for the configured attention implementation.
            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Position indices; defaults to `cache_position.unsqueeze(0)`.
            past_key_values (`Cache` or legacy tuple of per-layer `(key, value)` tensors, *optional*):
                Previously computed self-attention keys/values. Legacy tuples are converted to a
                `DynamicCache` (with a deprecation warning). When `use_cache` is set and nothing is passed,
                a new `DynamicCache` is created. The returned cache is always a `Cache` instance.
            use_cache (`bool`, *optional*):
                Whether to return the key/value cache. Forced to `False` while training with gradient
                checkpointing.
            output_attentions (`bool`, *optional*):
                Whether to return the attention weights of every layer.
            output_hidden_states (`bool`, *optional*):
                Whether to return the hidden states entering each layer plus the final output.
            return_dict (`bool`, *optional*):
                Return a [`BaseModelOutputWithPast`] instead of a plain tuple.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Absolute positions of the input tokens; defaults to positions following the cached length.

        Returns:
            [`BaseModelOutputWithPast`] or a tuple `(last_hidden_state, cache, hidden_states, attentions)`
            with `None` entries dropped.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        past_key_values = _as_cache(past_key_values, use_cache)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Causal mask for self-attention (format depends on the attention implementation).
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=hidden_states,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional order must match MimiTransformerLayer.forward's signature.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            # Layer outputs are (hidden, [attn_weights], [cache]); the cache slot shifts when weights are present.
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class MimiTransformerLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block: LayerNorm -> self-attention -> layer scale -> residual, then the same with an MLP."""

    def __init__(self, config: MimiConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = MIMI_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        self.mlp = MimiMLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.self_attn_layer_scale = MimiLayerScale(config)
        self.mlp_layer_scale = MimiLayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.Tensor`, *optional*): self-attention mask as produced by
                `create_causal_mask` for the configured attention implementation.
            position_ids (`torch.LongTensor`, *optional*): positions used by the attention module.
            past_key_value (`Cache`, *optional*): shared key/value cache; the attention module selects its own
                layer slot.
            output_attentions (`bool`, *optional*): also return the attention weights.
            use_cache (`bool`, *optional*): also return the cache returned by the attention module.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): absolute positions of
                the input tokens.
            kwargs (`dict`, *optional*): forwarded unchanged to the self-attention module.

        Returns:
            `(hidden_states,)` followed by the attention weights if `output_attentions`, then the cache if
            `use_cache`.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + self.self_attn_layer_scale(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.mlp_layer_scale(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class CrossAttention(nn.Module):
    """
    Eager cross-attention: queries come from the decoder, keys/values from the encoder.

    Rotary embeddings are applied to the queries only (when `position_ids` is given); encoder keys are not rotated.
    An optional monotonic mask restricts each query to a growing prefix of encoder positions.
    """

    def __init__(self, config: MimiConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        # NOTE(review): this eager forward never reads `is_causal` and applies no causal mask over encoder
        # positions; the attribute is kept for interface parity with the other attention modules.
        self.is_causal = True
        self.scaling = 1 / math.sqrt(config.head_dim)

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Query projection for decoder hidden states
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        # Key and value projections for encoder hidden states
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        # Rotary embeddings only for queries (decoder positions)
        self.rotary_emb = MimiRotaryEmbedding(config)

    def forward(
        self,
        hidden_states: torch.Tensor,  # Decoder hidden states (queries)
        encoder_hidden_states: torch.Tensor,  # Encoder hidden states (keys, values)
        attention_mask: Optional[torch.Tensor] = None,  # Additive mask for encoder positions
        position_ids: Optional[torch.LongTensor] = None,  # Decoder position IDs
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,  # Monotonic attention chunk sizes
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """
        Args:
            hidden_states: decoder states of shape `(batch, q_len, hidden_size)`.
            encoder_hidden_states: encoder states of shape `(batch, enc_len, hidden_size)`.
            attention_mask: additive mask broadcastable to `(batch, num_heads, q_len, kv_len)`.
            position_ids: decoder positions for the query rotary embedding (skipped when `None`).
            past_key_value: optional `Cache`. NOTE(review): the freshly projected encoder keys/values are passed
                to `Cache.update` on every call, so with a `DynamicCache` they are appended to what is already
                stored; callers re-sending the full encoder output each step should not pass a cache here.
            output_attentions: whether to return the attention probabilities.
            use_cache: unused by this module; accepted for signature parity.
            cache_position: forwarded to `Cache.update` through `cache_kwargs`.
            alignment_chunk_sizes: optional per-query chunk sizes for the monotonic mask
                (see `_create_monotonic_attention_mask`).

        Returns:
            `(attn_output, attn_weights or None, past_key_value)`.
        """
        bsz, q_len, _ = hidden_states.size()
        _, enc_len, _ = encoder_hidden_states.size()

        # Queries from decoder; keys and values from encoder.
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(encoder_hidden_states)
        value_states = self.v_proj(encoder_hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, enc_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, enc_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos = sin = None
        if position_ids is not None:
            cos, sin = self.rotary_emb(value_states, position_ids)
            # Only the rotated queries are kept; the second (key) output is discarded.
            query_states, _ = apply_rotary_pos_emb(query_states, query_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Measure the key length AFTER the cache update: with a non-empty cache the keys are longer than
        # the encoder input, and the monotonic mask must match the actual key axis.
        kv_len = key_states.shape[-2]

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        if alignment_chunk_sizes is not None:
            monotonic_mask = _create_monotonic_attention_mask(
                alignment_chunk_sizes=alignment_chunk_sizes,
                query_length=q_len,
                key_length=kv_len,
                device=attn_weights.device,
                dtype=attn_weights.dtype,
            )
            attn_weights = attn_weights + monotonic_mask

        if attention_mask is not None:
            # Shape: [batch_size, 1, 1, kv_len] or [batch_size, 1, q_len, kv_len]
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class CrossAttentionLayer(GradientCheckpointingLayer):
    """
    Pre-norm decoder block: self-attention over the decoder, cross-attention to the encoder, then an MLP.
    Each sub-block is followed by a layer scale and a residual connection.
    """

    def __init__(self, config: MimiConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Self-attention for decoder
        self.self_attn = MIMI_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
        # Cross-attention to encoder
        self.cross_attn = CrossAttention(config=config, layer_idx=layer_idx)

        self.mlp = MimiMLP(config)

        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.post_cross_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)

        self.self_attn_layer_scale = MimiLayerScale(config)
        self.cross_attn_layer_scale = MimiLayerScale(config)
        self.mlp_layer_scale = MimiLayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,  # Decoder hidden states
        encoder_hidden_states: torch.Tensor,  # Encoder hidden states
        attention_mask: Optional[torch.Tensor] = None,  # Causal mask for self-attention
        encoder_attention_mask: Optional[torch.Tensor] = None,  # Mask for encoder positions
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        cross_past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,  # Monotonic attention chunk sizes
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): decoder input of shape `(batch, seq_len, embed_dim)`.
            encoder_hidden_states (`torch.FloatTensor`): encoder output of shape `(batch, encoder_seq_len, embed_dim)`.
            attention_mask (`torch.Tensor`, *optional*): causal mask for decoder self-attention.
            encoder_attention_mask (`torch.Tensor`, *optional*): additive mask for encoder positions.
            position_ids (`torch.LongTensor`, *optional*): decoder positions (used by both attention modules).
            past_key_value (`Cache`, *optional*): shared self-attention cache.
            cross_past_key_value (`Cache`, *optional*): shared cross-attention cache.
            output_attentions (`bool`, *optional*): also return self- and cross-attention weights.
            use_cache (`bool`, *optional*): also return both caches.
            cache_position (`torch.LongTensor`, *optional*): absolute positions of the decoder tokens.
            alignment_chunk_sizes (`torch.Tensor`, *optional*): chunk sizes for the monotonic cross-attention mask.
            kwargs: forwarded to the self-attention module only.

        Returns:
            `(hidden_states,)`, then `(self_attn_weights, cross_attn_weights)` if `output_attentions`, then
            `(self_cache, cross_cache)` if `use_cache`.
        """
        residual = hidden_states

        # Self-attention on decoder (pre-norm)
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + self.self_attn_layer_scale(hidden_states)

        # Cross-attention to encoder
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states, cross_attn_weights, cross_present_key_value = self.cross_attn(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            position_ids=position_ids,
            past_key_value=cross_past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            alignment_chunk_sizes=alignment_chunk_sizes,
        )
        hidden_states = residual + self.cross_attn_layer_scale(hidden_states)

        # Feed Forward Network
        residual = hidden_states
        hidden_states = self.post_cross_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.mlp_layer_scale(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value, cross_present_key_value)

        return outputs


class CrossAttentionTransformer(nn.Module):
    """
    Decoder stack of `config.num_hidden_layers` [`CrossAttentionLayer`] blocks.
    Each block runs self-attention on the decoder and cross-attention to the encoder.

    Args:
        config: MimiConfig
    """

    def __init__(self, config: MimiConfig):
        super().__init__()
        self.layers = nn.ModuleList(
            [CrossAttentionLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,  # Decoder hidden states
        encoder_hidden_states: torch.Tensor,  # Encoder hidden states
        attention_mask: Optional[torch.Tensor] = None,  # Causal mask for decoder
        encoder_attention_mask: Optional[torch.Tensor] = None,  # Mask for encoder
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        cross_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,  # Monotonic attention chunk sizes
    ) -> Union[tuple, BaseModelOutputWithPast]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): decoder input of shape
                `(batch_size, decoder_sequence_length, hidden_size)`.
            encoder_hidden_states (`torch.FloatTensor`): encoder output of shape
                `(batch_size, encoder_sequence_length, hidden_size)`.
            attention_mask (`torch.Tensor`, *optional*): padding mask for the decoder, combined with the causal
                structure by `create_causal_mask`.
            encoder_attention_mask (`torch.Tensor`, *optional*): additive mask for encoder positions.
            position_ids (`torch.LongTensor`, *optional*): decoder positions; default `cache_position.unsqueeze(0)`.
            past_key_values (`Cache` or legacy tuple, *optional*): self-attention cache shared by all layers.
            cross_past_key_values (`Cache` or legacy tuple, *optional*): cross-attention cache shared by all layers.
            use_cache (`bool`, *optional*): whether to return the caches (forced off when training with
                gradient checkpointing).
            output_attentions (`bool`, *optional*): whether to return self- and cross-attention weights.
            output_hidden_states (`bool`, *optional*): whether to return per-layer hidden states.
            return_dict (`bool`, *optional*): whether to return a ModelOutput.
            cache_position (`torch.LongTensor`, *optional*): absolute positions of the decoder tokens.
            alignment_chunk_sizes (`torch.Tensor`, *optional*): tensor of shape
                `(batch_size, decoder_sequence_length)`; a 1-D `(decoder_sequence_length,)` tensor is applied to
                every batch element. Decoder position i may attend to encoder positions
                0 through sum(alignment_chunk_sizes[:i+1])-1.

        Returns:
            [`BaseModelOutputWithPast`] (which has no slot for cross-attention weights or the cross-attention
            cache), or the tuple `(last_hidden_state, cache, cross_cache, hidden_states, self_attns, cross_attns)`
            with `None` entries dropped.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if use_cache and past_key_values is None:
            logger.warning_once("use_cache=True was passed, but no past_key_values were given. Creating new cache.")
        past_key_values = _as_cache(past_key_values, use_cache)
        cross_past_key_values = _as_cache(cross_past_key_values, use_cache)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Causal mask for decoder self-attention
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=hidden_states,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attns = () if output_attentions else None
        next_decoder_cache = None
        next_cross_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # The whole Cache objects are passed down: every attention module selects its own slot via its
            # `layer_idx`. Indexing a cache per layer would hand over a raw (key, value) tuple, which has no
            # `.update()`, and raises on a freshly created empty cache.
            if self.gradient_checkpointing and self.training:
                # Positional order must match CrossAttentionLayer.forward's signature.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    encoder_hidden_states,
                    causal_mask,
                    encoder_attention_mask,
                    position_ids,
                    past_key_values,
                    cross_past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    alignment_chunk_sizes,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=causal_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    cross_past_key_value=cross_past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    alignment_chunk_sizes=alignment_chunk_sizes,
                )

            hidden_states = layer_outputs[0]

            # Layer outputs: (hidden, [self_w, cross_w], [self_cache, cross_cache]).
            if use_cache:
                cache_offset = 3 if output_attentions else 1
                next_decoder_cache = layer_outputs[cache_offset]
                next_cross_cache = layer_outputs[cache_offset + 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
                all_cross_attns += (layer_outputs[2],)

        # Add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        next_cross_cache = next_cross_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, next_cross_cache, all_hidden_states, all_self_attns, all_cross_attns]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head `n_rep` times along the head axis, like
    `torch.repeat_interleave(x, dim=1, repeats=n_rep)`.

    Maps `(batch, num_key_value_heads, seqlen, head_dim)` to `(batch, num_key_value_heads * n_rep, seqlen, head_dim)`.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def _create_monotonic_attention_mask(
    alignment_chunk_sizes: torch.Tensor,
    query_length: int,
    key_length: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Build an additive mask where each query may only attend to a growing prefix of the keys.

    Args:
        alignment_chunk_sizes: tensor of shape `(batch_size, num_queries)` or `(num_queries,)` (shared by the
            whole batch). Query i may attend to keys `0 .. cumsum(chunk_sizes)[i] - 1`, clamped to `key_length`.
            If it covers more queries than `query_length` (e.g. past + current tokens during incremental
            decoding), only the last `query_length` rows are used.
        query_length: number of queries in the current attention call.
        key_length: number of keys in the current attention call.
        device: device to build the mask on (the chunk sizes are moved there).
        dtype: dtype of the returned mask.

    Returns:
        Tensor of shape `(batch_size, 1, query_length, key_length)` holding 0.0 where attention is allowed and
        -inf elsewhere. A query whose cumulative size is 0 has every key masked, so its softmax row is NaN.
    """
    chunk_sizes = alignment_chunk_sizes.to(device)
    if chunk_sizes.dim() == 1:
        # Shared alignment for every batch element.
        chunk_sizes = chunk_sizes.unsqueeze(0)

    # Number of keys each query can see, never more than the keys that exist.
    visible = torch.clamp(torch.cumsum(chunk_sizes, dim=1), max=key_length)  # [batch, num_queries]

    num_queries = visible.shape[1]
    if num_queries > query_length:
        # Keep the rows for the current (most recent) queries.
        visible = visible[:, num_queries - query_length:]

    key_positions = torch.arange(key_length, device=device).view(1, 1, key_length)
    allowed = key_positions < visible.unsqueeze(2)  # [batch, query_length, key_length]

    attention_mask = torch.where(allowed, 0.0, float('-inf'))
    # Add the head dimension for broadcasting over attention heads.
    return attention_mask.unsqueeze(1).to(dtype)


__all__ = [
    "CausalAttentionTransformer",
    "MimiTransformerLayer",
    "CrossAttention",
    "CrossAttentionLayer",
    "CrossAttentionTransformer",
    "_create_monotonic_attention_mask",
]