from __future__ import annotations import torch import torchaudio import gradio as gr import spaces import requests import json import os import re import tempfile import shutil from typing import List, Dict, Optional, Union from transformers import AutoModel, AutoModelForAudioClassification, Wav2Vec2FeatureExtractor from transformers import WhisperProcessor, WhisperForConditionalGeneration import google.generativeai as genai from pathlib import Path DESCRIPTION = "Question Generation" device = "cuda" if torch.cuda.is_available() else "cpu" # --- API Configuration --- API_BASE_URL = "https://e1389f0b40fe.ngrok-free.app" HEADERS = { "Content-Type": "application/json", "ngrok-skip-browser-warning": "true" # Skip ngrok browser warning } # Google Gemini API Configuration - Using your existing API key GOOGLE_API_KEY = "AIzaSyCISuF4HL5EAc8iqMuJKOS5GipMT_ujbzw" os.environ['GOOGLE_API_KEY'] = GOOGLE_API_KEY genai.configure(api_key=GOOGLE_API_KEY) # --- Model Loading --- print("Loading ASR model (IndicConformer)...") asr_model_id = "ai4bharat/indic-conformer-600m-multilingual" asr_model = AutoModel.from_pretrained(asr_model_id, trust_remote_code=True).to(device) asr_model.eval() print("βœ… ASR Model loaded.") print("\nLoading Whisper model for English...") model_name = "openai/whisper-small" whisper_processor = WhisperProcessor.from_pretrained(model_name) whisper_model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device) print("βœ… Whisper Model loaded.") print("\nLoading Language ID model (MMS-LID-1024)...") lid_model_id = "facebook/mms-lid-1024" lid_processor = Wav2Vec2FeatureExtractor.from_pretrained(lid_model_id) lid_model = AutoModelForAudioClassification.from_pretrained(lid_model_id).to(device) lid_model.eval() print("βœ… Language ID Model loaded.") # [Keep all your existing mappings] LID_TO_ASR_LANG_MAP = { "asm_Beng": "as", "ben_Beng": "bn", "brx_Deva": "br", "doi_Deva": "doi", "guj_Gujr": "gu", "hin_Deva": "hi", "kan_Knda": "kn", "kas_Arab": "ks", "kas_Deva": "ks", "gom_Deva": "kok", "mai_Deva": "mai", "mal_Mlym": "ml", "mni_Beng": "mni", "mar_Deva": "mr", "nep_Deva": "ne", "ory_Orya": "or", "pan_Guru": "pa", "san_Deva": "sa", "sat_Olck": "sat", "snd_Arab": "sd", "tam_Taml": "ta", "tel_Telu": "te", "urd_Arab": "ur", "asm": "as", "ben": "bn", "brx": "br", "doi": "doi", "guj": "gu", "hin": "hi", "kan": "kn", "kas": "ks", "gom": "kok", "mai": "mai", "mal": "ml", "mni": "mni", "mar": "mr", "npi": "ne", "ory": "or", "pan": "pa", "san": "sa", "sat": "sat", "snd": "sd", "tam": "ta", "tel": "te", "urd": "ur", "eng": "en" } ASR_CODE_TO_NAME = { "as": "Assamese", "bn": "Bengali", "br": "Bodo", "doi": "Dogri", "gu": "Gujarati", "hi": "Hindi", "kn": "Kannada", "ks": "Kashmiri", "kok": "Konkani", "mai": "Maithili", "ml": "Malayalam", "mni": "Manipuri", "mr": "Marathi", "ne": "Nepali", "or": "Odia", "pa": "Punjabi", "sa": "Sanskrit", "sat": "Santali", "sd": "Sindhi", "ta": "Tamil", "te": "Telugu", "ur": "Urdu", "en": "English" } ASR_TO_INDICTRANS_MAP = { "as": "asm_Beng", "bn": "ben_Beng", "br": "brx_Deva", "doi": "doi_Deva", "gu": "guj_Gujr", "hi": "hin_Deva", "kn": "kan_Knda", "ks": "kas_Deva", "kok": "gom_Deva", "mai": "mai_Deva", "ml": "mal_Mlym", "mni": "mni_Beng", "mr": "mar_Deva", "ne": "nep_Deva", "or": "ory_Orya", "pa": "pan_Guru", "sa": "san_Deva", "sat": "sat_Olck", "sd": "snd_Arab", "ta": "tam_Taml", "te": "tel_Telu", "ur": "urd_Arab", "en": "eng_Latn" } LANGUAGE_OPTIONS = { "English": "eng_Latn", "Hindi": "hin_Deva", "Bengali": "ben_Beng", "Telugu": "tel_Telu", "Tamil": "tam_Taml", "Gujarati": "guj_Gujr", "Kannada": "kan_Knda", "Malayalam": "mal_Mlym", "Marathi": "mar_Deva", "Punjabi": "pan_Guru", "Odia": "ory_Orya", "Assamese": "asm_Beng", "Urdu": "urd_Arab", "Nepali": "nep_Deva", "Sanskrit": "san_Deva", "Kashmiri": "kas_Deva", "Sindhi": "snd_Arab", "Bodo": "brx_Deva", "Dogri": "doi_Deva", "Konkani": "gom_Deva", "Maithili": "mai_Deva", "Manipuri": "mni_Beng", "Santali": "sat_Olck" } # --- ENHANCED AUDIO PREPROCESSING FUNCTIONS --- SUPPORTED_AUDIO_FORMATS = { '.wav', '.mp3', '.flac', '.opus', '.ogg', '.m4a', '.aac', '.mp4', '.wma', '.amr', '.aiff', '.au', '.3gp', '.webm' } def detect_audio_format(audio_path: str) -> str: """ Detect the audio format from file extension. Args: audio_path: Path to audio file Returns: File extension in lowercase """ return Path(audio_path).suffix.lower() def get_optimal_backend(audio_format: str) -> str: """ Get the optimal torchaudio backend for the given format. Args: audio_format: File extension (e.g., '.mp3', '.wav') Returns: Recommended backend name """ # FFmpeg backend handles more formats but may not be available everywhere ffmpeg_formats = {'.mp3', '.opus', '.m4a', '.aac', '.mp4', '.webm', '.3gp'} try: # Check if FFmpeg backend is available backends = torchaudio.list_audio_backends() if 'ffmpeg' in backends and audio_format in ffmpeg_formats: return 'ffmpeg' elif 'sox_io' in backends: return 'sox_io' elif 'soundfile' in backends: return 'soundfile' else: return None # Use default except: return None def convert_to_mono(waveform: torch.Tensor) -> torch.Tensor: """ Convert stereo/multi-channel audio to mono by averaging channels. Args: waveform: Audio tensor with shape [channels, samples] Returns: Mono audio tensor with shape [1, samples] """ if waveform.shape[0] > 1: # Average all channels to create mono waveform = torch.mean(waveform, dim=0, keepdim=True) print(f"πŸ”„ Converted from {waveform.shape[0]} channels to mono") else: print("πŸ“» Audio is already mono") return waveform def preprocess_audio(audio_path: str, target_sr: int = 16000) -> tuple: """ Comprehensive audio preprocessing: load, convert to mono, and resample. Supports multiple audio formats: WAV, MP3, FLAC, OPUS, OGG, M4A, AAC, etc. Args: audio_path: Path to audio file target_sr: Target sampling rate (default: 16000 Hz for ASR) Returns: Tuple of (waveform, sample_rate) preprocessed for ASR """ try: # Detect audio format audio_format = detect_audio_format(audio_path) print(f"🎡 Detected format: {audio_format}") # Check if format is supported if audio_format not in SUPPORTED_AUDIO_FORMATS: print(f"⚠️ Warning: {audio_format} may not be fully supported") # Get optimal backend backend = get_optimal_backend(audio_format) if backend: print(f"πŸ”§ Using {backend} backend for {audio_format}") # Load audio file with appropriate backend try: if backend: waveform, orig_sr = torchaudio.load(audio_path, backend=backend) else: waveform, orig_sr = torchaudio.load(audio_path) except Exception as load_error: print(f"⚠️ Primary load method failed: {str(load_error)}") print("πŸ”„ Trying alternative loading method...") # Fallback: try different backends for fallback_backend in ['ffmpeg', 'sox_io', 'soundfile']: try: backends = torchaudio.list_audio_backends() if fallback_backend in backends: print(f"πŸ”„ Trying {fallback_backend} backend...") waveform, orig_sr = torchaudio.load(audio_path, backend=fallback_backend) print(f"βœ… Successfully loaded with {fallback_backend} backend") break except Exception as e: continue else: # If all backends fail, try without specifying backend try: waveform, orig_sr = torchaudio.load(audio_path) print("βœ… Loaded with default backend") except Exception as final_error: raise Exception(f"Failed to load audio file with any backend: {final_error}") print(f"🎡 Loaded audio: {waveform.shape} at {orig_sr} Hz") # Convert to mono if stereo/multi-channel waveform = convert_to_mono(waveform) # Resample to target sampling rate if needed if orig_sr != target_sr: print(f"πŸ”„ Resampling from {orig_sr} Hz to {target_sr} Hz...") waveform = torchaudio.functional.resample( waveform, orig_freq=orig_sr, new_freq=target_sr ) print(f"βœ… Resampled to {target_sr} Hz") else: print(f"βœ… Audio already at target {target_sr} Hz") print(f"βœ… Final preprocessed audio: {waveform.shape} at {target_sr} Hz") return waveform, target_sr except Exception as e: error_msg = f"❌ Error in audio preprocessing: {str(e)}" print(error_msg) raise Exception(error_msg) def validate_audio_file(audio_path: str) -> bool: """ Validate if the audio file exists and has a supported format. Args: audio_path: Path to audio file Returns: True if valid, False otherwise """ if not audio_path or not os.path.exists(audio_path): return False audio_format = detect_audio_format(audio_path) return audio_format in SUPPORTED_AUDIO_FORMATS # --- TEMPORARY STORAGE FUNCTIONS --- def create_temp_audio_file(audio_path: str) -> str: """ Create a temporary copy of the audio file for processing. Returns the path to the temporary file. """ if not audio_path: return None try: # Validate input file first if not validate_audio_file(audio_path): print(f"⚠️ Warning: Audio file may not be valid or supported") # Create a unique temporary file with original extension original_ext = detect_audio_format(audio_path) temp_fd, temp_path = tempfile.mkstemp(suffix=original_ext, prefix="audio_temp_") os.close(temp_fd) # Close the file descriptor # Copy the original audio to temp location shutil.copy2(audio_path, temp_path) print(f"πŸ“ Audio temporarily stored at: {temp_path}") return temp_path except Exception as e: print(f"❌ Error creating temporary audio file: {str(e)}") return audio_path # Fall back to original path def cleanup_temp_file(temp_path: str): """ Clean up temporary audio file after processing. """ try: if temp_path and os.path.exists(temp_path) and "temp" in temp_path: os.unlink(temp_path) print(f"πŸ—‘οΈ Cleaned up temporary file: {temp_path}") except Exception as e: print(f"⚠️ Warning: Could not clean up temp file: {str(e)}") # --- ENHANCED TRANSCRIPTION FUNCTIONS --- def transcribe_with_whisper(audio_path): """Transcribe English audio using Whisper with comprehensive preprocessing.""" try: # Preprocess audio: convert to mono and resample to 16kHz waveform, sr = preprocess_audio(audio_path, target_sr=16000) # Prepare input features input_features = whisper_processor( waveform.squeeze(), sampling_rate=sr, return_tensors="pt" ).input_features.to(device) # Generate tokens with model predicted_ids = whisper_model.generate(input_features) # Decode tokens to text transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] return transcription.strip() except Exception as e: return f"Error during Whisper transcription: {str(e)}" def recommend_questions_from_text(text: str) -> Dict: """Recommend 5 questions from text using Google Gemini API.""" if not text or len(text.strip()) < 20: return { "success": False, "error": "Text too short to generate meaningful questions.", "questions": ["Text too short to generate meaningful questions."] } try: # Create a prompt for Gemini to recommend 5 questions prompt = f"""Based on the following text, recommend exactly 5 insightful and relevant questions that someone might ask about this content. Focus on comprehension, analysis, and deeper understanding. Text: {text} Please provide exactly 5 questions, formatted as: 1. [Question 1] 2. [Question 2] 3. [Question 3] 4. [Question 4] 5. [Question 5] Make the questions thoughtful and educational.""" # Initialize Gemini model - KEEPING YOUR ORIGINAL MODEL model = genai.GenerativeModel('gemini-1.5-flash') # Generate content with Gemini response = model.generate_content( prompt, generation_config=genai.types.GenerationConfig( max_output_tokens=400, temperature=0.7, ) ) # Extract the response text questions_text = response.text.strip() # Parse the numbered questions questions = [] for line in questions_text.split('\n'): line = line.strip() if line and (line[0].isdigit() or line.startswith('-')): # Remove numbering (1., 2., etc.) if '.' in line: question = line[line.find('.')+1:].strip() else: question = line[1:].strip() # Clean up brackets if present question = question.strip('[]') if question: questions.append(question) # Ensure we have exactly 5 questions if len(questions) < 5: # Add generic questions if needed while len(questions) < 5: questions.append(f"What can you learn from this text? (Question {len(questions) + 1})") questions = questions[:5] # Limit to 5 return { "success": True, "questions": questions, "total_questions": len(questions), "source": "Google Gemini API" } except Exception as e: # Fallback to simple questions if API fails fallback_questions = [ "What is the main topic discussed in this text?", "What are the key points mentioned?", "How does this information relate to the broader context?", "What questions would you ask about this content?", "What can you learn from this text?" ] return { "success": False, "error": f"Gemini API call failed: {str(e)}", "questions": fallback_questions, "source": "Fallback" } def translate_indic_to_english(text: str, source_lang: str = "hin_Deva") -> Dict: """Translate Indic language text to English.""" try: url = f"{API_BASE_URL}/translate/indic-to-en" payload = { "text": text, "source_lang": source_lang } response = requests.post(url, json=payload, headers=HEADERS, timeout=30) response.raise_for_status() result = response.json() return { "success": True, "translated_text": result.get("translated_text", ""), "source_lang": source_lang, "target_lang": "eng_Latn" } except requests.exceptions.RequestException as e: return { "success": False, "error": f"API request failed: {str(e)}", "translated_text": "" } except Exception as e: return { "success": False, "error": f"Translation error: {str(e)}", "translated_text": "" } def process_text_input(text: str, language: str) -> tuple: """Process direct text input for translation and question recommendation.""" if not text or not text.strip(): return "Please provide text input.", "", "", "" # Get language code lang_code = LANGUAGE_OPTIONS.get(language, "hin_Deva") try: # Check if input is English if language == "English" or lang_code == "eng_Latn": # For English input, skip translation and directly recommend questions translation_result = text.strip() # Recommend questions using Gemini API question_response = recommend_questions_from_text(translation_result) if question_response["success"]: questions_list = question_response["questions"] questions_result = "\n".join([f"Q{i+1}: {q}" for i, q in enumerate(questions_list)]) questions_result += f"\n\nβœ… Questions recommended via {question_response.get('source', 'API')}" else: questions_result = f"❌ Question recommendation failed: {question_response.get('error', 'Unknown error')}" else: # For Indic languages, translate to English first translation_response = translate_indic_to_english(text.strip(), lang_code) if translation_response["success"]: translation_result = translation_response["translated_text"] # Recommend questions using Gemini API question_response = recommend_questions_from_text(translation_result) if question_response["success"]: questions_list = question_response["questions"] questions_result = "\n".join([f"Q{i+1}: {q}" for i, q in enumerate(questions_list)]) questions_result += f"\n\nβœ… Questions recommended via {question_response.get('source', 'API')}" else: questions_result = f"❌ Question recommendation failed: {question_response.get('error', 'Unknown error')}" else: translation_result = f"❌ Translation failed: {translation_response['error']}" questions_result = "Cannot recommend questions without valid translation." except Exception as e: return f"Error processing text: {str(e)}", "", "", "" return ( f"Text Input Processed (Language: {language})", text.strip(), translation_result, questions_result ) @spaces.GPU def transcribe_audio_with_lid(audio_path): """Main function to transcribe audio with language detection, translation, and question recommendation.""" if not audio_path: return "Please provide an audio file.", "", "", "" try: # ENHANCED: Use comprehensive audio preprocessing with format support waveform_16k, sr = preprocess_audio(audio_path, target_sr=16000) except Exception as e: return f"Error loading/preprocessing audio: {e}", "", "", "" try: # Language detection using preprocessed audio inputs = lid_processor(waveform_16k.squeeze(), sampling_rate=16000, return_tensors="pt").to(device) with torch.no_grad(): outputs = lid_model(**inputs) logits = outputs[0] predicted_lid_id = logits.argmax(-1).item() detected_lid_code = lid_model.config.id2label[predicted_lid_id] asr_lang_code = LID_TO_ASR_LANG_MAP.get(detected_lid_code) if not asr_lang_code: detected_lang_str = f"Detected '{detected_lid_code}', which is not supported by the ASR model." return detected_lang_str, "N/A", "N/A", "N/A" detected_lang_name = ASR_CODE_TO_NAME.get(asr_lang_code, 'Unknown') detected_lang_str = f"Detected Language: {detected_lang_name} ({detected_lid_code})" # Use Whisper Transformers for English, IndicConformer for others if asr_lang_code == "en": # Use Whisper Transformers for English audio with enhanced preprocessing transcription_rnnt = transcribe_with_whisper(audio_path) else: # Use IndicConformer for Indic languages - RNNT ONLY with torch.no_grad(): transcription_rnnt = asr_model(waveform_16k.to(device), asr_lang_code, "rnnt") # Translation to English translation_result = "" translation_error = "" if transcription_rnnt.strip() and asr_lang_code != "en": # Get IndicTrans2 language code indictrans_lang_code = ASR_TO_INDICTRANS_MAP.get(asr_lang_code) if indictrans_lang_code: # Translate to English translation_response = translate_indic_to_english( transcription_rnnt.strip(), indictrans_lang_code ) if translation_response["success"]: translation_result = translation_response["translated_text"] else: translation_error = translation_response["error"] translation_result = "Translation failed" else: translation_result = "Translation not supported for this language" elif asr_lang_code == "en": translation_result = transcription_rnnt.strip() # Use original text else: translation_result = "No text to translate" # Recommend questions using Gemini API questions_result = "" if translation_result and not translation_error and translation_result not in ["No text to translate", "Translation failed"]: question_response = recommend_questions_from_text(translation_result) if question_response["success"]: questions_list = question_response["questions"] questions_result = "\n".join([f"Q{i+1}: {q}" for i, q in enumerate(questions_list)]) questions_result += f"\n\nβœ… Questions recommended via {question_response.get('source', 'API')}" else: questions_result = f"❌ Question recommendation failed: {question_response.get('error', 'Unknown error')}" else: questions_result = "Cannot recommend questions without valid translation." # Combine results if translation_error: translation_display = f"❌ {translation_result}\nError: {translation_error}" else: translation_display = translation_result except Exception as e: return f"Error during processing: {str(e)}", "", "", "" return ( detected_lang_str, transcription_rnnt.strip(), translation_display, questions_result ) # --- ENHANCED AUDIO PROCESSING WITH TEMPORARY STORAGE --- @spaces.GPU def process_audio_with_temp_storage(audio_path): """ Process audio with temporary storage for better handling of recorded audio. Supports multiple formats: WAV, MP3, FLAC, OPUS, OGG, M4A, AAC, etc. """ if not audio_path: return "Please provide an audio file.", "", "", "" # Create temporary copy of audio file temp_audio_path = create_temp_audio_file(audio_path) try: # Get format info for logging audio_format = detect_audio_format(temp_audio_path) print(f"🎡 Processing {audio_format} file: {os.path.basename(temp_audio_path)}") # Process the temporarily stored audio with enhanced preprocessing result = transcribe_audio_with_lid(temp_audio_path) print("βœ… Audio processing completed successfully") return result except Exception as e: error_msg = f"❌ Error during audio processing: {str(e)}" print(error_msg) return error_msg, "", "", "" finally: # Clean up temporary file cleanup_temp_file(temp_audio_path) # --- GRADIO UI WITH ENHANCED FORMAT SUPPORT --- with gr.Blocks(theme=gr.themes.Soft(), title="Enhanced Multilingual ASR + Translation + Question Recommendations") as demo: gr.Markdown(f"## {DESCRIPTION}") gr.Markdown(""" 🎡 **Multi-Format Audio Support**: Upload audio in WAV, MP3, FLAC, OPUS, OGG, M4A, AAC, and more! πŸ”„ **Auto Stereoβ†’Mono Conversion**: Automatically converts multi-channel audio to mono for optimal ASR performance 🌐 **22+ Indian Languages**: Supports English + 22 Indian languages with automatic detection Upload/record audio OR input text in English or any supported Indian language """) with gr.Row(): with gr.Column(scale=1): # Input method selection input_method = gr.Radio( choices=["Audio Input", "Text Input"], value="Audio Input", label="Choose Input Method" ) # Enhanced audio input with format support audio = gr.Audio( label="Upload or Record Audio (Supports: WAV, MP3, FLAC, OPUS, OGG, M4A, AAC, etc.)", sources=["upload", "microphone"], type="filepath", visible=True, interactive=True ) # Text input (hidden by default) text_input = gr.Textbox( label="Enter Text in English or Indian Language", placeholder="Type your text here in English, Hindi, Bengali, Tamil, etc...", lines=4, visible=False ) # Language selection for text input (hidden by default) language_dropdown = gr.Dropdown( choices=list(LANGUAGE_OPTIONS.keys()), value="English", label="Select Language", visible=False ) process_btn = gr.Button( "🎯 Process & Get Question Recommendations", variant="primary", scale=2 ) with gr.Column(scale=2): # Detection/Processing Result detection_output = gr.Label( label="πŸ“Š Processing Result", show_label=True ) # Input/Transcription Results with gr.Tab("🎀 Input/Transcription"): gr.Markdown("### Original Text") input_output = gr.Textbox( lines=4, label="Input/Transcription Output", placeholder="Original text will appear here..." ) # Translation Results with gr.Tab("🌐 Translation"): translation_output = gr.Textbox( lines=4, label="English Translation", placeholder="English translation will appear here..." ) # Question Recommendations with gr.Tab("❓ Question Recommendations"): questions_output = gr.Textbox( lines=8, label="Recommended Questions", placeholder="Gemini AI-recommended questions will appear here..." ) # Enhanced Audio Processing Information gr.Markdown(""" ### 🎡 Supported Audio Formats: **Primary Formats**: WAV, MP3, FLAC, OPUS, OGG **Additional Formats**: M4A, AAC, MP4, WMA, AMR, AIFF, AU, 3GP, WebM ### πŸ”§ Audio Processing Features: - **Multi-Format Support**: Handles 14+ different audio formats - **Smart Backend Selection**: Automatically chooses optimal audio backend (FFmpeg/SoX/SoundFile) - **Robust Error Handling**: Fallback mechanisms for maximum compatibility - **Automatic Stereoβ†’Mono**: Converts multi-channel to mono for ASR optimization - **Intelligent Resampling**: Resamples to 16kHz as required by ASR models - **Format Validation**: Validates input files before processing """) # Toggle input visibility based on method selection def toggle_inputs(method): if method == "Audio Input": return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False) else: # Text Input return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True) input_method.change( fn=toggle_inputs, inputs=[input_method], outputs=[audio, text_input, language_dropdown] ) # Enhanced processing function def process_input(method, audio_file, text, language): if method == "Audio Input": if audio_file: return process_audio_with_temp_storage(audio_file) else: return "Please upload or record an audio file.", "", "", "" else: # Text Input if text: return process_text_input(text, language) else: return "Please enter some text.", "", "", "" # Event handlers process_btn.click( fn=process_input, inputs=[input_method, audio, text_input, language_dropdown], outputs=[ detection_output, input_output, translation_output, questions_output ], api_name="process" ) if __name__ == "__main__": demo.queue(max_size=10).launch( server_name="0.0.0.0", server_port=7860, share=True )