# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import os
import yaml
# import spaces
import gradio as gr
import librosa
from pydub import AudioSegment
import soundfile as sf
import numpy as np
import torch
import laion_clap
from inference_utils import prepare_tokenizer, prepare_model, inference
from data import AudioTextDataProcessor

if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
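
# LAION-CLAP is used below to rank the generated candidate answers by audio-text similarity.
# The fusion checkpoint '630k-audioset-fusion-best.pt' is expected in the working directory.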
# @spaces.GPU
def load_laionclap():
    model = laion_clap.CLAP_Module(enable_fusion=True, amodel='HTSAT-tiny').to(device)
    model.load_ckpt(ckpt='630k-audioset-fusion-best.pt')
    model.eval()
    return model
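
# Waveform helpers: the float32 -> int16 -> float32 round trip clips the signal to [-1, 1]
# and quantizes it before it is embedded by CLAP (see compute_laionclap_text_audio_sim below).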
def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)

def float32_to_int16(x):
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)
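
# Load an audio file (mp3 via pydub, other formats via soundfile), trim it to `duration`
# seconds starting at `start`, downmix to mono, resample to `target_sr`, and normalize to [-1, 1].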
def load_audio(file_path, target_sr=44100, duration=33.25, start=0.0):
    if file_path.endswith('.mp3'):
        audio = AudioSegment.from_file(file_path)
        if len(audio) > (start + duration) * 1000:
            audio = audio[start * 1000:(start + duration) * 1000]
        if audio.frame_rate != target_sr:
            audio = audio.set_frame_rate(target_sr)
        if audio.channels > 1:
            audio = audio.set_channels(1)
        data = np.array(audio.get_array_of_samples())
        if audio.sample_width == 2:
            data = data.astype(np.float32) / np.iinfo(np.int16).max
        elif audio.sample_width == 4:
            data = data.astype(np.float32) / np.iinfo(np.int32).max
        else:
            raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
    else:
        with sf.SoundFile(file_path) as audio:
            original_sr = audio.samplerate
            channels = audio.channels
            max_frames = int((start + duration) * original_sr)
            audio.seek(int(start * original_sr))
            frames_to_read = min(max_frames, len(audio))
            data = audio.read(frames_to_read)
            if data.max() > 1 or data.min() < -1:
                data = data / max(abs(data.max()), abs(data.min()))
        if original_sr != target_sr:
            if channels == 1:
                data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
            else:
                data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
        else:
            if channels != 1:
                data = data.T[0]
    if data.min() >= 0:
        data = 2 * data / abs(data.max()) - 1.0
    else:
        data = data / max(abs(data.max()), abs(data.min()))
    return data
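
# Score each candidate answer against the input audio with LAION-CLAP cosine similarity.
# On any audio loading error, all candidates get a score of 0.0 so inference can still return a result.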
# @spaces.GPU
def compute_laionclap_text_audio_sim(audio_file, laionclap_model, outputs):
    try:
        data = load_audio(audio_file, target_sr=48000)
    except Exception as e:
        print(audio_file, 'unsuccessful due to', e)
        return [0.0] * len(outputs)

    audio_data = data.reshape(1, -1)
    audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float().to(device)

    audio_embed = laionclap_model.get_audio_embedding_from_data(x=audio_data_tensor, use_tensor=True)
    text_embed = laionclap_model.get_text_embedding(outputs, use_tensor=True)

    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    cos_similarity = cos(audio_embed.repeat(text_embed.shape[0], 1), text_embed)
    return cos_similarity.squeeze().cpu().numpy()
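
# Sampling settings for the language model: 20 candidate answers are drawn and later
# re-ranked by their CLAP similarity to the input audio.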
inference_kwargs = {
    "do_sample": True,
    "top_k": 50,
    "top_p": 0.95,
    "num_return_sequences": 20
}
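
# Startup: read the chat config, then build the tokenizer, data processor, CLAP ranker,
# and the Audio Flamingo chat model from the local 'chat.pt' checkpoint.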
config = yaml.load(open('chat.yaml'), Loader=yaml.FullLoader)
clap_config = config['clap_config']
model_config = config['model_config']

text_tokenizer = prepare_tokenizer(model_config)
DataProcessor = AudioTextDataProcessor(
    data_root='./',
    clap_config=clap_config,
    tokenizer=text_tokenizer,
    max_tokens=512,
)

laionclap_model = load_laionclap()

model = prepare_model(
    model_config=model_config,
    clap_config=clap_config,
    checkpoint_path='chat.pt',
    device=device
)
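
# End-to-end inference for one (audio file, instruction) pair: preprocess the item,
# sample candidate answers, score them with CLAP, and return the best-scoring one.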
# @spaces.GPU
def inference_item(name, prompt):
    item = {
        'name': str(name),
        'prefix': 'The task is dialog.',
        'prompt': str(prompt)
    }
    processed_item = DataProcessor.process(item)
    outputs = inference(
        model, text_tokenizer, item, processed_item,
        inference_kwargs,
        device=device
    )

    laionclap_scores = compute_laionclap_text_audio_sim(
        item["name"],
        laionclap_model,
        outputs
    )

    outputs_joint = [(output, score) for (output, score) in zip(outputs, laionclap_scores)]
    outputs_joint.sort(key=lambda x: -x[1])
    return outputs_joint[0][0]
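
# Custom CSS for the Gradio page (buttons, footer, share button, and layout tweaks).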
css = """
a {
    color: inherit;
    text-decoration: underline;
}
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
    color: white;
    border-color: #000000;
    background: #000000;
}
input[type='range'] {
    accent-color: #000000;
}
.dark input[type='range'] {
    accent-color: #dfdfdf;
}
.container {
    max-width: 730px;
    margin: auto;
    padding-top: 1.5rem;
}
#gallery {
    min-height: 22rem;
    margin-bottom: 15px;
    margin-left: auto;
    margin-right: auto;
    border-bottom-right-radius: .5rem !important;
    border-bottom-left-radius: .5rem !important;
}
#gallery>div>.h-full {
    min-height: 20rem;
}
.details:hover {
    text-decoration: underline;
}
.gr-button {
    white-space: nowrap;
}
.gr-button:focus {
    border-color: rgb(147 197 253 / var(--tw-border-opacity));
    outline: none;
    box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
    --tw-border-opacity: 1;
    --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
    --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
    --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
    --tw-ring-opacity: .5;
}
#advanced-btn {
    font-size: .7rem !important;
    line-height: 19px;
    margin-top: 12px;
    margin-bottom: 12px;
    padding: 2px 8px;
    border-radius: 14px !important;
}
#advanced-options {
    margin-bottom: 20px;
}
.footer {
    margin-bottom: 45px;
    margin-top: 35px;
    text-align: center;
    border-bottom: 1px solid #e5e5e5;
}
.footer>p {
    font-size: .8rem;
    display: inline-block;
    padding: 0 10px;
    transform: translateY(10px);
    background: white;
}
.dark .footer {
    border-color: #303030;
}
.dark .footer>p {
    background: #0b0f19;
}
.acknowledgments h4 {
    margin: 1.25em 0 .25em 0;
    font-weight: bold;
    font-size: 115%;
}
#container-advanced-btns {
    display: flex;
    flex-wrap: wrap;
    justify-content: space-between;
    align-items: center;
}
.animate-spin {
    animation: spin 1s linear infinite;
}
@keyframes spin {
    from {
        transform: rotate(0deg);
    }
    to {
        transform: rotate(360deg);
    }
}
#share-btn-container {
    display: flex;
    padding-left: 0.5rem !important;
    padding-right: 0.5rem !important;
    background-color: #000000;
    justify-content: center;
    align-items: center;
    border-radius: 9999px !important;
    width: 13rem;
    margin-top: 10px;
    margin-left: auto;
}
#share-btn {
    all: initial;
    color: #ffffff;
    font-weight: 600;
    cursor: pointer;
    font-family: 'IBM Plex Sans', sans-serif;
    margin-left: 0.5rem !important;
    padding-top: 0.25rem !important;
    padding-bottom: 0.25rem !important;
    right: 0;
}
#share-btn * {
    all: unset;
}
#share-btn-container div:nth-child(-n+2) {
    width: auto !important;
    min-height: 0px !important;
}
#share-btn-container .wrap {
    display: none !important;
}
.gr-form {
    flex: 1 1 50%;
    border-top-right-radius: 0;
    border-bottom-right-radius: 0;
}
#prompt-container {
    gap: 0;
}
#generated_id {
    min-height: 700px;
}
#setting_id {
    margin-bottom: 12px;
    text-align: center;
    font-weight: 900;
}
"""
ui = gr.Blocks(css=css, title="Audio Flamingo - Demo")
with ui:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 900px; margin: 0 auto;">
            <div
                style="
                    display: inline-flex;
                    align-items: center;
                    gap: 0.8rem;
                    font-size: 1.5rem;
                "
            >
                <h1 style="font-weight: 700; margin-bottom: 7px; line-height: normal;">
                    Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities
                </h1>
            </div>
            <p style="margin-bottom: 10px; font-size: 125%">
                <a href="https://arxiv.org/abs/2402.01831">[Paper]</a> <a href="https://github.com/NVIDIA/audio-flamingo">[Code]</a> <a href="https://audioflamingo.github.io/">[Demo Website]</a> <a href="https://www.youtube.com/watch?v=ucttuS28RVE">[Demo Video]</a>
            </p>
        </div>
        """
    )
    gr.HTML(
        """
        <div>
            <h3>Overview</h3>
            Audio Flamingo is an audio language model that can understand sounds beyond speech.
            It can also answer questions about the sound in natural language. <br>
            Examples of questions include: <br>
            - Can you briefly describe what you hear in this audio? <br>
            - What is the emotion conveyed in this music? <br>
            - Where is this audio usually heard? <br>
            - What place is this music usually played at? <br>
        </div>
        """
    )

    name = gr.Textbox(
        label="Audio file path (choose one from: audio/wav{1--6}.wav)",
        value="audio/wav1.wav"
    )
    prompt = gr.Textbox(
        label="Instruction",
        value='Can you briefly describe what you hear in this audio?'
    )

    with gr.Row():
        play_audio_button = gr.Button("Play Audio")
        audio_output = gr.Audio(label="Playback")
    play_audio_button.click(fn=lambda x: x, inputs=name, outputs=audio_output)

    inference_button = gr.Button("Inference")
    output_text = gr.Textbox(label="Audio Flamingo output")
    inference_button.click(
        fn=inference_item,
        inputs=[name, prompt],
        outputs=output_text
    )
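
# Enable request queuing (useful on shared Spaces hardware) and start the app.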
ui.queue()
ui.launch()