HusseinBashir commited on
Commit
b79ebd5
·
verified ·
1 Parent(s): d898fdc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -26
app.py CHANGED
@@ -1,36 +1,33 @@
1
- import gradio as gr
2
  import torch
3
- from TTS.models.vits import VitsModel
4
- from transformers import AutoTokenizer
5
- import torchaudio
6
 
7
- # Load the model and tokenizer from Hugging Face
8
  model = VitsModel.from_pretrained("HusseinBashir/codad_tijaabo")
9
  tokenizer = AutoTokenizer.from_pretrained("HusseinBashir/codad_tijaabo")
10
-
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
- model = model.to(device).eval()
13
-
14
- def tts_infer(text):
15
- inputs = tokenizer(text, return_tensors="pt")
16
- input_ids = inputs.input_ids.to(device)
17
 
 
 
 
 
18
  with torch.no_grad():
19
- output = model(input_ids)
20
- waveform = output["waveform"]
21
 
22
- # Save or return audio
23
- sample_rate = 22050 # VITS typically uses 22.05kHz
24
- torchaudio.save("output.wav", waveform.cpu(), sample_rate)
25
- return "output.wav"
 
26
 
27
- # Create Gradio UI
28
- interface = gr.Interface(
29
- fn=tts_infer,
30
- inputs=gr.Textbox(label="Geli qoraalka aad rabto in cod laga dhigo"),
31
- outputs=gr.Audio(label="Codka la sameeyey"),
32
- title="Codad Tijaabo TTS",
33
- description="Ku qor qoraal Soomaali ah si aad cod u maqasho.",
34
- )
35
 
36
- interface.launch()
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ from transformers import VitsModel, AutoTokenizer
3
+ import gradio as gr
 
4
 
5
+ # Load the fine-tuned model and tokenizer
6
  model = VitsModel.from_pretrained("HusseinBashir/codad_tijaabo")
7
  tokenizer = AutoTokenizer.from_pretrained("HusseinBashir/codad_tijaabo")
 
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+ model.to(device)
 
 
 
 
10
 
11
+ # Gradio TTS function
12
+ # Gradio TTS function
13
+ def tts(text):
14
+ inputs = tokenizer(text, return_tensors="pt").to(device)
15
  with torch.no_grad():
16
+ output = model(**inputs).waveform.squeeze(1).cpu().numpy()
 
17
 
18
+ # Ensure the output is a 1D numpy array and normalized
19
+ if output.ndim > 1:
20
+ output = output.flatten()
21
+ output = output / max(abs(output)) # Normalize to [-1, 1]
22
+ return (22050, output) # Return a tuple (sample_rate, waveform)
23
 
 
 
 
 
 
 
 
 
24
 
25
+ # Gradio interface for the TTS model
26
+ iface = gr.Interface(
27
+ fn=tts,
28
+ inputs=gr.Textbox(label="Enter text"),
29
+ outputs=gr.Audio(label="Generated Speech"),
30
+ title="Fine-tuned VITS TTS",
31
+ description="Generate speech from text using the fine-tuned VITS model."
32
+ )
33
+ iface.launch()