Spaces:
Runtime error
Runtime error
ASG Models
commited on
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,6 +7,29 @@ import requests
|
|
| 7 |
from genai_chat_ai import AI,create_chat_session
|
| 8 |
api_key = os.environ.get("Id_mode_vits")
|
| 9 |
headers = {"Authorization": f"Bearer {api_key}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
def remove_extra_spaces(text):
|
| 12 |
"""
|
|
@@ -69,9 +92,9 @@ with gr.Blocks() as demo: # Use gr.Blocks to wrap the entire interface
|
|
| 69 |
API_URL = f"https://api-inference.huggingface.co/models/{model_choice}"
|
| 70 |
text_answer = get_answer_ai(text)
|
| 71 |
text_answer = remove_extra_spaces(text_answer)
|
| 72 |
-
data_ai = query(text_answer, API_URL)
|
| 73 |
if generate_user_audio: # Generate user audio if needed
|
| 74 |
-
data_user = query(text, API_URL)
|
| 75 |
return data_user, data_ai, text_answer
|
| 76 |
else:
|
| 77 |
return data_ai # Return None for user_audio
|
|
|
|
| 7 |
from genai_chat_ai import AI,create_chat_session
|
| 8 |
api_key = os.environ.get("Id_mode_vits")
|
| 9 |
headers = {"Authorization": f"Bearer {api_key}"}
|
| 10 |
+
from transformers import pipeline
|
| 11 |
+
from transformers import AutoTokenizer,VitsModel
|
| 12 |
+
import torch
|
| 13 |
+
models= {}
|
| 14 |
+
tokenizer = AutoTokenizer.from_pretrained("asg2024/vits-ar-sa-huba",token=api_key)
|
| 15 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 16 |
+
def get_model(name_model):
|
| 17 |
+
global models
|
| 18 |
+
if name_model in not models:
|
| 19 |
+
models[name_model]=VitsModel.from_pretrained(name_model,token=api_key).to(device)
|
| 20 |
+
return models[name_model]
|
| 21 |
+
|
| 22 |
+
def genrate_speech(text,name_model):
|
| 23 |
+
inputs=tokenizer(text,return_tensors="pt")
|
| 24 |
+
model=get_model(name_model)
|
| 25 |
+
with torch.no_grad():
|
| 26 |
+
wav=model(
|
| 27 |
+
input_ids= input_ids.input_ids.to(device),
|
| 28 |
+
attention_mask=input_ids.attention_mask.to(device),
|
| 29 |
+
speaker_id=0
|
| 30 |
+
).waveform.cpu().numpy().reshape(-1)
|
| 31 |
+
return model.config.sampling_rate,wav
|
| 32 |
+
|
| 33 |
|
| 34 |
def remove_extra_spaces(text):
|
| 35 |
"""
|
|
|
|
| 92 |
API_URL = f"https://api-inference.huggingface.co/models/{model_choice}"
|
| 93 |
text_answer = get_answer_ai(text)
|
| 94 |
text_answer = remove_extra_spaces(text_answer)
|
| 95 |
+
data_ai = genrate_speech(text_answer,model_choice)#query(text_answer, API_URL)
|
| 96 |
if generate_user_audio: # Generate user audio if needed
|
| 97 |
+
data_user =genrate_speech(text_answer,model_choice)# query(text, API_URL)
|
| 98 |
return data_user, data_ai, text_answer
|
| 99 |
else:
|
| 100 |
return data_ai # Return None for user_audio
|