Spaces:
Running
Running
| import os | |
| import torch | |
| import gradio as gr | |
| import logging | |
| import subprocess | |
| from pydub import AudioSegment | |
| from pydub.exceptions import CouldntDecodeError | |
| from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor | |
| from pathlib import Path | |
| from tempfile import NamedTemporaryFile | |
| from datetime import timedelta | |
# Setup logging: timestamped, level-tagged records at INFO and above.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Configuration
MODEL_ID = "KBLab/kb-whisper-large"  # identifier handed to from_pretrained()
CHUNK_DURATION_MS = 10000  # length of each transcription chunk, in milliseconds
SUPPORTED_FORMATS = {".wav", ".mp3", ".m4a"}  # accepted lower-cased file suffixes
# Run on the GPU in half precision when CUDA is available, else on CPU in float32.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
# Check for ffmpeg availability
def check_ffmpeg():
    """Report whether the ``ffmpeg`` executable can be run from PATH.

    Runs ``ffmpeg -version`` with its output captured and logs the outcome.

    Returns:
        True when the command ran and exited with status 0; False when the
        executable was not found or exited with a non-zero status.
    """
    probe_command = ["ffmpeg", "-version"]
    try:
        subprocess.run(probe_command, check=True, capture_output=True)
    except (FileNotFoundError, subprocess.CalledProcessError):
        logger.error("ffmpeg is not installed or not found in PATH.")
        return False
    else:
        logger.info("ffmpeg is installed and accessible.")
        return True
# Initialize model and pipeline
def initialize_pipeline():
    """Load the speech model and wrap it in an ASR pipeline.

    The model named by ``MODEL_ID`` is loaded with ``TORCH_DTYPE`` and moved
    to ``DEVICE``; its processor supplies the tokenizer and feature extractor
    for an ``"automatic-speech-recognition"`` pipeline.

    Returns:
        The object returned by ``transformers.pipeline``.

    Raises:
        RuntimeError: If any step of loading fails. The original exception is
            attached as ``__cause__`` so the root cause stays visible.
    """
    try:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            low_cpu_mem_usage=True,
        ).to(DEVICE)
        processor = AutoProcessor.from_pretrained(MODEL_ID)
        return pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            device=DEVICE,
            torch_dtype=TORCH_DTYPE,
            # NOTE(review): the model is passed in already instantiated, so
            # model_kwargs is presumably not applied to it - confirm against
            # the installed transformers version before relying on this flag.
            model_kwargs={"use_flash_attention_2": torch.cuda.is_available()},
        )
    except Exception as e:
        logger.error(f"Failed to initialize pipeline: {str(e)}")
        # Chain explicitly so the traceback shows what actually went wrong.
        raise RuntimeError("Unable to load transcription model. Please check your network connection or model ID.") from e
# Convert audio if needed
def convert_to_wav(audio_path: str) -> str:
    """Return a path to a WAV version of *audio_path*, converting if needed.

    Files whose suffix is already ``.wav`` are returned unchanged and need no
    ffmpeg. Other supported formats are decoded with pydub and exported next
    to the input as ``<stem>.converted.wav``.

    Args:
        audio_path: Path to the input audio file.

    Returns:
        ``audio_path`` itself for WAV input, otherwise the path of the newly
        written WAV file (the caller is responsible for deleting it).

    Raises:
        ValueError: If the suffix is not in ``SUPPORTED_FORMATS``, ffmpeg is
            unavailable for a file that needs converting, or decoding or
            exporting fails.
    """
    ext = Path(audio_path).suffix.lower()

    # Validate outside the try block so these deliberate errors reach the
    # caller verbatim instead of being relabelled as "unexpected".
    if ext not in SUPPORTED_FORMATS:
        # sorted() keeps the message stable; set iteration order is arbitrary.
        raise ValueError(f"Unsupported audio format: {ext}. Supported formats: {', '.join(sorted(SUPPORTED_FORMATS))}")

    if ext == ".wav":
        return audio_path

    # ffmpeg is only checked for input that actually gets converted.
    if not check_ffmpeg():
        raise ValueError("ffmpeg is required to process .m4a files. Please install ffmpeg and ensure it's in your PATH.")

    try:
        logger.info(f"Converting {ext} file to WAV: {audio_path}")
        audio = AudioSegment.from_file(audio_path)
        wav_path = str(Path(audio_path).with_suffix(".converted.wav"))
        audio.export(wav_path, format="wav")
        logger.info(f"Conversion successful: {wav_path}")
        return wav_path
    except CouldntDecodeError as e:
        logger.error(f"Failed to decode .m4a file: {audio_path}")
        raise ValueError("The .m4a file is corrupted or not supported. Ensure it's a valid iPhone recording and ffmpeg is installed.") from e
    except OSError as e:
        logger.error(f"OS error during audio conversion: {str(e)}")
        raise ValueError("Failed to process the .m4a file due to a system error. Check file permissions or disk space.") from e
    except Exception as e:
        logger.error(f"Unexpected error during .m4a conversion: {str(e)}")
        raise ValueError(f"An unexpected error occurred while converting the .m4a file: {str(e)}") from e
