# Hugging Face Spaces status: Runtime error
| from fastapi import FastAPI, File, Form | |
| import datetime | |
| import time | |
| import torch | |
| from typing import Optional | |
| import os | |
| import numpy as np | |
| from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, AutoConfig | |
| from huggingface_hub import hf_hub_download | |
| from fuzzywuzzy import fuzz | |
| from utils import ffmpeg_read, query_dummy, query_raw, find_different | |
## config
# All settings come from the environment; a missing variable raises KeyError at import time.
API_TOKEN = os.environ["API_TOKEN"]
MODEL_PATH = os.environ["MODEL_PATH"]  # Hub repo id of the word model
PITCH_PATH = os.environ["PITCH_PATH"]  # Hub repo id of the pitch model
# Local paths of the pre-quantized checkpoints, downloaded at import time.
QUANTIZED_MODEL_PATH = hf_hub_download(repo_id=MODEL_PATH, filename='quantized_model.pt', token=API_TOKEN)
QUANTIZED_PITCH_MODEL_PATH = hf_hub_download(repo_id=PITCH_PATH, filename='quantized_model.pt', token=API_TOKEN)


def _load_quantized_ctc(repo_id, checkpoint_path):
    """Rebuild a dynamically quantized ``Wav2Vec2ForCTC`` and load a saved state dict into it.

    The architecture is instantiated from ``repo_id``'s config, every
    ``torch.nn.Linear`` is replaced by its int8 dynamic-quantized version, and
    only then is ``checkpoint_path`` loaded. Doing it in this order makes the
    module structure match the quantized checkpoint's keys, which matters
    because ``load_state_dict`` is strict by default.

    Parameters
    ----------
    repo_id : str
        Hub repo id to read the model config from.
    checkpoint_path : str
        Local path of the quantized state dict.

    Returns
    -------
    tuple
        ``(config, model)``. ``model`` is the object ``quantize_dynamic``
        converted in place.
    """
    model_config = AutoConfig.from_pretrained(repo_id, use_auth_token=API_TOKEN)
    model = Wav2Vec2ForCTC(model_config)
    model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
    # Dynamic int8 quantized ops execute on CPU. Mapping to CPU keeps loading
    # independent of the device the checkpoint was saved from.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    return model_config, model


## word preprocessor
processor_with_lm = Wav2Vec2ProcessorWithLM.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
### quantized model
config, quantized_model = _load_quantized_ctc(MODEL_PATH, QUANTIZED_MODEL_PATH)
# Same object as quantized_model (quantization is in place); name kept for existing references.
dummy_model = quantized_model
## pitch preprocessor
processor_pitch = Wav2Vec2Processor.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)
### quantized pitch model
# This rebinds `config`, so after import it holds the pitch model's config.
config, quantized_pitch_model = _load_quantized_ctc(PITCH_PATH, QUANTIZED_PITCH_MODEL_PATH)
dummy_pitch_model = quantized_pitch_model
app = FastAPI()
# Route at GET "/": without a route decorator this function is never reachable
# through `app`, and its name designates the root route.
@app.get("/")
def read_root():
    """Report that the application has started.

    Returns
    -------
    dict
        The constant body ``{"Message": "Application startup complete"}``.
    """
    return {"Message": "Application startup complete"}
# NOTE(review): no route decorator is visible on this coroutine, so as written
# it is not registered on `app`. Presumably an `@app.post(...)` line was lost;
# confirm the intended path against clients before restoring it.
async def predict(
    file: bytes = File(...),
    word: str = Form(...),
    pitch: Optional[str] = Form(None),
    temperature: int = Form(...),
):
    """Score a spoken ``word`` (and optionally its ``pitch``) from an uploaded audio clip.

    The audio is decoded at 16 kHz and passed to the quantized word model via
    ``query_raw``. When ``pitch`` is given, it is also passed to the quantized
    pitch model via ``query_dummy``. Similarities are ``fuzz.ratio`` values
    (Levenshtein-based, 0-100).

    Parameters
    ----------
    file : bytes
        Raw bytes of the uploaded audio file (decoded with ``ffmpeg_read``).
    word : str
        Expected hiragana word; the word score is computed against it.
    pitch : str, optional
        Expected pitch string. When omitted, only the word is scored.
    temperature : int
        Difficulty of the AI model, forwarded to ``query_raw``.

    Returns
    -------
    dict
        timestamp : str
            Local time formatted with ``"%Y-%h-%d-%H:%M:%S"``. ``%h`` is the
            abbreviated month name (alias of ``%b``), e.g. ``Jan``, not a number.
        running_time : str
            Processing time in seconds, e.g. ``"1.2345 s"``. Audio decoding is
            not included.
        error_message : str or None
            Set when ``word`` and ``pitch`` lengths differ; scoring still proceeds.
        audio_duration : float
            Duration of the decoded audio in seconds.
        word_predict : str
            The candidate text from ``query_raw`` most similar to ``word``.
        pitch_predict : str or None
            Predicted pitch with spaces removed, padded with ``"1"`` or
            truncated to the length of ``word_predict``. None when ``pitch``
            is omitted.
        wrong_word_index : str
            ``find_different(word, word_predict)`` (documented as a string
            such as ``"100"``).
        wrong_pitch_index : str or None
            ``find_different(pitch, pitch_predict)``. None when ``pitch`` is
            omitted.
        score : int
            ``int((2 * word_score + pitch_score) / 3)`` when ``pitch`` is given,
            otherwise ``int(word_score)``.
    """
    # NOTE(review): nothing below is awaited, so decoding and model inference
    # run on the event loop and block other requests while they execute.
    upload_audio = ffmpeg_read(file, sampling_rate=16000)
    audio_duration = len(upload_audio) / 16000  # seconds (float)
    current_time = datetime.datetime.now().strftime("%Y-%h-%d-%H:%M:%S")
    start_time = time.time()
    error_message = None

    word_preds = query_raw(upload_audio, word, processor, processor_with_lm, quantized_model, temperature=temperature)
    pitch_preds = None
    if pitch is not None:
        if len(word) != len(pitch):
            # Reported to the caller, but scoring below still proceeds.
            error_message = "Length of word and pitch input is not equal"
        pitch_preds = query_dummy(upload_audio, processor_pitch, quantized_pitch_model)

    # Best word: the candidate text (first element of each prediction) most
    # similar to the target. max() keeps the earliest candidate on ties.
    word_score, best_word_predict = max(
        ((fuzz.ratio(word, candidate[0]), candidate[0]) for candidate in word_preds),
        key=lambda scored: scored[0],
    )
    wrong_word = find_different(word, best_word_predict)

    if pitch_preds is not None:
        # Align the predicted pitch with the predicted word: pad with "1" or truncate.
        target_len = len(best_word_predict)
        best_pitch_predict = pitch_preds.replace(" ", "").ljust(target_len, "1")[:target_len]
        pitch_score = fuzz.ratio(pitch, best_pitch_predict)
        # Word similarity is weighted twice as heavily as pitch similarity.
        score = int((word_score * 2 + pitch_score) / 3)
        wrong_pitch = find_different(pitch, best_pitch_predict)
    else:
        score = int(word_score)
        best_pitch_predict = None
        wrong_pitch = None

    return {"timestamp": current_time,
            "running_time": f"{round(time.time() - start_time, 4)} s",
            "error_message": error_message,
            "audio_duration": audio_duration,
            "word_predict": best_word_predict,
            "pitch_predict": best_pitch_predict,
            "wrong_word_index": wrong_word,
            "wrong_pitch_index": wrong_pitch,
            "score": score,
            }