from pathlib import Path
from typing import Generator

import librosa
import numpy as np
import torch
from neucodec import NeuCodec, DistillNeuCodec
from transformers import AutoTokenizer, AutoModelForCausalLM
from phonemizer.backend.espeak.espeak import EspeakWrapper
from phonemizer import phonemize
import platform
import re
import os

# Configure espeak for the supported platforms
if platform.system() == "Windows":
    EspeakWrapper.set_library(r"C:\Program Files\eSpeak NG\libespeak-ng.dll")
elif platform.system() == "Linux":
    # Try the common library paths
    espeak_paths = [
        "/usr/lib/x86_64-linux-gnu/libespeak-ng.so",
        "/usr/lib/libespeak-ng.so",
        "/usr/lib/x86_64-linux-gnu/libespeak-ng.so.1",
    ]
    for path in espeak_paths:
        if os.path.exists(path):
            EspeakWrapper.set_library(path)
            break
else:
    raise ValueError("Please set the espeak library path for your platform.")


def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
    # original impl --> https://github.com/facebookresearch/encodec/blob/main/encodec/utils.py
    assert len(frames)
    dtype = frames[0].dtype
    shape = frames[0].shape[:-1]

    total_size = 0
    for i, frame in enumerate(frames):
        frame_end = stride * i + frame.shape[-1]
        total_size = max(total_size, frame_end)

    sum_weight = np.zeros(total_size, dtype=dtype)
    out = np.zeros((*shape, total_size), dtype=dtype)

    offset: int = 0
    for frame in frames:
        frame_length = frame.shape[-1]
        t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
        # Triangular cross-fade window, as in the referenced encodec implementation
        weight = 0.5 - np.abs(t - 0.5)
        out[..., offset : offset + frame_length] += weight * frame
        sum_weight[offset : offset + frame_length] += weight
        offset += stride

    assert sum_weight.min() > 0
    return out / sum_weight


class VieNeuTTS:
    def __init__(
        self,
        backbone_repo="pnnbao-ump/VieNeu-TTS",
        backbone_device="cpu",
        codec_repo="neuphonic/neucodec",
        codec_device="cpu",
    ):
        # Constants
        self.sample_rate = 24_000
        self.max_context = 4096
        self.hop_length = 480
        self.streaming_overlap_frames = 1
        self.streaming_frames_per_chunk = 25
        self.streaming_lookforward = 5
        self.streaming_lookback = 50
        self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length

        # ggml & onnx flags
        self._is_quantized_model = False
        self._is_onnx_codec = False

        # HF tokenizer
        self.tokenizer = None

        # Load models
        self._load_backbone(backbone_repo, backbone_device)
        self._load_codec(codec_repo, codec_device)

    def _load_backbone(self, backbone_repo, backbone_device):
        print(f"Loading backbone from: {backbone_repo} on {backbone_device} ...")

        if "gguf" in backbone_repo.lower():
            try:
                from llama_cpp import Llama
            except ImportError as e:
                raise ImportError(
                    "Failed to import `llama_cpp`. "
                    "Please install it with:\n"
                    "  pip install llama-cpp-python"
                ) from e

            self.backbone = Llama.from_pretrained(
                repo_id=backbone_repo,
                filename="*.gguf",
                verbose=False,
                n_gpu_layers=-1 if backbone_device == "gpu" else 0,
                n_ctx=self.max_context,
                mlock=True,
                flash_attn=backbone_device == "gpu",
            )
            self._is_quantized_model = True
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(backbone_repo)
            self.backbone = AutoModelForCausalLM.from_pretrained(backbone_repo).to(
                torch.device(backbone_device)
            )

    def _load_codec(self, codec_repo, codec_device):
        print(f"Loading codec from: {codec_repo} on {codec_device} ...")

        match codec_repo:
            case "neuphonic/neucodec":
                self.codec = NeuCodec.from_pretrained(codec_repo)
                self.codec.eval().to(codec_device)

            case "neuphonic/distill-neucodec":
                self.codec = DistillNeuCodec.from_pretrained(codec_repo)
                self.codec.eval().to(codec_device)

            case "neuphonic/neucodec-onnx-decoder":
                if codec_device != "cpu":
                    raise ValueError("The ONNX decoder currently only runs on CPU.")
                try:
                    from neucodec import NeuCodecOnnxDecoder
                except ImportError as e:
                    raise ImportError(
                        "Failed to import the ONNX decoder. "
                        "Ensure you have onnxruntime installed as well as neucodec >= 0.0.4."
                    ) from e
                self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
                self._is_onnx_codec = True

            case _:
                raise ValueError(f"Unsupported codec repository: {codec_repo}")

    def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
        """
        Perform inference to generate speech from text using the TTS model and reference audio.

        Args:
            text (str): Input text to be converted to speech.
            ref_codes (np.ndarray | torch.Tensor): Encoded reference audio codes.
            ref_text (str): Transcript of the reference audio.

        Returns:
            np.ndarray: Generated speech waveform.
        """
        # Generate tokens
        if self._is_quantized_model:
            output_str = self._infer_ggml(ref_codes, ref_text, text)
        else:
            prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
            output_str = self._infer_torch(prompt_ids)

        # Decode
        wav = self._decode(output_str)

        return wav

    def infer_stream(
        self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str
    ) -> Generator[np.ndarray, None, None]:
        """
        Perform streaming inference to generate speech from text using the TTS model and reference audio.

        Args:
            text (str): Input text to be converted to speech.
            ref_codes (np.ndarray | torch.Tensor): Encoded reference audio codes.
            ref_text (str): Transcript of the reference audio.

        Yields:
            np.ndarray: Generated speech waveform chunks.
        """
        if self._is_quantized_model:
            return self._infer_stream_ggml(ref_codes, ref_text, text)
        else:
            raise NotImplementedError("Streaming is not implemented for the torch backend!")

    def encode_reference(self, ref_audio_path: str | Path):
        wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
        wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)  # [1, 1, T]
        with torch.no_grad():
            ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
        return ref_codes

    def _decode(self, codes: str):
        """Decode speech tokens to an audio waveform."""
        # Extract speech token IDs using regex
        speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]

        if len(speech_ids) == 0:
            raise ValueError(
                "No valid speech tokens found in the output. "
                "The model may not have generated proper speech tokens."
            )

        # Onnx decode
        if self._is_onnx_codec:
            codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
            recon = self.codec.decode_code(codes)
        # Torch decode
        else:
            with torch.no_grad():
                codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
                    self.codec.device
                )
                recon = self.codec.decode_code(codes).cpu().numpy()

        return recon[0, 0, :]

    def _to_phones(self, text: str) -> str:
        """Convert text to phonemes using phonemizer."""
        phones = phonemize(
            text,
            language="vi",
            backend="espeak",
            preserve_punctuation=True,
            with_stress=True,
            language_switch="remove-flags",
        )
        # Handle both string and list returns
        if isinstance(phones, list):
            if len(phones) == 0:
                raise ValueError(f"Phonemization failed for text: {text}")
            return phones[0]
        elif isinstance(phones, str):
            return phones
        else:
            raise TypeError(f"Unexpected phonemize return type: {type(phones)}")

    def _apply_chat_template(self, ref_codes: list[int], ref_text: str, input_text: str) -> list[int]:
        input_text = self._to_phones(ref_text) + " " + self._to_phones(input_text)

        speech_replace = self.tokenizer.convert_tokens_to_ids("<|SPEECH_REPLACE|>")
        speech_gen_start = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
        text_replace = self.tokenizer.convert_tokens_to_ids("<|TEXT_REPLACE|>")
        text_prompt_start = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_START|>")
        text_prompt_end = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_END|>")

        input_ids = self.tokenizer.encode(input_text, add_special_tokens=False)
        chat = """user: Convert the text to speech:<|TEXT_REPLACE|>\nassistant:<|SPEECH_REPLACE|>"""

        ids = self.tokenizer.encode(chat)
        text_replace_idx = ids.index(text_replace)
        ids = (
            ids[:text_replace_idx]
            + [text_prompt_start]
            + input_ids
            + [text_prompt_end]
            + ids[text_replace_idx + 1 :]  # noqa
        )

        speech_replace_idx = ids.index(speech_replace)
        codes_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
        codes = self.tokenizer.encode(codes_str, add_special_tokens=False)
        ids = ids[:speech_replace_idx] + [speech_gen_start] + list(codes)

        return ids

    def _infer_torch(self, prompt_ids: list[int]) -> str:
        prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
        speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")

        with torch.no_grad():
            output_tokens = self.backbone.generate(
                prompt_tensor,
                max_length=self.max_context,
                eos_token_id=speech_end_id,
                do_sample=True,
                temperature=1.0,
                top_k=50,
                use_cache=True,
                min_new_tokens=50,
            )

        input_length = prompt_tensor.shape[-1]
        output_str = self.tokenizer.decode(
            output_tokens[0, input_length:].cpu().numpy().tolist(), add_special_tokens=False
        )

        return output_str

    def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
        ref_text = self._to_phones(ref_text)
        input_text = self._to_phones(input_text)
        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
        prompt = (
            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
        )
        output = self.backbone(
            prompt,
            max_tokens=self.max_context,
            temperature=1.0,
            top_k=50,
            stop=["<|SPEECH_GENERATION_END|>"],
        )
        output_str = output["choices"][0]["text"]
        return output_str

    def _infer_stream_ggml(
        self, ref_codes: torch.Tensor, ref_text: str, input_text: str
    ) -> Generator[np.ndarray, None, None]:
        ref_text = self._to_phones(ref_text)
        input_text = self._to_phones(input_text)
        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
        prompt = (
            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
        )

        audio_cache: list[np.ndarray] = []
        token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]
        n_decoded_samples: int = 0
        n_decoded_tokens: int = len(ref_codes)

        for item in self.backbone(
            prompt,
            max_tokens=self.max_context,
            temperature=0.2,
            top_k=50,
            stop=["<|SPEECH_GENERATION_END|>"],
            stream=True,
        ):
            output_str = item["choices"][0]["text"]
            token_cache.append(output_str)

            if (
                len(token_cache[n_decoded_tokens:])
                >= self.streaming_frames_per_chunk + self.streaming_lookforward
            ):
                # Decode the next chunk together with its lookback/lookforward context
                tokens_start = max(
                    n_decoded_tokens - self.streaming_lookback - self.streaming_overlap_frames, 0
                )
                tokens_end = (
                    n_decoded_tokens
                    + self.streaming_frames_per_chunk
                    + self.streaming_lookforward
                    + self.streaming_overlap_frames
                )
                sample_start = (n_decoded_tokens - tokens_start) * self.hop_length
                sample_end = (
                    sample_start
                    + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames)
                    * self.hop_length
                )
                curr_codes = token_cache[tokens_start:tokens_end]
                recon = self._decode("".join(curr_codes))
                recon = recon[sample_start:sample_end]
                audio_cache.append(recon)

                # Postprocess: cross-fade the overlapping chunks and emit only the new samples
                processed_recon = _linear_overlap_add(
                    audio_cache, stride=self.streaming_stride_samples
                )
                new_samples_end = len(audio_cache) * self.streaming_stride_samples
                processed_recon = processed_recon[n_decoded_samples:new_samples_end]
                n_decoded_samples = new_samples_end
                n_decoded_tokens += self.streaming_frames_per_chunk

                yield processed_recon

        # Final decoding is handled separately as the last chunk has a non-constant size
        remaining_tokens = len(token_cache) - n_decoded_tokens
        if remaining_tokens > 0:
            tokens_start = max(
                len(token_cache)
                - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
                0,
            )
            sample_start = (
                len(token_cache) - tokens_start - remaining_tokens - self.streaming_overlap_frames
            ) * self.hop_length
            curr_codes = token_cache[tokens_start:]
            recon = self._decode("".join(curr_codes))
            recon = recon[sample_start:]
            audio_cache.append(recon)
            processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
            processed_recon = processed_recon[n_decoded_samples:]
            yield processed_recon
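

# --- Minimal usage sketch (illustrative only, under stated assumptions) ---
# Assumes a short reference clip at "ref.wav" with its transcript below (both are
# placeholder values), and the `soundfile` package for writing the result (it is
# already pulled in as a librosa dependency). The default Hugging Face repos above
# are used; streaming via `infer_stream` additionally requires a GGUF backbone.
if __name__ == "__main__":
    import soundfile as sf

    tts = VieNeuTTS(backbone_device="cpu", codec_device="cpu")

    # Encode the reference voice once, then reuse the codes for any number of inputs.
    ref_codes = tts.encode_reference("ref.wav")
    ref_text = "Xin chào, đây là câu tham chiếu."  # transcript of ref.wav (placeholder)

    # Batch inference returns the full 24 kHz waveform as a numpy array.
    wav = tts.infer("Hôm nay trời đẹp quá.", ref_codes, ref_text)
    sf.write("output.wav", wav, tts.sample_rate)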