#!/usr/bin/env python3 """ Inference script for N2 Schema.org Retrieval model. """ from transformers import AutoModelForCausalLM, AutoTokenizer import torch import json def select_schema_types(entities, max_new_tokens=500): """Select Schema.org types for entities using the N2 model.""" model_path = "/data/models/n2_schema_retrieval_merged" # Load model and tokenizer print("Loading N2 model...") model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True ) tokenizer = AutoTokenizer.from_pretrained(model_path) # Create prompt for schema selection messages = [ {"role": "system", "content": "You are an expert in selecting appropriate Schema.org types for entities."}, {"role": "user", "content": f"For each entity below, select the most appropriate Schema.org type:\n\n{json.dumps(entities, ensure_ascii=False, indent=2)}\n\nReturn a JSON array with each entity and its Schema.org type."}, ] # Apply chat template prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) # Tokenize input inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=8192) inputs = {k: v.to(model.device) for k, v in inputs.items()} # Generate with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=max_new_tokens, temperature=0.1, # Low temperature for consistent classification do_sample=True, top_p=0.95, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, ) # Decode and extract assistant response full_response = tokenizer.decode(outputs[0], skip_special_tokens=True) # Extract just the assistant's response assistant_marker = "assistant:" if assistant_marker in full_response: response = full_response.split(assistant_marker)[-1].strip() else: response = full_response # Try to parse as JSON try: schema_types = json.loads(response) return schema_types except: # Return raw response if not valid JSON return response if __name__ == "__main__": # Example usage test_entities = [ { "name": "Albert Einstein", "description": "German-born theoretical physicist who developed the theory of relativity" }, { "name": "Theory of Relativity", "description": "Scientific theory of the relationship between space and time" } ] print("Input entities:") print(json.dumps(test_entities, indent=2)) print("\nSelecting Schema.org types...") schema_types = select_schema_types(test_entities) print("\nSelected Schema.org types:") print(json.dumps(schema_types, indent=2) if isinstance(schema_types, (dict, list)) else schema_types)