# DSTK / semantic_detokenizer / f5tts_npu_patch.py
# gooorillax: first push of code and models for g2p, t2u, tokenizer and detokenizer (commit cd8454d)
# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from patch_utils import MindSpeedPatchesManager as aspm
import torchaudio
import torch
import logging
def get_vocos_mel_spectrogram_cpu(
    waveform,
    n_fft=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    hop_length=256,
    win_length=1024,
):
    """Compute a log-mel spectrogram on the CPU and return it on the input's device.

    The waveform is copied to the CPU, run through
    ``torchaudio.transforms.MelSpectrogram`` (magnitude spectrogram, ``power=1``),
    clamped to a minimum of ``1e-5``, log-compressed, and the result is moved
    back to the device the waveform originally lived on.

    Args:
        waveform: Tensor of rank 2, or rank 3 whose dimension 1 is squeezed
            away (the original comment describes this as ``b 1 nw -> b nw``).
        n_fft: FFT size passed to ``MelSpectrogram``.
        n_mel_channels: Number of mel bins (``n_mels``).
        target_sample_rate: Sample rate passed to ``MelSpectrogram``.
        hop_length: Hop length passed to ``MelSpectrogram``.
        win_length: Window length passed to ``MelSpectrogram``.

    Returns:
        The log-mel spectrogram tensor, placed on the waveform's original device.

    Raises:
        AssertionError: If the waveform is not rank 2 after the optional squeeze.
    """
    # Remember where the caller's tensor lives so the result can be sent back.
    origin_device = waveform.device
    cpu_wave = waveform.cpu()

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        n_mels=n_mel_channels,
        power=1,
        center=True,
        normalized=False,
        norm=None,
    ).to(cpu_wave.device)

    if cpu_wave.dim() == 3:
        cpu_wave = cpu_wave.squeeze(1)  # 'b 1 nw -> b nw'
    assert cpu_wave.dim() == 2, "waveform must be 2-D after squeezing dim 1"

    # Clamp before the log so silent frames do not produce -inf.
    log_mel = mel_transform(cpu_wave).clamp(min=1e-5).log()

    # Only the spectrogram is returned, so only it is moved back; the CPU copy
    # of the waveform is simply dropped.
    return log_mel.to(origin_device)
def load_checkpoint_npu(model, ckpt_path, device: str, dtype=None, use_ema=True):
    """Load weights from ``ckpt_path`` into ``model`` and move it to ``device``.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        ckpt_path: Path to the checkpoint. A ``.safetensors`` extension is read
            with ``safetensors``; anything else goes through ``torch.load``.
        device: Device string (for example ``"cpu"``, ``"cuda:0"``, ``"npu:0"``);
            used both as the load target and for dtype selection.
        dtype: Dtype to cast the model to before loading. When ``None``,
            ``torch.float16`` is chosen for a ``cuda``/``npu`` device that
            reports a major capability of at least 6 and is not a ZLUDA device;
            otherwise ``torch.float32``.
        use_ema: When true, read weights from ``ema_model_state_dict`` (stripping
            the ``ema_model.`` prefix and the ``initted``/``step`` entries);
            otherwise read ``model_state_dict``.

    Returns:
        The model, with weights loaded, moved to ``device``.
    """
    logging.info(f"Load checkpoint {ckpt_path}")

    if dtype is None:
        # The device test is evaluated on its own first: `and` binds tighter
        # than `or`, so writing `"cuda" in d or "npu" in d and ...` inline would
        # skip the capability/ZLUDA checks for every CUDA device.
        # NOTE(review): for "npu" devices this still queries torch.cuda.*, which
        # presumably is redirected to the NPU backend at runtime — confirm.
        on_accelerator = "cuda" in device or "npu" in device
        half_capable = (
            on_accelerator
            and torch.cuda.get_device_properties(device).major >= 6
            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
        )
        dtype = torch.float16 if half_capable else torch.float32
    model = model.to(dtype)

    ckpt_type = ckpt_path.split(".")[-1]
    is_safetensors = ckpt_type == "safetensors"
    if is_safetensors:
        from safetensors.torch import load_file

        checkpoint = load_file(ckpt_path, device=device)
    else:
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)

    if use_ema:
        # A safetensors file holds the bare tensor mapping; wrap it so both
        # formats expose the same top-level key.
        if is_safetensors:
            checkpoint = {"ema_model_state_dict": checkpoint}
        checkpoint["model_state_dict"] = {
            k.replace("ema_model.", ""): v
            for k, v in checkpoint["ema_model_state_dict"].items()
            if k not in ["initted", "step"]
        }

        # patch for backward compatibility, 305e3ea
        for key in [
            "mel_spec.mel_stft.mel_scale.fb",
            "mel_spec.mel_stft.spectrogram.window",
        ]:
            checkpoint["model_state_dict"].pop(key, None)
    elif is_safetensors:
        checkpoint = {"model_state_dict": checkpoint}

    model.load_state_dict(checkpoint["model_state_dict"])

    # Drop the loaded tensors before asking the allocator to release its cache.
    del checkpoint
    torch.cuda.empty_cache()

    return model.to(device)
def patch_for_npu():
    """Register the NPU replacements for F5-TTS and apply them.

    Two targets are registered with the patch manager, in this order:
    ``f5_tts.infer.utils_infer.load_checkpoint`` is replaced by
    ``load_checkpoint_npu`` and ``f5_tts.model.modules.get_vocos_mel_spectrogram``
    by ``get_vocos_mel_spectrogram_cpu``. ``apply_patches`` is then called once.
    """
    replacements = (
        ('f5_tts.infer.utils_infer.load_checkpoint', load_checkpoint_npu),
        ('f5_tts.model.modules.get_vocos_mel_spectrogram', get_vocos_mel_spectrogram_cpu),
    )
    for target_path, substitute in replacements:
        aspm.register_patch(target_path, substitute)
    aspm.apply_patches()