# Split audio into chunks
def split_audio(audio_path: str) -> list:
    """Load an audio file and cut it into consecutive fixed-length chunks.

    Args:
        audio_path: Path to the audio file to load.

    Returns:
        A list of audio segments, each at most ``CHUNK_DURATION_MS``
        milliseconds long, in playback order.

    Raises:
        ValueError: If the file cannot be decoded or loaded, or if the decoded
            audio has zero length.
    """
    # Only loading is guarded, so the empty-audio error below is not caught
    # and re-wrapped by the generic handler.
    try:
        audio = AudioSegment.from_file(audio_path)
    except CouldntDecodeError as e:
        logger.error(f"Failed to decode audio for splitting: {audio_path}")
        raise ValueError("The .m4a file is corrupted or not supported. Ensure it's a valid iPhone recording.") from e
    except Exception as e:
        logger.error(f"Failed to split audio: {str(e)}")
        raise ValueError(f"Failed to process the .m4a file: {str(e)}") from e

    if len(audio) == 0:
        raise ValueError("The .m4a file is empty or invalid.")

    logger.info(f"Splitting audio into {CHUNK_DURATION_MS/1000}-second chunks: {audio_path}")
    return [audio[i:i + CHUNK_DURATION_MS] for i in range(0, len(audio), CHUNK_DURATION_MS)]
# Helper to compute chunk start time
def get_chunk_time(index: int, chunk_duration_ms: int) -> str:
    """Return the start offset of chunk *index* as a ``timedelta`` string.

    Args:
        index: Zero-based position of the chunk.
        chunk_duration_ms: Length of every chunk in milliseconds.

    Returns:
        The offset ``index * chunk_duration_ms`` formatted the way
        ``str(timedelta)`` does, e.g. ``"0:00:10"``.
    """
    offset = timedelta(milliseconds=index * chunk_duration_ms)
    return f"{offset}"
# Transcribe audio with progress and timestamps
def _remove_file(path, description: str) -> None:
    """Delete *path* if it exists; log a warning instead of raising on failure."""
    if path and os.path.exists(path):
        try:
            os.remove(path)
        except OSError as e:
            logger.warning(f"Failed to delete {description} {path}: {str(e)}")


def _transcribe_chunk(chunk) -> str:
    """Transcribe one audio chunk and return its stripped text.

    The chunk is exported to a temporary WAV file, handed to the global
    ``PIPELINE`` with task "transcribe" and language "sv", and the temporary
    file is removed afterwards whether or not transcription succeeded.
    """
    # Reserve a unique path and close the handle before writing to it by name.
    with NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
        temp_file_path = temp_file.name
    try:
        chunk.export(temp_file_path, format="wav")
        result = PIPELINE(
            temp_file_path,
            generate_kwargs={"task": "transcribe", "language": "sv"},
        )
        return result["text"].strip()
    finally:
        _remove_file(temp_file_path, "temporary file")


