Spaces:
Runtime error
Runtime error
import json
import os
import tempfile

import gradio as gr
import numpy as np
import torch
import torchaudio
from huggingface_hub import login
from pyannote.audio import Pipeline
from pydub import AudioSegment
| # Authenticate with Huggingface | |
| AUTH_TOKEN = os.getenv("HF_TOKEN") | |
| # Load the diarization pipeline | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.0", use_auth_token = AUTH_TOKEN).to(device) | |
| def preprocess_audio(audio_path): | |
| """Convert audio to mono, 16kHz WAV format suitable for pyannote.""" | |
| try: | |
| # Load audio with pydub | |
| audio = AudioSegment.from_file(audio_path) | |
| # Convert to mono and set sample rate to 16kHz | |
| audio = audio.set_channels(1).set_frame_rate(16000) | |
| # Export to temporary WAV file | |
| temp_wav = "temp_audio.wav" | |
| audio.export(temp_wav, format="wav") | |
| return temp_wav | |
| except Exception as e: | |
| raise ValueError(f"Error preprocessing audio: {str(e)}") | |
| def diarize_audio(audio_path, num_speakers): | |
| """Perform speaker diarization and return formatted results.""" | |
| try: | |
| # Validate inputs | |
| if not os.path.exists(audio_path): | |
| raise ValueError("Audio file not found.") | |
| if not isinstance(num_speakers, int) or num_speakers < 1: | |
| raise ValueError("Number of speakers must be a positive integer.") | |
| # Preprocess audio | |
| wav_path = preprocess_audio(audio_path) | |
| # Load audio for pyannote | |
| waveform, sample_rate = torchaudio.load(wav_path) | |
| audio_dict = {"waveform": waveform, "sample_rate": sample_rate} | |
| # Configure pipeline with number of speakers | |
| pipeline_params = {"num_speakers": num_speakers} | |
| diarization = pipeline(audio_dict, **pipeline_params) | |
| # Format results | |
| results = [] | |
| text_output = "" | |
| for turn, _, speaker in diarization.itertracks(yield_label=True): | |
| result = { | |
| "start": round(turn.start, 3), | |
| "end": round(turn.end, 3), | |
| "speaker_id": speaker | |
| } | |
| results.append(result) | |
| text_output += f"Speaker {speaker}: {result['start']}s - {result['end']}s\n" | |
| # Clean up temporary file | |
| if os.path.exists(wav_path): | |
| os.remove(wav_path) | |
| # Return text and JSON results | |
| json_output = json.dumps(results, indent=2) | |
| return text_output, json_output | |
| except Exception as e: | |
| return f"Error: {str(e)}", "" | |
| # Gradio interface | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# Speaker Diarization with Pyannote 3.0") | |
| gr.Markdown("Upload an audio file and specify the number of speakers to diarize the audio.") | |
| with gr.Row(): | |
| audio_input = gr.Audio(label="Upload Audio File", type="filepath") | |
| num_speakers = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Speakers", value=2) | |
| submit_btn = gr.Button("Diarize") | |
| with gr.Row(): | |
| text_output = gr.Textbox(label="Diarization Results (Text)") | |
| json_output = gr.Textbox(label="Diarization Results (JSON)") | |
| submit_btn.click( | |
| fn=diarize_audio, | |
| inputs=[audio_input, num_speakers], | |
| outputs=[text_output, json_output] | |
| ) | |
| # Launch the Gradio app | |
| demo.launch() | |