jellecali8 committed on
Commit
0adbc6d
·
verified ·
1 Parent(s): e72ad8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -4
app.py CHANGED
@@ -2,32 +2,56 @@ import gradio as gr
2
  import torch
3
  import soundfile as sf
4
  import tempfile
 
5
  from transformers import AutoTokenizer, VitsModel
6
  import numpy as np
7
 
 
8
  repo_id = "jellecali8/Somali_tts_model"
9
 
 
10
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
11
  model = VitsModel.from_pretrained(repo_id)
12
  model.eval()
13
 
 
14
  try:
15
- custom_embedding_np = np.load("somali_speaker_embedding.npy")
16
  custom_embedding = torch.tensor(custom_embedding_np, dtype=torch.float32).unsqueeze(0)
17
  except Exception:
 
18
  custom_embedding = torch.randn(1, 256)
19
 
20
  def tts(text):
21
  try:
22
  if not text.strip():
23
- return None
24
- inputs = tokenizer(text, return_tensors="pt").to(next(model.parameters()).device)
25
- custom_embedding_ = custom_embedding.to(next(model.parameters()).device)
 
 
26
  with torch.no_grad():
27
  outputs = model(**inputs, speaker_embeddings=custom_embedding_)
28
  waveform = outputs.waveform.squeeze().cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
29
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
30
  sf.write(tmp.name, waveform, 16000)
 
 
 
 
 
 
31
  return tmp.name
32
  except Exception as e:
33
  return f"Error: {str(e)}"
 
2
  import torch
3
  import soundfile as sf
4
  import tempfile
5
+ import os
6
  from transformers import AutoTokenizer, VitsModel
7
  import numpy as np
8
 
9
+ # Replace with your own correct username/repo
10
  repo_id = "jellecali8/Somali_tts_model"
11
 
12
+ # Load tokenizer iyo model
13
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
14
  model = VitsModel.from_pretrained(repo_id)
15
  model.eval()
16
 
17
+ # Load custom speaker embedding (.npy) file path
18
  try:
19
+ custom_embedding_np = np.load("somali_speaker_embedding.npy") # Ku dar faylka Space folder-ka
20
  custom_embedding = torch.tensor(custom_embedding_np, dtype=torch.float32).unsqueeze(0)
21
  except Exception:
22
+ # Haddii embedding file ma jiro, isticmaal random tensor (kaliya tijaabo)
23
  custom_embedding = torch.randn(1, 256)
24
 
25
  def tts(text):
26
  try:
27
  if not text.strip():
28
+ return None # Qoraal madhan ha soo gelin
29
+ # U gudbi inputs iyo embedding device-ka model-ka (CPU/GPU)
30
+ device = next(model.parameters()).device
31
+ inputs = tokenizer(text, return_tensors="pt").to(device)
32
+ custom_embedding_ = custom_embedding.to(device)
33
  with torch.no_grad():
34
  outputs = model(**inputs, speaker_embeddings=custom_embedding_)
35
  waveform = outputs.waveform.squeeze().cpu().numpy()
36
+
37
+ # Normalize waveform si uu u dhex maro -1.0 ilaa 1.0
38
+ max_val = max(abs(waveform.max()), abs(waveform.min()))
39
+ if max_val > 0:
40
+ waveform = waveform / max_val
41
+ else:
42
+ # Haddii waveform dhan yahay eber, soo celi error ama None
43
+ print("Warning: Waveform is all zeros")
44
+ return None
45
+
46
+ # Kaydi waveform file .wav ah
47
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
48
  sf.write(tmp.name, waveform, 16000)
49
+
50
+ # Hubi file size si loo xaqiijiyo inuu sax yahay
51
+ if os.path.getsize(tmp.name) == 0:
52
+ print("Warning: Generated WAV file is empty")
53
+ return None
54
+
55
  return tmp.name
56
  except Exception as e:
57
  return f"Error: {str(e)}"