# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import functools
import io
import json
import math
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning
import random
import re
import string
import subprocess
import sys
import yaml
import numpy as np

from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pydub import AudioSegment
from tqdm import tqdm

import torch
import torchvision
from torch.utils.data import DataLoader, Dataset, get_worker_info
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer

import librosa
import soundfile as sf


def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)
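
# Hedged example of the int16 round-trip used in compute_sliding_window below
# (values are illustrative, not from the source):
#   x = np.array([0.5, -1.2], dtype=np.float32)
#   int16_to_float32(float32_to_int16(x))  # ~[0.49998, -1.0]: clipped, then quantized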


class AudioTextDataProcessor:
    def __init__(
        self,
        data_root: str,
        clap_config: dict,
        tokenizer,
        max_tokens: int,
        **kwargs
    ):
        self.data_root = data_root
        self.clap_config = clap_config
        self.tokenizer = tokenizer
        self.tokenizer.padding_side = "right"
        self.max_tokens = max_tokens
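        # Keys read from clap_config elsewhere in this class (inferred from the
        # accesses below): "method", "window_length" (seconds), "window_overlap"
        # (seconds), "max_num_window", and "max_num_fewshot".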

    def get_num_windows(self, T, sr):
        clap_config = self.clap_config
        window_length = int(float(clap_config["window_length"]) * sr)
        window_overlap = int(float(clap_config["window_overlap"]) * sr)
        max_num_window = int(clap_config["max_num_window"])

        num_windows = 1
        if T <= window_length:
            # the clip fits in a single window
            num_windows = 1
            full_length = window_length
        elif T >= (max_num_window * window_length - (max_num_window - 1) * window_overlap):
            # the clip exceeds the maximum covered span; clamp to max_num_window
            num_windows = max_num_window
            full_length = (max_num_window * window_length - (max_num_window - 1) * window_overlap)
        else:
            num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap)))
            full_length = num_windows * window_length - (num_windows - 1) * window_overlap

        return num_windows, full_length
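    # Worked example (hypothetical config: window_length=4.0 s, window_overlap=2.0 s,
    # max_num_window=4, sr=16000): a 7 s clip has T=112000 samples, window=64000,
    # overlap=32000, so num_windows = 1 + ceil((112000 - 64000) / 32000) = 3 and
    # full_length = 3*64000 - 2*32000 = 128000 samples (8 s; the tail is zero-padded).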

    def load_audio(self, file_path, target_sr=44100, duration=30.0, start=0.0):
        if file_path.endswith('.mp3'):
            audio = AudioSegment.from_file(file_path)
            if len(audio) > (start + duration) * 1000:
                audio = audio[int(start * 1000):int((start + duration) * 1000)]
            if audio.frame_rate != target_sr:
                audio = audio.set_frame_rate(target_sr)
            if audio.channels > 1:
                audio = audio.set_channels(1)
            data = np.array(audio.get_array_of_samples())
            if audio.sample_width == 2:
                data = data.astype(np.float32) / np.iinfo(np.int16).max
            elif audio.sample_width == 4:
                data = data.astype(np.float32) / np.iinfo(np.int32).max
            else:
                raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
        else:
            with sf.SoundFile(file_path) as audio:
                original_sr = audio.samplerate
                channels = audio.channels
                # read at most `duration` seconds starting at `start`
                start_frame = int(start * original_sr)
                audio.seek(start_frame)
                frames_to_read = min(int(duration * original_sr), len(audio) - start_frame)
                data = audio.read(frames_to_read)
            if data.max() > 1 or data.min() < -1:
                data = data / max(abs(data.max()), abs(data.min()))
            if original_sr != target_sr:
                if channels == 1:
                    data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
                else:
                    data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
            else:
                if channels != 1:
                    data = data.T[0]

        # normalize to [-1, 1]
        if data.min() >= 0:
            data = 2 * data / abs(data.max()) - 1.0
        else:
            data = data / max(abs(data.max()), abs(data.min()))

        assert len(data.shape) == 1, data.shape
        return data
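    # Usage sketch (hedged; "clip.wav" is a hypothetical path): load_audio returns
    # a mono 1-D numpy array in [-1, 1] resampled to target_sr, e.g.
    #   data = self.load_audio(os.path.join(self.data_root, "clip.wav"), target_sr=44100)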

    def compute_sliding_window(self, audio_file, audio_start=0.0):
        if isinstance(audio_start, str):
            audio_start = float(audio_start)

        clap_config = self.clap_config
        if clap_config["method"] == 'laion-clap':
            sr = 48000
        elif clap_config["method"] == 'microsoft-clap':
            sr = 44100
        else:
            raise NotImplementedError

        window_length = int(float(clap_config["window_length"]) * sr)
        window_overlap = int(float(clap_config["window_overlap"]) * sr)
        max_num_window = int(clap_config["max_num_window"])
        duration = max_num_window * (clap_config["window_length"] - clap_config["window_overlap"]) + clap_config["window_overlap"]

        audio_data = self.load_audio(audio_file, sr, duration, audio_start)
        T = len(audio_data)
        num_windows, full_length = self.get_num_windows(T, sr)

        # zero-pad to the full windowed length, then round-trip through int16
        # to match CLAP's expected quantization
        if full_length > T:
            audio_data = np.append(audio_data, np.zeros(full_length - T))
        audio_data = audio_data.reshape(1, -1)
        audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float()

        audio_clips = []
        audio_embed_mask = torch.zeros(max_num_window)
        for i in range(num_windows):
            start = i * (window_length - window_overlap)
            audio_clips.append(audio_data_tensor[:, start:start + window_length])
            audio_embed_mask[i] = 1

        assert sum(audio_embed_mask) == num_windows

        # pad with silent windows up to max_num_window
        if num_windows < max_num_window:
            for _ in range(max_num_window - num_windows):
                audio_clips.append(torch.zeros_like(audio_clips[-1]))

        audio_clips = torch.cat(audio_clips)  # (max_num_window, window_length) CPU tensor; window_length is in samples
        return audio_clips, audio_embed_mask
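    # Shape note: audio_embed_mask is (max_num_window,) with 1 for real windows
    # and 0 for the zero-padded ones appended above.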

    def preprocess_string_for_eval(self, x):
        x = x.strip().lower()
        return x

    def process(self, item):
        if isinstance(item['name'], str):
            audio_files = [os.path.join(self.data_root, item['name'])]
            audio_starts = [0 if 'audio_start' not in item else float(item['audio_start'])]
        else:
            audio_files = [os.path.join(self.data_root, name) for name in item['name']]
            audio_starts = [0] * len(audio_files) if 'audio_start' not in item else item['audio_start']

        audio_clips, audio_embed_mask = [], []
        for audio_file, audio_start in zip(audio_files, audio_starts):
            this_audio_clips, this_audio_embed_mask = self.compute_sliding_window(audio_file, audio_start)
            audio_clips.append(this_audio_clips)
            audio_embed_mask.append(this_audio_embed_mask)

        audio_clips = torch.cat(audio_clips)
        audio_embed_mask = torch.cat(audio_embed_mask)

        # pad to the fixed budget of max_num_window * max_num_fewshot windows
        correct_num_windows = int(self.clap_config["max_num_window"]) * int(self.clap_config["max_num_fewshot"])
        if len(audio_clips) < correct_num_windows:
            audio_clips = torch.cat([
                audio_clips,
                torch.zeros(correct_num_windows - len(audio_clips), audio_clips.shape[1])
            ])
            audio_embed_mask = torch.cat([
                audio_embed_mask,
                torch.zeros(correct_num_windows - len(audio_embed_mask))
            ])

        audio_clips.requires_grad = False
        audio_embed_mask.requires_grad = False

        # only single-audio items are supported by the prompt construction below
        assert isinstance(item['name'], str)

        # simple data - 1 audio, 1 text
        if 'prompt' in item:
            text_prompt = item['prompt'].lower()
            prefix = item['prefix'].lower()  # the task is xxx.
            sample = "{}{} <audio>{}\nanswer:{}".format(
                self.tokenizer.bos_token,
                self.preprocess_string_for_eval(prefix),
                self.preprocess_string_for_eval(text_prompt),
                self.tokenizer.sep_token
            )

        # dialog data - 1 audio, multiple text
        elif 'dialogue' in item:
            dialogue = item['dialogue']
            prefix = item['prefix'].lower()  # the task is dialog.
            sample = f"{self.tokenizer.bos_token}{prefix}<audio>"
            for each_round in dialogue:
                sample = sample + f"user: {each_round['user']} \nassistant: {self.tokenizer.sep_token}"
                if 'assistant' in each_round:
                    sample = sample + f"{each_round['assistant']}<|endofchunk|>{self.tokenizer.eos_token}\n"

        text = self.tokenizer(
            sample,
            max_length=self.max_tokens * 5,
            padding="longest",
            truncation="only_first",
            return_tensors="pt"
        )

        return (item['name'], audio_clips, audio_embed_mask, text["input_ids"], text["attention_mask"])
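

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file. The clap_config keys
    # mirror those read above; the tokenizer name, paths, and values are
    # hypothetical stand-ins.
    clap_config = {
        "method": "laion-clap",
        "window_length": 7.0,     # seconds per CLAP window (hypothetical)
        "window_overlap": 5.25,   # seconds of overlap (hypothetical)
        "max_num_window": 9,      # hypothetical
        "max_num_fewshot": 1,
    }
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical choice
    # the prompt templates above rely on a sep_token; add one if missing
    if tokenizer.sep_token is None:
        tokenizer.add_special_tokens({"sep_token": "<SEP>"})

    processor = AudioTextDataProcessor(
        data_root="./data",  # hypothetical root containing example.wav
        clap_config=clap_config,
        tokenizer=tokenizer,
        max_tokens=512,
    )
    item = {
        "name": "example.wav",
        "prefix": "The task is audio captioning.",
        "prompt": "Describe the audio.",
    }
    name, clips, mask, input_ids, attention_mask = processor.process(item)
    print(name, clips.shape, mask.shape, input_ids.shape)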