Doc strings + comments

modeling_auristream.py CHANGED (+28 -14)
@@ -240,7 +240,8 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
     AuriStream speech language model.
 
     A GPT-like transformer model for cochlear token prediction with optional
-    multi-token prediction (MTP) heads for
+    multi-token prediction (MTP) heads for improved representation learning and
+    novel inference capabilities.
 
     Developed by Greta Tuckute and Klemen Kotar.
     """
@@ -266,7 +267,7 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
         else:
             self.future_heads = None
 
-        #
+        # "Standard" LM output head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
         # Initialize weights
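Side note for readers of this hunk: only the no-MTP branch (`self.future_heads = None`) and the standard `lm_head` are visible here; the actual head construction is not part of this diff. As a rough, hypothetical sketch of how a multi-token-prediction head stack of this shape could look (assuming each future head is a bias-free `nn.Linear` like `lm_head`, and an assumed `n_future` count):

```python
import torch
import torch.nn as nn


class MTPHeadsSketch(nn.Module):
    """Illustrative only; not the AuriStream implementation."""

    def __init__(self, n_embd: int, vocab_size: int, n_future: int = 0):
        super().__init__()
        # Standard next-token head (predicts token i+1), mirroring the line in the diff.
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Hypothetical MTP heads: head k predicts token i+2+k from the same hidden state.
        self.future_heads = (
            nn.ModuleList(nn.Linear(n_embd, vocab_size, bias=False) for _ in range(n_future))
            if n_future > 0
            else None
        )

    def forward(self, hidden: torch.Tensor) -> list[torch.Tensor]:
        logits = [self.lm_head(hidden)]  # element 0: next-token logits
        if self.future_heads is not None:
            logits.extend(head(hidden) for head in self.future_heads)  # i+2, i+3, ...
        return logits
```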
@@ -305,16 +306,27 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
         Args:
             input_ids: Input token IDs of shape (batch_size, seq_len)
             labels: Target token IDs for computing loss
-            output_logits: Whether to return all logits (including from future heads)
-
-
-
+            output_logits: Whether to return all logits (including from future heads).
+                The first element corresponds to the standard next-token head (prediction of i+1);
+                subsequent elements correspond to future heads predicting tokens i+2, i+3, etc.
+            output_hidden_states: Whether to return all hidden states, including the input
+                embedding state and the final pre-ln_f state. Matches the HuggingFace GPT style.
+            return_dict: Whether to return a dict or a tuple. If True, return a CausalLMOutput dict;
+                otherwise return a tuple.
+            up_until_layer: If set, stop the forward pass after this transformer block
+                (inclusive) and return intermediate activations. Useful for saving compute.
             normalize_embeddings: 'l2' or 'learned' to normalize hidden states
-            seq: Legacy argument (alias for input_ids)
-            tgt: Legacy argument (alias for labels)
+            seq: Legacy argument (alias for input_ids, kept for backward compatibility)
+            tgt: Legacy argument (alias for labels, kept for backward compatibility)
 
         Returns:
-
+            If return_dict is True:
+                CausalLMOutput with fields:
+                    • loss (optional): Scalar training loss
+                    • logits: Tensor or list of tensors of prediction logits
+                    • hidden_states (optional): Tuple of hidden states
+            Otherwise:
+                Tuple of (logits or list of logits, loss).
         """
         # Handle legacy arguments
         if seq is not None:
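Taken together, the expanded docstring implies a call pattern along these lines. This is a hedged usage sketch, not code from the repository: `inspect_outputs` is a hypothetical helper, `model` is assumed to be an already-instantiated `AuriStreamModel`, and the returned object is assumed to expose `loss` / `logits` / `hidden_states` as attributes, as the Returns section describes.

```python
import torch


def inspect_outputs(model, input_ids: torch.LongTensor):
    """Hypothetical helper showing how the documented forward() arguments fit together."""
    out = model(
        input_ids=input_ids,
        output_logits=True,          # also return future-head logits when MTP heads are present
        output_hidden_states=True,   # embedding state, per-block states, final pre-ln_f state
        return_dict=True,            # CausalLMOutput-style object instead of a tuple
        normalize_embeddings="l2",   # or "learned" to route states through the learned RMSNorm
    )
    # Per the docstring, logits is a single tensor (next-token head only) or a list
    # whose element k predicts token i+1+k.
    heads = out.logits if isinstance(out.logits, (list, tuple)) else [out.logits]
    for k, head_logits in enumerate(heads):
        greedy = head_logits[:, -1].argmax(dim=-1)  # greedy pick at the last position
        print(f"head {k} (predicts token i+{k + 1}):", greedy.tolist())
    if out.hidden_states is not None:
        print("hidden states returned:", len(out.hidden_states))
    return out
```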
@@ -343,16 +355,18 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
         # Normalize hidden states if requested
         hs_to_return = all_hidden_states
         if output_hidden_states and normalize_embeddings is not None:
-            if normalize_embeddings == 'l2':
-                hs_to_return = [F.normalize(h, p=2, dim=-1) for h in all_hidden_states]
-            elif normalize_embeddings == 'learned':
+            if normalize_embeddings == 'l2':  # Preserve direction, get rid of magnitude
+                hs_to_return = [F.normalize(h, p=2, dim=-1) for h in all_hidden_states]  # Dim -1 is the hidden-state dim;
+                # after normalization torch.norm(h_norm, p=2, dim=-1) will be 1, i.e. for every token the hidden-state norm is 1.
+            elif normalize_embeddings == 'learned':  # Use the learned RMSNorm (the first one, which prepares embeddings for attn),
+                # i.e. these are the representations the model actually computes on.
                 hs_to_return = []
                 L = len(self.h)
                 for i, h in enumerate(all_hidden_states):
                     if i < L:
-                        hs_to_return.append(self.h[i].norm1(h))
+                        hs_to_return.append(self.h[i].norm1(h))
                     else:
-                        hs_to_return.append(self.ln_f(h))
+                        hs_to_return.append(self.ln_f(h))  # Final layer norm (after the main blocks, before the LM head(s))
 
         # If only hidden states requested (not logits), return early
         if output_hidden_states and not output_logits and labels is None:
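The 'l2' branch above uses only `torch.nn.functional.normalize`. A quick standalone check with toy shapes (nothing model-specific) confirms the property the new comment states, namely that every per-token hidden vector ends up with unit L2 norm:

```python
import torch
import torch.nn.functional as F

h = torch.randn(2, 5, 16)               # toy (batch, seq_len, n_embd) hidden states
h_norm = F.normalize(h, p=2, dim=-1)    # normalize over the hidden-state dimension

per_token_norms = torch.norm(h_norm, p=2, dim=-1)
print(per_token_norms)                                    # ~1.0 for every token
print(torch.allclose(per_token_norms, torch.ones(2, 5)))  # True
```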