def transcribe(audio_path: str, include_timestamps: bool = False, progress=gr.Progress()):
    """Transcribe an audio file chunk by chunk, streaming partial results.

    This is a generator: every result, including error messages and the final
    transcript, is delivered with ``yield``. A ``return value`` inside a
    generator never reaches the consumer, so none is used to carry output.

    Args:
        audio_path: Path to the uploaded audio file.
        include_timestamps: When true, the downloadable transcript holds one
            ``[start-time] text`` line per chunk instead of the plain text.
        progress: Progress callback, called with the completed fraction after
            every chunk.

    Yields:
        ``(text, download_path)`` tuples. ``download_path`` is ``None`` for
        partial results and error messages; the last tuple of a successful
        run carries the path of the transcript file, or ``None`` if that file
        could not be written.
    """
    wav_path = None
    try:
        if not audio_path or not os.path.exists(audio_path):
            logger.warning("Invalid or missing audio file path.")
            yield "Please upload a valid .m4a file.", None
            return

        # Convert to WAV if needed
        wav_path = convert_to_wav(audio_path)

        # Split and process
        chunks = split_audio(wav_path)
        total_chunks = len(chunks)
        transcript = []
        timestamped_transcript = []
        failed_chunks = 0

        for i, chunk in enumerate(chunks):
            try:
                text = _transcribe_chunk(chunk)
            except Exception as e:
                # A failed chunk must not abort the run: record a placeholder.
                if isinstance(e, RuntimeError):
                    logger.warning(f"Failed to transcribe chunk {i+1}/{total_chunks}: {str(e)}")
                else:
                    logger.error(f"Unexpected error in chunk {i+1}/{total_chunks}: {str(e)}")
                failed_chunks += 1
                transcript.append("[Transcription failed for this segment]")
                if include_timestamps:
                    timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
                    timestamped_transcript.append(f"[{timestamp}] [Transcription failed]")
            else:
                # Chunks that produced no text are skipped.
                if text:
                    transcript.append(text)
                    if include_timestamps:
                        timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
                        timestamped_transcript.append(f"[{timestamp}] {text}")

            progress((i + 1) / total_chunks)
            yield " ".join(transcript), None

        # Prepare final transcript and downloadable file
        final_transcript = " ".join(transcript)
        if failed_chunks > 0:
            final_transcript = f"Warning: {failed_chunks}/{total_chunks} chunks failed to transcribe.\n{final_transcript}"

        download_content = "\n".join(timestamped_transcript) if include_timestamps else final_transcript
        download_path = None
        try:
            with NamedTemporaryFile(suffix=".txt", delete=False, mode='w', encoding='utf-8') as temp_file:
                temp_file.write(download_content)
                download_path = temp_file.name
        except OSError as e:
            logger.error(f"Failed to create downloadable transcript: {str(e)}")
            final_transcript = f"{final_transcript}\nNote: Could not generate downloadable transcript due to a file error."

        yield final_transcript, download_path
    except ValueError as e:
        logger.error(f"Value error during transcription: {str(e)}")
        yield str(e), None
    except Exception as e:
        logger.error(f"Unexpected error during transcription: {str(e)}")
        yield f"An unexpected error occurred while processing the .m4a file: {str(e)}. Please ensure the file is a valid iPhone recording and try again.", None
    finally:
        # Clean up the converted file on every exit path, including errors
        # and the generator being closed early.
        if wav_path != audio_path:
            _remove_file(wav_path, "converted WAV file")
# Initialize pipeline globally
# The model is loaded once at import time; a failure is logged at CRITICAL
# level and then re-raised unchanged.
try:
    PIPELINE = initialize_pipeline()
except RuntimeError as init_error:
    logger.critical(f"Pipeline initialization failed: {init_error}")
    raise
# Gradio Interface with Blocks
def create_interface():
    """Build the Gradio Blocks UI and wire the button to ``transcribe``.

    The left column holds the audio upload, the timestamp checkbox and the
    button; the right column holds the transcript textbox and the download
    file component.

    Returns:
        The constructed ``gr.Blocks`` object; the caller launches it.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown("# Swedish Whisper Transcriber")
        gr.Markdown("Upload an .m4a file from your iPhone for real-time Swedish speech transcription.")

        with gr.Row():
            # Input controls
            with gr.Column():
                upload = gr.Audio(type="filepath", label="Upload .m4a Audio")
                want_timestamps = gr.Checkbox(label="Include Timestamps in Download", value=False)
                run_button = gr.Button("Transcribe")
            # Result widgets
            with gr.Column():
                live_text = gr.Textbox(label="Live Transcription", lines=10)
                transcript_file = gr.File(label="Download Transcript")

        handler_inputs = [upload, want_timestamps]
        handler_outputs = [live_text, transcript_file]
        run_button.click(
            fn=transcribe,
            inputs=handler_inputs,
            outputs=handler_outputs,
        )
    return app
if __name__ == "__main__":
    # Script entry point: refuse to start without ffmpeg, then launch the UI.
    try:
        if not check_ffmpeg():
            print("Error: ffmpeg is required to process .m4a files. Please install ffmpeg and ensure it's in your PATH.")
            # SystemExit needs no import and, unlike the site-provided
            # exit() helper, is always available.
            raise SystemExit(1)
        create_interface().launch()
    except Exception as e:
        # SystemExit is not an Exception subclass, so the exit above is not
        # intercepted here.
        logger.critical(f"Failed to launch Gradio interface: {str(e)}")
        print("Error: Could not start the application. Please check the logs for details.")
        # Report the failure through the exit status as well.
        raise SystemExit(1) from e