Spaces:
Paused
Paused
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchaudio | |
| import wget | |
| import json | |
| import os | |
# Local cache directory and file names for the TTS model assets.
TTS_FOLDER = "./TTSModel"
TTS_MODEL_NAME = "vits"
TTS_MODEL_CONFIG = "config.json"
TTS_MODEL_WEIGHTS = "pytorch_model.bin"
TTS_VOCAB = "vocab.json"

# All assets come from the same Hugging Face repository branch. Each URL is
# derived from its local file-name constant so the two cannot drift apart.
_TTS_REPO_BASE_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main"
TTS_CONFIG_URL = f"{_TTS_REPO_BASE_URL}/{TTS_MODEL_CONFIG}"
TTS_MODEL_WEIGHTS_URL = f"{_TTS_REPO_BASE_URL}/{TTS_MODEL_WEIGHTS}"
TTS_VOCAB_URL = f"{_TTS_REPO_BASE_URL}/{TTS_VOCAB}"

# (remote URL, local file name) pairs iterated by ensure_tts_files_exist().
TTS_FILES_URLS = [
    (TTS_CONFIG_URL, TTS_MODEL_CONFIG),
    (TTS_MODEL_WEIGHTS_URL, TTS_MODEL_WEIGHTS),
    (TTS_VOCAB_URL, TTS_VOCAB),
]
def ensure_tts_files_exist():
    """Download every TTS asset that is not already present in ``TTS_FOLDER``.

    The folder is created if needed. Each ``(url, filename)`` pair from
    ``TTS_FILES_URLS`` is checked in order; files that already exist on disk
    are skipped, so repeated calls only fetch what is missing.
    """
    os.makedirs(TTS_FOLDER, exist_ok=True)
    # Lazily pair each source URL with its destination path; the existence
    # check still happens per file, after any earlier download completes.
    targets = (
        (source_url, os.path.join(TTS_FOLDER, local_name))
        for source_url, local_name in TTS_FILES_URLS
    )
    for source_url, target_path in targets:
        if os.path.exists(target_path):
            continue
        wget.download(source_url, out=target_path)
class VITS(nn.Module):
    """Symbol embedding followed by a small convolutional decoder.

    NOTE(review): despite its name this is not the full VITS architecture
    (it has only an embedding and two convolutions), so the downloaded
    ``pytorch_model.bin`` checkpoint presumably will not load into it via
    ``load_state_dict`` -- confirm before relying on it.

    Args:
        spec_channels: Intermediate channel count used by the decoder.
        segment_size: Stored on the module; not used by this class.
        num_speakers: Stored on the module; not used by this class.
        num_languages: Stored on the module; not used by this class.
        num_symbols: Size of the input symbol vocabulary.
        hidden_channels: Embedding width, which is also the decoder's input
            channel count. Defaults to 192, the previously hard-coded value.
    """

    def __init__(self, spec_channels, segment_size, num_speakers, num_languages,
                 num_symbols, hidden_channels=192):
        super().__init__()
        self.spec_channels = spec_channels
        self.segment_size = segment_size
        self.num_speakers = num_speakers
        self.num_languages = num_languages
        self.num_symbols = num_symbols
        self.hidden_channels = hidden_channels
        self.embedding = nn.Embedding(num_symbols, hidden_channels)
        # The decoder must consume exactly as many channels as the embedding emits.
        self.decoder = Generator(spec_channels, in_channels=hidden_channels)

    def forward(self, text):
        """Run symbol ids through the embedding and decoder.

        Args:
            text: Integer tensor of symbol ids, shape ``(batch, seq_len)``.

        Returns:
            Decoder output of shape ``(batch, 2, 2 * seq_len)``.
        """
        # nn.Embedding yields channels-last (batch, seq_len, hidden); the
        # decoder's convolutions need channels-first (batch, hidden, seq_len).
        # Without this transpose the conv sees seq_len as its channel axis and
        # fails for any seq_len != hidden_channels.
        x = self.embedding(text).transpose(1, 2)
        audio = self.decoder(x)
        return audio


class Generator(nn.Module):
    """Small 2-D convolutional decoder over a channels-first sequence.

    The input ``(batch, in_channels, seq_len)`` is treated as a height-1
    image. A stride-2 transposed convolution upsamples it 2x in both spatial
    dimensions, and a size-preserving 7x7 convolution projects it to a
    single channel.

    Args:
        spec_channels: Channels produced by the transposed convolution.
        in_channels: Channels of the incoming sequence. Defaults to 192,
            the previously hard-coded value.
    """

    def __init__(self, spec_channels, in_channels=192):
        super().__init__()
        self.spec_channels = spec_channels
        self.initial_conv = nn.ConvTranspose2d(
            in_channels, spec_channels, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)
        )
        # kernel 7 with padding 3 keeps height and width unchanged.
        self.final_conv = nn.Conv2d(spec_channels, 1, kernel_size=(7, 7), padding=(3, 3))

    def forward(self, encoder_outputs):
        """Decode a channels-first sequence.

        Args:
            encoder_outputs: Tensor of shape ``(batch, in_channels, seq_len)``.

        Returns:
            Tensor of shape ``(batch, 2, 2 * seq_len)``.
        """
        x = encoder_outputs.unsqueeze(2)  # (batch, in_channels, 1, seq_len)
        x = self.initial_conv(x)          # (batch, spec_channels, 2, 2 * seq_len)
        x = self.final_conv(x)            # (batch, 1, 2, 2 * seq_len)
        return x.squeeze(1)               # drop the singleton channel axis