import csv
import datetime
import os
import re
import time
import uuid
from io import StringIO

import gradio as gr
import nltk
import numpy as np
import pyrubberband
import spaces
import torch
import torchaudio
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from nltk.sentiment import SentimentIntensityAnalyzer
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from vinorm import TTSnorm

# One-time downloads: the VADER lexicon for sentiment analysis and the
# unidic dictionary used for Japanese tokenization.
nltk.download("vader_lexicon")
os.system("python -m unidic download")

HF_TOKEN = None
api = HfApi(token=HF_TOKEN)

# Fetch the viXTTS checkpoint (plus the speaker file from the upstream
# XTTS-v2 repo) if it is not already cached locally.
checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
use_deepspeed = False

os.makedirs(checkpoint_dir, exist_ok=True)

required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
files_in_dir = os.listdir(checkpoint_dir)
if not all(file in files_in_dir for file in required_files):
    snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        local_dir=checkpoint_dir,
    )
    hf_hub_download(
        repo_id="coqui/XTTS-v2",
        filename="speakers_xtts.pth",
        local_dir=checkpoint_dir,
    )

xtts_config = os.path.join(checkpoint_dir, "config.json")
config = XttsConfig()
config.load_json(xtts_config)
MODEL = Xtts.init_from_config(config)
MODEL.load_checkpoint(
    config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
)
if torch.cuda.is_available():
    MODEL.cuda()

supported_languages = config.languages
if "vi" not in supported_languages:
    supported_languages.append("vi")
if "es-AR" not in supported_languages:
    supported_languages.append("es-AR")


def normalize_vietnamese_text(text):
    # Normalize Vietnamese text with vinorm, then clean up punctuation
    # artifacts and spell out "AI" so the model pronounces it naturally.
    text = (
        TTSnorm(text, unknown=False, lower=False, rule=True)
        .replace("..", ".")
        .replace("!.", "!")
        .replace("?.", "?")
        .replace(" .", ".")
        .replace(" ,", ",")
        .replace('"', "")
        .replace("'", "")
        .replace("AI", "Ây Ai")
        .replace("A.I", "Ây Ai")
    )
    return text


def calculate_keep_len(text, lang):
    """Heuristic trim length (in samples) to cut trailing artifacts on short inputs."""
    if lang in ["ja", "zh-cn"]:
        return -1

    word_count = len(text.split())
    num_punct = (
        text.count(".") + text.count("!") + text.count("?") + text.count(",")
    )

    if word_count < 5:
        return 15000 * word_count + 2000 * num_punct
    elif word_count < 10:
        return 13000 * word_count + 2000 * num_punct
    return -1


def analyze_sentiment(text):
    """Return the VADER compound sentiment score, in [-1.0, 1.0]."""
    sia = SentimentIntensityAnalyzer()
    scores = sia.polarity_scores(text)
    return scores["compound"]


def change_pitch(audio_data, sampling_rate, sentiment):
    # Shift pitch in proportion to sentiment (up to +/-2 semitones).
    # Note: pyrubberband shells out to the rubberband CLI, which must be installed.
    semitones = sentiment * 2
    return pyrubberband.pitch_shift(audio_data, sampling_rate, semitones)
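
# How the sentiment score steers synthesis in predict() below (summarizing the
# code that follows; the constants are the author's heuristics):
#   temperature = clamp(0.75 + 0.2 * sentiment, 0.5, 1.0)
#   pitch shift = 2 * sentiment semitones, applied after synthesis
# so strongly positive text is sampled "hotter" and pitched slightly up, while
# negative text is sampled more conservatively and pitched down.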

@spaces.GPU(duration=0)
def predict(
    prompt,
    language,
    audio_file_pth,
    normalize_text=True,
):
    if language not in supported_languages:
        metrics_text = gr.Warning(
            f"El idioma seleccionado ({language}) no está disponible. Por favor, elige uno de la lista."
        )
        return (None, metrics_text)

    speaker_wav = audio_file_pth

    if len(prompt) < 2:
        metrics_text = gr.Warning("Por favor, introduce un texto más largo.")
        return (None, metrics_text)

    # Keep the limit consistent with the UI, which advertises 250 characters.
    if len(prompt) > 250:
        metrics_text = gr.Warning(
            f"El texto tiene {len(prompt)} caracteres. Es demasiado largo, por favor, mantenlo por debajo de 250 caracteres."
        )
        return (None, metrics_text)

    try:
        metrics_text = ""
        t_latent = time.time()

        # Compute the speaker conditioning latents from the reference audio.
        try:
            (
                gpt_cond_latent,
                speaker_embedding,
            ) = MODEL.get_conditioning_latents(
                audio_path=speaker_wav,
                gpt_cond_len=30,
                gpt_cond_chunk_len=4,
                max_ref_length=60,
            )
        except Exception as e:
            print("Speaker encoding error", str(e))
            metrics_text = gr.Warning(
                "¿Has activado el micrófono? Parece que hay un problema con la referencia de audio."
            )
            return (None, metrics_text)

        # Double sentence-final punctuation so the model pauses more naturally.
        prompt = re.sub(r"([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)

        if normalize_text and language == "vi":
            prompt = normalize_vietnamese_text(prompt)

        # Sentiment-driven sampling temperature, clamped to [0.5, 1.0].
        sentiment = analyze_sentiment(prompt)
        temperature = 0.75 + sentiment * 0.2
        temperature = max(0.5, min(temperature, 1.0))

        t0 = time.time()
        out = MODEL.inference(
            prompt,
            language,
            gpt_cond_latent,
            speaker_embedding,
            repetition_penalty=5.0,
            temperature=temperature,
            enable_text_splitting=True,
        )
        inference_time = time.time() - t0
        metrics_text += (
            f"Tiempo de generación de audio: {round(inference_time*1000)} milisegundos\n"
        )
        real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
        metrics_text += f"Factor de tiempo real (RTF): {real_time_factor:.2f}\n"

        # Trim trailing samples on short prompts, then apply the sentiment
        # pitch shift before writing the result to disk at 24 kHz.
        keep_len = calculate_keep_len(prompt, language)
        out["wav"] = out["wav"][:keep_len]

        audio_data = np.array(out["wav"])
        modified_audio = change_pitch(audio_data, 24000, sentiment)
        torchaudio.save(
            "output.wav",
            torch.tensor(modified_audio, dtype=torch.float32).unsqueeze(0),
            24000,
        )

    except RuntimeError as e:
        if "device-side assert" in str(e):
            # CUDA is in an unrecoverable state: flag the failing input to a
            # dataset repo, then restart the Space.
            error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
            error_data = [
                error_time,
                prompt,
                language,
                audio_file_pth,
            ]
            error_data = [
                str(item) if not isinstance(item, str) else item
                for item in error_data
            ]
            write_io = StringIO()
            csv.writer(write_io).writerows([error_data])
            csv_upload = write_io.getvalue().encode()
            filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
            error_api = HfApi()
            error_api.upload_file(
                path_or_fileobj=csv_upload,
                path_in_repo=filename,
                repo_id="coqui/xtts-flagged-dataset",
                repo_type="dataset",
            )

            speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
            error_api = HfApi()
            error_api.upload_file(
                path_or_fileobj=speaker_wav,
                path_in_repo=speaker_filename,
                repo_id="coqui/xtts-flagged-dataset",
                repo_type="dataset",
            )

            # NOTE: this assumes repo_id points at the running Space; here it
            # names the model repo, so the restart call will fail unless the
            # Space id is substituted.
            space = api.get_space_runtime(repo_id=repo_id)
            if space.stage != "BUILDING":
                api.restart_space(repo_id=repo_id)
        else:
            if "Failed to decode" in str(e):
                metrics_text = gr.Warning(
                    "Parece que hay un problema con la referencia de audio. ¿Has activado el micrófono?"
                )
            else:
                metrics_text = gr.Warning(
                    "Se ha producido un error inesperado. Por favor, inténtalo de nuevo."
                )
        return (None, metrics_text)

    return ("output.wav", metrics_text)


with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                # viXTTS Demo ✨
                """
            )
        with gr.Column():
            pass

    with gr.Row():
        with gr.Column():
            input_text_gr = gr.Textbox(
                label="Texto a convertir a voz",
                info="Cada frase debe tener al menos 10 palabras. Máximo 250 caracteres (alrededor de 2-3 frases).",
                value="Hola, soy un modelo de texto a voz.",
            )
            language_gr = gr.Dropdown(
                label="Idioma",
                choices=[
                    "es-AR",
                    "vi",
                    "en",
                    "es",
                    "fr",
                    "de",
                    "it",
                    "pt",
                    "pl",
                    "tr",
                    "ru",
                    "nl",
                    "cs",
                    "ar",
                    "zh-cn",
                    "ja",
                    "ko",
                    "hu",
                    "hi",
                ],
                max_choices=1,
                value="es-AR",
            )
            normalize_text = gr.Checkbox(
                label="Normalizar texto en vietnamita",
                info="Solo aplicable al idioma vietnamita",
                value=True,
            )
            ref_gr = gr.Audio(
                label="Audio de referencia (opcional)",
                type="filepath",
                value="model/samples/nu-luu-loat.wav",
            )
            tts_button = gr.Button(
                "Generar voz 🗣️🔥",
                elem_id="send-btn",
                visible=True,
                variant="primary",
            )

        with gr.Column():
            audio_gr = gr.Audio(label="Audio generado", autoplay=True)
            out_text_gr = gr.Text(label="Métricas")

    tts_button.click(
        predict,
        [
            input_text_gr,
            language_gr,
            ref_gr,
            normalize_text,
        ],
        outputs=[audio_gr, out_text_gr],
        api_name="predict",
    )

demo.queue()
demo.launch(debug=True, show_api=True, share=True)
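
# To try the demo locally (assuming this file is saved as app.py and the
# viXTTS checkpoint download succeeds): run `python app.py`. Because
# share=True is passed to launch(), Gradio prints a temporary public URL in
# addition to the local address.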