# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
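"""Gradio demo app for Audio Flamingo.

Loads the Audio Flamingo checkpoint plus a LAION-CLAP model, generates
several candidate answers per query, and returns the candidate whose CLAP
text embedding best matches the input audio.
"""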
import os
import yaml
import gradio as gr
import librosa
from pydub import AudioSegment
import soundfile as sf
import numpy as np
import torch
import laion_clap

from inference_utils import prepare_tokenizer, prepare_model, inference
from data import AudioTextDataProcessor
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
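
# LAION-CLAP is used only to rerank generated answers against the input audio.
# The fused HTSAT-tiny checkpoint ('630k-audioset-fusion-best.pt') is expected
# in the working directory (weights are published at
# https://github.com/LAION-AI/CLAP).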
def load_laionclap():
    model = laion_clap.CLAP_Module(enable_fusion=True, amodel='HTSAT-tiny').to(device)
    model.load_ckpt(ckpt='630k-audioset-fusion-best.pt')
    model.eval()
    return model
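
# laion_clap's reference pipeline round-trips audio through int16 before
# embedding; the two helpers below reproduce that quantization step.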
def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)
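
# Load up to `duration` seconds of mono audio as float32 in [-1, 1].
# MP3 files go through pydub; everything else through soundfile, with
# librosa resampling when the file's rate differs from `target_sr`.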
def load_audio(file_path, target_sr=44100, duration=33.25, start=0.0):
    if file_path.endswith('.mp3'):
        audio = AudioSegment.from_file(file_path)
        if len(audio) > (start + duration) * 1000:
            audio = audio[start * 1000:(start + duration) * 1000]
        if audio.frame_rate != target_sr:
            audio = audio.set_frame_rate(target_sr)
        if audio.channels > 1:
            audio = audio.set_channels(1)
        data = np.array(audio.get_array_of_samples())
        if audio.sample_width == 2:
            data = data.astype(np.float32) / np.iinfo(np.int16).max
        elif audio.sample_width == 4:
            data = data.astype(np.float32) / np.iinfo(np.int32).max
        else:
            raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
    else:
        with sf.SoundFile(file_path) as audio:
            original_sr = audio.samplerate
            channels = audio.channels
            audio.seek(int(start * original_sr))
            # Read at most `duration` seconds from the seek position; reading
            # (start + duration) * original_sr frames here would overshoot the
            # requested window whenever start > 0.
            frames_to_read = min(int(duration * original_sr), len(audio) - int(start * original_sr))
            data = audio.read(frames_to_read)
            if data.max() > 1 or data.min() < -1:
                data = data / max(abs(data.max()), abs(data.min()))
        if original_sr != target_sr:
            if channels == 1:
                data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
            else:
                data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
        else:
            if channels != 1:
                data = data.T[0]
    if data.min() >= 0:
        data = 2 * data / abs(data.max()) - 1.0
    else:
        data = data / max(abs(data.max()), abs(data.min()))
    return data
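
# Score each candidate answer by cosine similarity between the CLAP audio
# embedding of the input file and the CLAP text embedding of the answer.
# Returns zero scores if the audio cannot be loaded, so inference still
# produces an output.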
def compute_laionclap_text_audio_sim(audio_file, laionclap_model, outputs):
    try:
        data = load_audio(audio_file, target_sr=48000)
    except Exception as e:
        print(audio_file, 'unsuccessful due to', e)
        return [0.0] * len(outputs)

    audio_data = data.reshape(1, -1)
    audio_data_tensor = torch.from_numpy(
        int16_to_float32(float32_to_int16(audio_data))
    ).float().to(device)

    audio_embed = laionclap_model.get_audio_embedding_from_data(x=audio_data_tensor, use_tensor=True)
    text_embed = laionclap_model.get_text_embedding(outputs, use_tensor=True)

    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    cos_similarity = cos(audio_embed.repeat(text_embed.shape[0], 1), text_embed)
    return cos_similarity.squeeze().cpu().numpy()
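
# Sample 10 candidate answers per query (top-k / nucleus sampling); the best
# candidate is selected downstream via the CLAP similarity scores above.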
inference_kwargs = {
    "do_sample": True,
    "top_k": 50,
    "top_p": 0.95,
    "num_return_sequences": 10,
}
with open('chat.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
clap_config = config['clap_config']
model_config = config['model_config']

text_tokenizer = prepare_tokenizer(model_config)
DataProcessor = AudioTextDataProcessor(
    data_root='./',
    clap_config=clap_config,
    tokenizer=text_tokenizer,
    max_tokens=512,
)

laionclap_model = load_laionclap()

model = prepare_model(
    model_config=model_config,
    clap_config=clap_config,
    checkpoint_path='chat.pt',
)
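
# End-to-end handler for one query: preprocess the (audio, prompt) pair,
# generate candidate answers, then return the candidate with the highest
# CLAP audio-text similarity.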
def inference_item(name, prompt):
    item = {
        'name': str(name),
        'prefix': 'The task is dialog.',
        'prompt': str(prompt)
    }
    processed_item = DataProcessor.process(item)
    outputs = inference(
        model, text_tokenizer, item, processed_item,
        inference_kwargs,
    )
    laionclap_scores = compute_laionclap_text_audio_sim(
        item["name"],
        laionclap_model,
        outputs,
    )
    outputs_joint = list(zip(outputs, laionclap_scores))
    outputs_joint.sort(key=lambda x: x[1], reverse=True)
    return outputs_joint[0][0]
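
# Gradio UI: a text box for the audio path, a prompt box, audio playback,
# and a button that runs inference_item on the selected file.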
with gr.Blocks(title="Audio Flamingo - Demo") as ui:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 900px; margin: 0 auto;">
            <div
                style="
                    display: inline-flex;
                    align-items: center;
                    gap: 0.8rem;
                    font-size: 1.5rem;
                "
            >
                <h1 style="font-weight: 700; margin-bottom: 7px; line-height: normal;">
                    Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities
                </h1>
            </div>
            <p style="margin-bottom: 10px; font-size: 125%">
                <a href="https://arxiv.org/abs/2402.01831">[Paper]</a>
                <a href="https://github.com/NVIDIA/audio-flamingo">[Code]</a>
                <a href="https://audioflamingo.github.io/">[Demo]</a>
            </p>
        </div>
        """
    )
    gr.HTML(
        """
        <div>
            <h3>Model Overview</h3>
            Audio Flamingo is an audio language model that can understand sounds beyond speech.
            It can also answer questions about the sound in natural language.
            Examples of questions include:
            "Can you briefly describe what you hear in this audio?",
            "What is the emotion conveyed in this music?",
            "Where is this audio usually heard?",
            or "What place is this music usually played at?".
        </div>
        """
    )
    name = gr.Textbox(
        label="Audio file path (choose one from: audio/wav{1--6}.wav)",
        value="audio/wav5.wav"
    )
    prompt = gr.Textbox(
        label="Instruction",
        value='Can you briefly describe what you hear in this audio?'
    )

    with gr.Row():
        play_audio_button = gr.Button("Play Audio")
        audio_output = gr.Audio(label="Playback")
        play_audio_button.click(fn=lambda x: x, inputs=name, outputs=audio_output)

    inference_button = gr.Button("Inference")
    output_text = gr.Textbox(label="Audio Flamingo output")
    inference_button.click(
        fn=inference_item,
        inputs=[name, prompt],
        outputs=output_text
    )

ui.queue()
ui.launch()