vermouthdky committed on
Commit a19f17d · verified · 1 Parent(s): 56ccb13

Upload 2 files

Files changed (1)
  1. modeling_qwen2.py +87 -266
modeling_qwen2.py CHANGED
@@ -40,8 +40,8 @@ from transformers.utils import (add_start_docstrings,
40
  is_flash_attn_greater_or_equal_2_10, logging,
41
  replace_return_docstrings)
42
 
 
43
  from .configuration_qwen2 import QwenEnPRMConfig as Qwen2Config
44
- from .nets import EnsembleModel
45
 
46
  if is_flash_attn_2_available():
47
  from transformers.modeling_flash_attention_utils import \
@@ -92,30 +92,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
92
  # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
93
  causal_mask = attention_mask
94
  else:
95
- causal_mask = torch.full(
96
- (sequence_length, target_length),
97
- fill_value=min_dtype,
98
- dtype=dtype,
99
- device=device,
100
- )
101
  if sequence_length != 1:
102
  causal_mask = torch.triu(causal_mask, diagonal=1)
103
- causal_mask *= torch.arange(
104
- target_length, device=device
105
- ) > cache_position.reshape(-1, 1)
106
  causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
107
  if attention_mask is not None:
108
- causal_mask = (
109
- causal_mask.clone()
110
- ) # copy to contiguous memory for in-place edit
111
  mask_length = attention_mask.shape[-1]
112
- padding_mask = (
113
- causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
114
- )
115
  padding_mask = padding_mask == 0
116
- causal_mask[:, :, :, :mask_length] = causal_mask[
117
- :, :, :, :mask_length
118
- ].masked_fill(padding_mask, min_dtype)
119
 
120
  return causal_mask
121
 
@@ -149,27 +138,17 @@ class Qwen2RotaryEmbedding(nn.Module):
149
  self.dim = dim
150
  self.max_position_embeddings = max_position_embeddings
151
  self.base = base
152
- inv_freq = 1.0 / (
153
- self.base
154
- ** (
155
- torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
156
- / self.dim
157
- )
158
- )
159
  self.register_buffer("inv_freq", inv_freq, persistent=False)
160
 
161
  # Build here to make `torch.jit.trace` work.
162
  self._set_cos_sin_cache(
163
- seq_len=max_position_embeddings,
164
- device=self.inv_freq.device,
165
- dtype=torch.get_default_dtype(),
166
  )
167
 
168
  def _set_cos_sin_cache(self, seq_len, device, dtype):
169
  self.max_seq_len_cached = seq_len
170
- t = torch.arange(
171
- self.max_seq_len_cached, device=device, dtype=torch.int64
172
- ).type_as(self.inv_freq)
173
 
174
  freqs = torch.outer(t, self.inv_freq)
175
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
@@ -237,9 +216,7 @@ class Qwen2MLP(nn.Module):
237
  self.act_fn = ACT2FN[config.hidden_act]
238
 
239
  def forward(self, hidden_state):
240
- return self.down_proj(
241
- self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)
242
- )
243
 
244
 
245
  # Copied from transformers.models.llama.modeling_llama.repeat_kv
@@ -251,9 +228,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
251
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
252
  if n_rep == 1:
253
  return hidden_states
254
- hidden_states = hidden_states[:, :, None, :, :].expand(
255
- batch, num_key_value_heads, n_rep, slen, head_dim
256
- )
257
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
258
 
259
 
@@ -289,18 +264,10 @@ class Qwen2Attention(nn.Module):
289
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
290
  f" and `num_heads`: {self.num_heads})."
291
  )
292
- self.q_proj = nn.Linear(
293
- self.hidden_size, self.num_heads * self.head_dim, bias=True
294
- )
295
- self.k_proj = nn.Linear(
296
- self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
297
- )
298
- self.v_proj = nn.Linear(
299
- self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
300
- )
301
- self.o_proj = nn.Linear(
302
- self.num_heads * self.head_dim, self.hidden_size, bias=False
303
- )
304
 
305
  self.rotary_emb = Qwen2RotaryEmbedding(
306
  self.head_dim,
@@ -324,15 +291,9 @@ class Qwen2Attention(nn.Module):
324
  key_states = self.k_proj(hidden_states)
325
  value_states = self.v_proj(hidden_states)
326
 
327
- query_states = query_states.view(
328
- bsz, q_len, self.num_heads, self.head_dim
329
- ).transpose(1, 2)
330
- key_states = key_states.view(
331
- bsz, q_len, self.num_key_value_heads, self.head_dim
332
- ).transpose(1, 2)
333
- value_states = value_states.view(
334
- bsz, q_len, self.num_key_value_heads, self.head_dim
335
- ).transpose(1, 2)
336
 
337
  kv_seq_len = key_states.shape[-2]
338
  if past_key_value is not None:
@@ -344,27 +305,17 @@ class Qwen2Attention(nn.Module):
344
  )
345
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
346
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
347
- query_states, key_states = apply_rotary_pos_emb(
348
- query_states, key_states, cos, sin, position_ids
349
- )
350
 
351
  if past_key_value is not None:
352
- cache_kwargs = {
353
- "sin": sin,
354
- "cos": cos,
355
- "cache_position": cache_position,
356
- } # Specific to RoPE models
357
- key_states, value_states = past_key_value.update(
358
- key_states, value_states, self.layer_idx, cache_kwargs
359
- )
360
 
361
  # repeat k/v heads if n_kv_heads < n_heads
362
  key_states = repeat_kv(key_states, self.num_key_value_groups)
363
  value_states = repeat_kv(value_states, self.num_key_value_groups)
364
 
365
- attn_weights = torch.matmul(
366
- query_states, key_states.transpose(2, 3)
367
- ) / math.sqrt(self.head_dim)
368
 
369
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
370
  raise ValueError(
@@ -377,12 +328,8 @@ class Qwen2Attention(nn.Module):
377
  attn_weights = attn_weights + causal_mask
378
 
379
  # upcast attention to fp32
380
- attn_weights = nn.functional.softmax(
381
- attn_weights, dim=-1, dtype=torch.float32
382
- ).to(query_states.dtype)
383
- attn_weights = nn.functional.dropout(
384
- attn_weights, p=self.attention_dropout, training=self.training
385
- )
386
  attn_output = torch.matmul(attn_weights, value_states)
387
 
388
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -436,15 +383,9 @@ class Qwen2FlashAttention2(Qwen2Attention):
436
  key_states = self.k_proj(hidden_states)
437
  value_states = self.v_proj(hidden_states)
438
 
439
- query_states = query_states.view(
440
- bsz, q_len, self.num_heads, self.head_dim
441
- ).transpose(1, 2)
442
- key_states = key_states.view(
443
- bsz, q_len, self.num_key_value_heads, self.head_dim
444
- ).transpose(1, 2)
445
- value_states = value_states.view(
446
- bsz, q_len, self.num_key_value_heads, self.head_dim
447
- ).transpose(1, 2)
448
 
449
  kv_seq_len = key_states.shape[-2]
450
  if past_key_value is not None:
@@ -458,16 +399,12 @@ class Qwen2FlashAttention2(Qwen2Attention):
458
 
459
  # Because the input can be padded, the absolute sequence length depends on the max position id.
460
  rotary_seq_len = (
461
- max(kv_seq_len, position_ids[:, -1].max().item() + 1)
462
- if position_ids is not None
463
- else kv_seq_len
464
  )
465
 
466
  cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
467
 
468
- query_states, key_states = apply_rotary_pos_emb(
469
- query_states, key_states, cos, sin, position_ids
470
- )
471
 
472
  if past_key_value is not None:
473
  # Activate slicing cache only if the config has a value `sliding_windows` attribute
@@ -493,19 +430,10 @@ class Qwen2FlashAttention2(Qwen2Attention):
493
 
494
  if attention_mask is not None:
495
  attention_mask = attention_mask[:, slicing_tokens:]
496
- attention_mask = torch.cat(
497
- [attention_mask, torch.ones_like(attention_mask[:, -1:])],
498
- dim=-1,
499
- )
500
 
501
- cache_kwargs = {
502
- "sin": sin,
503
- "cos": cos,
504
- "cache_position": cache_position,
505
- } # Specific to RoPE models
506
- key_states, value_states = past_key_value.update(
507
- key_states, value_states, self.layer_idx, cache_kwargs
508
- )
509
 
510
  # repeat k/v heads if n_kv_heads < n_heads
511
  key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -611,34 +539,20 @@ class Qwen2SdpaAttention(Qwen2Attention):
611
  key_states = self.k_proj(hidden_states)
612
  value_states = self.v_proj(hidden_states)
613
 
614
- query_states = query_states.view(
615
- bsz, q_len, self.num_heads, self.head_dim
616
- ).transpose(1, 2)
617
- key_states = key_states.view(
618
- bsz, q_len, self.num_key_value_heads, self.head_dim
619
- ).transpose(1, 2)
620
- value_states = value_states.view(
621
- bsz, q_len, self.num_key_value_heads, self.head_dim
622
- ).transpose(1, 2)
623
 
624
  kv_seq_len = key_states.shape[-2]
625
  if past_key_value is not None:
626
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
627
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
628
 
629
- query_states, key_states = apply_rotary_pos_emb(
630
- query_states, key_states, cos, sin, position_ids
631
- )
632
 
633
  if past_key_value is not None:
634
- cache_kwargs = {
635
- "sin": sin,
636
- "cos": cos,
637
- "cache_position": cache_position,
638
- } # Specific to RoPE models
639
- key_states, value_states = past_key_value.update(
640
- key_states, value_states, self.layer_idx, cache_kwargs
641
- )
642
 
643
  key_states = repeat_kv(key_states, self.num_key_value_groups)
644
  value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -693,15 +607,11 @@ class Qwen2DecoderLayer(nn.Module):
693
  f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
694
  "unexpected results may be encountered."
695
  )
696
- self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](
697
- config, layer_idx
698
- )
699
 
700
  self.mlp = Qwen2MLP(config)
701
  self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
702
- self.post_attention_layernorm = Qwen2RMSNorm(
703
- config.hidden_size, eps=config.rms_norm_eps
704
- )
705
 
706
  def forward(
707
  self,
@@ -713,9 +623,7 @@ class Qwen2DecoderLayer(nn.Module):
713
  use_cache: Optional[bool] = False,
714
  cache_position: Optional[torch.LongTensor] = None,
715
  **kwargs,
716
- ) -> Tuple[
717
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
718
- ]:
719
  """
720
  Args:
721
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -902,14 +810,9 @@ class Qwen2Model(Qwen2PreTrainedModel):
902
  self.padding_idx = config.pad_token_id
903
  self.vocab_size = config.vocab_size
904
 
905
- self.embed_tokens = nn.Embedding(
906
- config.vocab_size, config.hidden_size, self.padding_idx
907
- )
908
  self.layers = nn.ModuleList(
909
- [
910
- Qwen2DecoderLayer(config, layer_idx)
911
- for layer_idx in range(config.num_hidden_layers)
912
- ]
913
  )
914
  self._attn_implementation = config._attn_implementation
915
  self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -938,21 +841,13 @@ class Qwen2Model(Qwen2PreTrainedModel):
938
  return_dict: Optional[bool] = None,
939
  cache_position: Optional[torch.LongTensor] = None,
940
  ) -> Union[Tuple, BaseModelOutputWithPast]:
941
- output_attentions = (
942
- output_attentions
943
- if output_attentions is not None
944
- else self.config.output_attentions
945
- )
946
  output_hidden_states = (
947
- output_hidden_states
948
- if output_hidden_states is not None
949
- else self.config.output_hidden_states
950
  )
951
  use_cache = use_cache if use_cache is not None else self.config.use_cache
952
 
953
- return_dict = (
954
- return_dict if return_dict is not None else self.config.use_return_dict
955
- )
956
 
957
  if (input_ids is None) ^ (inputs_embeds is not None):
958
  raise ValueError(
@@ -979,23 +874,15 @@ class Qwen2Model(Qwen2PreTrainedModel):
979
  inputs_embeds = self.embed_tokens(input_ids)
980
 
981
  if cache_position is None:
982
- past_seen_tokens = (
983
- past_key_values.get_seq_length() if past_key_values is not None else 0
984
- )
985
  cache_position = torch.arange(
986
- past_seen_tokens,
987
- past_seen_tokens + inputs_embeds.shape[1],
988
- device=inputs_embeds.device,
989
  )
990
  if position_ids is None:
991
  position_ids = cache_position.unsqueeze(0)
992
 
993
  causal_mask = self._update_causal_mask(
994
- attention_mask,
995
- inputs_embeds,
996
- cache_position,
997
- past_key_values,
998
- output_attentions,
999
  )
1000
 
1001
  hidden_states = inputs_embeds
@@ -1047,18 +934,10 @@ class Qwen2Model(Qwen2PreTrainedModel):
1047
 
1048
  next_cache = None
1049
  if use_cache:
1050
- next_cache = (
1051
- next_decoder_cache.to_legacy_cache()
1052
- if use_legacy_cache
1053
- else next_decoder_cache
1054
- )
1055
 
1056
  if not return_dict:
1057
- return tuple(
1058
- v
1059
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1060
- if v is not None
1061
- )
1062
  return BaseModelOutputWithPast(
1063
  last_hidden_state=hidden_states,
1064
  past_key_values=next_cache,
@@ -1088,17 +967,11 @@ class Qwen2Model(Qwen2PreTrainedModel):
1088
  # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1089
  # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1090
  # to infer the attention mask.
1091
- past_seen_tokens = (
1092
- past_key_values.get_seq_length() if past_key_values is not None else 0
1093
- )
1094
  using_static_cache = False # isinstance(past_key_values, StaticCache)
1095
 
1096
  # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1097
- if (
1098
- self.config._attn_implementation == "sdpa"
1099
- and not using_static_cache
1100
- and not output_attentions
1101
- ):
1102
  if AttentionMaskConverter._ignore_causal_mask_sdpa(
1103
  attention_mask,
1104
  inputs_embeds=input_tensor,
@@ -1140,9 +1013,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
1140
  # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1141
  # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1142
  # Details: https://github.com/pytorch/pytorch/issues/110213
1143
- causal_mask = AttentionMaskConverter._unmask_unattended(
1144
- causal_mask, min_dtype
1145
- )
1146
 
1147
  return causal_mask
1148
 
@@ -1178,9 +1049,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1178
  return self.model
1179
 
1180
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1181
- @replace_return_docstrings(
1182
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1183
- )
1184
  def forward(
1185
  self,
1186
  input_ids: torch.LongTensor = None,
@@ -1221,19 +1090,11 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1221
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1222
  ```"""
1223
 
1224
- output_attentions = (
1225
- output_attentions
1226
- if output_attentions is not None
1227
- else self.config.output_attentions
1228
- )
1229
  output_hidden_states = (
1230
- output_hidden_states
1231
- if output_hidden_states is not None
1232
- else self.config.output_hidden_states
1233
- )
1234
- return_dict = (
1235
- return_dict if return_dict is not None else self.config.use_return_dict
1236
  )
 
1237
 
1238
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1239
  outputs = self.model(
@@ -1296,9 +1157,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1296
  if past_key_values is not None:
1297
  if inputs_embeds is not None: # Exception 1
1298
  input_ids = input_ids[:, -cache_position.shape[0] :]
1299
- elif (
1300
- input_ids.shape[1] != cache_position.shape[0]
1301
- ): # Default case (the "else", a no op, is Exception 2)
1302
  input_ids = input_ids[:, cache_position]
1303
 
1304
  if attention_mask is not None and position_ids is None:
@@ -1317,11 +1176,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1317
  else:
1318
  model_inputs = {"input_ids": input_ids}
1319
 
1320
- if (
1321
- False
1322
- and isinstance(past_key_values, StaticCache)
1323
- and attention_mask.ndim == 2
1324
- ):
1325
  if inputs_embeds is not None:
1326
  batch_size, sequence_length = inputs_embeds.shape
1327
  device = inputs_embeds.device
@@ -1406,9 +1261,7 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
1406
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1407
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1408
  """
1409
- return_dict = (
1410
- return_dict if return_dict is not None else self.config.use_return_dict
1411
- )
1412
 
1413
  transformer_outputs = self.model(
1414
  input_ids,
@@ -1430,25 +1283,19 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
1430
  batch_size = inputs_embeds.shape[0]
1431
 
1432
  if self.config.pad_token_id is None and batch_size != 1:
1433
- raise ValueError(
1434
- "Cannot handle batch sizes > 1 if no padding token is defined."
1435
- )
1436
  if self.config.pad_token_id is None:
1437
  sequence_lengths = -1
1438
  else:
1439
  if input_ids is not None:
1440
  # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1441
- sequence_lengths = (
1442
- torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1443
- )
1444
  sequence_lengths = sequence_lengths % input_ids.shape[-1]
1445
  sequence_lengths = sequence_lengths.to(logits.device)
1446
  else:
1447
  sequence_lengths = -1
1448
 
1449
- pooled_logits = logits[
1450
- torch.arange(batch_size, device=logits.device), sequence_lengths
1451
- ]
1452
 
1453
  loss = None
1454
  if labels is not None:
@@ -1456,9 +1303,7 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
1456
  if self.config.problem_type is None:
1457
  if self.num_labels == 1:
1458
  self.config.problem_type = "regression"
1459
- elif self.num_labels > 1 and (
1460
- labels.dtype == torch.long or labels.dtype == torch.int
1461
- ):
1462
  self.config.problem_type = "single_label_classification"
1463
  else:
1464
  self.config.problem_type = "multi_label_classification"
@@ -1471,9 +1316,7 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
1471
  loss = loss_fct(pooled_logits, labels)
1472
  elif self.config.problem_type == "single_label_classification":
1473
  loss_fct = CrossEntropyLoss()
1474
- loss = loss_fct(
1475
- pooled_logits.view(-1, self.num_labels), labels.view(-1)
1476
- )
1477
  elif self.config.problem_type == "multi_label_classification":
1478
  loss_fct = BCEWithLogitsLoss()
1479
  loss = loss_fct(pooled_logits, labels)
@@ -1541,9 +1384,7 @@ class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
1541
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1542
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1543
  """
1544
- return_dict = (
1545
- return_dict if return_dict is not None else self.config.use_return_dict
1546
- )
1547
 
1548
  outputs = self.model(
1549
  input_ids,
@@ -1604,7 +1445,7 @@ class Qwen2ForEnsemblePRM(Qwen2PreTrainedModel):
1604
  encoding_dim=config.hidden_size,
1605
  num_ensemble=config.num_ensemble,
1606
  )
1607
- self.score.init()
1608
  # Initialize weights and apply final processing
1609
  self.post_init()
1610
 
@@ -1621,7 +1462,7 @@ class Qwen2ForEnsemblePRM(Qwen2PreTrainedModel):
1621
  outputs.logits = torch.nn.functional.sigmoid(outputs.logits)
1622
  return outputs
1623
 
1624
- def _compute_loss(self, logits, labels):
1625
  # NOTE: we only compute the loss for specific position (labels != -100)
1626
  logits = logits.float()
1627
  loss = None
@@ -1630,23 +1471,21 @@ class Qwen2ForEnsemblePRM(Qwen2PreTrainedModel):
1630
  # only support hard labels; no need for soft labels
1631
  loss_fct = BCEWithLogitsLoss(reduction="none")
1632
 
1633
- loss = loss_fct(
1634
- logits, labels[None].repeat([logits.size(0), 1, 1]).to(logits.dtype)
1635
- )
1636
  # select loss for specific position
1637
  mask = (labels != -100)[None].repeat([logits.size(0), 1, 1])
1638
  # and random mask instance for different ensemble model
1639
- data_aloc_mask = (
1640
- torch.rand(mask.size(0), mask.size(1)) < self.learning_probability
1641
- )
1642
  mask = mask & data_aloc_mask[:, :, None].to(mask.device)
1643
 
1644
  loss = torch.masked_select(loss, mask)
1645
  loss = loss.mean()
1646
- loss += (
1647
- self.regularization_lambda * labels.size(0) * self.score.regularization()
1648
- )
1649
- return loss
 
 
1650
 
1651
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1652
  def forward(
@@ -1662,9 +1501,7 @@ class Qwen2ForEnsemblePRM(Qwen2PreTrainedModel):
1662
  output_hidden_states: Optional[bool] = None,
1663
  return_dict: Optional[bool] = None,
1664
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1665
- return_dict = (
1666
- return_dict if return_dict is not None else self.config.use_return_dict
1667
- )
1668
 
1669
  transformer_outputs = self.model(
1670
  input_ids,
@@ -1678,9 +1515,7 @@ class Qwen2ForEnsemblePRM(Qwen2PreTrainedModel):
1678
  return_dict=return_dict,
1679
  )
1680
  hidden_states = transformer_outputs[0] # (b, l, h)
1681
- hidden_states = hidden_states[None, :, :, :].repeat(
1682
- self.score.num_ensemble, 1, 1, 1
1683
- ) # (e, l, h)
1684
  logits = self.score(hidden_states)
1685
 
1686
  if input_ids is not None:
@@ -1689,17 +1524,13 @@ class Qwen2ForEnsemblePRM(Qwen2PreTrainedModel):
1689
  batch_size = inputs_embeds.shape[0]
1690
 
1691
  if self.config.pad_token_id is None and batch_size != 1:
1692
- raise ValueError(
1693
- "Cannot handle batch sizes > 1 if no padding token is defined."
1694
- )
1695
  if self.config.pad_token_id is None:
1696
  sequence_lengths = -1
1697
  else:
1698
  if input_ids is not None:
1699
  # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1700
- sequence_lengths = (
1701
- torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1702
- )
1703
  sequence_lengths = sequence_lengths % input_ids.shape[-1]
1704
  sequence_lengths = sequence_lengths.to(logits.device)
1705
  else:
@@ -1707,9 +1538,7 @@ class Qwen2ForEnsemblePRM(Qwen2PreTrainedModel):
1707
 
1708
  logits = logits.float()
1709
  loss = None
1710
- logits = logits.squeeze(
1711
- -1
1712
- ) # (ensemble, batch_size, seq_len, 1) -> (ensemble, batch_size, seq_len)
1713
  if labels is not None:
1714
  if self.config.problem_type is None: # NOTE: no use
1715
  if labels.dtype is not torch.long:
@@ -1721,24 +1550,16 @@ class Qwen2ForEnsemblePRM(Qwen2PreTrainedModel):
1721
  # only support hard labels; not need for soft labels
1722
  loss_fct = BCEWithLogitsLoss(reduction="none")
1723
 
1724
- loss = loss_fct(
1725
- logits, labels[None].repeat([logits.size(0), 1, 1]).to(logits.dtype)
1726
- )
1727
  # select loss for specific position
1728
  mask = (labels != -100)[None].repeat([logits.size(0), 1, 1])
1729
  # and random mask instance for differnet ensemble model
1730
- data_aloc_mask = (
1731
- torch.rand(mask.size(0), mask.size(1)) < self.learning_probability
1732
- )
1733
  mask = mask & data_aloc_mask[:, :, None].to(mask.device)
1734
 
1735
  loss = torch.masked_select(loss, mask)
1736
  loss = loss.mean()
1737
- loss += (
1738
- self.regularization_lambda
1739
- * labels.size(0)
1740
- * self.score.regularization()
1741
- )
1742
 
1743
  if not return_dict:
1744
  output = (logits,) + transformer_outputs[1:]
 
40
  is_flash_attn_greater_or_equal_2_10, logging,
41
  replace_return_docstrings)
42
 
43
+ from ..nets import EnsembleModel
44
  from .configuration_qwen2 import QwenEnPRMConfig as Qwen2Config
 
45
 
46
  if is_flash_attn_2_available():
47
  from transformers.modeling_flash_attention_utils import \
 
92
  # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
93
  causal_mask = attention_mask
94
  else:
95
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
 
 
 
 
 
96
  if sequence_length != 1:
97
  causal_mask = torch.triu(causal_mask, diagonal=1)
98
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
 
 
99
  causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
100
  if attention_mask is not None:
101
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
 
 
102
  mask_length = attention_mask.shape[-1]
103
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
 
 
104
  padding_mask = padding_mask == 0
105
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
106
+ padding_mask, min_dtype
107
+ )
108
 
109
  return causal_mask
110
 
 
138
  self.dim = dim
139
  self.max_position_embeddings = max_position_embeddings
140
  self.base = base
141
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
 
 
 
 
 
 
142
  self.register_buffer("inv_freq", inv_freq, persistent=False)
143
 
144
  # Build here to make `torch.jit.trace` work.
145
  self._set_cos_sin_cache(
146
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
 
 
147
  )
148
 
149
  def _set_cos_sin_cache(self, seq_len, device, dtype):
150
  self.max_seq_len_cached = seq_len
151
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
 
 
152
 
153
  freqs = torch.outer(t, self.inv_freq)
154
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
 
216
  self.act_fn = ACT2FN[config.hidden_act]
217
 
218
  def forward(self, hidden_state):
219
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 
220
 
221
 
222
  # Copied from transformers.models.llama.modeling_llama.repeat_kv
 
228
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
229
  if n_rep == 1:
230
  return hidden_states
231
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
 
 
232
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
233
 
234
 
 
264
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
265
  f" and `num_heads`: {self.num_heads})."
266
  )
267
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
268
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
269
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
270
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
 
 
 
 
 
 
 
 
271
 
272
  self.rotary_emb = Qwen2RotaryEmbedding(
273
  self.head_dim,
 
291
  key_states = self.k_proj(hidden_states)
292
  value_states = self.v_proj(hidden_states)
293
 
294
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
295
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
296
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
297
 
298
  kv_seq_len = key_states.shape[-2]
299
  if past_key_value is not None:
 
305
  )
306
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
307
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
308
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
 
309
 
310
  if past_key_value is not None:
311
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
312
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
 
 
 
 
313
 
314
  # repeat k/v heads if n_kv_heads < n_heads
315
  key_states = repeat_kv(key_states, self.num_key_value_groups)
316
  value_states = repeat_kv(value_states, self.num_key_value_groups)
317
 
318
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
 
319
 
320
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
321
  raise ValueError(
 
328
  attn_weights = attn_weights + causal_mask
329
 
330
  # upcast attention to fp32
331
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
332
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
 
 
 
 
333
  attn_output = torch.matmul(attn_weights, value_states)
334
 
335
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 
383
  key_states = self.k_proj(hidden_states)
384
  value_states = self.v_proj(hidden_states)
385
 
386
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
387
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
388
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
389
 
390
  kv_seq_len = key_states.shape[-2]
391
  if past_key_value is not None:
 
399
 
400
  # Because the input can be padded, the absolute sequence length depends on the max position id.
401
  rotary_seq_len = (
402
+ max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
 
 
403
  )
404
 
405
  cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
406
 
407
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
 
408
 
409
  if past_key_value is not None:
410
  # Activate slicing cache only if the config has a value `sliding_windows` attribute
 
430
 
431
  if attention_mask is not None:
432
  attention_mask = attention_mask[:, slicing_tokens:]
433
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
 
 
 
434
 
435
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
436
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
 
 
 
 
437
 
438
  # repeat k/v heads if n_kv_heads < n_heads
439
  key_states = repeat_kv(key_states, self.num_key_value_groups)
 
539
  key_states = self.k_proj(hidden_states)
540
  value_states = self.v_proj(hidden_states)
541
 
542
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
543
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
544
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
545
 
546
  kv_seq_len = key_states.shape[-2]
547
  if past_key_value is not None:
548
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
549
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
550
 
551
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
 
552
 
553
  if past_key_value is not None:
554
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
555
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
 
 
 
 
556
 
557
  key_states = repeat_kv(key_states, self.num_key_value_groups)
558
  value_states = repeat_kv(value_states, self.num_key_value_groups)
 
607
  f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
608
  "unexpected results may be encountered."
609
  )
610
+ self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
 
 
611
 
612
  self.mlp = Qwen2MLP(config)
613
  self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
614
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
615
 
616
  def forward(
617
  self,
 
623
  use_cache: Optional[bool] = False,
624
  cache_position: Optional[torch.LongTensor] = None,
625
  **kwargs,
626
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
 
627
  """
628
  Args:
629
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
 
810
  self.padding_idx = config.pad_token_id
811
  self.vocab_size = config.vocab_size
812
 
813
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
 
814
  self.layers = nn.ModuleList(
815
+ [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
 
 
 
816
  )
817
  self._attn_implementation = config._attn_implementation
818
  self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
841
  return_dict: Optional[bool] = None,
842
  cache_position: Optional[torch.LongTensor] = None,
843
  ) -> Union[Tuple, BaseModelOutputWithPast]:
844
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
845
  output_hidden_states = (
846
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
847
  )
848
  use_cache = use_cache if use_cache is not None else self.config.use_cache
849
 
850
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
851
 
852
  if (input_ids is None) ^ (inputs_embeds is not None):
853
  raise ValueError(
 
874
  inputs_embeds = self.embed_tokens(input_ids)
875
 
876
  if cache_position is None:
877
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
 
 
878
  cache_position = torch.arange(
879
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
 
 
880
  )
881
  if position_ids is None:
882
  position_ids = cache_position.unsqueeze(0)
883
 
884
  causal_mask = self._update_causal_mask(
885
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
 
 
 
 
886
  )
887
 
888
  hidden_states = inputs_embeds
 
934
 
935
  next_cache = None
936
  if use_cache:
937
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
 
 
 
 
938
 
939
  if not return_dict:
940
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
 
 
 
941
  return BaseModelOutputWithPast(
942
  last_hidden_state=hidden_states,
943
  past_key_values=next_cache,
 
967
  # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
968
  # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
969
  # to infer the attention mask.
970
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
 
 
971
  using_static_cache = False # isinstance(past_key_values, StaticCache)
972
 
973
  # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
974
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
 
 
 
 
975
  if AttentionMaskConverter._ignore_causal_mask_sdpa(
976
  attention_mask,
977
  inputs_embeds=input_tensor,
 
1013
  # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1014
  # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1015
  # Details: https://github.com/pytorch/pytorch/issues/110213
1016
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
 
 
1017
 
1018
  return causal_mask
1019
 
 
1049
  return self.model
1050
 
1051
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1052
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
 
1053
  def forward(
1054
  self,
1055
  input_ids: torch.LongTensor = None,
 
1090
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1091
  ```"""
1092
 
1093
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
1094
  output_hidden_states = (
1095
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
1096
  )
1097
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1098
 
1099
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1100
  outputs = self.model(
 
1157
  if past_key_values is not None:
1158
  if inputs_embeds is not None: # Exception 1
1159
  input_ids = input_ids[:, -cache_position.shape[0] :]
1160
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
 
 
1161
  input_ids = input_ids[:, cache_position]
1162
 
1163
  if attention_mask is not None and position_ids is None:
 
1176
  else:
1177
  model_inputs = {"input_ids": input_ids}
1178
 
1179
+ if False and isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
 
 
 
 
1180
  if inputs_embeds is not None:
1181
  batch_size, sequence_length = inputs_embeds.shape
1182
  device = inputs_embeds.device
 
1261
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1262
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1263
  """
1264
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1265
 
1266
  transformer_outputs = self.model(
1267
  input_ids,
 
1283
  batch_size = inputs_embeds.shape[0]
1284
 
1285
  if self.config.pad_token_id is None and batch_size != 1:
1286
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
 
 
1287
  if self.config.pad_token_id is None:
1288
  sequence_lengths = -1
1289
  else:
1290
  if input_ids is not None:
1291
  # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1292
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
 
 
1293
  sequence_lengths = sequence_lengths % input_ids.shape[-1]
1294
  sequence_lengths = sequence_lengths.to(logits.device)
1295
  else:
1296
  sequence_lengths = -1
1297
 
1298
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
 
 
1299
 
1300
  loss = None
1301
  if labels is not None:
 
1303
  if self.config.problem_type is None:
1304
  if self.num_labels == 1:
1305
  self.config.problem_type = "regression"
1306
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
 
 
1307
  self.config.problem_type = "single_label_classification"
1308
  else:
1309
  self.config.problem_type = "multi_label_classification"
 
1316
  loss = loss_fct(pooled_logits, labels)
1317
  elif self.config.problem_type == "single_label_classification":
1318
  loss_fct = CrossEntropyLoss()
1319
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
 
 
1320
  elif self.config.problem_type == "multi_label_classification":
1321
  loss_fct = BCEWithLogitsLoss()
1322
  loss = loss_fct(pooled_logits, labels)
 
1384
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1385
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1386
  """
1387
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1388
 
1389
  outputs = self.model(
1390
  input_ids,
 
1445
  encoding_dim=config.hidden_size,
1446
  num_ensemble=config.num_ensemble,
1447
  )
1448
+ # self.score.init()
1449
  # Initialize weights and apply final processing
1450
  self.post_init()
1451
 
 
1462
  outputs.logits = torch.nn.functional.sigmoid(outputs.logits)
1463
  return outputs
1464
 
1465
+ def _compute_loss(self, logits, labels, return_reg_loss=False):
1466
  # NOTE: we only compute the loss for specific position (labels != -100)
1467
  logits = logits.float()
1468
  loss = None
 
1471
  # only support hard labels; no need for soft labels
1472
  loss_fct = BCEWithLogitsLoss(reduction="none")
1473
 
1474
+ loss = loss_fct(logits, labels[None].repeat([logits.size(0), 1, 1]).to(logits.dtype))
 
 
1475
  # select loss for specific position
1476
  mask = (labels != -100)[None].repeat([logits.size(0), 1, 1])
1477
  # and random mask instance for differnet ensemble model
1478
+ data_aloc_mask = torch.rand(mask.size(0), mask.size(1)) < self.learning_probability
 
 
1479
  mask = mask & data_aloc_mask[:, :, None].to(mask.device)
1480
 
1481
  loss = torch.masked_select(loss, mask)
1482
  loss = loss.mean()
1483
+ reg_loss = self.regularization_lambda * self.score.regularization()
1484
+ loss += reg_loss
1485
+ if not return_reg_loss:
1486
+ return loss
1487
+ else:
1488
+ return (loss, reg_loss)
1489
 
1490
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1491
  def forward(
 
1501
  output_hidden_states: Optional[bool] = None,
1502
  return_dict: Optional[bool] = None,
1503
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1504
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1505
 
1506
  transformer_outputs = self.model(
1507
  input_ids,
 
1515
  return_dict=return_dict,
1516
  )
1517
  hidden_states = transformer_outputs[0] # (b, l, h)
1518
+ hidden_states = hidden_states[None, :, :, :].repeat(self.score.num_ensemble, 1, 1, 1) # (e, l, h)
 
 
1519
  logits = self.score(hidden_states)
1520
 
1521
  if input_ids is not None:
 
1524
  batch_size = inputs_embeds.shape[0]
1525
 
1526
  if self.config.pad_token_id is None and batch_size != 1:
1527
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
 
 
1528
  if self.config.pad_token_id is None:
1529
  sequence_lengths = -1
1530
  else:
1531
  if input_ids is not None:
1532
  # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1533
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
 
 
1534
  sequence_lengths = sequence_lengths % input_ids.shape[-1]
1535
  sequence_lengths = sequence_lengths.to(logits.device)
1536
  else:
 
1538
 
1539
  logits = logits.float()
1540
  loss = None
1541
+ logits = logits.squeeze(-1) # (ensemble, batch_size, seq_len, 1) -> (ensemble, batch_size, seq_len)
 
 
1542
  if labels is not None:
1543
  if self.config.problem_type is None: # NOTE: no use
1544
  if labels.dtype is not torch.long:
 
1550
  # only support hard labels; no need for soft labels
1551
  loss_fct = BCEWithLogitsLoss(reduction="none")
1552
 
1553
+ loss = loss_fct(logits, labels[None].repeat([logits.size(0), 1, 1]).to(logits.dtype))
 
 
1554
  # select loss for specific position
1555
  mask = (labels != -100)[None].repeat([logits.size(0), 1, 1])
1556
  # and random mask instance for different ensemble model
1557
+ data_aloc_mask = torch.rand(mask.size(0), mask.size(1)) < self.learning_probability
 
 
1558
  mask = mask & data_aloc_mask[:, :, None].to(mask.device)
1559
 
1560
  loss = torch.masked_select(loss, mask)
1561
  loss = loss.mean()
1562
+ loss += self.regularization_lambda * labels.size(0) * self.score.regularization()
 
 
 
 
1563
 
1564
  if not return_dict:
1565
  output = (logits,) + transformer_outputs[1:]
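
For context, the hunks above factor the ensemble loss into a `_compute_loss` helper. Below is a minimal standalone sketch (not part of this commit) of the masked ensemble BCE pattern it implements; the function name `ensemble_masked_bce`, the default `learning_probability=0.7`, and the tensor shapes are illustrative assumptions inferred from the code shown, and the `regularization()` term is omitted.

import torch
from torch.nn import BCEWithLogitsLoss

def ensemble_masked_bce(logits, labels, learning_probability=0.7):
    # logits: (num_ensemble, batch, seq_len); labels: (batch, seq_len), with -100 marking unsupervised positions.
    loss_fct = BCEWithLogitsLoss(reduction="none")
    # Broadcast the hard labels to every ensemble member before the element-wise BCE.
    per_token = loss_fct(logits, labels[None].repeat([logits.size(0), 1, 1]).to(logits.dtype))
    # Keep only supervised positions (labels != -100) ...
    mask = (labels != -100)[None].repeat([logits.size(0), 1, 1])
    # ... and randomly allocate a subset of instances to each ensemble member.
    data_mask = torch.rand(mask.size(0), mask.size(1), device=mask.device) < learning_probability
    mask = mask & data_mask[:, :, None]
    # Mean over the selected (ensemble, instance, position) entries only.
    return torch.masked_select(per_token, mask).mean()

if __name__ == "__main__":
    logits = torch.randn(4, 2, 8)      # 4 ensemble heads, batch of 2, sequence length 8
    labels = torch.randint(0, 2, (2, 8))
    labels[:, :3] = -100               # ignore prompt positions
    print(ensemble_masked_bce(logits, labels))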