# generate_text.py - Improved text generation with advanced sampling
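"""Generate text from a trained NeoMini checkpoint.

Supports temperature, top-k, and top-p (nucleus) sampling with a repetition
penalty, a preset-comparison mode, and an interactive prompt loop.
"""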
import torch
from transformers import AutoTokenizer
from model_neo import NeoMini, NeoMiniConfig
import json
import os
from pathlib import Path
def load_model(checkpoint_path="checkpoints/extended_context_model.pt"):
    """Load trained model and tokenizer"""
    print(f"Loading model from {checkpoint_path}...")

    # Check if checkpoint exists
    if not os.path.exists(checkpoint_path):
        print(f"Error: Checkpoint not found at {checkpoint_path}")
        print("Available checkpoints:")
        checkpoint_dir = Path("checkpoints")
        if checkpoint_dir.exists():
            for ckpt in sorted(checkpoint_dir.glob("checkpoint_step_*.pt")):
                print(f"  - {ckpt}")
        return None, None

    # Load checkpoint onto the available device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Create model with same config
    config = NeoMiniConfig()
    model = NeoMini(config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Move to GPU if available
    model = model.to(device)
    print(f"Model loaded on {device}")

    # Load tokenizer
    tokenizer_path = "data/tokenizer"
    if os.path.exists(tokenizer_path):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    else:
        print("Using GPT-2 tokenizer as fallback...")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Tokenizer vocab size: {tokenizer.vocab_size}")
    print(f"Model parameters: {model.get_num_params():,}")
    return model, tokenizer
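# Note: load_model() assumes the checkpoint is a dict saved with a
# 'model_state_dict' key and that NeoMiniConfig() matches the trained
# architecture; a mismatched config will fail in load_state_dict().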
def generate_text(model, tokenizer, prompt, max_length=100,
                  temperature=0.7,          # Lower = more focused
                  top_k=50,                 # Only consider top 50 tokens
                  top_p=0.9,                # Nucleus sampling
                  repetition_penalty=1.1):  # Penalize repetition
    """Generate text with advanced sampling techniques"""
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    original_length = input_ids.size(1)

    print(f"Generating with: temp={temperature}, top_k={top_k}, top_p={top_p}")

    with torch.no_grad():
        for step in range(max_length):
            # Forward pass; keep only the logits for the last position
            logits = model(input_ids)
            next_token_logits = logits[0, -1, :] / temperature

            # Apply repetition penalty (CTRL-style): make every token that
            # already appears in the sequence less likely to be sampled again
            if repetition_penalty != 1.0:
                for token_id in set(input_ids[0].tolist()):
                    if next_token_logits[token_id] < 0:
                        next_token_logits[token_id] *= repetition_penalty
                    else:
                        next_token_logits[token_id] /= repetition_penalty

            # Top-k filtering: mask everything below the k-th largest logit
            if top_k > 0:
                k = min(top_k, next_token_logits.size(-1))
                top_k_logits, _ = torch.topk(next_token_logits, k)
                min_top_k = top_k_logits[-1]
                next_token_logits[next_token_logits < min_top_k] = float('-inf')

            # Top-p (nucleus) sampling
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above the threshold;
                # shift the mask right by one so the first token that crosses
                # the threshold (and thus at least one token) is always kept
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                # Convert back to original vocabulary indices
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float('-inf')

            # Sample next token from the filtered distribution
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

            # Check for EOS token
            if next_token.item() == tokenizer.eos_token_id:
                print(f"  → Stopped at EOS token (step {step+1})")
                break

            # Check for max context length (model's max is 1024 tokens)
            if input_ids.size(1) >= 1024:
                print(f"  → Stopped at max context length (step {step+1})")
                break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
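# Standalone usage sketch (assumes load_model() can find the default checkpoint):
#   model, tokenizer = load_model()
#   text = generate_text(model, tokenizer, "Scientists have discovered",
#                        max_length=60, temperature=0.7, top_k=50, top_p=0.9)
#   print(text)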
def compare_generation_settings(model, tokenizer, prompt):
    """Compare different generation settings"""
    print(f"\n{'='*80}")
    print("COMPARING GENERATION SETTINGS")
    print(f"Prompt: '{prompt}'")
    print(f"{'='*80}")

    settings = [
        {"name": "Conservative", "temp": 0.5, "top_k": 20, "top_p": 0.8},
        {"name": "Balanced", "temp": 0.7, "top_k": 50, "top_p": 0.9},
        {"name": "Creative", "temp": 0.9, "top_k": 100, "top_p": 0.95},
        {"name": "Focused", "temp": 0.3, "top_k": 10, "top_p": 0.7},
    ]

    for setting in settings:
        print(f"\n--- {setting['name']} Generation ---")
        generated = generate_text(
            model, tokenizer, prompt, max_length=80,
            temperature=setting['temp'],
            top_k=setting['top_k'],
            top_p=setting['top_p']
        )
        # Only show the generated part (after prompt)
        generated_only = generated[len(prompt):].strip()
        print(f"Output: {generated_only}")
def interactive_mode(model, tokenizer):
    """Interactive text generation"""
    print(f"\n{'='*60}")
    print("INTERACTIVE MODE - Enter prompts (or 'quit' to exit)")
    print(f"{'='*60}")

    while True:
        try:
            prompt = input("\nEnter your prompt: ").strip()
            if prompt.lower() in ['quit', 'exit', 'q']:
                break
            if not prompt:
                continue

            # Get generation parameters
            try:
                temp = float(input("Temperature (0.1-1.5, default 0.7): ") or "0.7")
                top_k = int(input("Top-K (1-100, default 50): ") or "50")
                top_p = float(input("Top-P (0.1-1.0, default 0.9): ") or "0.9")
                max_len = int(input("Max length (10-200, default 100): ") or "100")
            except ValueError:
                print("Using default parameters...")
                temp, top_k, top_p, max_len = 0.7, 50, 0.9, 100

            print("\nGenerating...")
            generated = generate_text(
                model, tokenizer, prompt,
                max_length=max_len, temperature=temp,
                top_k=top_k, top_p=top_p
            )

            print(f"\nFull Output:\n{'-'*40}")
            print(generated)
            print(f"{'-'*40}")

        except KeyboardInterrupt:
            break

    print("\nExiting interactive mode...")
def main():
    print("MAP-NEO Mini Text Generator")
    print("=" * 50)

    # Load model
    model, tokenizer = load_model()
    if model is None or tokenizer is None:
        print("Failed to load model. Exiting.")
        return

    # Test prompts
    test_prompts = [
        "The future of artificial intelligence",
        "In a world where technology",
        "Scientists have discovered",
        "The key to success is",
        "Climate change is",
        "The importance of education",
        "Once upon a time, there was",
        "To solve this problem, we need to"
    ]

    print(f"\n{'='*60}")
    print("BASIC GENERATION TEST")
    print(f"{'='*60}")

    # Test basic generation
    for i, prompt in enumerate(test_prompts[:3], 1):
        print(f"\n--- Test {i}/3 ---")
        print(f"Prompt: {prompt}")
        print("-" * 50)
        generated = generate_text(
            model, tokenizer, prompt,
            max_length=80, temperature=0.7,
            top_k=50, top_p=0.9
        )
        # Show only generated part
        generated_only = generated[len(prompt):].strip()
        print(f"Generated: {generated_only}")

    # Compare settings
    compare_generation_settings(
        model, tokenizer,
        "The most important discovery in science was"
    )

    # Interactive mode
    print(f"\n{'='*60}")
    choice = input("Start interactive mode? (y/n): ").lower().strip()
    if choice in ['y', 'yes']:
        interactive_mode(model, tokenizer)

    print("\nText generation complete!")


if __name__ == "__main__":
    main()
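# Running this file directly loads the default checkpoint, prints a few sample
# generations, compares the sampling presets, and optionally enters interactive mode:
#   python generate_text.py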