|
|
""" |
|
|
CUPE: Easy usage with automatic downloading from Hugging Face Hub |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torchaudio |
|
|
from huggingface_hub import hf_hub_download |
|
|
import importlib.util |
|
|
import sys |
|
|
import os |
|
|
|
|
|
def load_cupe_model(model_name="english", device="auto"):
    """
    Load a CUPE model, downloading its code and checkpoint from the Hugging Face Hub.

    Args:
        model_name: "english", "multilingual-mls", or "multilingual-mswc"
        device: "auto", "cpu", or "cuda". "auto" selects cuda when available.

    Returns:
        Tuple of (extractor, windowing_module, token_to_phoneme, token_to_group):
        the CUPE embeddings extractor, the windowing helper module, and two
        dicts mapping class indices back to phoneme / phoneme-group labels.

    Raises:
        ValueError: if model_name is not one of the supported variants.
    """
    # Checkpoint filenames as published in the Tabahi/CUPE-2i repository.
    model_files = {
        "english": "en_libri1000_uj01d_e199_val_GER=0.2307.ckpt",
        "multilingual-mls": "multi_MLS8_uh02_e36_val_GER=0.2334.ckpt",
        "multilingual-mswc": "multi_mswc38_ug20_e59_val_GER=0.5611.ckpt"
    }

    if model_name not in model_files:
        raise ValueError(f"Model {model_name} not available. Choose from: {list(model_files.keys())}")

    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Loading CUPE {model_name} model...")

    repo_id = "Tabahi/CUPE-2i"

    # Fetch the supporting source files and the selected checkpoint
    # (hf_hub_download caches locally, so repeat calls are cheap).
    model_file = hf_hub_download(repo_id=repo_id, filename="model2i.py")
    windowing_file = hf_hub_download(repo_id=repo_id, filename="windowing.py")
    mapper_file = hf_hub_download(repo_id=repo_id, filename="mapper.py")
    model_utils_file = hf_hub_download(repo_id=repo_id, filename="model_utils.py")
    checkpoint_file = hf_hub_download(repo_id=repo_id, filename=f"ckpt/{model_files[model_name]}")

    def import_module_from_file(module_name, file_path):
        # Import a downloaded .py file as a regular module. Registering it in
        # sys.modules first lets the CUPE modules import each other by name.
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        return module

    # model_utils must be loaded first so model2i can resolve its import.
    _ = import_module_from_file("model_utils", model_utils_file)
    model2i = import_module_from_file("model2i", model_file)
    windowing = import_module_from_file("windowing", windowing_file)
    mapper = import_module_from_file("mapper", mapper_file)

    # Invert the label->index maps so predictions can be decoded to labels.
    phoneme_to_token = mapper.phoneme_mapped_index
    token_to_phoneme = {v: k for k, v in phoneme_to_token.items()}
    group_to_token = mapper.phoneme_groups_index
    token_to_group = {v: k for k, v in group_to_token.items()}

    extractor = model2i.CUPEEmbeddingsExtractor(checkpoint_file, device=device)

    print(f"Model loaded on {device}")
    return extractor, windowing, token_to_phoneme, token_to_group
|
|
|
|
|
def predict_phonemes(audio_path, model_name="english", device="auto"):
    """
    Predict frame-level phonemes from an audio file.

    Args:
        audio_path: Path to an audio file readable by torchaudio
        model_name: CUPE model variant to use
        device: "auto", "cpu", or "cuda"

    Returns:
        Dictionary with per-frame probabilities and argmax predictions for
        phonemes and phoneme groups, decoded label sequences (noise frames
        removed), and model metadata.
    """
    # Resolve "auto" here as well: load_cupe_model resolves its own copy but
    # does not return it, and audio.to("auto") below would raise.
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    extractor, windowing, token_to_phoneme, token_to_group = load_cupe_model(model_name, device)

    # CUPE operates on 16 kHz audio, sliced into 120 ms windows with 80 ms stride.
    sample_rate = 16000
    window_size_ms = 120
    stride_ms = 80

    audio, orig_sr = torchaudio.load(audio_path)

    # Downmix multi-channel audio to mono.
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)

    # Resample from the file's native rate to the model's 16 kHz.
    # (Bug fix: orig_sr was previously ignored — Resample received 16000 as
    # orig_freq and defaulted new_freq to 16000, a no-op for non-16kHz files.)
    if orig_sr != sample_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=orig_sr,
            new_freq=sample_rate,
            lowpass_filter_width=64,
            rolloff=0.9475937167399596,
            resampling_method="sinc_interp_kaiser",
            beta=14.769656459379492,
        )
        audio = resampler(audio)

    audio = audio.to(device)
    audio_batch = audio.unsqueeze(0)

    print(f"Processing audio: {audio.shape[1]/sample_rate:.2f}s duration")

    # Slice into overlapping windows: (batch, num_windows, window_samples).
    windowed_audio = windowing.slice_windows(
        audio_batch,
        sample_rate,
        window_size_ms,
        stride_ms
    )

    batch_size, num_windows, window_size = windowed_audio.shape
    # Flatten so every window is one batch item for the extractor.
    windows_flat = windowed_audio.reshape(-1, window_size)

    # Per-window frame logits for phoneme classes and phoneme-group classes.
    logits_phonemes, logits_groups = extractor.predict(
        windows_flat,
        return_embeddings=False,
        groups_only=False
    )

    frames_per_window = logits_phonemes.shape[1]

    # Restore the window dimension before stitching.
    logits_phonemes = logits_phonemes.reshape(batch_size, num_windows, frames_per_window, -1)
    logits_groups = logits_groups.reshape(batch_size, num_windows, frames_per_window, -1)

    # Stitch the overlapping window predictions back into one continuous
    # sequence of frame logits over the original audio length.
    phoneme_logits = windowing.stich_window_predictions(
        logits_phonemes,
        original_audio_length=audio_batch.size(2),
        cnn_output_size=frames_per_window,
        sample_rate=sample_rate,
        window_size_ms=window_size_ms,
        stride_ms=stride_ms
    )

    group_logits = windowing.stich_window_predictions(
        logits_groups,
        original_audio_length=audio_batch.size(2),
        cnn_output_size=frames_per_window,
        sample_rate=sample_rate,
        window_size_ms=window_size_ms,
        stride_ms=stride_ms
    )

    # Convert logits to probabilities and hard predictions per frame.
    phoneme_probs = torch.softmax(phoneme_logits.squeeze(0), dim=-1)
    group_probs = torch.softmax(group_logits.squeeze(0), dim=-1)

    phoneme_preds = torch.argmax(phoneme_probs, dim=-1)
    group_preds = torch.argmax(group_probs, dim=-1)

    # Decode class indices to labels, dropping 'noise' frames.
    phonemes_sequence = [token_to_phoneme[int(p)] for p in phoneme_preds.cpu().numpy()]
    groups_sequence = [token_to_group[int(g)] for g in group_preds.cpu().numpy()]

    phonemes_sequence = [p for p in phonemes_sequence if p != 'noise']
    groups_sequence = [g for g in groups_sequence if g != 'noise']

    num_frames = phoneme_probs.shape[0]

    # Each output frame covers 16 ms (62.5 frames/s at 16 kHz).
    print(f"Processed {num_frames} frames ({num_frames*16}ms total)")

    return {
        'phoneme_probabilities': phoneme_probs.cpu().numpy(),
        'phoneme_predictions': phoneme_preds.cpu().numpy(),
        'group_probabilities': group_probs.cpu().numpy(),
        'group_predictions': group_preds.cpu().numpy(),
        'phonemes_sequence': phonemes_sequence,
        'groups_sequence': groups_sequence,
        'model_info': {
            'model_name': model_name,
            'sample_rate': sample_rate,
            'frames_per_second': 1000/16,
            'num_phoneme_classes': phoneme_probs.shape[-1],
            'num_group_classes': group_probs.shape[-1]
        }
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
|
|
|
audio_file = "samples/109867__timkahn__butterfly.wav.wav" |
|
|
|
|
|
|
|
|
if not os.path.exists(audio_file): |
|
|
print(f"Audio file {audio_file} does not exist. Please provide a valid path.") |
|
|
sys.exit(1) |
|
|
|
|
|
torch.manual_seed(42) |
|
|
|
|
|
results = predict_phonemes( |
|
|
audio_path=audio_file, |
|
|
model_name="english", |
|
|
device="cpu" |
|
|
) |
|
|
|
|
|
print(f"\nResults:") |
|
|
print(f"Phoneme predictions shape: {results['phoneme_predictions'].shape}") |
|
|
print(f"Group predictions shape: {results['group_predictions'].shape}") |
|
|
print(f"Model info: {results['model_info']}") |
|
|
|
|
|
|
|
|
print(f"\nFirst 10 frame predictions:") |
|
|
for i in range(min(10, len(results['phoneme_predictions']))): |
|
|
print(f"Frame {i}: phoneme={results['phoneme_predictions'][i]}, " |
|
|
f"group={results['group_predictions'][i]}") |
|
|
|
|
|
print(f"\nPhonemes sequence: {results['phonemes_sequence'][:10]}...") |
|
|
print(f"Groups sequence: {results['groups_sequence'][:10]}...") |
|
|
|
|
|
''' output: |
|
|
Loading CUPE english model... |
|
|
Model loaded on cpu |
|
|
Processing audio: 1.26s duration |
|
|
Processed 75 frames (1200ms total) |
|
|
|
|
|
Results: |
|
|
Phoneme predictions shape: (75,) |
|
|
Group predictions shape: (75,) |
|
|
Model info: {'model_name': 'english', 'sample_rate': 16000, 'frames_per_second': 62.5, 'num_phoneme_classes': 67, 'num_group_classes': 17} |
|
|
|
|
|
First 10 frame predictions: |
|
|
Frame 0: phoneme=66, group=16 |
|
|
Frame 1: phoneme=66, group=16 |
|
|
Frame 2: phoneme=29, group=7 |
|
|
Frame 3: phoneme=66, group=16 |
|
|
Frame 4: phoneme=66, group=16 |
|
|
Frame 5: phoneme=66, group=16 |
|
|
Frame 6: phoneme=10, group=2 |
|
|
Frame 7: phoneme=66, group=16 |
|
|
Frame 8: phoneme=66, group=16 |
|
|
Frame 9: phoneme=66, group=16 |
|
|
|
|
|
Phonemes sequence: ['b', 'ʌ', 't', 'h', 'ʌ', 'f', 'l', 'æ']... |
|
|
Groups sequence: ['voiced_stops', 'central_vowels', 'voiceless_stops', 'voiceless_fricatives', 'central_vowels', 'voiceless_fricatives', 'laterals', 'low_vowels']... |
|
|
''' |
|
|
|