gretatuckute committed
Commit df1b42e · verified · 1 Parent(s): ffc5144

Doc strings + comments

Files changed (1)
  1. modeling_auristream.py +28 -14
modeling_auristream.py CHANGED
@@ -240,7 +240,8 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
     AuriStream speech language model.
 
     A GPT-like transformer model for cochlear token prediction with optional
-    multi-token prediction (MTP) heads for speculative decoding.
+    multi-token prediction (MTP) heads for improved representation learning and
+    novel inference capabilities.
 
     Developed by Greta Tuckute and Klemen Kotar.
     """
@@ -266,7 +267,7 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
         else:
             self.future_heads = None
 
-        # Output head
+        # "Standard" LM output head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
         # Initialize weights
@@ -305,16 +306,27 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
         Args:
             input_ids: Input token IDs of shape (batch_size, seq_len)
             labels: Target token IDs for computing loss
-            output_logits: Whether to return all logits (including from future heads)
-            output_hidden_states: Whether to return all hidden states
-            return_dict: Whether to return a dict or tuple
-            up_until_layer: Stop forward pass at this layer index
+            output_logits: Whether to return all logits (including from future heads).
+                The first element corresponds to the standard next-token head (prediction of i+1);
+                subsequent elements correspond to future heads predicting tokens i+2, i+3, etc.
+            output_hidden_states: Whether to return all hidden states, including the input
+                embedding state and final pre-ln_f state. Matches HuggingFace GPT-style.
+            return_dict: Whether to return a dict or tuple. If True, return a CausalLMOutput dict,
+                otherwise return a tuple.
+            up_until_layer: If set, stop the forward pass after this transformer block
+                (inclusive) and return intermediate activations. Useful for saving compute.
             normalize_embeddings: 'l2' or 'learned' to normalize hidden states
-            seq: Legacy argument (alias for input_ids)
-            tgt: Legacy argument (alias for labels)
+            seq: Legacy argument (alias for input_ids for backward compatibility)
+            tgt: Legacy argument (alias for labels for backward compatibility)
 
         Returns:
-            CausalLMOutput with logits and optional loss, or tuple
+            If return_dict is True:
+                CausalLMOutput with fields:
+                • loss (optional): Scalar training loss
+                • logits: Tensor or list of tensors of prediction logits
+                • hidden_states (optional): Tuple of hidden states
+            Otherwise:
+                Tuple of (logits or list of logits, loss).
         """
         # Handle legacy arguments
         if seq is not None:
@@ -343,16 +355,18 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
         # Normalize hidden states if requested
         hs_to_return = all_hidden_states
         if output_hidden_states and normalize_embeddings is not None:
-            if normalize_embeddings == 'l2':
-                hs_to_return = [F.normalize(h, p=2, dim=-1) for h in all_hidden_states]
-            elif normalize_embeddings == 'learned':
+            if normalize_embeddings == 'l2':  # Preserve direction, get rid of magnitude
+                hs_to_return = [F.normalize(h, p=2, dim=-1) for h in all_hidden_states]  # Dim -1 is the hidden state dim;
+                # after normalization torch.norm(h_norm, p=2, dim=-1) will be 1. I.e. for every token, the hidden state dim norm is 1.
+            elif normalize_embeddings == 'learned':  # We use the learned RMSNorm (first one; used to prepare embeddings for attn)
+                # I.e. these are the representations on which the model computes.
                 hs_to_return = []
                 L = len(self.h)
                 for i, h in enumerate(all_hidden_states):
                     if i < L:
-                        hs_to_return.append(self.h[i].norm1(h))
+                        hs_to_return.append(self.h[i].norm1(h))
                     else:
-                        hs_to_return.append(self.ln_f(h))
+                        hs_to_return.append(self.ln_f(h))  # Final layer norm (after the main blocks, before LM head(s))
 
         # If only hidden states requested (not logits), return early
         if output_hidden_states and not output_logits and labels is None:
 
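A few hedged usage sketches follow, based only on the behavior the new docstrings describe; names like `model` and `input_ids` are illustrative assumptions, not part of the diff. First, consuming the ordered head outputs that `output_logits` documents:

```python
import torch

# Assumes `model` is a loaded AuriStreamModel and `input_ids` is a
# (batch_size, seq_len) LongTensor of cochlear token IDs.
out = model(input_ids, output_logits=True, return_dict=True)

logits = out.logits
if isinstance(logits, list):
    # Per the docstring: element 0 is the standard next-token head (token i+1),
    # element k is the k-th future (MTP) head predicting token i+1+k.
    next_token = logits[0][:, -1].argmax(dim=-1)
    two_ahead = logits[1][:, -1].argmax(dim=-1) if len(logits) > 1 else None
else:
    next_token = logits[:, -1].argmax(dim=-1)
```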
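The `output_hidden_states` description implies one entry per transformer block plus the input embedding state, so a GPT-style model with `n_layer` blocks should return `n_layer + 1` states (the `n_layer` config field is an assumption here):

```python
out = model(input_ids, output_hidden_states=True, return_dict=True)

# HuggingFace GPT-style convention: embedding output + one state per block.
assert len(out.hidden_states) == model.config.n_layer + 1
embeddings = out.hidden_states[0]    # input embedding state
final_state = out.hidden_states[-1]  # final state, before ln_f is applied
```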
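`up_until_layer` reads as an early-exit probe: run only the first blocks and return their activations, skipping the rest of the network. A sketch of the intended call (the exact return contents at early exit are not shown in this diff):

```python
# Stop after transformer block 3 (inclusive) instead of running all blocks;
# cheaper than a full forward pass when only early-layer features are needed.
early = model(input_ids, output_hidden_states=True,
              up_until_layer=3, return_dict=True)
```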
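The claim in the `'l2'` comment is easy to verify in isolation: `F.normalize` over the last dim keeps each token's direction and rescales it to unit length. A self-contained check with dummy shapes:

```python
import torch
import torch.nn.functional as F

h = torch.randn(2, 10, 768)           # dummy (batch, seq, hidden) activations
h_norm = F.normalize(h, p=2, dim=-1)  # normalize over the hidden-state dim

# Direction preserved, magnitude removed: every token vector has L2 norm 1.
assert torch.allclose(h_norm.norm(p=2, dim=-1), torch.ones(2, 10), atol=1e-5)
```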
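The `'learned'` branch instead reuses the model's own norms: each intermediate state goes through its block's first RMSNorm (the one that prepares inputs for attention) and the last state through `ln_f`, so the returned embeddings match the representations the blocks actually compute on. A hedged call-site sketch, assuming the hidden-states-only early return yields the same dict fields:

```python
outs = model(input_ids,
             output_hidden_states=True,
             output_logits=False,            # triggers the hidden-states-only early return
             normalize_embeddings='learned',
             return_dict=True)
normed_states = outs.hidden_states  # each state passed through norm1 / ln_f
```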