from __future__ import annotations

import asyncio
import time
from queue import Empty, Queue
from typing import Any, Optional

import torch
from transformers.generation import BaseStreamer

|
class AudioStreamer(BaseStreamer):
    """
    Audio streamer that stores audio chunks in per-sample queues, allowing audio to be
    streamed for multiple samples in a batch simultaneously.

    Parameters:
        batch_size (`int`):
            The batch size used for generation.
        stop_signal (`Any`, *optional*):
            The signal put in a sample's queue when its generation ends. Defaults to `None`.
        timeout (`float`, *optional*):
            The timeout for queue operations. If `None`, queue operations block indefinitely.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        self.batch_size = batch_size
        self.stop_signal = stop_signal
        self.timeout = timeout

        # One queue and one finished flag per sample in the batch.
        self.audio_queues = [Queue() for _ in range(batch_size)]
        self.finished_flags = [False for _ in range(batch_size)]
        self.sample_indices_map = {}
|
    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """
        Receives audio chunks and puts them in the appropriate queues.

        Args:
            audio_chunks: Tensor of shape `(num_samples, ...)` containing audio chunks.
            sample_indices: Tensor indicating which sample each chunk belongs to.
        """
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            if idx < self.batch_size and not self.finished_flags[idx]:
                # Detach from the graph and move to CPU before handing off to consumers.
                audio_chunk = audio_chunks[i].detach().cpu()
                self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)
|
    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """
        Signals the end of generation for the specified samples, or for all samples.

        Args:
            sample_indices: Optional tensor of sample indices to end. If `None`, ends all samples.
        """
        if sample_indices is None:
            indices_to_end = range(self.batch_size)
        else:
            indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]

        for idx in indices_to_end:
            if idx < self.batch_size and not self.finished_flags[idx]:
                self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
                self.finished_flags[idx] = True
|
    def __iter__(self):
        """Returns an iterator over the batch of audio streams."""
        return AudioBatchIterator(self)

    def get_stream(self, sample_idx: int):
        """Get the audio stream for a specific sample."""
        if sample_idx >= self.batch_size:
            raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
        return AudioSampleIterator(self, sample_idx)
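
# Usage sketch (assumptions: a model whose `generate` accepts a `streamer` argument and
# calls `put`/`end` with audio chunks; `model`, `inputs`, and `play_or_buffer` are
# hypothetical names, not part of this module):
#
#     from threading import Thread
#
#     streamer = AudioStreamer(batch_size=2)
#     thread = Thread(target=model.generate, kwargs={"streamer": streamer, **inputs})
#     thread.start()
#     for chunk in streamer.get_stream(0):  # blocks until sample 0 produces a chunk
#         play_or_buffer(chunk)
#     thread.join()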
|
|
class AudioSampleIterator:
    """Iterator over a single sample's audio stream from the batch."""

    def __init__(self, streamer: AudioStreamer, sample_idx: int):
        self.streamer = streamer
        self.sample_idx = sample_idx

    def __iter__(self):
        return self

    def __next__(self):
        value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout)
        # Compare by identity: the stop signal is a sentinel, and `==` on a tensor
        # performs an elementwise comparison rather than testing for the sentinel.
        if value is self.streamer.stop_signal:
            raise StopIteration()
        return value
|
|
class AudioBatchIterator:
    """Iterator that yields audio chunks for all samples in the batch."""

    def __init__(self, streamer: AudioStreamer):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __iter__(self):
        return self

    def __next__(self):
        while self.active_samples:
            batch_chunks = {}
            samples_to_remove = set()

            # Collect whatever is currently available without blocking.
            for idx in self.active_samples:
                try:
                    value = self.streamer.audio_queues[idx].get(block=False)
                    if value is self.streamer.stop_signal:
                        samples_to_remove.add(idx)
                    else:
                        batch_chunks[idx] = value
                except Empty:
                    # No chunk ready for this sample yet.
                    pass

            self.active_samples -= samples_to_remove

            if batch_chunks:
                return batch_chunks
            if self.active_samples:
                # Nothing ready yet: back off briefly and poll again. A loop is used
                # instead of recursion so long waits cannot exhaust the stack.
                time.sleep(0.01)
        raise StopIteration()
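
# Batch consumption sketch: iterating the streamer itself yields dicts mapping sample
# index -> chunk for whichever samples produced audio since the last yield (`buffers`
# is a hypothetical dict of lists):
#
#     for chunks in streamer:  # uses AudioBatchIterator under the hood
#         for idx, chunk in chunks.items():
#             buffers[idx].append(chunk)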
|
|
class AsyncAudioStreamer(AudioStreamer):
    """
    Async version of `AudioStreamer` for use in async contexts.

    Must be instantiated inside a running event loop: the constructor captures the
    loop with `asyncio.get_running_loop()` so that `put` and `end`, which are
    typically called from a generation worker thread, can schedule queue writes
    thread-safely.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        super().__init__(batch_size, stop_signal, timeout)
        # Replace the synchronous queues with asyncio queues.
        self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
        self.loop = asyncio.get_running_loop()

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """Put audio chunks in the appropriate async queues."""
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            if idx < self.batch_size and not self.finished_flags[idx]:
                audio_chunk = audio_chunks[i].detach().cpu()
                # Hand the chunk to the event loop thread-safely.
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, audio_chunk
                )
|
    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """Signal the end of generation for the specified samples, or for all samples."""
        if sample_indices is None:
            indices_to_end = range(self.batch_size)
        else:
            indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]

        for idx in indices_to_end:
            if idx < self.batch_size and not self.finished_flags[idx]:
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, self.stop_signal
                )
                self.finished_flags[idx] = True

    async def get_stream(self, sample_idx: int):
        """Async generator over a specific sample's audio stream."""
        if sample_idx >= self.batch_size:
            raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")

        while True:
            value = await self.audio_queues[sample_idx].get()
            # Identity comparison, for the same reason as in AudioSampleIterator.
            if value is self.stop_signal:
                break
            yield value

    def __aiter__(self):
        """Returns an async iterator over all audio streams."""
        return AsyncAudioBatchIterator(self)
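
# Async usage sketch (same assumptions as above: `model.generate` drives the streamer
# from a worker thread; `model`, `inputs`, and `send_to_client` are hypothetical):
#
#     async def stream_sample():
#         streamer = AsyncAudioStreamer(batch_size=1)  # must be created inside the loop
#         thread = Thread(target=model.generate, kwargs={"streamer": streamer, **inputs})
#         thread.start()
#         async for chunk in streamer.get_stream(0):
#             await send_to_client(chunk)
#         thread.join()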
|
|
class AsyncAudioBatchIterator:
    """Async iterator for batch audio streaming."""

    def __init__(self, streamer: AsyncAudioStreamer):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __aiter__(self):
        return self

    async def __anext__(self):
        while self.active_samples:
            batch_chunks = {}
            samples_to_remove = set()

            # Race a `get` on every active queue and take whatever finishes first.
            tasks = {
                idx: asyncio.create_task(self._get_chunk(idx))
                for idx in self.active_samples
            }

            done, pending = await asyncio.wait(
                tasks.values(),
                return_when=asyncio.FIRST_COMPLETED,
                timeout=self.streamer.timeout,
            )

            # Cancel the gets that did not complete; a cancelled `asyncio.Queue.get`
            # does not consume an item, so no chunks are lost.
            for task in pending:
                task.cancel()

            for idx, task in tasks.items():
                if task in done:
                    try:
                        value = await task
                        if value is self.streamer.stop_signal:
                            samples_to_remove.add(idx)
                        else:
                            batch_chunks[idx] = value
                    except asyncio.CancelledError:
                        pass

            self.active_samples -= samples_to_remove

            if batch_chunks:
                return batch_chunks
            # Otherwise loop and wait again; looping avoids the unbounded recursion
            # a repeated `await self.__anext__()` would create.
        raise StopAsyncIteration()

    async def _get_chunk(self, idx):
        """Helper to await a single chunk from a specific queue."""
        return await self.streamer.audio_queues[idx].get()
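
# Async batch consumption sketch: `async for` over the streamer yields dicts mapping
# sample index -> chunk, ending once every sample has received its stop signal:
#
#     async for chunks in streamer:  # uses AsyncAudioBatchIterator under the hood
#         for idx, chunk in chunks.items():
#             ...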