# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
# LICENSE is in incl_licenses directory.

import sys
sys.path.append('../')

from typing import Optional

from transformers import AutoModelForCausalLM, AutoTokenizer

import torch
from torch import nn

# CLAP_Module is required by the 'laion-clap' branch below.
from laion_clap import CLAP_Module
from ms_clap.src.CLAPWrapper import CLAPWrapper
# Support running both as a package module and as a standalone script.
try:
    from .flamingo import Flamingo
    from .flamingo_lm import FlamingoLMMixin
    from .utils import extend_instance
except ImportError:
    from flamingo import Flamingo
    from flamingo_lm import FlamingoLMMixin
    from utils import extend_instance

class CLAP(nn.Module):
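    """Frozen CLAP audio encoder.

    Wraps either a LAION-CLAP or a Microsoft CLAP checkpoint, selected by
    clap_config["method"] ('laion-clap' or 'microsoft-clap'). The encoder is
    frozen and used only to extract audio embeddings.
    """
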
    def __init__(self, clap_config):
        super().__init__()
        self.method = clap_config["method"]

        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        if self.method == 'laion-clap':
            # https://github.com/LAION-AI/CLAP
            if clap_config["model_name"] in ['630k-audioset-best', '630k-best', '630k-audioset-fusion-best', '630k-fusion-best']:
                amodel = 'HTSAT-tiny'
            elif clap_config["model_name"] in ['music_speech_audioset_epoch_15_esc_89.98']:
                amodel = 'HTSAT-base'
            else:
                raise NotImplementedError

            # 'fusion' checkpoints use feature fusion to handle longer audio inputs.
            enable_fusion = 'fusion' in clap_config["model_name"].lower()
            self.laion_clap = CLAP_Module(enable_fusion=enable_fusion, amodel=amodel, device=device)
            self.laion_clap.load_ckpt(ckpt=clap_config["checkpoint"])

            for param in self.laion_clap.parameters():
                param.requires_grad = False
            self.laion_clap.eval()

            print('loaded laion-clap model: {}'.format(clap_config["checkpoint"]))
        elif self.method == 'microsoft-clap':
            # https://github.com/microsoft/CLAP
            self.ms_clap = CLAPWrapper(
                clap_config["checkpoint"],
                config_root=clap_config["config_root"],
                version=clap_config['model_name'],
                use_cuda=torch.cuda.is_available()
            )

            # The '2022'/'2023' checkpoints expose the encoder as .clap;
            # other versions are clapcap (audio-captioning) variants exposed as .clapcap.
            if clap_config['model_name'] in ['2022', '2023']:
                for param in self.ms_clap.clap.parameters():
                    param.requires_grad = False
                self.ms_clap.clap.eval()
            else:
                for param in self.ms_clap.clapcap.parameters():
                    param.requires_grad = False
                self.ms_clap.clapcap.eval()

            print('loaded microsoft-clap model: {}'.format(clap_config["checkpoint"]))

        else:
            raise NotImplementedError

    def forward(self, audio_clips):
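        """Embed a batch of audio clips.

        audio_clips: tensor of shape (num_clips, samples) or
        (batch, num_clips, samples). Returns a stacked tensor of detached
        per-clip embeddings with a leading batch dimension.
        """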
        # Accept a single example (num_clips, samples) by adding a batch dim.
        if len(audio_clips.shape) == 2:
            audio_clips = audio_clips.unsqueeze(0)
        assert len(audio_clips.shape) == 3

        audio_embeds = []
        for x in audio_clips:
            if self.method == 'laion-clap':
                audio_embed = self.laion_clap.get_audio_embedding_from_data(x=x, use_tensor=True)
            elif self.method == 'microsoft-clap':
                audio_embed = self.ms_clap.get_audio_embeddings_from_clips(x)
            audio_embeds.append(audio_embed)

        audio_embeds = torch.stack(audio_embeds, dim=0)
        # Detach rather than flipping the flag: requires_grad cannot be set
        # on a non-leaf tensor, and the encoder is frozen anyway.
        audio_embeds = audio_embeds.detach()

        return audio_embeds


def create_model_and_transforms(
    clap_config: dict,
    lang_encoder_path: str,
    tokenizer_path: str,
    audio_transformer_kwargs: dict,
    cross_attn_every_n_layers: int = 1,
    use_local_files: bool = False,
    decoder_layers_attr_name: Optional[str] = None,
    freeze_lm_embeddings: bool = False,
    unfreeze_full_lm: bool = False,
    cache_dir: Optional[str] = None,
    **flamingo_kwargs,
):
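    """Build a Flamingo-style audio-language model.

    Instantiates the frozen CLAP audio encoder and the causal LM, mixes
    FlamingoLMMixin into the LM, wires in gated cross-attention every
    `cross_attn_every_n_layers` layers, and returns (model, tokenizer).
    """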
    clap = CLAP(clap_config)

    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    # <audio> marks where audio embeddings are injected; <|endofchunk|>
    # terminates each interleaved audio-text chunk.
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<audio>", "<|endofchunk|>"]}
    )
    if text_tokenizer.pad_token is None:
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    if text_tokenizer.sep_token is None:
        text_tokenizer.add_special_tokens({"sep_token": "<SEP>"})

    lang_encoder = AutoModelForCausalLM.from_pretrained(
        lang_encoder_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )

    # Mix FlamingoLMMixin into the LM instance so its decoder layers can be
    # interleaved with gated cross-attention.
    extend_instance(lang_encoder, FlamingoLMMixin)

    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    # Grow the embedding table to cover the newly added special tokens.
    lang_encoder.resize_token_embeddings(len(text_tokenizer))

    unfreeze_clap = False  # the CLAP encoder always stays frozen

    model = Flamingo(
        clap,
        unfreeze_clap,
        lang_encoder,
        # encode() may prepend special tokens (e.g. BOS), so take the last id.
        text_tokenizer.encode("<|endofchunk|>")[-1],
        text_tokenizer.encode("<audio>")[-1],
        text_tokenizer.sep_token_id,
        audio_embed_dim=clap_config["audio_embed_dim"],
        audio_transformer_kwargs=audio_transformer_kwargs,
        cross_attn_every_n_layers=cross_attn_every_n_layers,
        **flamingo_kwargs,
    )

    # Freeze everything, then selectively unfreeze the trainable pieces.
    model.requires_grad_(False)
    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

    model.audio_transformer.requires_grad_(True)
    model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
    if not freeze_lm_embeddings:
        model.lang_encoder.get_input_embeddings().requires_grad_(True)
    if unfreeze_full_lm:
        model.lang_encoder.requires_grad_(True)
    if unfreeze_clap:
        model.clap.requires_grad_(True)

    print("Flamingo model initialized with {:,} trainable parameters (audio transformer has {:,}, LM has {:,})".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad),
        sum(p.numel() for p in model.audio_transformer.parameters() if p.requires_grad),
        sum(p.numel() for p in model.lang_encoder.parameters() if p.requires_grad)
    ))

    return model, text_tokenizer


def _infer_decoder_layers_attr_name(model):
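    """Look up the decoder-layer attribute path from the model class name."""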
    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
        if k.lower() in model.__class__.__name__.lower():
            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]

    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
    )

__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
    "gptneoxforcausallm": "gpt_neox.layers",
    "mpt": "transformer.blocks",
    "mosaicgpt": "transformer.blocks",
}
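
# Minimal usage sketch (illustrative only). The clap_config keys mirror what
# CLAP.__init__ and create_model_and_transforms read above; all paths and the
# audio_transformer_kwargs contents are placeholders, not shipped defaults.
#
# clap_config = {
#     "method": "microsoft-clap",
#     "model_name": "2023",
#     "checkpoint": "/path/to/clap_checkpoint.pth",   # placeholder path
#     "config_root": "/path/to/ms_clap/configs",      # placeholder path
#     "audio_embed_dim": 1024,
# }
# model, tokenizer = create_model_and_transforms(
#     clap_config=clap_config,
#     lang_encoder_path="/path/to/causal_lm",
#     tokenizer_path="/path/to/causal_lm",
#     audio_transformer_kwargs={...},  # placeholder: see the repo's training configs
#     cross_attn_every_n_layers=1,
# )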