# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the NVIDIA Open Model License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
 Load one of the NeMo streaming speaker diarization models:
 [Streaming Sortformer Diarizer v2](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2),
 [Streaming Sortformer Diarizer v2.1](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2.1)
""" 
```python
from nemo.collections.asr.models import SortformerEncLabelModel, ASRModel
import torch
# A speaker diarization model is needed for tracking the speech activity of each speaker.
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2.1").eval().to(torch.device("cuda"))
# from_pretrained() expects a model name; use restore_from() to load a local .nemo checkpoint instead.
asr_model = ASRModel.from_pretrained("nvidia/multitalker-parakeet-streaming-0.6b-v1").eval().to(torch.device("cuda"))

# Use the pre-defined dataclass template `MultitalkerTranscriptionConfig` from `multitalker_transcript_config.py`. 
# Configure the diarization model using streaming parameters:
from multitalker_transcript_config import MultitalkerTranscriptionConfig
from omegaconf import OmegaConf
cfg = OmegaConf.structured(MultitalkerTranscriptionConfig())
cfg.audio_file = "/path/to/your/audio.wav"
cfg.output_path = "/path/to/output_transcription.json"
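# Streaming-related fields such as `cfg.online_normalization` and
# `cfg.pad_and_drop_preencoded` are also carried on this config and are used
# further below; the dataclass defaults are kept here, but they can be
# overridden in the same way as the paths above.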

diar_model = MultitalkerTranscriptionConfig.init_diar_model(cfg, diar_model)

# Load your audio file into a streaming audio buffer to simulate a real-time audio session.
from nemo.collections.asr.parts.utils.streaming_utils import CacheAwareStreamingAudioBuffer

samples = [{'audio_filepath': cfg.audio_file}]
streaming_buffer = CacheAwareStreamingAudioBuffer(
    model=asr_model,
    online_normalization=cfg.online_normalization,
    pad_and_drop_preencoded=cfg.pad_and_drop_preencoded,
)
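# stream_id=-1 adds the file as a new stream in the buffer.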
streaming_buffer.append_audio_file(audio_filepath=cfg.audio_file, stream_id=-1)
streaming_buffer_iter = iter(streaming_buffer)

# Use the helper class `SpeakerTaggedASR`, which handles all ASR and diarization cache data for streaming.
from nemo.collections.asr.parts.utils.multispk_transcribe_utils import SpeakerTaggedASR
multispk_asr_streamer = SpeakerTaggedASR(cfg, asr_model, diar_model)

for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter):
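    # On the very first chunk (unless pre-encoded frames are padded and dropped
    # up front), keep all pre-encoded frames; on later chunks drop the extra
    # left-context frames the encoder has already consumed.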
    drop_extra_pre_encoded = (
        0
        if step_num == 0 and not cfg.pad_and_drop_preencoded
        else asr_model.encoder.streaming_cfg.drop_extra_pre_encoded
    )
    with torch.inference_mode(), torch.amp.autocast(diar_model.device.type):
        multispk_asr_streamer.perform_parallel_streaming_stt_spk(
            step_num=step_num,
            chunk_audio=chunk_audio,
            chunk_lengths=chunk_lengths,
            is_buffer_empty=streaming_buffer.is_buffer_empty(),
            drop_extra_pre_encoded=drop_extra_pre_encoded,
        )

# Generate the speaker-tagged transcript and print it.
multispk_asr_streamer.generate_seglst_dicts_from_parallel_streaming(samples=samples)
print(multispk_asr_streamer.instance_manager.seglst_dict_list)
```
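
Each entry in `seglst_dict_list` is a speaker-attributed segment. To keep the transcript on disk rather than only printing it, a minimal sketch (not part of the NeMo API, and assuming the entries are plain JSON-serializable dicts) is to dump them to the `output_path` configured earlier:

```python
import json

# Minimal sketch: persist the speaker-tagged segments to the configured path.
# The exact keys of each segment dict (e.g. speaker label, start/end time,
# words) depend on the NeMo version, so inspect the list before relying on a
# fixed schema.
seglst = multispk_asr_streamer.instance_manager.seglst_dict_list
with open(cfg.output_path, "w", encoding="utf-8") as f:
    json.dump(seglst, f, indent=2, ensure_ascii=False)
```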