from __future__ import annotations
import numpy as np
import inference
from utils import get_poem_embeddings
import config as CFG
# for running this script as __main__
from utils import get_datasets, build_loaders
from models import PoemTextModel
from train import train, test
import json
import os
def calc_metrics(test_dataset, model):
    """
    Compute the rank of each test entry's true beyt, plus mean rank and MRR.

    Parameters
    ----------
    test_dataset: list of dict
        dataset containing the texts and poem beyts to compute metrics from
    model: PoemTextModel
        the PoemTextModel to get poem embeddings from and to predict poems for each text

    Returns
    -------
    metrics: dict
        mean rank, mean reciprocal rank (MRR), and the per-entry list of ranks
    """
    # compute all poem embeddings once (to avoid recomputing them for every test text)
    _, embedding = get_poem_embeddings(test_dataset, model)
    # collect the poem beyts and their corresponding texts
    poems = []
    meanings = []
    for p in test_dataset:
        poems.append(p['beyt'])
        meanings.append(p['text'])
    # instantiate a text tokenizer to encode the texts
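    # (assumption: CFG.tokenizers maps each encoder model name to a HuggingFace
    # tokenizer class, so .from_pretrained(CFG.text_tokenizer) loads its vocabulary)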
    text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)
    ranks = []
    for i, meaning in enumerate(meanings):
        # predict the most similar poem beyts for this text, sorted by similarity
        sorted_pred = inference.predict_poems_from_text(model, embedding, meaning, poems, text_tokenizer, n=len(test_dataset))
        # find the position of this text's true beyt in the sorted predictions
        idx = sorted_pred.index(poems[i])
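        # convert the 0-based index to a 1-based rank (rank 1 = top prediction)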
        ranks.append(idx + 1)
    ranks = np.array(ranks)
    metrics = {
        "mean_rank": np.mean(ranks),
        "mean_reciprocal_rank_(MRR)": np.mean(np.reciprocal(ranks.astype(float))),
        "rank": ranks.tolist()
    }
    return metrics
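
# A minimal, illustrative sketch of how the two summary metrics above behave on
# a made-up ranks array (hypothetical values, not from any real run; not used
# by the evaluation pipeline in this script):
def _toy_metrics_example():
    toy_ranks = np.array([1, 3, 2])  # hypothetical 1-based ranks for 3 test texts
    mean_rank = np.mean(toy_ranks)  # (1 + 3 + 2) / 3 = 2.0
    mrr = np.mean(np.reciprocal(toy_ranks.astype(float)))  # (1 + 1/3 + 1/2) / 3 ≈ 0.611
    return mean_rank, mrr
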
if __name__ == "__main__":
    # Create a PoemTextModel based on configs and compute its metrics on the test set.
    # get the train/val/test splits from dataset_path (the same splits saved as
    # the train, val and test dataset files in the data directory)
    train_dataset, val_dataset, test_dataset = get_datasets()
    model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
    model.eval()
    # compute accuracy, mean rank and MRR on the test set and write them to a file
    print("Accuracy on test set: ", test(model, test_dataset))
    metrics = calc_metrics(test_dataset, model)
    print('mean rank: ', metrics["mean_rank"])
    print('mean reciprocal rank (MRR): ', metrics["mean_reciprocal_rank_(MRR)"])
    with open('test_metrics_{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model), 'w', encoding="utf-8") as f:
        f.write(json.dumps(metrics, indent=4))
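    # the JSON lands next to this script as test_metrics_<poem_encoder>_<text_encoder>.json,
    # with the encoder names taken from config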