import functools
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

from models.ecapa_tdnn import ECAPA_TDNN_SMALL

# Cosine similarity with nn.CosineSimilarity's default arguments; used to
# compare the two embeddings produced in inference_kathbadh().
score_fn = nn.CosineSimilarity()

# Fine-tuned checkpoint used by inference_kathbadh(), relative to the CWD.
_CHECKPOINT_PATH = r"./wavlm_large_kathbadh_finetune.pth"

# WavLM models are pretrained on 16 kHz speech, so audio recorded at any
# other rate is resampled before being fed to the "wavlm_large" front-end.
_TARGET_SAMPLE_RATE = 16000


def load_model(checkpoint):
    """Build an ECAPA_TDNN_SMALL model and load weights from *checkpoint*.

    Args:
        checkpoint: Path (or file-like object) accepted by ``torch.load``
            that contains the model's state dict.

    Returns:
        The constructed model with the checkpoint weights applied. The model
        is returned in whatever train/eval mode the constructor leaves it in;
        callers that run inference should call ``.eval()`` themselves.
    """
    model = ECAPA_TDNN_SMALL(
        feat_dim=1024,
        feat_type="wavlm_large",
        config_path=None,
    )
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    # The identity map_location keeps tensors where they were deserialized
    # (CPU) instead of moving them to the device they were saved from.
    state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    # strict=False: keys missing from / unexpected in the checkpoint are
    # ignored rather than raising, so a mismatched checkpoint loads silently.
    model.load_state_dict(state_dict, strict=False)
    return model


@functools.lru_cache(maxsize=1)
def _get_eval_model(checkpoint):
    """Return the eval-mode model for *checkpoint*, loading it only once."""
    model = load_model(checkpoint)
    model.eval()
    return model


def _load_waveform(path):
    """Load *path* as a mono waveform of shape (1, frames) at the target rate."""
    waveform, sample_rate = torchaudio.load(path)
    # torchaudio.load returns (channels, frames) by default. Downmix so the
    # channel axis is not mistaken for a batch of separate utterances.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sample_rate != _TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate,
            new_freq=_TARGET_SAMPLE_RATE,
        )
        waveform = resampler(waveform)
    return waveform


def inference_kathbadh(wav1, wav2):
    """Score how similar the speakers in two audio files are.

    Both files are loaded, downmixed to mono and resampled to 16 kHz when
    necessary, embedded with the fine-tuned model, and compared with cosine
    similarity.

    Args:
        wav1: Path to the first audio file (anything ``torchaudio.load``
            accepts).
        wav2: Path to the second audio file.

    Returns:
        The cosine similarity between the two embeddings as a Python float.
    """
    model = _get_eval_model(_CHECKPOINT_PATH)

    waveform1 = _load_waveform(wav1)
    waveform2 = _load_waveform(wav2)

    # Inference only: skip autograd bookkeeping.
    # Assumes the model accepts a (batch, samples) tensor - TODO confirm
    # against ECAPA_TDNN_SMALL.forward.
    with torch.no_grad():
        embedding1 = model(waveform1)
        embedding2 = model(waveform2)
        score = score_fn(embedding1, embedding2)

    return score.item()