Feature Extraction
Transformers
Safetensors
moss-audio-tokenizer
audio
audio-tokenizer
neural-codec
moss-tts-family
MOSS Audio Tokenizer
speech-tokenizer
trust-remote-code
custom_code
Instructions to use OpenMOSS-Team/MOSS-Audio-Tokenizer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use OpenMOSS-Team/MOSS-Audio-Tokenizer with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="OpenMOSS-Team/MOSS-Audio-Tokenizer", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("OpenMOSS-Team/MOSS-Audio-Tokenizer", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """PyTorch MossAudioTokenizer model.""" | |
| from __future__ import annotations | |
| import copy | |
| import math | |
| from contextlib import ExitStack, contextmanager | |
| from dataclasses import dataclass | |
| from typing import cast | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers.modeling_utils import PreTrainedAudioTokenizerBase | |
| from transformers.utils import ModelOutput, auto_docstring, logging | |
| from .configuration_moss_audio_tokenizer import MossAudioTokenizerConfig | |
| logger = logging.get_logger(__name__) | |
| # ============================================================================= | |
| # Output Classes | |
| # ============================================================================= | |
# Result container for the encoding path: discrete codes, their valid lengths,
# and the encoder features taken before quantization.
# The `@dataclass` decorator is required: the annotated names below only become
# constructor fields through it, and transformers' `ModelOutput` expects its
# subclasses to be dataclasses.
@dataclass
class MossAudioTokenizerEncoderOutput(ModelOutput):
    r"""
    audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*):
        Discrete audio codes computed using the encoder and quantizer.
    audio_codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Valid lengths for each sample's audio codes.
    encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, hidden_size, sequence_length)`, *optional*):
        Hidden states from the encoder before quantization.
    """

    audio_codes: torch.Tensor | None = None
    audio_codes_lengths: torch.Tensor | None = None
    encoder_hidden_states: torch.Tensor | None = None
# Result container for the decoding path: the reconstructed waveform and the
# number of valid samples per batch item.
# The `@dataclass` decorator is required: the annotated names below only become
# constructor fields through it, and transformers' `ModelOutput` expects its
# subclasses to be dataclasses.
@dataclass
class MossAudioTokenizerDecoderOutput(ModelOutput):
    r"""
    audio (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
        Decoded audio waveform.
    audio_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Valid lengths for each sample's audio.
    """

    audio: torch.Tensor | None = None
    audio_lengths: torch.Tensor | None = None
# Combined result container carrying both the decoded waveform (with lengths)
# and the discrete codes (with lengths).
# The `@dataclass` decorator is required: the annotated names below only become
# constructor fields through it, and transformers' `ModelOutput` expects its
# subclasses to be dataclasses.
@dataclass
class MossAudioTokenizerOutput(ModelOutput):
    r"""
    audio (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
        Decoded audio waveform.
    audio_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Valid lengths for each sample's audio.
    audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*):
        Discrete audio codes computed using the encoder and quantizer.
    audio_codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Valid lengths for each sample's audio codes.
    """

    audio: torch.Tensor | None = None
    audio_lengths: torch.Tensor | None = None
    audio_codes: torch.Tensor | None = None
    audio_codes_lengths: torch.Tensor | None = None
| # ============================================================================= | |
| # Streaming Module Base Classes | |
| # ============================================================================= | |
@dataclass
class StreamingState:
    """Base state for streaming modules.

    Declared as a dataclass so that it can be built as
    ``StreamingState(batch_size, device)`` and so that ``__post_init__`` runs
    and allocates ``exec_mask``.

    Attributes set by the constructor:
        batch_size: number of parallel streams tracked by this state.
        device: device on which ``exec_mask`` is allocated.

    Attribute created in ``__post_init__``:
        exec_mask: bool tensor of shape ``(batch_size,)``, initially all True.
    """

    batch_size: int
    device: torch.device

    def __post_init__(self):
        # Derived attribute, deliberately not a dataclass field.
        self.exec_mask = torch.ones(self.batch_size, dtype=torch.bool, device=self.device)

    def set_exec_mask(self, exec_mask: torch.Tensor):
        """Overwrite the execution mask in place (the tensor object is kept)."""
        self.exec_mask[:] = exec_mask

    def reset(self, reset_mask: torch.Tensor) -> None:
        """Set ``exec_mask`` back to True for every entry selected by ``reset_mask``."""
        self.exec_mask[:] = torch.where(reset_mask, torch.ones_like(self.exec_mask), self.exec_mask)

    def __enter__(self):
        # ExitStack expects a context manager; returning self is conventional and useful for debugging.
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Nothing to release; returning None lets exceptions propagate.
        pass
class StreamingModule(nn.Module):
    """Base class for modules that can be switched into a stateful streaming mode.

    A module is "streaming" while ``_streaming_state`` is not None. Entering and
    leaving streaming mode is applied to this module and to every
    ``StreamingModule`` found below it in the module tree.
    """

    def __init__(self) -> None:
        super().__init__()
        self._streaming_state: StreamingState | None = None
        self._streaming_detached: bool = False
        # Filled on the first traversal and reused afterwards.
        self._cached_children: list[tuple[str, StreamingModule]] | None = None

    def is_streaming(self):
        """Return True while a streaming state is attached to this module."""
        return self._streaming_state is not None

    def _apply_named_streaming(self, fn):
        """Call ``fn(dotted_name, module)`` on self and each nested streaming module.

        The list of (name, module) pairs is computed once and cached. A nested
        module flagged ``_streaming_detached`` is skipped together with its
        whole subtree; the flag is ignored for the root (empty name).
        """

        def _collect(path: str, node: nn.Module):
            if isinstance(node, StreamingModule):
                if node._streaming_detached and path != "":
                    return
                if self._cached_children is None:
                    raise RuntimeError("Internal error: _cached_children should be initialized before traversal.")
                self._cached_children.append((path, node))
            # Non-streaming containers are still walked so nested streaming modules are found.
            for child_name, child in node.named_children():
                _collect(f"{path}.{child_name}" if path else child_name, child)

        if self._cached_children is None:
            self._cached_children = []
            _collect("", self)
        for path, streaming_child in self._cached_children:
            fn(path, streaming_child)

    def _start_streaming(self, batch_size: int, exit_stack: ExitStack):
        """Create and attach a fresh state on every streaming module in the tree."""

        def _attach(path: str, target: StreamingModule):
            if target._streaming_state is not None:
                raise RuntimeError(f"{path} is already streaming!")
            fresh_state = target._init_streaming_state(batch_size)
            exit_stack.enter_context(fresh_state)
            target._streaming_state = fresh_state

        self._apply_named_streaming(_attach)

    def _stop_streaming(self) -> None:
        """Detach the streaming state from every streaming module in the tree."""

        def _detach(path: str, target: StreamingModule):
            target._streaming_state = None

        self._apply_named_streaming(_detach)

    def _init_streaming_state(self, batch_size: int) -> StreamingState:
        """Build the default state on the device of this module's first parameter."""
        first_param = next(iter(self.parameters()))
        return StreamingState(batch_size, first_param.device)

    def streaming(self, batch_size: int) -> ExitStack:
        """Context manager to enter streaming mode."""
        stack = ExitStack()
        self._start_streaming(batch_size, stack)
        # Registered last, so it runs first when the stack unwinds.
        stack.callback(self._stop_streaming)
        return stack
class StreamingContainer(StreamingModule):
    """Streaming module that only groups other modules; it adds no behavior of its own."""
| # ============================================================================= | |
| # Normalization Layers | |
| # ============================================================================= | |
class MossAudioTokenizerRMSNorm(nn.Module):
    """RMS normalization over dim 2 with a learned per-channel gain ``alpha``.

    Args:
        dim: size of the normalized dimension (length of ``alpha``).
        eps: constant added to the mean square before the reciprocal square root.
        dtype: if given, inputs are cast to this dtype for the computation; the
            result is always cast back to the input dtype.
        device: device on which ``alpha`` is created.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-5,
        dtype: torch.dtype | None = None,
        device=None,
    ):
        super().__init__()
        self.eps = eps
        self.dtype = dtype
        initial_gain = torch.full((1, 1, dim), 1.0, requires_grad=True, device=device, dtype=dtype)
        self.alpha = nn.Parameter(initial_gain)

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        if self.dtype is not None:
            x = x.to(self.dtype)
        mean_square = torch.mean(x**2, dim=2, keepdim=True)
        var = self.eps + mean_square
        # `alpha.to(var)` aligns the gain with the dtype/device used for the statistics.
        gain = self.alpha.to(var) * torch.rsqrt(var)
        return (x * gain).to(input_dtype)
class MossAudioTokenizerLayerScale(nn.Module):
    """Learned per-channel rescaling (LayerScale, Touvron et al. 2021).

    Args:
        channels: number of channels, i.e. length of the ``scale`` parameter.
        init: initial value of every scale entry.
        channel_last: if True the scale multiplies the last dimension of the
            input; otherwise it is broadcast as ``scale[:, None]``.
        device, dtype: forwarded to the parameter allocation.
    """

    def __init__(
        self,
        channels: int,
        init: float = 1e-4,
        channel_last: bool = True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.channel_last = channel_last
        initial_scale = torch.full((channels,), init, requires_grad=True, device=device, dtype=dtype)
        self.scale = nn.Parameter(initial_scale)

    def forward(self, x: torch.Tensor):
        factor = self.scale if self.channel_last else self.scale[:, None]
        return factor * x
def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
    """Build the normalization layer named by ``norm_type``.

    Supported names: ``"layer_norm"`` (``nn.LayerNorm``, eps 1e-5),
    ``"rms_norm"`` (RMS norm, eps 1e-5) and ``"rms_norm_f32"`` (RMS norm
    computed in float32 with eps 1e-8; any caller-supplied ``dtype`` is dropped).

    Raises:
        ValueError: for any other ``norm_type``.
    """
    if norm_type == "layer_norm":
        return nn.LayerNorm(dim, eps=1e-5, **kwargs)
    if norm_type == "rms_norm":
        return MossAudioTokenizerRMSNorm(dim, eps=1e-5, **kwargs)
    if norm_type == "rms_norm_f32":
        # The float32 variant fixes its own dtype, so discard the caller's.
        kwargs.pop("dtype", None)
        return MossAudioTokenizerRMSNorm(dim, eps=1e-8, dtype=torch.float, **kwargs)
    raise ValueError(f"Unknown norm type: {norm_type}")
| # ============================================================================= | |
| # Rotary Position Embedding | |
| # ============================================================================= | |
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    offset: torch.Tensor,
    max_period: float = 10_000,
    time_before_heads: bool = False,
):
    """Rotate query/key feature pairs by position-dependent angles (RoPE).

    Args:
        q, k: 4-D tensors of identical shape, laid out ``[B, T, H, D]`` when
            ``time_before_heads`` is True and ``[B, H, T, D]`` otherwise.
        offset: position of the first time step; flattened to one value per row
            and added to ``arange(T)``.
        max_period: base controlling the geometric frequency progression.
        time_before_heads: selects which of the two layouts is used.

    Returns:
        ``(q_rotated, k_rotated)`` with the shape and dtype of the inputs; the
        trigonometry itself is carried out in float32.

    Raises:
        ValueError: if ``k`` and ``q`` differ in shape or ``D`` is not a
            positive even number.
    """
    # Unpacking four names keeps the implicit "must be 4-D" check.
    batch, axis1, axis2, head_dim = q.shape
    num_steps = axis1 if time_before_heads else axis2
    if k.shape != q.shape:
        raise ValueError(f"Expected k.shape == q.shape, got k={tuple(k.shape)} q={tuple(q.shape)}")
    if head_dim <= 0 or (head_dim % 2) != 0:
        raise ValueError(f"RoPE requires an even last dimension, got D={head_dim}")

    half = head_dim // 2
    pair_index = torch.arange(half, device=q.device, dtype=torch.float32)
    inv_freq = torch.exp(pair_index * (-math.log(max_period) * 2 / head_dim))

    step_index = torch.arange(num_steps, device=q.device, dtype=torch.float32)
    abs_pos = offset.float().view(-1, 1) + step_index
    # Put the time axis where the chosen layout expects it so broadcasting lines up.
    abs_pos = abs_pos.view(batch, -1, 1, 1) if time_before_heads else abs_pos.view(batch, 1, -1, 1)

    leading = q.shape[:-1]
    q_pairs = q.view(*leading, half, 2)
    k_pairs = k.view(*leading, half, 2)
    out_dtype = q_pairs.dtype

    # Each (even, odd) pair is treated as the real/imaginary part of a complex number.
    q_re, q_im = q_pairs[..., 0].float(), q_pairs[..., 1].float()
    k_re, k_im = k_pairs[..., 0].float(), k_pairs[..., 1].float()

    cos = torch.cos(inv_freq * abs_pos)
    sin = torch.sin(inv_freq * abs_pos)

    q_rot = torch.stack(
        [(q_re * cos - q_im * sin).to(out_dtype), (q_re * sin + q_im * cos).to(out_dtype)],
        dim=-1,
    )
    k_rot = torch.stack(
        [(k_re * cos - k_im * sin).to(out_dtype), (k_re * sin + k_im * cos).to(out_dtype)],
        dim=-1,
    )
    return q_rot.view(*leading, head_dim), k_rot.view(*leading, head_dim)
class MossAudioTokenizerRotaryEmbedding(nn.Module):
    """Module wrapper around ``apply_rope`` that remembers ``max_period``."""

    def __init__(self, max_period: float = 10000.0):
        super().__init__()
        self.max_period = max_period

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        offset: torch.Tensor,
        time_before_heads: bool = False,
    ):
        """Return the rotated ``(q, k)`` pair; see ``apply_rope`` for the layouts."""
        return apply_rope(
            q,
            k,
            offset,
            max_period=self.max_period,
            time_before_heads=time_before_heads,
        )
| # ============================================================================= | |
| # Gating Modules | |
| # ============================================================================= | |
class MossAudioTokenizerActivationGating(nn.Module):
    """Gated feed-forward block: ``linear_out(activation(a) * b)``.

    ``linear_in`` produces both halves ``a`` and ``b`` in one projection. The
    hidden width is ``21 * dim // 8`` when ``dim_feedforward == 4 * dim`` and
    ``2 * dim_feedforward // 3`` otherwise.

    Args:
        dim: input and output feature size.
        dim_feedforward: nominal feed-forward width used to derive the hidden width.
        activation: callable applied to the gate half.
        **factory_kwargs: forwarded to both ``nn.Linear`` layers.
    """

    def __init__(self, dim: int, dim_feedforward: int, activation, **factory_kwargs):
        super().__init__()
        hidden = (21 * dim) // 8 if dim_feedforward == 4 * dim else (2 * dim_feedforward) // 3
        self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs)
        self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs)
        self.activation = activation

    def forward(self, x: torch.Tensor):
        projected = self.linear_in(x)
        batch, steps, _ = projected.shape
        # First half of the projection is the gate, second half the value.
        gate, value = projected.view(batch, steps, 2, -1).unbind(dim=2)
        return self.linear_out(self.activation(gate) * value)
| def _get_activation(name: str): | |
| if name in ["sigmoid", "tanh", "relu"]: | |
| return getattr(torch, name) | |
| elif name in ["leaky_relu", "elu", "gelu", "silu", "mish", "softsign"]: | |
| return getattr(F, name) | |
| elif name == "identity": | |
| return nn.Identity() | |
| else: | |
| raise ValueError(f"Unknown activation {name}") | |
def make_gating(name: str, dim: int, dim_feedforward: int, **factory_kwargs) -> nn.Module:
    """Build a gated feed-forward block whose activation is looked up by ``name``."""
    activation = _get_activation(name)
    return MossAudioTokenizerActivationGating(dim, dim_feedforward, activation, **factory_kwargs)
| # ============================================================================= | |
| # Positional Embeddings | |
| # ============================================================================= | |
def create_sin_embedding(
    positions: torch.Tensor,
    dim: int,
    max_period: float = 10000,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Create sinusoidal positional embedding with shape [B, T, C].

    The first ``dim // 2`` output channels are cosines and the remaining ones
    sines of ``positions / max_period ** (i / (dim // 2 - 1))``.

    Args:
        positions: position values; broadcast against a ``(1, 1, dim // 2)``
            frequency axis (callers in this file pass ``[B, T, 1]``).
        dim: number of output channels; must be even and at least 4.
        max_period: largest period of the frequency progression.
        dtype: dtype used for the whole computation and the result.

    Raises:
        ValueError: if ``dim`` is odd or smaller than 4.
    """
    if dim % 2 != 0:
        raise ValueError(f"Sinusoidal embedding requires even dim, got dim={dim}")
    half_dim = dim // 2
    if half_dim <= 1:
        # half_dim - 1 is the divisor below, so half_dim == 1 would divide by zero.
        raise ValueError(f"Sinusoidal embedding requires dim >= 4, got dim={dim}")

    target = positions.device
    positions = positions.to(dtype)
    channel_index = torch.arange(half_dim, device=target, dtype=dtype).view(1, 1, -1)
    base = torch.full([], max_period, device=target, dtype=dtype)
    exponent = channel_index / (half_dim - 1)
    phase = positions / (base**exponent)
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
| # ============================================================================= | |
| # KV Cache for Attention | |
| # ============================================================================= | |
class KVCacheResult:
    """Container for KV cache results that supports tuple unpacking.

    Attributes:
        keys: key tensor handed to attention.
        values: value tensor handed to attention.
        positions: integer position associated with each key/value slot.
    """

    __slots__ = ("keys", "values", "positions")

    def __init__(self, keys: torch.Tensor, values: torch.Tensor, positions: torch.Tensor):
        self.keys = keys
        self.values = values
        self.positions = positions

    def __iter__(self):
        """Allow unpacking as (keys, values, positions)."""
        return iter((self.keys, self.values, self.positions))

    @staticmethod
    def from_kv(keys: torch.Tensor, values: torch.Tensor) -> KVCacheResult:
        """Wrap un-cached keys/values, numbering positions ``0..T-1`` per batch row.

        ``keys`` must be 4-D ``[B, H, T, D]``. The returned ``positions`` is an
        expanded view of shape ``[B, T]``. Declared as a static method because
        it takes no instance or class argument.
        """
        B, H, T, D = keys.shape
        positions = torch.arange(T, device=keys.device, dtype=torch.long)
        return KVCacheResult(keys, values, positions.expand(B, -1))
class RingKVCache:
    """Fixed-capacity circular key/value cache for streaming attention.

    Storage is one tensor of shape ``(2, batch_size, num_heads, capacity,
    dim_per_head)``; index 0 holds keys, index 1 holds values. ``end_offset``
    counts the time steps written so far: one counter per batch row when
    ``respect_exec_mask`` is True, a single shared counter otherwise.
    """

    def __init__(
        self,
        batch_size: int,
        num_heads: int,
        dim_per_head: int,
        capacity: int,
        respect_exec_mask: bool = True,
        device: torch.device = torch.device("cuda"),
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.capacity = capacity
        storage_shape = (2, batch_size, num_heads, capacity, dim_per_head)
        self.cache = torch.zeros(storage_shape, device=device, dtype=dtype)
        self.respect_exec_mask = respect_exec_mask
        num_counters = batch_size if self.respect_exec_mask else 1
        self.end_offset = torch.zeros(num_counters, device=device, dtype=torch.long)

    def reset(self, reset_mask: torch.Tensor) -> None:
        """Zero the write counter of every entry selected by ``reset_mask`` (in place)."""
        cleared = torch.zeros_like(self.end_offset)
        self.end_offset[:] = torch.where(reset_mask, cleared, self.end_offset)

    def complete(self, k: torch.Tensor, v: torch.Tensor, exec_mask: torch.Tensor) -> KVCacheResult:
        """Write ``k``/``v`` (``[B, H, T, D]``) into the ring and return the full cache.

        Returns the key and value storage together with ``positions``: for each
        ring slot the absolute time step it currently holds, or -1 for slots at
        or beyond the updated ``end_offset``.

        Raises:
            ValueError: if ``T`` is not positive.
        """
        B, H, T, D = k.shape
        if T <= 0:
            raise ValueError(f"Expected T > 0, got T={T}")

        counter_device = self.end_offset.device
        step_ids = torch.arange(T, device=counter_device, dtype=self.end_offset.dtype)
        write_slots = (step_ids + self.end_offset.view(-1, 1)) % self.capacity

        if self.respect_exec_mask:
            # Per-row write positions differ, so scatter along the time axis.
            scatter_index = write_slots.view(B, 1, T, 1).expand(-1, H, T, D)
            self.cache[0].scatter_(2, scatter_index, k)
            self.cache[1].scatter_(2, scatter_index, v)
        else:
            # Single shared counter: every row writes to the same slots.
            shared_slots = write_slots[0]
            self.cache[0].index_copy_(2, shared_slots, k)
            self.cache[1].index_copy_(2, shared_slots, v)

        key_store, value_store = self.cache[0], self.cache[1]

        # Recover the absolute position held by each slot, measured from the
        # newest write (computed before the counter is advanced).
        slot_ids = torch.arange(self.capacity, device=counter_device, dtype=torch.long)
        last_offset = self.end_offset.view(-1, 1) + T - 1
        delta = slot_ids - last_offset % self.capacity
        positions = torch.where(
            delta <= 0,
            last_offset + delta,
            last_offset + delta - self.capacity,
        )

        if self.respect_exec_mask:
            advanced = self.end_offset + T
            self.end_offset[:] = torch.where(exec_mask, advanced, self.end_offset)
        else:
            self.end_offset.add_(T)

        # Uses the counter *after* the update above.
        unwritten = slot_ids >= self.end_offset.view(-1, 1)
        positions = torch.where(unwritten, torch.full_like(positions, -1), positions)
        return KVCacheResult(key_store, value_store, positions)
| # ============================================================================= | |
| # Multi-Head Attention | |
| # ============================================================================= | |
@dataclass
class MHAState(StreamingState):
    """Streaming state of one multi-head attention module.

    Declared as a dataclass so that it can be built as
    ``MHAState(batch_size, device, kv_cache, offset=..., offset_cpu=...)``.
    The base ``StreamingState`` must itself be a dataclass for the inherited
    ``batch_size`` and ``device`` fields to be part of the constructor.

    Fields:
        kv_cache: ring cache holding past keys/values, or None.
        offset: per-row tensor counter of consumed time steps.
        offset_cpu: plain-int counter of consumed time steps.
    """

    kv_cache: RingKVCache | None
    offset: torch.Tensor
    offset_cpu: int

    def reset(self, reset_mask: torch.Tensor):
        """Zero the selected rows of ``offset``, reset the cache, and clear ``offset_cpu``."""
        super().reset(reset_mask)
        self.offset[:] = torch.where(reset_mask, torch.zeros_like(self.offset), self.offset)
        if self.kv_cache is not None:
            self.kv_cache.reset(reset_mask)
        # The int counter is shared by all rows, so it is cleared unconditionally.
        self.offset_cpu = 0
def apply_weights_per_step(
    modules: nn.ModuleList,
    schedule: list[int] | None,
    x: torch.Tensor,
    offset: int | None,
) -> torch.Tensor:
    """Apply a (possibly different) module to each time step of ``x``.

    With a single module it is applied to the whole ``x`` at once. Otherwise
    step ``t`` uses the module at index ``t + offset``, optionally remapped
    through ``schedule``, and the per-step outputs are concatenated on dim 1.

    Args:
        modules: candidate modules.
        schedule: optional map from absolute step index to module index.
        x: 3-D input; dim 1 is the time axis.
        offset: absolute index of the first step; required when more than one
            module is given.

    Raises:
        ValueError: if ``offset`` is missing, or an index falls outside
            ``schedule`` or ``modules``.
    """
    if len(modules) == 1:
        return modules[0](x)
    if offset is None:
        raise ValueError("offset must be provided when using per-step weights (len(modules) > 1).")

    _, num_steps, _ = x.shape
    outputs = []
    for step in range(num_steps):
        slot = step + offset
        if schedule is not None:
            if not 0 <= slot < len(schedule):
                raise ValueError(
                    f"weights_per_step_schedule is too short for module_index={slot} (len={len(schedule)})."
                )
            slot = schedule[slot]
        if not 0 <= slot < len(modules):
            raise ValueError(f"module_index={slot} out of range for len(modules)={len(modules)}.")
        # Slice with a length-1 range so the time axis is kept.
        outputs.append(modules[slot](x[:, step : step + 1]))
    return torch.cat(outputs, 1)
class MossAudioTokenizerMultiheadAttention(StreamingModule):
    """Multi-head self-attention with streaming support.

    Queries, keys and values are all projected from ``query`` by one fused
    input projection. In streaming mode, past keys/values are kept in a
    ``RingKVCache`` and position offsets are carried in an ``MHAState``.

    Args:
        embed_dim: model width; split evenly across ``num_heads``.
        num_heads: number of attention heads.
        causal: if True, a query only attends to keys at the same or an
            earlier position.
        context: optional cap on how far back a causal query may look; also
            the ring cache capacity when set.
        rope: optional rotary embedding applied to q and k.
        weights_per_step: if non-zero, separate projection weights are used per
            time step.
        weights_per_step_schedule: optional map from step index to weight index.
        device, dtype: forwarded to the linear layers.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        causal: bool = False,
        context: int | None = None,
        rope: MossAudioTokenizerRotaryEmbedding | None = None,
        weights_per_step: int = 0,
        weights_per_step_schedule: list[int] | None = None,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.embed_dim = embed_dim
        self.causal = causal
        self.context = context
        self.rope = rope
        self.num_heads = num_heads
        self.weights_per_step = weights_per_step
        self.weights_per_step_schedule = weights_per_step_schedule

        # q, k and v come out of a single fused projection.
        out_dim = 3 * embed_dim
        mult = 1
        if weights_per_step:
            mult = max(weights_per_step_schedule) + 1 if weights_per_step_schedule else weights_per_step
        self.mult = mult

        self.out_projs = nn.ModuleList(
            [nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs) for _ in range(mult)]
        )
        self.in_projs = nn.ModuleList(
            [nn.Linear(embed_dim, out_dim, bias=False, **factory_kwargs) for _ in range(mult)]
        )
        self._register_load_state_dict_pre_hook(self._load_hook, with_module=True)

    @staticmethod
    def _load_hook(module, state_dict, prefix, *_):
        """Rewrite legacy fused projection keys into the per-step layout.

        Must be a static method: the hook is registered with
        ``with_module=True``, so the owning module is supplied as the first
        argument. A bound method would receive it a second time and shift
        every other argument by one.

        Entries named ``in_proj_weight``, ``in_proj.weight`` or
        ``out_proj.weight`` (optionally suffixed ``_scb``) are split into
        ``module.mult`` chunks along dim 0 and stored under
        ``in_projs.{i}.weight`` / ``out_projs.{i}.weight``.
        """
        mappings = {
            "in_proj_weight": "in_projs.{i}.weight",
            "in_proj.weight": "in_projs.{i}.weight",
            "out_proj.weight": "out_projs.{i}.weight",
        }
        mult = module.mult
        for suffix in ["", "_scb"]:
            for source, target in mappings.items():
                this_source = prefix + source + suffix
                if this_source in state_dict:
                    weight = state_dict[this_source]
                    _, *OD = weight.shape
                    weight = weight.view(mult, -1, *OD)
                    for i in range(mult):
                        state_dict[prefix + target.format(i=i) + suffix] = weight[i]
                    state_dict.pop(this_source)

    def _init_streaming_state(self, batch_size: int) -> MHAState:
        """Allocate the KV cache and offsets using the first input projection's device/dtype."""
        in_proj = cast(nn.Linear, self.in_projs[0])
        device = cast(torch.device, in_proj.weight.device)
        dtype = cast(torch.dtype, in_proj.weight.dtype)
        dim_per_head = self.embed_dim // self.num_heads

        if self.context is None:
            capacity = self.weights_per_step if self.weights_per_step else 1024
        else:
            capacity = self.context

        kv_cache = RingKVCache(
            batch_size,
            self.num_heads,
            dim_per_head,
            capacity,
            respect_exec_mask=not self.weights_per_step,
            device=device,
            dtype=dtype,
        )
        return MHAState(
            batch_size,
            device,
            kv_cache,
            offset=torch.zeros(batch_size, device=device, dtype=torch.long),
            offset_cpu=0,
        )

    def _complete_kv(self, k, v) -> KVCacheResult:
        """Return keys/values/positions, routed through the ring cache when streaming."""
        state = cast(MHAState | None, self._streaming_state)
        if state is None or state.kv_cache is None:
            return KVCacheResult.from_kv(k, v)
        return state.kv_cache.complete(k, v, state.exec_mask)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        """Run self-attention on ``query`` (``[B, T, embed_dim]``).

        ``key`` and ``value`` are accepted for interface compatibility but not
        read; q, k and v are all derived from ``query``.
        """
        state = cast(MHAState | None, self._streaming_state)
        B, T = query.shape[:2]

        if state is None:
            offset = torch.zeros(B, device=query.device, dtype=torch.long)
            offset_cpu = 0
        else:
            offset = state.offset
            offset_cpu = state.offset_cpu

        projected = apply_weights_per_step(self.in_projs, self.weights_per_step_schedule, query, offset_cpu)
        dim_per_head = self.embed_dim // self.num_heads
        # [B, T, 3*E] -> [3, B, H, T, D]
        projected = projected.reshape(B, T, 3, self.num_heads, dim_per_head).permute(2, 0, 3, 1, 4)
        q, k, v = projected[0], projected[1], projected[2]

        if self.rope:
            q, k = self.rope(q, k, offset, time_before_heads=False)

        k, v, pos_k = self._complete_kv(k, v)
        pos_k = pos_k[:, None]

        if self.causal:
            pos_q = offset.view(-1, 1, 1) + torch.arange(T, device=q.device, dtype=torch.long).view(-1, 1)
            delta = pos_q - pos_k
            # Keep keys with a valid (non-negative) position that are not in the future.
            attn_bias = (pos_k >= 0) & (delta >= 0)
            if self.context is not None:
                attn_bias = attn_bias & (delta < self.context)
            attn_bias = attn_bias[:, None]
        else:
            attn_bias = None

        x = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)
        x = x.transpose(1, 2).reshape(B, T, self.embed_dim)
        x = apply_weights_per_step(self.out_projs, self.weights_per_step_schedule, x, offset_cpu)

        if state is not None:
            # Only rows flagged in exec_mask advance their tensor offset.
            state.offset[:] = torch.where(state.exec_mask, state.offset + T, state.offset)
            state.offset_cpu += T

        return x
| # ============================================================================= | |
| # Transformer Layer | |
| # ============================================================================= | |
@dataclass
class LayerState(StreamingState):
    """Streaming state of one transformer layer.

    Declared as a dataclass so that it can be built as
    ``LayerState(batch_size, device, offset_cpu=0)``. The base
    ``StreamingState`` must itself be a dataclass for the inherited
    ``batch_size`` and ``device`` fields to be part of the constructor.

    Fields:
        offset_cpu: plain-int count of time steps processed so far.
    """

    offset_cpu: int = 0

    def reset(self, reset_mask: torch.Tensor):
        """Reset the base state and clear the step counter."""
        super().reset(reset_mask)
        # The int counter is shared by all rows, so it is cleared unconditionally.
        self.offset_cpu = 0
class MossAudioTokenizerTransformerLayer(StreamingModule):
    """Pre-norm transformer layer (self-attention + feed-forward) with streaming support.

    Each sub-block normalizes its input, computes an update, optionally
    rescales it with LayerScale, and adds it back to the un-normalized input.
    The feed-forward part is either a plain two-layer MLP (``gating="none"``)
    or a gated block; with ``weights_per_step`` a separate gated block is kept
    per step index.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_feedforward: int = 2048,
        causal: bool = False,
        context: int | None = None,
        rope: MossAudioTokenizerRotaryEmbedding | None = None,
        norm: str = "layer_norm",
        layer_scale: float | None = None,
        gating: str = "none",
        weights_per_step: int = 0,
        weights_per_step_schedule: list[int] | None = None,
        activation=F.gelu,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.self_attn = MossAudioTokenizerMultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            causal=causal,
            context=context,
            rope=rope,
            weights_per_step=weights_per_step,
            weights_per_step_schedule=weights_per_step_schedule,
            device=device,
            dtype=dtype,
        )
        self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)
        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)

        self.weights_per_step = weights_per_step
        self.weights_per_step_schedule = weights_per_step_schedule

        # Exactly one of (linear1, linear2) or gating is populated below.
        self.gating: nn.Module | nn.ModuleList | None = None
        self.linear1: nn.Module | None = None
        self.linear2: nn.Module | None = None
        self.activation = activation

        num_weights = 1
        if weights_per_step:
            if weights_per_step_schedule:
                num_weights = max(weights_per_step_schedule) + 1
            else:
                num_weights = weights_per_step

        if gating == "none":
            self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False, **factory_kwargs)
            self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False, **factory_kwargs)
        elif weights_per_step:
            if isinstance(dim_feedforward, int):
                dim_ff_list = [dim_feedforward] * num_weights
            else:
                dim_ff_list = dim_feedforward
            self.gating = nn.ModuleList(
                [make_gating(gating, d_model, ff_dim, **factory_kwargs) for ff_dim in dim_ff_list]
            )
        else:
            self.gating = make_gating(gating, d_model, dim_feedforward, **factory_kwargs)

        if layer_scale is None:
            self.layer_scale_1 = nn.Identity()
            self.layer_scale_2 = nn.Identity()
        else:
            self.layer_scale_1 = MossAudioTokenizerLayerScale(
                channels=d_model, init=layer_scale, channel_last=True, device=device, dtype=dtype
            )
            self.layer_scale_2 = MossAudioTokenizerLayerScale(
                channels=d_model, init=layer_scale, channel_last=True, device=device, dtype=dtype
            )

    def _init_streaming_state(self, batch_size: int) -> LayerState:
        """Create the layer state on the device of the first parameter."""
        first_param = next(iter(self.parameters()))
        return LayerState(batch_size, first_param.device, offset_cpu=0)

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Feed-forward sub-block with residual connection."""
        state = self._streaming_state
        step_offset = state.offset_cpu if isinstance(state, LayerState) else 0

        residual = x
        normed = self.norm2(x)

        if self.gating is None:
            assert self.linear1 is not None
            assert self.linear2 is not None
            update = self.linear2(self.activation(self.linear1(normed)))
        elif self.weights_per_step:
            assert isinstance(self.gating, nn.ModuleList)
            update = apply_weights_per_step(self.gating, self.weights_per_step_schedule, normed, step_offset)
        else:
            update = self.gating(normed)

        # `.to(update)` aligns the residual with the update's dtype/device.
        return residual.to(update) + self.layer_scale_2(update)

    def _sa_block(self, x: torch.Tensor):
        """Self-attention sub-block with residual connection."""
        residual = x
        normed = self.norm1(x)
        update = self.self_attn(normed, normed, normed)
        return residual.to(update) + self.layer_scale_1(update)

    def forward(self, x: torch.Tensor):
        """Apply attention then feed-forward; advance the step counter when streaming."""
        out = self._ff_block(self._sa_block(x))
        state = self._streaming_state
        if state is not None:
            assert isinstance(state, LayerState)
            state.offset_cpu += out.shape[1]
        return out
| # ============================================================================= | |
| # Streaming Transformer | |
| # ============================================================================= | |
@dataclass
class TransformerState(StreamingState):
    """Streaming state of the transformer stack.

    Declared as a dataclass so that it can be built as
    ``TransformerState(batch_size, device, offsets=...)``. The base
    ``StreamingState`` must itself be a dataclass for the inherited
    ``batch_size`` and ``device`` fields to be part of the constructor.

    Fields:
        offsets: per-row tensor counter of time steps processed so far.
    """

    offsets: torch.Tensor

    def reset(self, reset_mask: torch.Tensor):
        """Reset the base state and zero ``offsets`` for the selected rows (in place)."""
        super().reset(reset_mask)
        self.offsets[:] = torch.where(reset_mask, torch.zeros_like(self.offsets), self.offsets)
class MossAudioTokenizerTransformer(StreamingModule):
    """Stack of transformer layers with optional sinusoidal and/or rotary positions.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: attention heads per layer.
        num_layers: number of layers in the stack.
        dim_feedforward: feed-forward width passed to each layer.
        causal, context: attention masking options passed to each layer.
        positional_embedding: ``"sin"`` and ``"sin_rope"`` add a sinusoidal
            embedding to the input; ``"rope"`` and ``"sin_rope"`` enable a
            rotary embedding shared by all layers.
        max_period: period base for both embedding kinds.
        positional_scale: multiplier on the sinusoidal embedding.
        device, dtype: forwarded to each layer.
        **kwargs: further keyword arguments forwarded to each layer.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_layers: int,
        dim_feedforward: int = 2048,
        causal: bool = False,
        context: int | None = None,
        positional_embedding: str = "sin",
        max_period: float = 10_000,
        positional_scale: float = 1.0,
        device=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"d_model must be divisible by num_heads, got d_model={d_model}, num_heads={num_heads}")

        self.positional_embedding = positional_embedding
        self.max_period = max_period
        self.positional_scale = positional_scale

        self.rope: MossAudioTokenizerRotaryEmbedding | None = None
        if positional_embedding in {"rope", "sin_rope"}:
            self.rope = MossAudioTokenizerRotaryEmbedding(max_period=max_period)

        # Every layer receives the same rope instance.
        layer_stack = [
            MossAudioTokenizerTransformerLayer(
                d_model=d_model,
                num_heads=num_heads,
                dim_feedforward=dim_feedforward,
                causal=causal,
                context=context,
                rope=self.rope,
                device=device,
                dtype=dtype,
                **kwargs,
            )
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList()
        for layer in layer_stack:
            self.layers.append(layer)

    def _init_streaming_state(self, batch_size: int) -> TransformerState:
        """Create the stack-level state with zeroed per-row offsets."""
        target = next(self.parameters()).device
        start_offsets = torch.zeros(batch_size, device=target, dtype=torch.long)
        return TransformerState(batch_size, target, offsets=start_offsets)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Run ``x`` (``[B, T, C]``) through all layers; extra args go to every layer."""
        B, T, C = x.shape
        state = self._streaming_state

        if isinstance(state, TransformerState):
            offsets = state.offsets
        else:
            # Not streaming (or an unexpected state type): positions start at zero.
            offsets = torch.zeros(1, dtype=torch.long, device=x.device)

        if self.positional_embedding in {"sin", "sin_rope"}:
            step_ids = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = step_ids + offsets.view(-1, 1, 1)
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + self.positional_scale * pos_emb

        for layer in self.layers:
            x = layer(x, *args, **kwargs)

        if state is not None:
            assert isinstance(state, TransformerState)
            # Only rows flagged in exec_mask advance.
            state.offsets[:] = torch.where(state.exec_mask, state.offsets + T, state.offsets)
        return x
class MossAudioTokenizerProjectedTransformer(StreamingContainer):
    """Transformer wrapped with linear input/output projections.

    A projection is only created when the corresponding dimension differs from
    ``d_model``; otherwise an ``nn.Identity`` is used.

    Args:
        input_dimension: channel count of the incoming features.
        output_dimension: channel count of the produced features.
        d_model: width of the inner transformer.
        conv_layout: stored on the instance; not read in this class.
        module_type: stored on the instance as a label.
        **kwargs: forwarded to ``MossAudioTokenizerTransformer``.
    """

    def __init__(
        self,
        input_dimension: int,
        output_dimension: int,
        d_model: int,
        *,
        conv_layout: bool = False,
        module_type: str,
        **kwargs,
    ):
        super().__init__()
        self.module_type = module_type
        self.downsample_ratio: int = 1
        self.input_dimension = input_dimension
        self.output_dimension = output_dimension

        if d_model != input_dimension:
            self.input_proj = nn.Linear(input_dimension, d_model, bias=False)
        else:
            self.input_proj = nn.Identity()

        self.transformer = MossAudioTokenizerTransformer(d_model=d_model, **kwargs)
        self.conv_layout = conv_layout

        if d_model != output_dimension:
            self.output_proj = nn.Linear(d_model, output_dimension, bias=False)
        else:
            self.output_proj = nn.Identity()

    def forward(self, x, input_lengths, *args, **kwargs):
        """Apply the projected transformer; ``input_lengths`` is returned unchanged."""
        hidden = x.transpose(1, 2)  # (B, D, T) -> (B, T, D)
        hidden = self.input_proj(hidden)
        hidden = self.transformer(hidden, *args, **kwargs)
        hidden = self.output_proj(hidden)
        return hidden.transpose(1, 2), input_lengths  # (B, T, D) -> (B, D, T)
| # ============================================================================= | |
| # Patched Pretransform Module | |
| # ============================================================================= | |
class MossAudioTokenizerPatchedPretransform(nn.Module):
    """Parameter-free patching module for downsampling/upsampling.

    `encode` folds every `patch_size` consecutive time steps into the channel
    axis, and `decode` performs the inverse unfolding. `forward` dispatches to
    one of them depending on `is_downsample`.
    """

    def __init__(self, patch_size: int, is_downsample: bool, module_type: str, **kwargs):
        """Store the patching configuration.

        Args:
            patch_size: Number of time steps folded into one frame.
            is_downsample: If True `forward` calls `encode`, otherwise `decode`.
            module_type: Module type tag, stored on the instance.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.patch_size = patch_size
        self.downsample_ratio: int = patch_size
        self.is_downsample = is_downsample
        self.module_type = module_type

    def encode(self, x, input_lengths):
        """Fold time into channels: `(b, d, t)` -> `(b, d * patch_size, t // patch_size)`.

        Args:
            x: Tensor of shape `(b, d, t)`; `t` must be a multiple of `patch_size`.
            input_lengths: Valid length of each sample before patching.

        Returns:
            The patched tensor and the per-sample number of frames.
        """
        b, d, _ = x.shape
        h = self.patch_size
        x = x.reshape(b, d, -1, h).permute(0, 1, 3, 2).reshape(b, d * h, -1)
        # The waveform is padded to a multiple of the downsample rate before the
        # encoder runs, while the lengths still describe the unpadded signal.
        # Ceil division matches that padding, so the last (partially padded)
        # frame is counted instead of being dropped.
        output_lengths = (input_lengths + h - 1) // h
        return x, output_lengths

    def decode(self, x, input_lengths):
        """Unfold channels back into time: `(b, d * patch_size, l)` -> `(b, d, l * patch_size)`.

        Args:
            x: Tensor of shape `(b, d * patch_size, l)`.
            input_lengths: Valid number of frames of each sample.

        Returns:
            The unpatched tensor and the per-sample lengths multiplied by `patch_size`.
        """
        b, dh, l = x.shape
        h = self.patch_size
        d = dh // h
        x = x.reshape(b, d, h, l).permute(0, 1, 3, 2).reshape(b, d, l * h)
        output_lengths = input_lengths * h
        return x, output_lengths

    def forward(self, x, input_lengths):
        """Apply `encode` when configured for downsampling, otherwise `decode`."""
        if self.is_downsample:
            return self.encode(x, input_lengths)
        return self.decode(x, input_lengths)
| # ============================================================================= | |
| # Vector Quantization | |
| # ============================================================================= | |
| def WNConv1d(*args, **kwargs): | |
| """Weight-normalized Conv1d.""" | |
| return nn.utils.parametrizations.weight_norm(nn.Conv1d(*args, **kwargs)) | |
class MossAudioTokenizerVectorQuantize(nn.Module):
    """Single-codebook vector quantizer (inference only).

    Inputs are optionally projected to `codebook_dim` with a weight-normalized
    1x1 convolution, matched to the nearest codebook row by squared Euclidean
    distance, and projected back to `input_dim`.
    """

    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
        **kwargs,
    ):
        """Create the projections (identity when the dims match) and the codebook.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        needs_projection = input_dim != codebook_dim
        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) if needs_projection else nn.Identity()
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) if needs_projection else nn.Identity()
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            z: Input tensor of shape (B, D, T)
        Returns:
            z_q: Quantized tensor of shape (B, D, T)
            indices: Code indices of shape (B, T)
            z_e: Encoded tensor before quantization
        """
        z = z.float()
        z_e = self.in_proj(z).float()

        # One row per (batch, time) position.
        flat = z_e.transpose(1, 2).reshape(-1, z_e.shape[1])
        codebook = self.codebook.weight.float()

        # Squared distance expanded as |x|^2 - 2 x.c + |c|^2.
        sq_inputs = flat.pow(2).sum(1, keepdim=True)
        cross = 2 * flat @ codebook.t()
        sq_codes = codebook.pow(2).sum(1, keepdim=True).t()
        dist = sq_inputs - cross + sq_codes

        indices = torch.max(-dist, dim=1).indices.reshape(z.size(0), -1)
        z_q = self.out_proj(self.decode_code(indices)).float()
        return z_q, indices, z_e

    def decode_code(self, embed_id: torch.Tensor) -> torch.Tensor:
        """Look up codebook vectors for `embed_id` and return them channel-first, without `out_proj`."""
        embedded = self.codebook(embed_id)
        return embedded.transpose(1, 2).float()
class MossAudioTokenizerLFQ(nn.Module):
    """LFQ (inference-only) used by ResidualLFQ.

    Codes are selected by comparing L2-normalized latents with the
    L2-normalized codebook; the returned vectors are the codebook rows.
    """

    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
        **kwargs,
    ):
        """Create the projections (identity when the dims match) and the codebook.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        needs_projection = self.input_dim != self.codebook_dim
        self.in_proj = (
            WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1) if needs_projection else nn.Identity()
        )
        self.out_proj = (
            WNConv1d(self.codebook_dim, self.input_dim, kernel_size=1) if needs_projection else nn.Identity()
        )
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Quantize `z` and return `(z_q, indices, z_e)`.

        `z_e` is the projected input, `indices` the selected codes and `z_q`
        the selected codebook vectors after `out_proj`.
        """
        z = z.float()
        z_e = self.in_proj(z).float()
        selected, indices = self.decode_latents(z_e)
        # Straight-through form: the value equals `selected`, gradients follow `z_e`.
        passthrough = (z_e + (selected - z_e).detach()).float()
        z_q = self.out_proj(passthrough).float()
        return z_q, indices, z_e

    def embed_code(self, embed_id: torch.Tensor) -> torch.Tensor:
        """Look up codebook rows for `embed_id` (channel-last)."""
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code_wo_out_proj(self, embed_id: torch.Tensor) -> torch.Tensor:
        """Look up codebook rows for `embed_id` and move channels to dim 1."""
        embedded = self.embed_code(embed_id)
        return embedded.transpose(1, 2)

    def decode_code(self, embed_id: torch.Tensor) -> torch.Tensor:
        """Decode code indices to channel-first vectors and apply `out_proj`."""
        raw = self.decode_code_wo_out_proj(embed_id).float()
        return self.out_proj(raw).float()

    def decode_latents(self, latents: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Match training LFQ: L2-normalize then argmin squared distance."""
        flat = latents.transpose(1, 2).reshape(-1, latents.shape[1]).float()
        flat = F.normalize(flat)
        codebook = F.normalize(self.codebook.weight.float())

        sq_inputs = flat.pow(2).sum(1, keepdim=True)
        cross = 2 * flat @ codebook.t()
        sq_codes = codebook.pow(2).sum(1, keepdim=True).t()
        dist = sq_inputs - cross + sq_codes

        indices = torch.max(-dist, dim=1).indices.reshape(latents.size(0), -1)
        z_q = self.decode_code_wo_out_proj(indices).float()
        return z_q, indices
