# generate_text.py - Improved text generation with advanced sampling
import torch
from transformers import AutoTokenizer
from model_neo import NeoMini, NeoMiniConfig
import json
import os
from pathlib import Path


def load_model(checkpoint_path="checkpoints/extended_context_model.pt"):
    """Load trained model and tokenizer"""
    print(f"Loading model from {checkpoint_path}...")

    # Check if checkpoint exists
    if not os.path.exists(checkpoint_path):
        print(f"Error: Checkpoint not found at {checkpoint_path}")
        print("Available checkpoints:")
        checkpoint_dir = Path("checkpoints")
        if checkpoint_dir.exists():
            for ckpt in sorted(checkpoint_dir.glob("checkpoint_step_*.pt")):
                print(f"  - {ckpt}")
        return None, None

    # Load checkpoint
    checkpoint = torch.load(
        checkpoint_path,
        map_location="cuda" if torch.cuda.is_available() else "cpu"
    )

    # Create model with same config
    config = NeoMiniConfig()
    model = NeoMini(config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Move to GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    print(f"Model loaded on {device}")

    # Load tokenizer
    tokenizer_path = "data/tokenizer"
    if os.path.exists(tokenizer_path):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    else:
        print("Using GPT-2 tokenizer as fallback...")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Tokenizer vocab size: {tokenizer.vocab_size}")
    print(f"Model parameters: {model.get_num_params():,}")

    return model, tokenizer


def generate_text(model, tokenizer, prompt, max_length=100,
                  temperature=0.7,         # Lower = more focused
                  top_k=50,                # Only consider top 50 tokens
                  top_p=0.9,               # Nucleus sampling
                  repetition_penalty=1.1):  # Penalize repetition
    """Generate text with advanced sampling techniques"""
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    original_length = input_ids.size(1)

    print(f"Generating with: temp={temperature}, top_k={top_k}, top_p={top_p}")

    with torch.no_grad():
        for step in range(max_length):
            # Forward pass
            logits = model(input_ids)
            next_token_logits = logits[0, -1, :] / temperature

            # Apply repetition penalty
            if repetition_penalty != 1.0:
                for token_id in set(input_ids[0].tolist()):
                    if next_token_logits[token_id] < 0:
                        next_token_logits[token_id] *= repetition_penalty
                    else:
                        next_token_logits[token_id] /= repetition_penalty

            # Top-k filtering
            if top_k > 0:
                top_k_logits, _ = torch.topk(next_token_logits, top_k)
                min_top_k = top_k_logits[-1]
                next_token_logits[next_token_logits < min_top_k] = float('-inf')

            # Top-p (nucleus) sampling
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                # Keep at least one token
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                # Convert back to original indices
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float('-inf')

            # Sample next token
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

            # Check for EOS token
            if next_token.item() == tokenizer.eos_token_id:
                print(f"  → Stopped at EOS token (step {step+1})")
                break

            # Check for max context length
            if input_ids.size(1) >= 1024:  # Model's max context
                print(f"  → Stopped at max context length (step {step+1})")
                break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


def compare_generation_settings(model, tokenizer, prompt):
    """Compare different generation settings"""
    print(f"\n{'='*80}")
    print("COMPARING GENERATION SETTINGS")
    print(f"Prompt: '{prompt}'")
    print(f"{'='*80}")

    settings = [
        {"name": "Conservative", "temp": 0.5, "top_k": 20, "top_p": 0.8},
        {"name": "Balanced", "temp": 0.7, "top_k": 50, "top_p": 0.9},
        {"name": "Creative", "temp": 0.9, "top_k": 100, "top_p": 0.95},
        {"name": "Focused", "temp": 0.3, "top_k": 10, "top_p": 0.7}
    ]

    for setting in settings:
        print(f"\n--- {setting['name']} Generation ---")
        generated = generate_text(
            model, tokenizer, prompt,
            max_length=80,
            temperature=setting['temp'],
            top_k=setting['top_k'],
            top_p=setting['top_p']
        )
        # Only show the generated part (after prompt)
        generated_only = generated[len(prompt):].strip()
        print(f"Output: {generated_only}")


def interactive_mode(model, tokenizer):
    """Interactive text generation"""
    print(f"\n{'='*60}")
    print("INTERACTIVE MODE - Enter prompts (or 'quit' to exit)")
    print(f"{'='*60}")

    while True:
        try:
            prompt = input("\nEnter your prompt: ").strip()

            if prompt.lower() in ['quit', 'exit', 'q']:
                break
            if not prompt:
                continue

            # Get generation parameters
            try:
                temp = float(input("Temperature (0.1-1.5, default 0.7): ") or "0.7")
                top_k = int(input("Top-K (1-100, default 50): ") or "50")
                top_p = float(input("Top-P (0.1-1.0, default 0.9): ") or "0.9")
                max_len = int(input("Max length (10-200, default 100): ") or "100")
            except ValueError:
                print("Using default parameters...")
                temp, top_k, top_p, max_len = 0.7, 50, 0.9, 100

            print("\nGenerating...")
            generated = generate_text(
                model, tokenizer, prompt,
                max_length=max_len,
                temperature=temp,
                top_k=top_k,
                top_p=top_p
            )

            print(f"\nFull Output:\n{'-'*40}")
            print(generated)
            print(f"{'-'*40}")

        except KeyboardInterrupt:
            break

    print("\nExiting interactive mode...")


def main():
    print("MAP-NEO Mini Text Generator")
    print("=" * 50)

    # Load model
    model, tokenizer = load_model()
    if model is None or tokenizer is None:
        print("Failed to load model. Exiting.")
        return

    # Test prompts
    test_prompts = [
        "The future of artificial intelligence",
        "In a world where technology",
        "Scientists have discovered",
        "The key to success is",
        "Climate change is",
        "The importance of education",
        "Once upon a time, there was",
        "To solve this problem, we need to"
    ]

    print(f"\n{'='*60}")
    print("BASIC GENERATION TEST")
    print(f"{'='*60}")

    # Test basic generation
    for i, prompt in enumerate(test_prompts[:3], 1):
        print(f"\n--- Test {i}/3 ---")
        print(f"Prompt: {prompt}")
        print("-" * 50)

        generated = generate_text(
            model, tokenizer, prompt,
            max_length=80,
            temperature=0.7,
            top_k=50,
            top_p=0.9
        )

        # Show only generated part
        generated_only = generated[len(prompt):].strip()
        print(f"Generated: {generated_only}")

    # Compare settings
    compare_generation_settings(
        model, tokenizer,
        "The most important discovery in science was"
    )

    # Interactive mode
    print(f"\n{'='*60}")
    choice = input("Start interactive mode? (y/n): ").lower().strip()
    if choice in ['y', 'yes']:
        interactive_mode(model, tokenizer)

    print("\nText generation complete!")


if __name__ == "__main__":
    main()