Spaces:
Sleeping
Sleeping
| import logging | |
| import numpy as np | |
| import streamlit as st | |
| import torch | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
# Emit INFO-level (and above) records through the root logger's default handler.
logging.basicConfig(level=logging.INFO)

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
# One row per arXiv tag: (tag, top-level category, subcategory).
# Row order is preserved in the resulting dictionary.
_TAG_ROWS = (
    ("cs.AI", "Computer Science", "Artificial Intelligence"),
    ("cs.CL", "Computer Science", "Computation and Language"),
    ("cs.CV", "Computer Science", "Computer Vision and Pattern Recognition"),
    ("cs.NE", "Computer Science", "Neural and Evolutionary Computing"),
    ("stat.ML", "Statistics", "Machine Learning"),
    ("cs.LG", "Computer Science", "Machine Learning"),
    ("physics.soc-ph", "Physics", "Physics and Society"),
    ("stat.AP", "Statistics", "Applications"),
    ("cs.RO", "Computer Science", "Robotics"),
    ("cs.MA", "Computer Science", "Multiagent Systems"),
    ("math.OC", "Mathematics", "Optimization and Control"),
    ("cs.IR", "Computer Science", "Information Retrieval"),
    ("stat.ME", "Statistics", "Methodology"),
    ("cs.DC", "Computer Science", "Distributed, Parallel, and Cluster Computing"),
    ("stat.CO", "Statistics", "Computation"),
    ("q-bio.NC", "Quantitative Biology", "Neurons and Cognition"),
    ("cs.GT", "Computer Science", "Computer Science and Game Theory"),
    ("cs.MM", "Computer Science", "Multimedia"),
    ("cs.CR", "Computer Science", "Cryptography and Security"),
    ("cs.HC", "Computer Science", "Human-Computer Interaction"),
    ("cs.SD", "Computer Science", "Sound"),
    ("cs.GR", "Computer Science", "Graphics"),
    ("cs.CY", "Computer Science", "Computers and Society"),
    ("math.ST", "Mathematics", "Statistics Theory"),
    ("stat.TH", "Statistics", "Statistics Theory"),
    ("cs.IT", "Computer Science", "Information Theory"),
    ("math.IT", "Mathematics", "Information Theory"),
    ("cs.SI", "Computer Science", "Social and Information Networks"),
    ("cs.DB", "Computer Science", "Databases"),
    ("cs.LO", "Computer Science", "Logic in Computer Science"),
    ("cs.SY", "Computer Science", "Systems and Control"),
    ("q-bio.QM", "Quantitative Biology", "Quantitative Methods"),
    ("cs.DS", "Computer Science", "Data Structures and Algorithms"),
    ("cs.NA", "Computer Science", "Numerical Analysis"),
    ("cs.CE", "Computer Science", "Computational Engineering, Finance, and Science"),
)

# Maps an arXiv tag to a two-element list: [category, subcategory].
MAPPING_FROM_TAG_TO_CATEGORY = {
    tag: [category, subcategory] for tag, category, subcategory in _TAG_ROWS
}
# Model identifier that main() hands to load_model_and_tokenizer().
MODEL_PATH = "minemile/arxiv-tag-classifier"

# Upper bound on tokens per input; longer text is truncated by the tokenizer call.
MODEL_MAX_LENGTH = 512

# prepare_output() stops listing tags once their cumulative probability
# reaches this value.
CUM_PROB_THRESHOLD = 0.95
def load_model_and_tokenizer(model_path):
    """Load the sequence-classification model and its tokenizer.

    Streamlit re-executes the script on every widget interaction, so the
    actual loading is delegated to a helper wrapped in ``st.cache_resource``;
    repeated calls with the same ``model_path`` reuse the loaded objects.

    Args:
        model_path: Identifier or path accepted by ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple, or ``(None, None)`` if loading failed.
        On failure the error is logged and shown in the Streamlit UI.
    """
    logger.info("Loading model and tokenizer from %s", model_path)
    try:
        model, tokenizer = _load_model_and_tokenizer_cached(model_path)
    except Exception as e:
        # Top-level UI boundary: report the problem instead of crashing the app.
        logger.exception("Failed to load model or tokenizer from %s", model_path)
        st.error(f"An error occurred while loading the model: {e}")
        return None, None
    logger.info("Model and tokenizer loaded successfully.")
    return model, tokenizer


