import numpy as np
import gradio as gr
from models.atstframe.ATSTF_wrapper import ATSTWrapper
from models.beats.BEATs_wrapper import BEATsWrapper
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
from models.m2d.M2D_wrapper import M2DWrapper
from models.asit.ASIT_wrapper import ASiTWrapper
from models.frame_mn.Frame_MN_wrapper import FrameMNWrapper
from models.prediction_wrapper import PredictionsWrapper
from models.frame_mn.utils import NAME_TO_WIDTH
import torch
from torch import nn
import pandas as pd

class TransformerClassifier(nn.Module):
    def __init__(self, model, n_classes):
        super(TransformerClassifier, self).__init__()
        self.model = model
        self.linear = nn.Linear(model.embed_dim, n_classes)

    def forward(self, x):
        # Compute the mel spectrogram, extract embeddings from the backbone,
        # and project them to one logit per class.
        mel = self.model.mel_forward(x)
        features = self.model(mel).squeeze(1)
        return self.linear(features)

def get_model(model_name):
    # Instantiate the requested pretrained backbone.
    if model_name == "BEATs":
        beats = BEATsWrapper()
        model = PredictionsWrapper(beats, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "ATST-F":
        atst = ATSTWrapper()
        model = PredictionsWrapper(atst, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "fpasst":
        fpasst = FPaSSTWrapper()
        model = PredictionsWrapper(fpasst, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "M2D":
        m2d = M2DWrapper()
        model = PredictionsWrapper(m2d, checkpoint=None, head_type=None, seq_len=1,
                                   embed_dim=m2d.m2d.cfg.feature_d)
    elif model_name == "ASIT":
        asit = ASiTWrapper()
        model = PredictionsWrapper(asit, checkpoint=None, head_type=None, seq_len=1)
    elif model_name.startswith("frame_mn"):
        width = NAME_TO_WIDTH(model_name)
        frame_mn = FrameMNWrapper(width)
        embed_dim = frame_mn.state_dict()['frame_mn.features.16.1.bias'].shape[0]
        model = PredictionsWrapper(frame_mn, checkpoint=None, head_type=None, seq_len=1, embed_dim=embed_dim)
    else:
        raise NotImplementedError(f"Model {model_name} not (yet) implemented")

    # Attach the linear classification head and load the fine-tuned weights.
    main_model = TransformerClassifier(model, n_classes=200)
    # main_model.compile()
    main_model.load_state_dict(torch.load(f"resources/best_model_{model_name}.pth", map_location='cpu'))
    print(main_model)
    main_model.eval()
    return main_model
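
# The other backbones wired up above can be selected the same way (a sketch, assuming the
# corresponding resources/best_model_<name>.pth checkpoints are present in the Space):
#   get_model("ATST-F"), get_model("fpasst"), get_model("M2D"), get_model("ASIT"),
#   or a frame_mn variant whose name encodes its width, e.g. get_model("frame_mn10").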

model = get_model("BEATs")
# Label vocabulary: maps a class index (column 0) to its human-readable label (column 1).
label_mapping = pd.read_csv("resources/labelvocabulary.csv", header=None, index_col=0).to_dict()[1]
threshold = 0.4

def predict(input_audio):
    # Gradio passes the recording as a (sample_rate, numpy_array) tuple;
    # run the classifier on the waveform and return all labels above the threshold.
    with torch.no_grad():
        waveform = torch.from_numpy(input_audio[1]).float()  # convert the numpy audio data to a tensor
        output = model(waveform.unsqueeze(0)).squeeze(0)  # add the batch dimension, then drop it again
        output = output.sigmoid()
        num_labels = torch.where(output >= threshold)[0].tolist()
        labels = [label_mapping[str(i)] for i in num_labels]
    return ", ".join(labels) if labels else "No sound event detected"

demo = gr.Interface(predict, gr.Audio(max_length=30), "text", title="Freesound Sound Event Detection")
demo.launch()