class MossAudioTokenizerResidualVQ(nn.Module):
    """Residual Vector Quantization (inference only).

    A stack of `MossAudioTokenizerVectorQuantize` modules in which each
    quantizer encodes the residual left by the previous ones. Optional
    weight-normalized 1x1 convolutions map `input_dim -> rvq_dim` before and
    `rvq_dim -> output_dim` after the stack.
    """

    def __init__(
        self,
        input_dim: int = 1024,
        rvq_dim: int | None = None,
        output_dim: int | None = None,
        num_quantizers: int = 32,
        codebook_size: int = 1024,
        codebook_dim: int = 8,
        **kwargs,
    ):
        """Build the projections and the quantizer stack.

        Args:
            input_dim: Channel count of the incoming features.
            rvq_dim: Channel count inside the quantizer stack; falls back to `input_dim`.
            output_dim: Channel count of the returned features; falls back to `input_dim`.
            num_quantizers: Number of quantizers in the stack.
            codebook_size: Number of entries per codebook.
            codebook_dim: Dimension of each codebook entry.
            **kwargs: Forwarded to every `MossAudioTokenizerVectorQuantize`.
        """
        super().__init__()
        self.input_dim = input_dim
        self.rvq_dim = rvq_dim or input_dim
        self.output_dim = output_dim or input_dim
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.input_proj = (
            WNConv1d(input_dim, self.rvq_dim, kernel_size=1) if input_dim != self.rvq_dim else nn.Identity()
        )
        self.output_proj = (
            WNConv1d(self.rvq_dim, self.output_dim, kernel_size=1)
            if self.rvq_dim != self.output_dim
            else nn.Identity()
        )
        self.quantizers = nn.ModuleList(
            [
                MossAudioTokenizerVectorQuantize(
                    input_dim=self.rvq_dim,
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim,
                    **kwargs,
                )
                for _ in range(num_quantizers)
            ]
        )

    def forward(
        self,
        z: torch.Tensor,
        input_length: torch.Tensor,
        n_quantizers: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            z: Input tensor of shape (B, D, T)
            input_length: Valid lengths for each sample (B,)
            n_quantizers: Number of quantizers to use
        Returns:
            quantized_out: Quantized output (B, D, T)
            all_indices: All code indices (N, B, T)
            output_length: Output lengths (B,)
        """
        z = self.input_proj(z)
        batch_size, _, max_time = z.shape
        mask = torch.arange(max_time, device=z.device).expand(batch_size, max_time) < input_length.unsqueeze(1)
        # (B, 1, T): zeroes padded positions in the residual and in every update.
        update_mask = mask.unsqueeze(1)

        quantized_out = torch.zeros_like(z, dtype=torch.float32)
        residual = z.clone().float()
        all_indices = []
        n_quantizers = n_quantizers or self.num_quantizers

        for i, quantizer in enumerate(self.quantizers):
            if i >= n_quantizers:
                break
            z_q_i, indices_i, _ = quantizer(residual * update_mask)
            quantized_out = quantized_out + z_q_i * update_mask
            residual = residual - z_q_i * update_mask
            all_indices.append(indices_i)

        # `torch.stack` rejects an empty list; return an empty (0, B, T) tensor
        # instead, as the residual LFQ variant does.
        all_indices = (
            torch.stack(all_indices)  # (N, B, T)
            if all_indices
            else torch.empty(0, batch_size, max_time, device=z.device, dtype=torch.long)
        )
        quantized_out = self.output_proj(quantized_out)
        return quantized_out, all_indices, input_length

    def decode_codes(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode codes from multiple quantizers to embeddings.

        Args:
            codes: Code indices of shape (N, B, T); only the first N quantizers are used.

        Returns:
            The summed per-quantizer embeddings after `output_proj`.
        """
        nq, B, T = codes.shape
        emb = torch.zeros(B, self.rvq_dim, T, device=codes.device, dtype=torch.float32)
        for i, quantizer in enumerate(self.quantizers[:nq]):
            quantizer = cast(MossAudioTokenizerVectorQuantize, quantizer)
            # `decode_code` returns raw codebook vectors with `codebook_dim`
            # channels. They must go through the quantizer's `out_proj`, exactly
            # as in the quantizer's own forward pass, to land in `rvq_dim` space.
            quantized_i = quantizer.out_proj(quantizer.decode_code(codes[i])).float()
            emb += quantized_i
        emb = self.output_proj(emb)
        return emb
class MossAudioTokenizerResidualLFQ(nn.Module):
    """Residual LFQ (inference only).

    A stack of `MossAudioTokenizerLFQ` modules in which each quantizer encodes
    the residual left by the previous ones, with optional weight-normalized
    1x1 convolutions before and after the stack.
    """

    def __init__(
        self,
        input_dim: int = 1024,
        rvq_dim: int | None = None,
        output_dim: int | None = None,
        num_quantizers: int = 32,
        codebook_size: int = 1024,
        codebook_dim: int = 8,
        **kwargs,
    ):
        """Build the projections and the quantizer stack.

        `rvq_dim` and `output_dim` fall back to `input_dim` when falsy; extra
        keyword arguments are forwarded to every `MossAudioTokenizerLFQ`.
        """
        super().__init__()
        self.input_dim = input_dim
        self.rvq_dim = rvq_dim or input_dim
        self.output_dim = output_dim or input_dim
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        if input_dim != self.rvq_dim:
            self.input_proj = WNConv1d(input_dim, self.rvq_dim, kernel_size=1)
        else:
            self.input_proj = nn.Identity()

        if self.rvq_dim != self.output_dim:
            self.output_proj = WNConv1d(self.rvq_dim, self.output_dim, kernel_size=1)
        else:
            self.output_proj = nn.Identity()

        stages = []
        for _ in range(num_quantizers):
            stages.append(
                MossAudioTokenizerLFQ(
                    input_dim=self.rvq_dim,
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim,
                    **kwargs,
                )
            )
        self.quantizers = nn.ModuleList(stages)

    def forward(
        self,
        z: torch.Tensor,
        input_length: torch.Tensor,
        n_quantizers: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Quantize `z` with the first `n_quantizers` stages (all stages when falsy).

        Returns the projected quantized output, the stacked code indices and
        `input_length` unchanged.
        """
        z = self.input_proj(z).float()
        batch_size, _, max_time = z.shape

        time_index = torch.arange(max_time, device=z.device).expand(batch_size, max_time)
        # (B, 1, T): true on valid positions, used to zero padded frames.
        valid = (time_index < input_length.unsqueeze(1)).unsqueeze(1)

        quantized_out = torch.zeros_like(z, dtype=torch.float32)
        residual = z.clone().float()
        collected = []
        n_quantizers = n_quantizers or self.num_quantizers

        for stage_idx, quantizer in enumerate(self.quantizers):
            if stage_idx >= n_quantizers:
                break
            z_q_i, indices_i, _ = quantizer(residual * valid)
            contribution = z_q_i * valid
            quantized_out = quantized_out + contribution
            residual = residual - contribution
            collected.append(indices_i)

        if collected:
            all_indices = torch.stack(collected)
        else:
            all_indices = torch.empty(0, batch_size, max_time, device=z.device, dtype=torch.long)

        quantized_out = self.output_proj(quantized_out)
        return quantized_out, all_indices, input_length

    def decode_codes(self, codes: torch.Tensor) -> torch.Tensor:
        """Sum the decoded embeddings of each quantizer's codes and apply `output_proj`."""
        nq, B, T = codes.shape
        emb = torch.zeros(B, self.rvq_dim, T, device=codes.device, dtype=torch.float32)
        for stage_idx, quantizer in enumerate(self.quantizers[:nq]):
            quantizer = cast(MossAudioTokenizerLFQ, quantizer)
            emb += quantizer.decode_code(codes[stage_idx]).float()
        return self.output_proj(emb)
| # ============================================================================= | |
| # Main Model Classes | |
| # ============================================================================= | |
class MossAudioTokenizerPreTrainedModel(PreTrainedAudioTokenizerBase):
    """Common base of the MossAudioTokenizer models; only declares class-level settings."""

    # Configuration class paired with this model family.
    config_class = MossAudioTokenizerConfig

    # Input description.
    main_input_name = "input_values"
    input_modalities = "audio"

    # Loading / training flags.
    base_model_prefix = ""
    supports_gradient_checkpointing = False

    # NOTE(review): presumably the module classes that must stay on a single
    # device when the model is sharded; confirm against the base class.
    _no_split_modules = [
        "MossAudioTokenizerTransformerLayer",
        "MossAudioTokenizerResidualVQ",
        "MossAudioTokenizerResidualLFQ",
    ]
| class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel): | |
| """ | |
| MossAudioTokenizer model for audio tokenization and synthesis. | |
| This model can encode audio waveforms into discrete tokens and decode | |
| tokens back into audio waveforms. | |
| """ | |
def __init__(self, config: MossAudioTokenizerConfig):
    """Build the encoder stack, the residual quantizer and the decoder stack from `config`.

    The running frame rate starts at `config.sampling_rate`, is divided by each
    encoder module's `downsample_ratio` and multiplied by each decoder module's.
    Every transformer receives a `context` equal to the current frame rate times
    `config.causal_transformer_context_duration`, truncated to an int.

    Args:
        config: Model configuration.

    Raises:
        ValueError: If an encoder/decoder entry has an unsupported `module_type`
            or the quantizer type is unsupported.
    """
    super().__init__(config)
    self.config = config
    # NOTE(review): the value is discarded; presumably this only checks that the
    # attribute exists on the config. Confirm before removing.
    _ = config.version
    self.sampling_rate = config.sampling_rate
    self.downsample_rate = config.downsample_rate
    self.causal_transformer_context_duration = config.causal_transformer_context_duration

    # Build encoder
    current_frame_rate: float = float(self.sampling_rate)
    self.encoder = nn.ModuleList()
    for encoder_kwargs_i in config.encoder_kwargs:
        encoder_kwargs_i = dict(encoder_kwargs_i)  # Make a copy
        module_type = encoder_kwargs_i["module_type"]
        if module_type == "PatchedPretransform":
            encoder_module = MossAudioTokenizerPatchedPretransform(**encoder_kwargs_i, is_downsample=True)
        elif module_type == "Transformer":
            encoder_module = MossAudioTokenizerProjectedTransformer(
                **encoder_kwargs_i,
                context=int(current_frame_rate * self.causal_transformer_context_duration),
            )
        else:
            # Skipping an unknown entry would leave the frame-rate bookkeeping
            # below reading the ratio of the previously added module.
            raise ValueError(f"Unsupported encoder module_type: {module_type}")
        self.encoder.append(encoder_module)
        current_frame_rate /= encoder_module.downsample_ratio

    # Build quantizer
    quantizer_kwargs = dict(config.quantizer_kwargs)
    quantizer_type = quantizer_kwargs.get("quantizer_type", getattr(config, "quantizer_type", "rvq"))
    if quantizer_type in {"rvq", "spec_rvq"}:
        self.quantizer = MossAudioTokenizerResidualVQ(**quantizer_kwargs)
    elif quantizer_type in {"rlfq", "random_prefix_rlfq"}:
        self.quantizer = MossAudioTokenizerResidualLFQ(**quantizer_kwargs)
    else:
        raise ValueError(f"Unsupported quantizer_type: {quantizer_type}")

    # Build decoder
    decoder_kwargs_list = copy.deepcopy(config.decoder_kwargs)
    self.decoder = nn.ModuleList()
    for decoder_kwargs_i in decoder_kwargs_list:
        decoder_kwargs_i = dict(decoder_kwargs_i)
        module_type = decoder_kwargs_i["module_type"]
        if module_type == "PatchedPretransform":
            decoder_module = MossAudioTokenizerPatchedPretransform(**decoder_kwargs_i, is_downsample=False)
        elif module_type == "Transformer":
            decoder_module = MossAudioTokenizerProjectedTransformer(
                **decoder_kwargs_i,
                context=int(current_frame_rate * self.causal_transformer_context_duration),
            )
        else:
            raise ValueError(f"Unsupported decoder module_type: {module_type}")
        self.decoder.append(decoder_module)
        current_frame_rate *= decoder_module.downsample_ratio

    self.post_init()
def _start_streaming(self, batch_size: int):
    """Give every `StreamingModule` in this model a fresh streaming state for `batch_size` streams."""

    def _enable(submodule):
        if not isinstance(submodule, StreamingModule):
            return
        submodule._streaming_state = submodule._init_streaming_state(batch_size)

    self.apply(_enable)
def _stop_streaming(self):
    """Clear the streaming state of every `StreamingModule` in this model."""

    def _disable(submodule):
        if not isinstance(submodule, StreamingModule):
            return
        submodule._streaming_state = None

    self.apply(_disable)
@contextmanager
def streaming(self, batch_size: int = 1):
    """Context manager for streaming mode.

    Streaming is started for `batch_size` streams on entry and always stopped
    on exit, including when the body raises.

    Args:
        batch_size: Number of parallel streams to initialise.
    """
    # Without the decorator this generator function cannot be used in a `with` statement.
    self._start_streaming(batch_size)
    try:
        yield
    finally:
        self._stop_streaming()
def batch_encode(
    self, wav_list: list[torch.Tensor], num_quantizers: int | None = None
) -> MossAudioTokenizerEncoderOutput:
    """Batch encode a list of audio waveforms.

    The waveforms are right-padded with zeros to the longest one and stacked
    into a `(batch, 1, max_length)` tensor on the device of the first waveform.

    Args:
        wav_list: List of audio tensors, each of shape `(num_samples,)`.
        num_quantizers: Number of quantizers to use. By default, all quantizers are used.

    Returns:
        [`MossAudioTokenizerEncoderOutput`] with `audio_codes` and `audio_codes_lengths`.

    Raises:
        ValueError: If `wav_list` is empty.
    """
    if not len(wav_list):
        raise ValueError("`wav_list` must contain at least one waveform.")

    device = wav_list[0].device
    sample_counts = [wav.shape[-1] for wav in wav_list]

    input_values = torch.zeros(len(wav_list), 1, max(sample_counts), device=device)
    for row, (wav, count) in enumerate(zip(wav_list, sample_counts)):
        input_values[row, 0, :count] = wav
    input_lengths = torch.tensor(sample_counts, device=device, dtype=torch.long)

    return self._encode_frame(input_values, input_lengths, n_quantizers=num_quantizers)
def batch_decode(
    self, codes_list: list[torch.Tensor], num_quantizers: int | None = None
) -> MossAudioTokenizerDecoderOutput:
    """Batch decode a list of audio codes.

    The code tensors are truncated to the first `num_quantizers` rows,
    right-padded with zeros to the longest sequence and stacked into a
    `(num_quantizers, batch, max_length)` tensor on the device of the first element.

    Args:
        codes_list: List of audio code tensors, each of shape `(num_quantizers, codes_length)`.
        num_quantizers: If provided, decode only the first `num_quantizers` quantizers from each element in
            `codes_list`. If omitted, all elements in `codes_list` must have the same number of quantizers.

    Returns:
        [`MossAudioTokenizerDecoderOutput`] with `audio` and `audio_lengths`.

    Raises:
        ValueError: If `codes_list` is empty, if the elements disagree on the number of
            quantizers while `num_quantizers` is None, or if any element has fewer than
            `num_quantizers` quantizers.
    """
    if not len(codes_list):
        raise ValueError("`codes_list` must contain at least one code tensor.")

    device = codes_list[0].device
    quantizer_counts = [codes.shape[0] for codes in codes_list]

    if num_quantizers is None:
        num_quantizers = quantizer_counts[0]
        if not all(count == num_quantizers for count in quantizer_counts):
            raise ValueError(
                "All elements in `codes_list` must have the same number of quantizers when `num_quantizers` is None. "
                "Pass `num_quantizers=...` to decode a common prefix."
            )
    else:
        min_nq = min(quantizer_counts)
        if min_nq < num_quantizers:
            raise ValueError(
                "`num_quantizers` must be <= the number of quantizers for every element in `codes_list`. "
                f"Got num_quantizers={num_quantizers}, min(codes.shape[0])={min_nq}."
            )

    frame_counts = [codes.shape[-1] for codes in codes_list]
    audio_codes = torch.zeros(
        num_quantizers, len(codes_list), max(frame_counts), device=device, dtype=torch.long
    )
    for column, (codes, count) in enumerate(zip(codes_list, frame_counts)):
        audio_codes[:, column, :count] = codes[:num_quantizers]
    audio_codes_lengths = torch.tensor(frame_counts, device=device, dtype=torch.long)

    return self._decode_frame(audio_codes, audio_codes_lengths)
def _encode_frame(
    self,
    input_values: torch.Tensor,
    input_lengths: torch.Tensor | None = None,
    n_quantizers: int | None = None,
) -> MossAudioTokenizerEncoderOutput:
    """Tokenize audio waveform into discrete tokens.

    The waveform is right-padded to a multiple of `downsample_rate`, passed
    through every encoder module and then quantized. `input_lengths` is handed
    to the encoder as given; it is not adjusted for the padding added here.

    Args:
        input_values: Waveform of shape `(batch, time)` or `(batch, channels, time)`.
        input_lengths: Valid length of each sample; defaults to the full length.
        n_quantizers: Number of quantizers to use, forwarded to the quantizer.

    Returns:
        Encoder output holding the codes, their lengths and the pre-quantization hidden states.
    """
    # Accept (batch, time) input by inserting the channel axis.
    if input_values.dim() == 2:
        input_values = input_values.unsqueeze(1)

    batch_size, _, num_samples = input_values.shape
    if input_lengths is None:
        input_lengths = torch.full(
            (batch_size,), num_samples, device=input_values.device, dtype=torch.long
        )

    # Pad to multiple of downsample_rate
    remainder = num_samples % self.downsample_rate
    if remainder != 0:
        input_values = F.pad(input_values, (0, self.downsample_rate - remainder))

    # Encode
    hidden, hidden_lengths = input_values, input_lengths
    for stage in self.encoder:
        hidden, hidden_lengths = stage(hidden, hidden_lengths)

    # Quantize; the quantized embedding itself is not needed here.
    quantizer = cast(MossAudioTokenizerResidualVQ | MossAudioTokenizerResidualLFQ, self.quantizer)
    _, audio_codes, audio_codes_lengths = quantizer(hidden, hidden_lengths, n_quantizers)

    return MossAudioTokenizerEncoderOutput(
        audio_codes=audio_codes,
        audio_codes_lengths=audio_codes_lengths,
        encoder_hidden_states=hidden,
    )
def _decode_frame(
    self,
    codes: torch.Tensor,
    codes_lengths: torch.Tensor | None = None,
) -> MossAudioTokenizerDecoderOutput:
    """Detokenize discrete tokens into audio waveform.

    Args:
        codes: Code indices; the shape is unpacked as `(num_quantizers, batch, frames)`.
        codes_lengths: Valid number of frames per sample; defaults to the full length.

    Returns:
        Decoder output holding the waveform and its per-sample lengths.
    """
    _, batch_size, num_frames = codes.shape
    if codes_lengths is None:
        codes_lengths = torch.full((batch_size,), num_frames, device=codes.device, dtype=torch.long)

    # Decode from codes
    quantizer = cast(MossAudioTokenizerResidualVQ | MossAudioTokenizerResidualLFQ, self.quantizer)
    hidden = quantizer.decode_codes(codes)

    hidden_lengths = codes_lengths
    for stage in self.decoder:
        hidden, hidden_lengths = stage(hidden, hidden_lengths)

    return MossAudioTokenizerDecoderOutput(audio=hidden, audio_lengths=hidden_lengths)
def encode(  # type: ignore[override]
    self,
    input_values: torch.Tensor,
    padding_mask: torch.Tensor | None = None,
    num_quantizers: int | None = None,
    return_dict: bool | None = None,
    chunk_duration: float | None = None,
):
    """
    Encodes the input audio waveform into discrete codes.

    Args:
        input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
            Float values of the input audio waveform. A 2D `(batch_size, sequence_length)` tensor is also
            accepted and gets a channel axis inserted.
        padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to indicate valid audio samples. The valid length of each sample is the sum of its mask.
        num_quantizers (`int`, *optional*):
            Number of quantizers to use. By default, all quantizers are used.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        chunk_duration (`float`, *optional*):
            If provided, encode the input waveform in successive chunks of `chunk_duration` seconds while keeping a
            streaming KV cache for the causal transformers.
            `chunk_duration` must be <= `config.causal_transformer_context_duration`, and
            `chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`.
            Chunked encoding only supports `batch_size == 1`.

    Returns:
        `MossAudioTokenizerEncoderOutput` or tuple containing audio codes and lengths.

    Raises:
        ValueError: If `chunk_duration` violates one of the constraints above.
    """
    if return_dict is None:
        return_dict = self.config.return_dict

    # Handle input shape
    if input_values.dim() == 2:
        input_values = input_values.unsqueeze(1)
    B, _, T = input_values.shape
    device = input_values.device

    if padding_mask is None:
        input_lengths = torch.full((B,), T, device=device, dtype=torch.long)
    else:
        input_lengths = padding_mask.sum(dim=-1).long()

    if chunk_duration is None:
        encoder_output = self._encode_frame(input_values, input_lengths, num_quantizers)
    else:
        if chunk_duration <= 0:
            raise ValueError("`chunk_duration` must be > 0 when provided.")
        if chunk_duration > self.causal_transformer_context_duration:
            raise ValueError(
                "`chunk_duration` must be <= `config.causal_transformer_context_duration` "
                f"({self.causal_transformer_context_duration}), got {chunk_duration}."
            )
        if B != 1:
            raise ValueError("Streaming encode via `chunk_duration` currently only supports batch_size=1.")

        chunk_length = int(round(chunk_duration * self.sampling_rate))
        if chunk_length <= 0:
            raise ValueError("`chunk_duration` is too small and results in chunk_length <= 0.")
        if chunk_length % self.downsample_rate != 0:
            raise ValueError(
                "`chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`. "
                f"Got chunk_length={chunk_length}, downsample_rate={self.downsample_rate}."
            )

        total_samples = int(input_lengths[0].item())
        if total_samples <= chunk_length:
            # A single chunk covers everything: no streaming state is needed.
            encoder_output = self._encode_frame(input_values[..., :total_samples], input_lengths, num_quantizers)
        else:
            codes_chunks: list[torch.Tensor] = []
            hidden_chunks: list[torch.Tensor] = []

            with ExitStack() as exit_stack:
                # Keep every streaming-capable encoder module in streaming mode for all chunks.
                for encoder_module in self.encoder:
                    if isinstance(encoder_module, StreamingModule):
                        exit_stack.enter_context(encoder_module.streaming(batch_size=B))

                for chunk_start in range(0, total_samples, chunk_length):
                    chunk_stop = min(chunk_start + chunk_length, total_samples)
                    chunk_lengths = torch.tensor([chunk_stop - chunk_start], device=device, dtype=torch.long)
                    chunk_values = input_values[..., chunk_start:chunk_stop]

                    chunk_output = self._encode_frame(chunk_values, chunk_lengths, num_quantizers)
                    if chunk_output.audio_codes is None or chunk_output.audio_codes_lengths is None:
                        raise RuntimeError("Internal error: `_encode_frame` returned empty audio codes.")
                    if chunk_output.encoder_hidden_states is None:
                        raise RuntimeError("Internal error: `_encode_frame` returned empty encoder hidden states.")

                    # Keep only the frames reported as valid for this chunk.
                    valid_frames = chunk_output.audio_codes_lengths[0]
                    codes_chunks.append(chunk_output.audio_codes[:, :, :valid_frames])
                    hidden_chunks.append(chunk_output.encoder_hidden_states[:, :, :valid_frames])

            audio_codes = torch.cat(codes_chunks, dim=-1)
            encoder_hidden_states = torch.cat(hidden_chunks, dim=-1)
            audio_codes_lengths = torch.tensor([audio_codes.shape[-1]], device=device, dtype=torch.long)
            encoder_output = MossAudioTokenizerEncoderOutput(
                audio_codes=audio_codes,
                audio_codes_lengths=audio_codes_lengths,
                encoder_hidden_states=encoder_hidden_states,
            )

    if return_dict:
        return encoder_output

    assert encoder_output.audio_codes is not None
    assert encoder_output.audio_codes_lengths is not None
    return (
        cast(torch.Tensor, encoder_output.audio_codes),
        cast(torch.Tensor, encoder_output.audio_codes_lengths),
    )
def decode(  # type: ignore[override]
    self,
    audio_codes: torch.Tensor,
    padding_mask: torch.Tensor | None = None,
    return_dict: bool | None = None,
    chunk_duration: float | None = None,
    num_quantizers: int | None = None,
):
    """
    Decodes the given codes into an output audio waveform.

    Args:
        audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`):
            Discrete code embeddings computed using `model.encode`. A 2D `(num_quantizers, sequence_length)`
            tensor is also accepted and is treated as a batch of one.
        padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to indicate valid code positions. The valid length of each sample is the sum of its mask.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        chunk_duration (`float`, *optional*):
            If provided, decode the input codes in successive chunks of `chunk_duration` seconds while keeping a
            streaming KV cache for the causal transformers.
            `chunk_duration` must be <= `config.causal_transformer_context_duration`, and
            `chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`.
            Chunked decoding only supports `batch_size == 1`.
        num_quantizers (`int`, *optional*):
            Number of quantizers to use. By default, all quantizers in `audio_codes` are used.

    Returns:
        `MossAudioTokenizerDecoderOutput` or tuple containing decoded audio.

    Raises:
        ValueError: If `num_quantizers` exceeds the number of quantizers in `audio_codes`, or if
            `chunk_duration` violates one of the constraints above.
    """
    if return_dict is None:
        return_dict = self.config.return_dict

    if audio_codes.dim() == 2:
        audio_codes = audio_codes.unsqueeze(1)  # nq, T -> nq, B=1, T

    if num_quantizers is not None:
        if num_quantizers > audio_codes.shape[0]:
            raise ValueError(
                f"`num_quantizers` ({num_quantizers}) must be <= audio_codes.shape[0] ({audio_codes.shape[0]})."
            )
        audio_codes = audio_codes[:num_quantizers]

    _, B, T = audio_codes.shape
    device = audio_codes.device

    if padding_mask is None:
        codes_lengths = torch.full((B,), T, device=device, dtype=torch.long)
    else:
        codes_lengths = padding_mask.sum(dim=-1).long()

    if chunk_duration is None:
        decoder_output = self._decode_frame(audio_codes, codes_lengths)
    else:
        if chunk_duration <= 0:
            raise ValueError("`chunk_duration` must be > 0 when provided.")
        if chunk_duration > self.causal_transformer_context_duration:
            raise ValueError(
                "`chunk_duration` must be <= `config.causal_transformer_context_duration` "
                f"({self.causal_transformer_context_duration}), got {chunk_duration}."
            )
        if B != 1:
            raise ValueError("Streaming decode via `chunk_duration` currently only supports batch_size=1.")

        chunk_length = int(round(chunk_duration * self.sampling_rate))
        if chunk_length <= 0:
            raise ValueError("`chunk_duration` is too small and results in chunk_length <= 0.")
        if chunk_length % self.downsample_rate != 0:
            raise ValueError(
                "`chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`. "
                f"Got chunk_length={chunk_length}, downsample_rate={self.downsample_rate}."
            )

        # Chunk size expressed in code frames rather than audio samples.
        frames_per_chunk = chunk_length // self.downsample_rate
        total_frames = int(codes_lengths[0].item())

        if total_frames <= frames_per_chunk:
            # A single chunk covers everything: no streaming state is needed.
            decoder_output = self._decode_frame(audio_codes[..., :total_frames], codes_lengths)
        else:
            wav_chunks: list[torch.Tensor] = []

            with ExitStack() as exit_stack:
                # Keep every streaming-capable decoder module in streaming mode for all chunks.
                for decoder_module in self.decoder:
                    if isinstance(decoder_module, StreamingModule):
                        exit_stack.enter_context(decoder_module.streaming(batch_size=B))

                for chunk_start in range(0, total_frames, frames_per_chunk):
                    chunk_stop = min(chunk_start + frames_per_chunk, total_frames)
                    chunk_lengths = torch.tensor([chunk_stop - chunk_start], device=device, dtype=torch.long)
                    chunk_codes = audio_codes[:, :, chunk_start:chunk_stop]

                    chunk_output = self._decode_frame(chunk_codes, chunk_lengths)
                    if chunk_output.audio is None or chunk_output.audio_lengths is None:
                        raise RuntimeError("Internal error: `_decode_frame` returned empty audio.")

                    # Keep only the samples reported as valid for this chunk.
                    wav_chunks.append(chunk_output.audio[:, :, : chunk_output.audio_lengths[0]])

            wav = torch.cat(wav_chunks, dim=-1)
            audio_lengths = torch.tensor([wav.shape[-1]], device=device, dtype=torch.long)
            decoder_output = MossAudioTokenizerDecoderOutput(audio=wav, audio_lengths=audio_lengths)

    if return_dict:
        return decoder_output

    assert decoder_output.audio is not None
    return (cast(torch.Tensor, decoder_output.audio),)
| def forward( | |
| self, | |
| input_values: torch.FloatTensor | None = None, | |
| padding_mask: torch.BoolTensor | None = None, | |
| audio_codes: torch.Tensor | None = None, | |
| num_quantizers: int | None = None, | |
| return_dict: bool | None = None, | |
| ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | MossAudioTokenizerOutput: # type: ignore[override] | |
| r""" | |
| input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*): | |
| Raw audio input converted to Float. | |
| padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): | |
| Mask to avoid computing on padding token indices. Mask values selected in `[0, 1]`: | |
| - 1 for tokens that are **not masked**, | |
| - 0 for tokens that are **masked**. | |
| audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*): | |
| Discrete code embeddings computed using `model.encode`. | |
| num_quantizers (`int`, *optional*): | |
| Number of quantizers (codebooks) to use. By default, all quantizers are used. | |
| return_dict (`bool`, *optional*): | |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. | |
| Examples: | |
| ```python | |
| >>> import torch | |
| >>> from transformers import MossAudioTokenizerModel | |
| >>> model = MossAudioTokenizerModel.from_pretrained("moss_audio_tokenizer-model") | |
| >>> # Create dummy audio input | |
| >>> audio = torch.randn(1, 1, 24000) # 1 second of audio at 24kHz | |
| >>> outputs = model(input_values=audio) | |
| >>> audio_codes = outputs.audio_codes | |
| >>> audio_values = outputs.audio | |
| ``` | |
| """ | |
| return_dict = return_dict if return_dict is not None else self.config.return_dict | |
| output_audio_codes: torch.Tensor | None = None | |
| output_audio_codes_lengths: torch.Tensor | None = None | |
| output_audio: torch.Tensor | None = None | |
| output_audio_lengths: torch.Tensor | None = None | |
| decoded_from_encoded_codes = False | |
| # Encode if input_values provided | |
| if input_values is not None: | |
| encoder_output = self.encode(input_values, padding_mask, num_quantizers, return_dict=True) | |
| encoder_output = cast(MossAudioTokenizerEncoderOutput, encoder_output) | |
| output_audio_codes = encoder_output.audio_codes | |
| output_audio_codes_lengths = encoder_output.audio_codes_lengths | |
| # If codes not provided separately, use encoded codes for decoding | |
| if audio_codes is None: | |
| audio_codes = output_audio_codes | |
| decoded_from_encoded_codes = True | |
| # Decode if codes available | |
| if audio_codes is not None: | |
| # If we're decoding the codes we just produced, use the computed lengths so we don't decode padded garbage. | |
| if decoded_from_encoded_codes and output_audio_codes_lengths is not None: | |
| decoder_output = self._decode_frame(audio_codes, output_audio_codes_lengths) | |
| else: | |
| decoder_output = self.decode( | |
| audio_codes, | |
| padding_mask=padding_mask, | |
| return_dict=True, | |
| num_quantizers=num_quantizers, | |
| ) | |
| decoder_output = cast(MossAudioTokenizerDecoderOutput, decoder_output) | |
| output_audio = decoder_output.audio | |
| output_audio_lengths = decoder_output.audio_lengths | |
| if not return_dict: | |
| return (output_audio_codes, output_audio, output_audio_lengths) | |
| return MossAudioTokenizerOutput( | |
| audio=output_audio, | |
| audio_lengths=output_audio_lengths, | |
| audio_codes=output_audio_codes, | |
| audio_codes_lengths=output_audio_codes_lengths, | |
| ) | |
# Public names exported by this module.
__all__ = [
    "MossAudioTokenizerModel",
    "MossAudioTokenizerPreTrainedModel",
]