# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import os
import string
import yaml
from copy import deepcopy

import torch
from transformers import AutoTokenizer, set_seed
set_seed(0)

from data import AudioTextDataProcessor
from src.factory import create_model_and_transforms


def prepare_tokenizer(model_config):
    # Load the base tokenizer and register the special tokens the model expects:
    # an audio placeholder token and an end-of-chunk marker.
    tokenizer_path = model_config['tokenizer_path']
    cache_dir = model_config['cache_dir']
    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=False,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<audio>", "<|endofchunk|>"]}
    )
    if text_tokenizer.pad_token is None:
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    if text_tokenizer.sep_token is None:
        text_tokenizer.add_special_tokens({"sep_token": "<SEP>"})
    return text_tokenizer


def prepare_model(model_config, clap_config, checkpoint_path, device=0):
    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning

    model, tokenizer = create_model_and_transforms(
        **model_config,
        clap_config=clap_config,
        use_local_files=False,
        gradient_checkpointing=False,
        freeze_lm_embeddings=False,
    )
    model.eval()
    model = model.to(device)

    # Load the fine-tuned weights; checkpoint keys may carry a DistributedDataParallel
    # "module." prefix, which is stripped before loading.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model_state_dict = checkpoint["model_state_dict"]
    model_state_dict = {k.replace("module.", ""): v for k, v in model_state_dict.items()}
    model.load_state_dict(model_state_dict, strict=False)

    return model


def inference(model, tokenizer, item, processed_item, inference_kwargs, device=0):
    filename, audio_clips, audio_embed_mask, input_ids, attention_mask = processed_item
    audio_clips = audio_clips.to(device, dtype=None, non_blocking=True)
    audio_embed_mask = audio_embed_mask.to(device, dtype=None, non_blocking=True)
    input_ids = input_ids.to(device, dtype=None, non_blocking=True).squeeze()

    media_token_id = tokenizer.encode("<audio>")[-1]
    eoc_token_id = tokenizer.encode("<|endofchunk|>")[-1]
    sep_token_id = tokenizer.sep_token_id
    eos_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        audio_x=audio_clips.unsqueeze(0),
        audio_x_mask=audio_embed_mask.unsqueeze(0),
        lang_x=input_ids.unsqueeze(0),
        eos_token_id=eos_token_id,
        max_new_tokens=128,
        **inference_kwargs,
    )

    # Keep only the text after the separator token and strip special tokens.
    outputs_decoded = [
        tokenizer.decode(output)
        .split(tokenizer.sep_token)[-1]
        .replace(tokenizer.eos_token, '')
        .replace(tokenizer.pad_token, '')
        .replace('<|endofchunk|>', '')
        for output in outputs
    ]

    return outputs_decoded
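

# Example usage sketch. Assumptions (not confirmed by this file): the YAML config
# layout ("model_config", "clap_config", "data_config", "inference_kwargs"), the
# checkpoint path, and the AudioTextDataProcessor interface (a process(item) method
# returning the tuple that inference() unpacks).
if __name__ == "__main__":
    with open("configs/config.yaml") as f:  # hypothetical config path
        config = yaml.safe_load(f)

    model_config = config["model_config"]
    clap_config = config["clap_config"]

    tokenizer = prepare_tokenizer(model_config)
    model = prepare_model(model_config, clap_config, checkpoint_path="checkpoint.pt")  # hypothetical path

    processor = AudioTextDataProcessor(**config["data_config"])  # assumed constructor
    item = {"name": "example.wav", "prompt": "Describe the audio."}  # assumed item format
    processed_item = processor.process(item)  # assumed method

    print(inference(model, tokenizer, item, processed_item, config.get("inference_kwargs", {})))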