Translsis committed on
Commit
6f66838
·
verified ·
1 Parent(s): 214cc33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -19
app.py CHANGED
@@ -93,34 +93,57 @@ class VoiceMapper:
93
  return default_voice
94
 
95
 
96
- # Monkey patch the _update_model_kwargs_for_generation function to handle dict outputs
97
- def _update_model_kwargs_for_generation_fixed(
 
98
  outputs,
99
  model_kwargs,
100
  is_encoder_decoder=False,
101
- **kwargs,
 
102
  ):
103
- """Fixed version that handles both dict and object-like outputs"""
104
- # Update past_key_values - handle both dict and object-like outputs
 
 
105
  if isinstance(outputs, dict):
106
- model_kwargs["past_key_values"] = outputs.get("past_key_values")
 
107
  else:
108
- model_kwargs["past_key_values"] = getattr(outputs, "past_key_values", None)
109
-
110
- # Update attention mask
111
- if "attention_mask" in model_kwargs:
112
- attention_mask = model_kwargs["attention_mask"]
113
- model_kwargs["attention_mask"] = torch.cat(
114
- [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
115
- dim=-1,
116
- )
117
 
118
- return model_kwargs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
 
 
120
 
121
- # Apply the monkey patch
122
- import vibevoice.modular.modeling_vibevoice_streaming_inference as inference_module
123
- inference_module._update_model_kwargs_for_generation = _update_model_kwargs_for_generation_fixed
124
 
125
 
126
  # Check if CUDA is available
@@ -147,6 +170,9 @@ MODEL = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
147
  attn_implementation="sdpa",
148
  )
149
 
 
 
 
150
  MODEL.eval()
151
  MODEL.set_ddpm_inference_steps(num_steps=5)
152
 
 
93
  return default_voice
94
 
95
 
96
# Patch the _update_model_kwargs_for_generation method
def patched_update_model_kwargs_for_generation(
    self,
    outputs,
    model_kwargs,
    is_encoder_decoder=False,
    model_inputs=None,
    num_new_tokens=1,
):
    """Patched version that handles both dict and object-like outputs.

    Mirrors the shape of transformers'
    ``GenerationMixin._update_model_kwargs_for_generation`` but also accepts
    ``outputs`` as a plain ``dict`` in addition to a ModelOutput-like object.

    Args:
        self: The model instance the method is bound to; only
            ``self.config`` is consulted, and only for a non-None check.
        outputs: The step's output — a ``dict`` or an object exposing a
            ``past_key_values`` attribute.
        model_kwargs: Mutable kwargs dict carried between generation steps;
            updated in place and returned.
        is_encoder_decoder: If True, extend ``decoder_attention_mask``
            instead of ``attention_mask``.
        model_inputs: Optional dict of this step's inputs; if it contains
            ``cache_position``, the next position is derived from it.
        num_new_tokens: Number of tokens produced this step; advances
            ``cache_position`` by this amount.

    Returns:
        The updated ``model_kwargs`` dict.
    """
    # Handle both dict and object-like outputs for cache.
    cache_name = "past_key_values"
    if isinstance(outputs, dict):
        # For dict outputs, use .get() method.
        model_kwargs[cache_name] = outputs.get(cache_name)
    else:
        # For object outputs, try to get the attribute.
        model_kwargs[cache_name] = getattr(outputs, cache_name, None)

    # Repeat the last token_type_id for the newly generated position.
    if getattr(self, "config", None) is not None:
        if "token_type_ids" in model_kwargs and model_kwargs["token_type_ids"] is not None:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat(
                [token_type_ids, token_type_ids[:, -1:]], dim=-1
            )

    if not is_encoder_decoder:
        # Decoder-only: extend the attention mask by one column.
        if "attention_mask" in model_kwargs and model_kwargs["attention_mask"] is not None:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                dim=-1,
            )
    else:
        # Encoder-decoder: extend the decoder attention mask instead.
        if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None:
            decoder_attention_mask = model_kwargs["decoder_attention_mask"]
            model_kwargs["decoder_attention_mask"] = torch.cat(
                [
                    decoder_attention_mask,
                    decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1)),
                ],
                dim=-1,
            )

    # Advance cache_position past the tokens generated this step.
    if model_inputs is not None and "cache_position" in model_inputs:
        model_kwargs["cache_position"] = model_inputs["cache_position"][-1:] + num_new_tokens

    return model_kwargs
 
 
147
 
148
 
149
  # Check if CUDA is available
 
170
  attn_implementation="sdpa",
171
  )
172
 
173
+ # Apply the patch to the model instance
174
+ MODEL._update_model_kwargs_for_generation = patched_update_model_kwargs_for_generation.__get__(MODEL, type(MODEL))
175
+
176
  MODEL.eval()
177
  MODEL.set_ddpm_inference_steps(num_steps=5)
178