@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer_cached(model_path):
    """Load the model (in eval mode, on CPU) and tokenizer; raise on failure.

    Errors propagate so that a failed load is never stored in the cache and
    the next rerun can retry. The built-in cache spinner is disabled because
    the caller already displays its own.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path, device_map="cpu"
    )
    model.eval()
    return model, tokenizer
def inference(model, tokenizer, title, summary=None):
    """Predict class probabilities for a paper title and optional summary.

    Args:
        model: Loaded sequence-classification model, or ``None``.
        tokenizer: Tokenizer matching ``model``, or ``None``.
        title: Paper title.
        summary: Optional abstract; when non-empty it is appended to the title
            separated by a single space.

    Returns:
        A NumPy array of softmax probabilities (one entry per class), or
        ``None`` if the model/tokenizer is missing or prediction failed.
        Failures are reported in the Streamlit UI.
    """
    if model is None or tokenizer is None:
        st.error("Model or tokenizer not loaded. Cannot predict.")
        # A single None (not a tuple) so callers can test `is not None`.
        return None

    combined_text = title
    if summary:
        # Skip empty summaries so no stray trailing space is added.
        combined_text += " " + summary
    logger.info("Predicting for text: %s", combined_text)

    try:
        tokenized_inputs = tokenizer(
            combined_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=MODEL_MAX_LENGTH,
        )
        with torch.no_grad():
            logits = model(**tokenized_inputs).logits
        # Only one text is tokenized, so take the first (only) batch row.
        # Indexing instead of squeeze() keeps the result 1-D even when the
        # model has a single output label.
        return torch.softmax(logits, dim=-1)[0].cpu().numpy()
    except Exception as e:
        logger.exception("Prediction failed")
        st.error(f"An error occurred during prediction: {e}")
        return None
def prepare_output(probabilities, id2label, max_tags=5):
    """Build the table of top predicted tags from class probabilities.

    Tags are listed in order of decreasing probability. Listing stops as soon
    as the cumulative probability reaches ``CUM_PROB_THRESHOLD`` or
    ``max_tags`` tags have been collected; at least one tag is always
    returned for non-empty input.

    Args:
        probabilities: Array-like of per-class probabilities.
        id2label: Mapping from class index to tag name.
        max_tags: Maximum number of tags to list (default 5).

    Returns:
        A dict of equal-length lists under the keys ``"Tag"``, ``"Category"``,
        ``"Subcategory"`` and ``"Probability"`` (formatted as a percentage
        string with two decimals). Tags missing from
        ``MAPPING_FROM_TAG_TO_CATEGORY`` get ``"Unknown"`` for both category
        fields and a warning is logged.
    """
    # Guard against a 0-d input (e.g. a squeezed single-class output).
    probabilities = np.atleast_1d(probabilities)
    ranked_indices = np.argsort(probabilities)[::-1]

    cum_sum = 0.0
    top_tags = []
    top_category = []
    top_subcategory = []
    top_probas = []

    for idx in ranked_indices:
        tag = id2label[int(idx)]
        entry = MAPPING_FROM_TAG_TO_CATEGORY.get(tag)
        if entry is None:
            logger.warning("Tag %s not found in mapping from tag to category.", tag)
            category = subcategory = "Unknown"
        else:
            category, subcategory = entry[0], entry[1]

        top_tags.append(f"{tag}")
        top_category.append(category)
        top_subcategory.append(subcategory)
        top_probas.append(f"{probabilities[idx]*100:.2f}%")

        cum_sum += probabilities[idx]
        if cum_sum >= CUM_PROB_THRESHOLD or len(top_tags) >= max_tags:
            break

    return {
        "Tag": top_tags,
        "Category": top_category,
        "Subcategory": top_subcategory,
        "Probability": top_probas,
    }
def main():
    """Render the Streamlit page: load the model, collect input, show predictions."""
    st.set_page_config(page_title="ArXiv Category Tag Classifier", layout="wide")
    st.title("ArXiv Category Tag Classifier")

    with st.spinner("Loading model and tokenizer..."):
        model, tokenizer = load_model_and_tokenizer(MODEL_PATH)
    if model is None or tokenizer is None:
        st.error("Failed to load model/tokenizer.")
        return

    n_categories = len(model.config.id2label)
    st.markdown(
        "Enter the title (required) and summary (abstract) of an ArXiv paper to predict its "
        f"ArXiv category using a transformer model. There are {n_categories} available categories."
    )
    st.divider()

    # Two side-by-side inputs: title on the left, abstract on the right.
    title_col, summary_col = st.columns(2)
    with title_col:
        title = st.text_input(
            "Title",
            placeholder="Enter the title of the paper (Required)",
        )
    with summary_col:
        paper_summary = st.text_area(
            "Paper Summary (Optional):",
            placeholder="Paste the paper's abstract here. It can increase the accuracy of the prediction.",
            height=180,
        )

    predict_button = st.button("Predict Category", type="primary")
    st.markdown("---")

    # Nothing more to render until the user asks for a prediction.
    if not predict_button:
        return
    if not title:
        st.error("Title of the paper is required!")
        return

    with st.spinner("Predicting category..."):
        class_probabilities = inference(model, tokenizer, title, paper_summary)
    if class_probabilities is None:
        return

    st.subheader("Top Tags Predictions:")
    output = prepare_output(class_probabilities, model.config.id2label)
    logger.info("Output: %s", output)
    st.dataframe(output, use_container_width=True)
    st.markdown("---")
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()