import csv
import datetime
import os
import re
import time
import uuid
from io import StringIO

import gradio as gr
import nltk
import numpy as np
import pyrubberband
import spaces  # Hugging Face Spaces runtime helper (e.g., ZeroGPU); not used directly below
import torch
import torchaudio
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from nltk.sentiment import SentimentIntensityAnalyzer
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from vinorm import TTSnorm
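
# One-time environment setup: fetch the VADER sentiment lexicon, download the
# unidic dictionary used for Japanese tokenization, and print GPU info.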
nltk.download("vader_lexicon")
os.system("python -m unidic download")
os.system("nvidia-smi")
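
# NOTE: HF_TOKEN is left as None, so HfApi calls run unauthenticated; supply a
# real token (e.g., from an environment variable) if the restart_space call
# further down is expected to actually work.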
HF_TOKEN = None
api = HfApi(token=HF_TOKEN)

checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
use_deepspeed = False

os.makedirs(checkpoint_dir, exist_ok=True)
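
# If any required checkpoint file is missing, pull the full viXTTS snapshot,
# plus the speaker file that ships with the upstream coqui/XTTS-v2 repo.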
required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
files_in_dir = os.listdir(checkpoint_dir)
if not all(file in files_in_dir for file in required_files):
    snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        local_dir=checkpoint_dir,
    )
    hf_hub_download(
        repo_id="coqui/XTTS-v2",
        filename="speakers_xtts.pth",
        local_dir=checkpoint_dir,
    )
xtts_config = os.path.join(checkpoint_dir, "config.json")
config = XttsConfig()
config.load_json(xtts_config)
MODEL = Xtts.init_from_config(config)
MODEL.load_checkpoint(
    config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL.to(device)
supported_languages = config.languages
if "vi" not in supported_languages:
    supported_languages.append("vi")
if "es-AR" not in supported_languages:
    supported_languages.append("es-AR")
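# Note: appending a code here only lets it pass the availability check in
# predict(); the underlying XTTS tokenizer still has to accept the code at
# inference time ("es-AR" is not a stock XTTS language).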


def normalize_vietnamese_text(text):
    # Normalize Vietnamese text with vinorm, then clean up punctuation
    # artifacts and spell out "AI" phonetically so it is pronounced correctly.
    text = (
        TTSnorm(text, unknown=False, lower=False, rule=True)
        .replace("..", ".")
        .replace("!.", "!")
        .replace("?.", "?")
        .replace(" .", ".")
        .replace(" ,", ",")
        .replace('"', "")
        .replace("'", "")
        .replace("AI", "Ây Ai")
        .replace("A.I", "Ây Ai")
    )
    return text
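
# Illustrative example (TTSnorm may apply further rewrites of its own):
#   normalize_vietnamese_text("AI là gì?")  ->  "Ây Ai là gì?"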


def analyze_sentiment(text):
    # VADER compound score in [-1, 1]: negative -> unhappy, positive -> happy.
    sia = SentimentIntensityAnalyzer()
    scores = sia.polarity_scores(text)
    return scores["compound"]


def change_pitch(audio_data, sampling_rate, sentiment):
    # Shift pitch proportionally to sentiment (up to +/-2 semitones).
    # pyrubberband shells out to the rubberband CLI, which must be installed.
    semitones = sentiment * 2
    return pyrubberband.pitch_shift(audio_data, sampling_rate, semitones)


def apply_distortion(audio_data, sentiment):
    # Add multiplicative Gaussian noise scaled by sentiment magnitude, so more
    # emotional text gets a rougher, more distorted voice.
    distortion_factor = abs(sentiment) * 0.5
    return audio_data * (1 + distortion_factor * np.random.randn(len(audio_data)))
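
# Illustrative mapping (hypothetical inputs): a sentence scoring +0.8 is
# pitched up 1.6 semitones with noise amplitude 0.4; one scoring -0.5 is
# pitched down 1 semitone with noise amplitude 0.25.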


def predict(
    prompt,
    language,
    audio_file_pth,
    normalize_text=True,
):
    # gr.Warning only displays a toast and returns None, so the message is
    # kept in metrics_text and returned for the metrics textbox as well.
    if language not in supported_languages:
        metrics_text = f"El idioma seleccionado ({language}) no está disponible. Por favor, elige uno de la lista."
        gr.Warning(metrics_text)
        return (None, metrics_text)

    speaker_wav = audio_file_pth

    if len(prompt) < 2:
        metrics_text = "Por favor, introduce un texto más largo."
        gr.Warning(metrics_text)
        return (None, metrics_text)
    try:
        metrics_text = ""
        t_latent = time.time()

        # Compute the conditioning latents and speaker embedding from the
        # reference audio (voice cloning).
        try:
            (
                gpt_cond_latent,
                speaker_embedding,
            ) = MODEL.get_conditioning_latents(
                audio_path=speaker_wav,
                gpt_cond_len=30,
                gpt_cond_chunk_len=4,
                max_ref_length=60,
            )
        except Exception as e:
            print("Speaker encoding error", str(e))
            metrics_text = "¿Has activado el micrófono? Parece que hay un problema con la referencia de audio."
            gr.Warning(metrics_text)
            return (None, metrics_text)

        # Double sentence-final punctuation after non-ASCII or word characters;
        # the raw string avoids the invalid "\w" escape warning.
        prompt = re.sub(r"([^\x00-\x7F]|\w)(\.|。|\?)", r"\1 \2\2", prompt)

        if normalize_text and language == "vi":
            prompt = normalize_vietnamese_text(prompt)

        # Map sentiment to sampling temperature: more emotional text samples
        # more freely, clamped to [0.5, 1.0].
        sentiment = analyze_sentiment(prompt)
        temperature = 0.75 + sentiment * 0.2
        temperature = max(0.5, min(temperature, 1.0))
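
        # Worked example: compound = +1.0 gives temperature 0.95; compound =
        # -1.0 gives 0.55; both already lie inside the clamp range.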

        t0 = time.time()
        out = MODEL.inference(
            prompt,
            language,
            gpt_cond_latent,
            speaker_embedding,
            repetition_penalty=5.0,
            temperature=temperature,
            enable_text_splitting=True,
        )
        inference_time = time.time() - t0
        metrics_text += (
            f"Tiempo de generación de audio: {round(inference_time * 1000)} milisegundos\n"
        )
        real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
        metrics_text += f"Factor de tiempo real (RTF): {real_time_factor:.2f}\n"
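
        # The RTF computed above is synthesis time divided by the duration of
        # the generated audio (24 kHz output); values above 1.0 mean
        # generation is slower than real time.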

        # Apply the sentiment-driven pitch shift and distortion, then save as
        # float32 (torchaudio expects a (channels, samples) tensor).
        audio_data = np.array(out["wav"])
        modified_audio = change_pitch(audio_data, 24000, sentiment)
        modified_audio = apply_distortion(modified_audio, sentiment)
        torchaudio.save(
            "output.wav",
            torch.tensor(modified_audio, dtype=torch.float32).unsqueeze(0),
            24000,
        )
    except RuntimeError as e:
        if "device-side assert" in str(e):
            # A CUDA device-side assert leaves the GPU in a bad state: log the
            # failing input to a dataset repo, then restart the Space.
            error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
            error_data = [
                error_time,
                prompt,
                language,
                audio_file_pth,
            ]
            error_data = [str(item) if not isinstance(item, str) else item for item in error_data]
            write_io = StringIO()
            csv.writer(write_io).writerows([error_data])
            csv_upload = write_io.getvalue().encode()
            filename = error_time + "_" + str(uuid.uuid4()) + ".csv"

            error_api = HfApi()
            error_api.upload_file(
                path_or_fileobj=csv_upload,
                path_in_repo=filename,
                repo_id="coqui/xtts-flagged-dataset",
                repo_type="dataset",
            )

            speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
            error_api.upload_file(
                path_or_fileobj=speaker_wav,
                path_in_repo=speaker_filename,
                repo_id="coqui/xtts-flagged-dataset",
                repo_type="dataset",
            )

            # Note: repo_id above points at the model repo; restarting a Space
            # needs the Space's own repo id and an authenticated HfApi token.
            space = api.get_space_runtime(repo_id=repo_id)
            if space.stage != "BUILDING":
                api.restart_space(repo_id=repo_id)
        else:
            if "Failed to decode" in str(e):
                metrics_text = "Parece que hay un problema con la referencia de audio. ¿Has activado el micrófono?"
            else:
                metrics_text = "Se ha producido un error inesperado. Por favor, inténtalo de nuevo."
            gr.Warning(metrics_text)
            return (None, metrics_text)

    return ("output.wav", metrics_text)
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                # viXTTS Demo ✨
                """
            )
        with gr.Column():
            pass

    with gr.Row():
        with gr.Column():
            input_text_gr = gr.Textbox(
                label="Texto a convertir a voz",
                value="Hola, soy un modelo de texto a voz.",
            )
            language_gr = gr.Dropdown(
                label="Idioma",
                choices=[
                    "es-AR",
                    "vi",
                    "en",
                    "es",
                    "fr",
                    "de",
                    "it",
                    "pt",
                    "pl",
                    "tr",
                    "ru",
                    "nl",
                    "cs",
                    "ar",
                    "zh-cn",
                    "ja",
                    "ko",
                    "hu",
                    "hi",
                ],
                max_choices=1,
                value="es-AR",
            )
            normalize_text = gr.Checkbox(
                label="Normalizar texto en vietnamita",
                info="Solo aplicable al idioma vietnamita",
                value=True,
            )
            ref_gr = gr.Audio(
                label="Audio de referencia (opcional)",
                type="filepath",
                value="model/samples/nu-luu-loat.wav",
            )
            tts_button = gr.Button(
                "Generar voz 🗣️🔥",
                elem_id="send-btn",
                visible=True,
                variant="primary",
            )
        with gr.Column():
            audio_gr = gr.Audio(label="Audio generado", autoplay=True)
            out_text_gr = gr.Text(label="Métricas")

    tts_button.click(
        predict,
        [
            input_text_gr,
            language_gr,
            ref_gr,
            normalize_text,
        ],
        outputs=[audio_gr, out_text_gr],
        api_name="predict",
    )

demo.queue()
demo.launch(debug=True, show_api=True, share=True)
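
# Note: share=True is ignored when running inside a Hugging Face Space (the
# Space already exposes a public URL); it only matters for local runs.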