# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
| """API methods for demucs | |
| Classes | |
| ------- | |
| `demucs.api.Separator`: The base separator class | |
| Functions | |
| --------- | |
| `demucs.api.save_audio`: Save an audio | |
| `demucs.api.list_models`: Get models list | |
| Examples | |
| -------- | |
| See the end of this module (if __name__ == "__main__") | |
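
A minimal sketch (the file name is illustrative):

    from demucs.api import Separator, save_audio

    separator = Separator(model="htdemucs")
    origin, separated = separator.separate_audio_file("song.mp3")
    for stem, source in separated.items():
        save_audio(source, f"{stem}.wav", samplerate=separator.samplerate)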
| """ | |

import subprocess

import torch as th
import torchaudio as ta

from dora.log import fatal
from pathlib import Path
from typing import Optional, Callable, Dict, Tuple, Union

from .apply import apply_model, _replace_dict
from .audio import AudioFile, convert_audio, save_audio
from .pretrained import get_model, _parse_remote_files, REMOTE_ROOT
from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo


class LoadAudioError(Exception):
    pass


class LoadModelError(Exception):
    pass


class _NotProvided:
    pass


NotProvided = _NotProvided()


class Separator:
    def __init__(
        self,
        model: str = "htdemucs",
        repo: Optional[Path] = None,
        device: str = "cuda" if th.cuda.is_available() else "cpu",
        shifts: int = 1,
        overlap: float = 0.25,
        split: bool = True,
        segment: Optional[int] = None,
        jobs: int = 0,
        progress: bool = False,
        callback: Optional[Callable[[dict], None]] = None,
        callback_arg: Optional[dict] = None,
    ):
| """ | |
| `class Separator` | |
| ================= | |
| Parameters | |
| ---------- | |
| model: Pretrained model name or signature. Default is htdemucs. | |
| repo: Folder containing all pre-trained models for use. | |
| segment: Length (in seconds) of each segment (only available if `split` is `True`). If \ | |
| not specified, will use the command line option. | |
| shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \ | |
| apply the oppositve shift to the output. This is repeated `shifts` time and all \ | |
| predictions are averaged. This effectively makes the model time equivariant and \ | |
| improves SDR by up to 0.2 points. If not specified, will use the command line option. | |
| split: If True, the input will be broken down into small chunks (length set by `segment`) \ | |
| and predictions will be performed individually on each and concatenated. Useful for \ | |
| model with large memory footprint like Tasnet. If not specified, will use the command \ | |
| line option. | |
| overlap: The overlap between the splits. If not specified, will use the command line \ | |
| option. | |
| device (torch.device, str, or None): If provided, device on which to execute the \ | |
| computation, otherwise `wav.device` is assumed. When `device` is different from \ | |
| `wav.device`, only local computations will be on `device`, while the entire tracks \ | |
| will be stored on `wav.device`. If not specified, will use the command line option. | |
| jobs: Number of jobs. This can increase memory usage but will be much faster when \ | |
| multiple cores are available. If not specified, will use the command line option. | |
| callback: A function will be called when the separation of a chunk starts or finished. \ | |
| The argument passed to the function will be a dict. For more information, please see \ | |
| the Callback section. | |
| callback_arg: A dict containing private parameters to be passed to callback function. For \ | |
| more information, please see the Callback section. | |
| progress: If true, show a progress bar. | |
| Callback | |
| -------- | |
| The function will be called with only one positional parameter whose type is `dict`. The | |
| `callback_arg` will be combined with information of current separation progress. The | |
| progress information will override the values in `callback_arg` if same key has been used. | |
| To abort the separation, raise `KeyboardInterrupt`. | |
| Progress information contains several keys (These keys will always exist): | |
| - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0. | |
| - `shift_idx`: The index of shifts. Starts from 0. | |
| - `segment_offset`: The offset of current segment. If the number is 441000, it doesn't | |
| mean that it is at the 441000 second of the audio, but the "frame" of the tensor. | |
| - `state`: Could be `"start"` or `"end"`. | |
| - `audio_length`: Length of the audio (in "frame" of the tensor). | |
| - `models`: Count of submodels in the model. | |
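
        Examples
        --------
        A minimal sketch with a progress callback (the callback body is illustrative, not
        prescribed by the API):

            def show_progress(info: dict):
                # `info` always carries the keys listed above.
                if info["state"] == "end":
                    print(f"chunk at frame {info['segment_offset']} done")

            separator = Separator(model="htdemucs", progress=True, callback=show_progress)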
| """ | |
        self._name = model
        self._repo = repo
        self._load_model()
        self.update_parameter(device=device, shifts=shifts, overlap=overlap, split=split,
                              segment=segment, jobs=jobs, progress=progress, callback=callback,
                              callback_arg=callback_arg)

    def update_parameter(
        self,
        device: Union[str, _NotProvided] = NotProvided,
        shifts: Union[int, _NotProvided] = NotProvided,
        overlap: Union[float, _NotProvided] = NotProvided,
        split: Union[bool, _NotProvided] = NotProvided,
        segment: Optional[Union[int, _NotProvided]] = NotProvided,
        jobs: Union[int, _NotProvided] = NotProvided,
        progress: Union[bool, _NotProvided] = NotProvided,
        callback: Optional[
            Union[Callable[[dict], None], _NotProvided]
        ] = NotProvided,
        callback_arg: Optional[Union[dict, _NotProvided]] = NotProvided,
    ):
| """ | |
| Update the parameters of separation. | |
| Parameters | |
| ---------- | |
| segment: Length (in seconds) of each segment (only available if `split` is `True`). If \ | |
| not specified, will use the command line option. | |
| shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \ | |
| apply the oppositve shift to the output. This is repeated `shifts` time and all \ | |
| predictions are averaged. This effectively makes the model time equivariant and \ | |
| improves SDR by up to 0.2 points. If not specified, will use the command line option. | |
| split: If True, the input will be broken down into small chunks (length set by `segment`) \ | |
| and predictions will be performed individually on each and concatenated. Useful for \ | |
| model with large memory footprint like Tasnet. If not specified, will use the command \ | |
| line option. | |
| overlap: The overlap between the splits. If not specified, will use the command line \ | |
| option. | |
| device (torch.device, str, or None): If provided, device on which to execute the \ | |
| computation, otherwise `wav.device` is assumed. When `device` is different from \ | |
| `wav.device`, only local computations will be on `device`, while the entire tracks \ | |
| will be stored on `wav.device`. If not specified, will use the command line option. | |
| jobs: Number of jobs. This can increase memory usage but will be much faster when \ | |
| multiple cores are available. If not specified, will use the command line option. | |
| callback: A function will be called when the separation of a chunk starts or finished. \ | |
| The argument passed to the function will be a dict. For more information, please see \ | |
| the Callback section. | |
| callback_arg: A dict containing private parameters to be passed to callback function. For \ | |
| more information, please see the Callback section. | |
| progress: If true, show a progress bar. | |
| Callback | |
| -------- | |
| The function will be called with only one positional parameter whose type is `dict`. The | |
| `callback_arg` will be combined with information of current separation progress. The | |
| progress information will override the values in `callback_arg` if same key has been used. | |
| To abort the separation, raise `KeyboardInterrupt`. | |
| Progress information contains several keys (These keys will always exist): | |
| - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0. | |
| - `shift_idx`: The index of shifts. Starts from 0. | |
| - `segment_offset`: The offset of current segment. If the number is 441000, it doesn't | |
| mean that it is at the 441000 second of the audio, but the "frame" of the tensor. | |
| - `state`: Could be `"start"` or `"end"`. | |
| - `audio_length`: Length of the audio (in "frame" of the tensor). | |
| - `models`: Count of submodels in the model. | |
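
        Examples
        --------
        A sketch of adjusting parameters on an existing `separator` instance (the values are
        illustrative):

            separator.update_parameter(shifts=2, overlap=0.5)
            separator.update_parameter(segment=10)  # only used when `split` is True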
| """ | |
        if not isinstance(device, _NotProvided):
            self._device = device
        if not isinstance(shifts, _NotProvided):
            self._shifts = shifts
        if not isinstance(overlap, _NotProvided):
            self._overlap = overlap
        if not isinstance(split, _NotProvided):
            self._split = split
        if not isinstance(segment, _NotProvided):
            self._segment = segment
        if not isinstance(jobs, _NotProvided):
            self._jobs = jobs
        if not isinstance(progress, _NotProvided):
            self._progress = progress
        if not isinstance(callback, _NotProvided):
            self._callback = callback
        if not isinstance(callback_arg, _NotProvided):
            self._callback_arg = callback_arg

    def _load_model(self):
        self._model = get_model(name=self._name, repo=self._repo)
        if self._model is None:
            raise LoadModelError("Failed to load model")
        self._audio_channels = self._model.audio_channels
        self._samplerate = self._model.samplerate

    def _load_audio(self, track: Path):
        errors = {}
        wav = None
        try:
            wav = AudioFile(track).read(streams=0, samplerate=self._samplerate,
                                        channels=self._audio_channels)
        except FileNotFoundError:
            errors["ffmpeg"] = "FFmpeg is not installed."
        except subprocess.CalledProcessError:
            errors["ffmpeg"] = "FFmpeg could not read the file."

        if wav is None:
            # Fall back to torchaudio when FFmpeg is unavailable or failed.
            try:
                wav, sr = ta.load(str(track))
            except RuntimeError as err:
                errors["torchaudio"] = err.args[0]
            else:
                wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)

        if wav is None:
            raise LoadAudioError(
                "\n".join(
                    "When trying to load using {}, got the following error: {}".format(
                        backend, error
                    )
                    for backend, error in errors.items()
                )
            )
        return wav

    def separate_tensor(
        self, wav: th.Tensor, sr: Optional[int] = None
    ) -> Tuple[th.Tensor, Dict[str, th.Tensor]]:
| """ | |
| Separate a loaded tensor. | |
| Parameters | |
| ---------- | |
| wav: Waveform of the audio. Should have 2 dimensions, the first is each audio channel, \ | |
| while the second is the waveform of each channel. Type should be float32. \ | |
| e.g. `tuple(wav.shape) == (2, 884000)` means the audio has 2 channels. | |
| sr: Sample rate of the original audio, the wave will be resampled if it doesn't match the \ | |
| model. | |
| Returns | |
| ------- | |
| A tuple, whose first element is the original wave and second element is a dict, whose keys | |
| are the name of stems and values are separated waves. The original wave will have already | |
| been resampled. | |
| Notes | |
| ----- | |
| Use this function with cautiousness. This function does not provide data verifying. | |
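
        Examples
        --------
        A sketch using torchaudio to load the input for an existing `separator` instance (the
        file name is illustrative):

            import torchaudio as ta

            wav, sr = ta.load("song.wav")
            origin, separated = separator.separate_tensor(wav, sr)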
| """ | |
        if sr is not None and sr != self.samplerate:
            wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)
        # Normalize the input to zero mean and unit variance; the inverse
        # transform is applied to the outputs below.
        ref = wav.mean(0)
        wav -= ref.mean()
        wav /= ref.std() + 1e-8
        out = apply_model(
            self._model,
            wav[None],
            segment=self._segment,
            shifts=self._shifts,
            split=self._split,
            overlap=self._overlap,
            device=self._device,
            num_workers=self._jobs,
            callback=self._callback,
            callback_arg=_replace_dict(
                self._callback_arg, ("audio_length", wav.shape[1])
            ),
            progress=self._progress,
        )
        if out is None:
            raise KeyboardInterrupt
        # Undo the normalization on both the separated stems and the input.
        out *= ref.std() + 1e-8
        out += ref.mean()
        wav *= ref.std() + 1e-8
        wav += ref.mean()
        return (wav, dict(zip(self._model.sources, out[0])))

    def separate_audio_file(self, file: Path):
        """
        Separate an audio file. The method will automatically read the file.

        Parameters
        ----------
        file: Path of the file to be separated.

        Returns
        -------
        A tuple whose first element is the original waveform and whose second element is a dict
        mapping stem names to the separated waveforms. The original waveform will have already
        been resampled.
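
        Examples
        --------
        A minimal sketch (the file name is illustrative):

            origin, separated = separator.separate_audio_file("song.mp3")
            vocals = separated["vocals"]  # assumes a model with a "vocals" stem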
| """ | |
| return self.separate_tensor(self._load_audio(file), self.samplerate) | |

    @property
    def samplerate(self):
        return self._samplerate

    @property
    def audio_channels(self):
        return self._audio_channels

    @property
    def model(self):
        return self._model


def list_models(repo: Optional[Path] = None) -> Dict[str, Dict[str, Union[str, Path]]]:
    """
    List the available models. Please remember that not all the returned models can be
    successfully loaded.

    Parameters
    ----------
    repo: The repo whose models are to be listed.

    Returns
    -------
    A dict with two keys ("single" for single models and "bag" for bags of models). Each value
    is a dict mapping a model name to its URL or path.
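
    Examples
    --------
    A minimal sketch:

        models = list_models()
        print(models["single"])  # name -> URL or path of each single model
        print(models["bag"])     # name -> path of each bag of models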
| """ | |
    model_repo: ModelOnlyRepo
    if repo is None:
        models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
        model_repo = RemoteRepo(models)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        if not repo.is_dir():
            fatal(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    return {"single": model_repo.list_model(), "bag": bag_repo.list_model()}


if __name__ == "__main__":
    # Test API functions
    # two-stem not supported

    from .separate import get_parser

    args = get_parser().parse_args()
    separator = Separator(
        model=args.name,
        repo=args.repo,
        device=args.device,
        shifts=args.shifts,
        overlap=args.overlap,
        split=args.split,
        segment=args.segment,
        jobs=args.jobs,
        callback=print,
    )
    out = args.out / args.name
    out.mkdir(parents=True, exist_ok=True)
    for file in args.tracks:
        separated = separator.separate_audio_file(file)[1]
        if args.mp3:
            ext = "mp3"
        elif args.flac:
            ext = "flac"
        else:
            ext = "wav"
        kwargs = {
            "samplerate": separator.samplerate,
            "bitrate": args.mp3_bitrate,
            "clip": args.clip_mode,
            "as_float": args.float32,
            "bits_per_sample": 24 if args.int24 else 16,
        }
        for stem, source in separated.items():
            stem = out / args.filename.format(
                track=Path(file).name.rsplit(".", 1)[0],
                trackext=Path(file).name.rsplit(".", 1)[-1],
                stem=stem,
                ext=ext,
            )
            stem.parent.mkdir(parents=True, exist_ok=True)
            save_audio(source, str(stem), **kwargs)