Phonepadith commited on
Commit
60a3a10
·
verified ·
1 Parent(s): 020ac09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -43
app.py CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
3
  import torch
4
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
5
  import numpy as np
 
6
 
7
  # Load model and processor
8
  model_id = "Phonepadith/whisper-3-large-lao-finetuned-v1"
@@ -13,67 +14,77 @@ model = WhisperForConditionalGeneration.from_pretrained(model_id)
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
  model.to(device)
15
 
 
 
16
  def transcribe_audio(audio):
17
  """
18
  Transcribe audio to Lao text
19
  Args:
20
- audio: tuple (sample_rate, audio_array) from Gradio
21
  Returns:
22
  transcription: Lao text
23
  """
24
  if audio is None:
25
  return "Please upload or record audio."
26
 
27
- # Get sample rate and audio array
28
- sample_rate, audio_array = audio
29
-
30
- # Convert to float32 and normalize
31
- if audio_array.dtype != np.float32:
32
- # If integer type, normalize to [-1, 1]
33
- if np.issubdtype(audio_array.dtype, np.integer):
34
- max_val = np.iinfo(audio_array.dtype).max
35
- audio_array = audio_array.astype(np.float32) / max_val
36
  else:
37
- audio_array = audio_array.astype(np.float32)
38
-
39
- # Ensure audio is in [-1, 1] range
40
- if np.abs(audio_array).max() > 1.0:
41
- audio_array = audio_array / np.abs(audio_array).max()
42
-
43
- # Resample to 16kHz if needed
44
- if sample_rate != 16000:
45
- import librosa
46
- audio_array = librosa.resample(
47
- audio_array,
48
- orig_sr=sample_rate,
49
- target_sr=16000
50
- )
51
-
52
- # Process audio
53
- input_features = processor(
54
- audio_array,
55
- sampling_rate=16000,
56
- return_tensors="pt"
57
- ).input_features.to(device)
58
-
59
- # Generate transcription
60
- with torch.no_grad():
61
- predicted_ids = model.generate(input_features)
62
-
63
- # Decode transcription
64
- transcription = processor.batch_decode(
65
- predicted_ids,
66
- skip_special_tokens=True
67
- )[0]
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- return transcription
 
70
 
71
  # Create Gradio interface
72
  demo = gr.Interface(
73
  fn=transcribe_audio,
74
  inputs=gr.Audio(
75
  sources=["microphone", "upload"],
76
- type="numpy",
77
  label="Record or Upload Lao Audio"
78
  ),
79
  outputs=gr.Textbox(
 
3
  import torch
4
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
5
  import numpy as np
6
+ import librosa
7
 
8
  # Load model and processor
9
  model_id = "Phonepadith/whisper-3-large-lao-finetuned-v1"
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  model.to(device)
16
 
17
+ print(f"Model loaded on: {device}")
18
+
19
  def transcribe_audio(audio):
20
  """
21
  Transcribe audio to Lao text
22
  Args:
23
+ audio: Audio file path (string) or tuple (sample_rate, audio_array) from Gradio
24
  Returns:
25
  transcription: Lao text
26
  """
27
  if audio is None:
28
  return "Please upload or record audio."
29
 
30
+ try:
31
+ # Handle both file paths and numpy arrays
32
+ if isinstance(audio, str):
33
+ # Audio is a file path - use librosa to load it
34
+ audio_array, sample_rate = librosa.load(audio, sr=16000, mono=True)
 
 
 
 
35
  else:
36
+ # Audio is a tuple (sample_rate, audio_array)
37
+ sample_rate, audio_array = audio
38
+
39
+ # Convert to float32 and normalize
40
+ if audio_array.dtype != np.float32:
41
+ # If integer type, normalize to [-1, 1]
42
+ if np.issubdtype(audio_array.dtype, np.integer):
43
+ max_val = np.iinfo(audio_array.dtype).max
44
+ audio_array = audio_array.astype(np.float32) / max_val
45
+ else:
46
+ audio_array = audio_array.astype(np.float32)
47
+
48
+ # Ensure audio is in [-1, 1] range
49
+ if np.abs(audio_array).max() > 1.0:
50
+ audio_array = audio_array / np.abs(audio_array).max()
51
+
52
+ # Resample to 16kHz if needed
53
+ if sample_rate != 16000:
54
+ audio_array = librosa.resample(
55
+ audio_array,
56
+ orig_sr=sample_rate,
57
+ target_sr=16000
58
+ )
59
+
60
+ # Process audio
61
+ input_features = processor(
62
+ audio_array,
63
+ sampling_rate=16000,
64
+ return_tensors="pt"
65
+ ).input_features.to(device)
66
+
67
+ # Generate transcription
68
+ with torch.no_grad():
69
+ predicted_ids = model.generate(input_features)
70
+
71
+ # Decode transcription
72
+ transcription = processor.batch_decode(
73
+ predicted_ids,
74
+ skip_special_tokens=True
75
+ )[0]
76
+
77
+ return transcription
78
 
79
+ except Exception as e:
80
+ return f"Error processing audio: {str(e)}"
81
 
82
  # Create Gradio interface
83
  demo = gr.Interface(
84
  fn=transcribe_audio,
85
  inputs=gr.Audio(
86
  sources=["microphone", "upload"],
87
+ type="filepath", # Changed to filepath to handle various formats
88
  label="Record or Upload Lao Audio"
89
  ),
90
  outputs=gr.Textbox(