Spaces · Build error

Commit 0195d32 · update
Parent(s): 64fc4c7

Files changed:
- app.py +3 -1
- inference_utils.py +6 -6
app.py
CHANGED

@@ -132,7 +132,8 @@ laionclap_model = load_laionclap()
 model = prepare_model(
     model_config=model_config,
     clap_config=clap_config,
-    checkpoint_path='chat.pt'
+    checkpoint_path='chat.pt',
+    device=device
 )


@@ -147,6 +148,7 @@ def inference_item(name, prompt):
     outputs = inference(
         model, text_tokenizer, item, processed_item,
         inference_kwargs,
+        device=device
     )

     laionclap_scores = compute_laionclap_text_audio_sim(
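The new device argument implies that app.py defines a device value before these calls. A minimal sketch of that setup, assuming standard PyTorch device selection (the selection logic is an assumption; this commit only shows the call sites passing device, not where it is defined):

import torch

# Assumed device selection for app.py; the commit shows `device`
# being passed into prepare_model and inference, but not its origin.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = prepare_model(
    model_config=model_config,
    clap_config=clap_config,
    checkpoint_path='chat.pt',
    device=device,
)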
inference_utils.py
CHANGED

@@ -33,7 +33,7 @@ def prepare_tokenizer(model_config):
     return text_tokenizer


-def prepare_model(model_config, clap_config, checkpoint_path, device_id=0):
+def prepare_model(model_config, clap_config, checkpoint_path, device=0):
     os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning
     model, tokenizer = create_model_and_transforms(
         **model_config,
@@ -43,7 +43,7 @@ def prepare_model(model_config, clap_config, checkpoint_path, device_id=0):
         freeze_lm_embeddings=False,
     )
     model.eval()
-    model = model.to(device_id)
+    model = model.to(device)

     checkpoint = torch.load(checkpoint_path, map_location="cpu")
     model_state_dict = checkpoint["model_state_dict"]
@@ -53,11 +53,11 @@ def prepare_model(model_config, clap_config, checkpoint_path, device_id=0):
     return model


-def inference(model, tokenizer, item, processed_item, inference_kwargs, device_id=0):
+def inference(model, tokenizer, item, processed_item, inference_kwargs, device=0):
     filename, audio_clips, audio_embed_mask, input_ids, attention_mask = processed_item
-    audio_clips = audio_clips.to(device_id, dtype=None, non_blocking=True)
-    audio_embed_mask = audio_embed_mask.to(device_id, dtype=None, non_blocking=True)
-    input_ids = input_ids.to(device_id, dtype=None, non_blocking=True).squeeze()
+    audio_clips = audio_clips.to(device, dtype=None, non_blocking=True)
+    audio_embed_mask = audio_embed_mask.to(device, dtype=None, non_blocking=True)
+    input_ids = input_ids.to(device, dtype=None, non_blocking=True).squeeze()

     media_token_id = tokenizer.encode("<audio>")[-1]
     eoc_token_id = tokenizer.encode("<|endofchunk|>")[-1]
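Taken together, the changes rename device_id to device and thread it through both helpers, so the model weights and the input tensors end up on the same device. A hedged end-to-end usage sketch, assuming model_config, clap_config, item, processed_item, and inference_kwargs are built the same way app.py builds them:

import torch

# Hypothetical wiring: the function names and call shapes mirror the
# diff above, but the surrounding setup (configs, item preprocessing)
# is assumed, not shown in this commit.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

text_tokenizer = prepare_tokenizer(model_config)
model = prepare_model(
    model_config=model_config,
    clap_config=clap_config,
    checkpoint_path='chat.pt',
    device=device,
)
outputs = inference(
    model, text_tokenizer, item, processed_item,
    inference_kwargs,
    device=device,
)

Note that device defaults to 0 in both signatures; PyTorch accepts a bare integer as a CUDA device index in .to(), so the default keeps prior behavior while allowing an explicit torch.device (e.g. "cpu") to be passed in from app.py.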