"""NPU (Ascend) compatibility patches for F5-TTS inference.

Replaces the stock checkpoint loader and the Vocos mel-spectrogram routine
with NPU-friendly variants via the MindSpeed patch manager.
"""

import logging

import torch
import torchaudio

from patch_utils import MindSpeedPatchesManager as aspm
|
|
def get_vocos_mel_spectrogram_cpu(
    waveform,
    n_fft=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    hop_length=256,
    win_length=1024,
):
    # CPU fallback for the Vocos mel-spectrogram: run the transform on CPU and
    # move the result back to the caller's original device.
    wave_device = waveform.device
    waveform = waveform.cpu()
    mel_stft = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        n_mels=n_mel_channels,
        power=1,
        center=True,
        normalized=False,
        norm=None,
    ).to(waveform.device)

    if len(waveform.shape) == 3:
        waveform = waveform.squeeze(1)  # 'b 1 nw -> b nw'

    assert len(waveform.shape) == 2

    mel = mel_stft(waveform)
    mel = mel.clamp(min=1e-5).log()
    waveform = waveform.to(wave_device)
    mel = mel.to(wave_device)
    return mel
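
# Illustrative I/O for get_vocos_mel_spectrogram_cpu (a sketch; the concrete
# shape assumes the default arguments above and an Ascend device via torch_npu,
# but any device behaves the same way):
#
#   wave = torch.randn(1, 24000, device="npu:0")  # 1 s of mono audio at 24 kHz
#   mel = get_vocos_mel_spectrogram_cpu(wave)     # transform runs on CPU internally
#   mel.shape   # torch.Size([1, 100, 24000 // 256 + 1]) == (1, 100, 94)
#   mel.device  # npu:0 -- the result is moved back to the caller's device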
|
|
def load_checkpoint_npu(model, ckpt_path, device: str, dtype=None, use_ema=True):
    logging.info(f"Load checkpoint {ckpt_path}")

    if dtype is None:
        if "npu" in device:
            # Ascend NPUs run fp16 inference here; the CUDA device queries
            # below do not apply to the NPU backend.
            dtype = torch.float16
        else:
            dtype = (
                torch.float16
                if "cuda" in device
                and torch.cuda.get_device_properties(device).major >= 6
                and not torch.cuda.get_device_name().endswith("[ZLUDA]")
                else torch.float32
            )
    model = model.to(dtype)

    ckpt_type = ckpt_path.split(".")[-1]
    if ckpt_type == "safetensors":
        from safetensors.torch import load_file

        checkpoint = load_file(ckpt_path, device=device)
    else:
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)

    if use_ema:
        if ckpt_type == "safetensors":
            checkpoint = {"ema_model_state_dict": checkpoint}
        # Strip the EMA wrapper prefix so keys match the bare model's state dict.
        checkpoint["model_state_dict"] = {
            k.replace("ema_model.", ""): v
            for k, v in checkpoint["ema_model_state_dict"].items()
            if k not in ["initted", "step"]
        }

        # Drop persisted mel-spectrogram buffers; they are rebuilt at runtime.
        for key in [
            "mel_spec.mel_stft.mel_scale.fb",
            "mel_spec.mel_stft.spectrogram.window",
        ]:
            if key in checkpoint["model_state_dict"]:
                del checkpoint["model_state_dict"][key]

        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        if ckpt_type == "safetensors":
            checkpoint = {"model_state_dict": checkpoint}
        model.load_state_dict(checkpoint["model_state_dict"])

    del checkpoint
    torch.cuda.empty_cache()

    return model.to(device)
|
|
def patch_for_npu():
    # Swap in the NPU-friendly checkpoint loader and the CPU mel-spectrogram fallback.
    aspm.register_patch('f5_tts.infer.utils_infer.load_checkpoint', load_checkpoint_npu)
    aspm.register_patch('f5_tts.model.modules.get_vocos_mel_spectrogram', get_vocos_mel_spectrogram_cpu)
    aspm.apply_patches()
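

# A minimal smoke test, kept as an illustrative sketch rather than part of the
# patch: it exercises the CPU mel-spectrogram fallback on random audio. The
# `patch_for_npu()` call is commented out because it assumes `f5_tts` and the
# MindSpeed patch manager can resolve the patched import paths.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # One second of fake mono audio at 24 kHz, batch size 1.
    fake_wave = torch.randn(1, 24000)
    mel = get_vocos_mel_spectrogram_cpu(fake_wave)
    # With center=True and hop_length=256 this prints (1, 100, 94).
    print(f"mel shape: {tuple(mel.shape)}")

    # patch_for_npu()  # uncomment on an Ascend setup with f5_tts installed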