HusseinBashir commited on
Commit
055ad76
·
verified ·
1 Parent(s): 1b4a453

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -18
app.py CHANGED
@@ -2,32 +2,29 @@ import torch
2
  from transformers import VitsModel, AutoTokenizer
3
  import gradio as gr
4
 
5
- # Load the fine-tuned model and tokenizer
6
  model = VitsModel.from_pretrained("HusseinBashir/codad_tijaabo")
7
  tokenizer = AutoTokenizer.from_pretrained("HusseinBashir/codad_tijaabo")
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
- model.to(device)
10
 
11
- # Gradio TTS function
12
- # Gradio TTS function
13
  def tts(text):
14
  inputs = tokenizer(text, return_tensors="pt").to(device)
15
  with torch.no_grad():
16
- output = model(**inputs).waveform.squeeze(1).cpu().numpy()
 
17
 
18
- # Ensure the output is a 1D numpy array and normalized
19
- if output.ndim > 1:
20
- output = output.flatten()
21
- output = output / max(abs(output)) # Normalize to [-1, 1]
22
- return (22050, output) # Return a tuple (sample_rate, waveform)
23
 
 
24
 
25
- # Gradio interface for the TTS model
26
- iface = gr.Interface(
27
  fn=tts,
28
- inputs=gr.Textbox(label="Enter text"),
29
- outputs=gr.Audio(label="Generated Speech"),
30
- title="Fine-tuned VITS TTS",
31
- description="Generate speech from text using the fine-tuned VITS model."
32
- )
33
- iface.launch()
 
2
  from transformers import VitsModel, AutoTokenizer
3
  import gradio as gr
4
 
5
+ # Load fine-tuned model and tokenizer
6
  model = VitsModel.from_pretrained("HusseinBashir/codad_tijaabo")
7
  tokenizer = AutoTokenizer.from_pretrained("HusseinBashir/codad_tijaabo")
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+ model.to(device).eval()
10
 
 
 
11
  def tts(text):
12
  inputs = tokenizer(text, return_tensors="pt").to(device)
13
  with torch.no_grad():
14
+ output = model(**inputs)
15
+ waveform = output["waveform"].squeeze(1).cpu().numpy()
16
 
17
+ # Normalize waveform
18
+ waveform = waveform.flatten()
19
+ waveform = waveform / max(abs(waveform))
 
 
20
 
21
+ return (22050, waveform) # 22.05 kHz sample rate typical for VITS
22
 
23
+ # Gradio interface
24
+ gr.Interface(
25
  fn=tts,
26
+ inputs=gr.Textbox(label="Geli qoraal Soomaali ah"),
27
+ outputs=gr.Audio(label="Codka la sameeyey"),
28
+ title="Codad Tijaabo TTS",
29
+ description="Ku qor qoraal Soomaali ah si aad cod u maqasho iyadoo la adeegsanayo VITS."
30
+ ).launch()