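"""Text generation script for the MAP-NEO Mini model.

Loads a trained checkpoint and tokenizer, generates text with temperature,
top-k, top-p, and repetition-penalty sampling, compares several sampling
presets, and offers an interactive prompt loop.
"""
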
import os
from pathlib import Path

import torch
from transformers import AutoTokenizer

from model_neo import NeoMini, NeoMiniConfig


def load_model(checkpoint_path="checkpoints/extended_context_model.pt"):
    """Load trained model and tokenizer"""
    print(f"Loading model from {checkpoint_path}...")

    if not os.path.exists(checkpoint_path):
        print(f"Error: Checkpoint not found at {checkpoint_path}")
        print("Available checkpoints:")
        checkpoint_dir = Path("checkpoints")
        if checkpoint_dir.exists():
            for ckpt in sorted(checkpoint_dir.glob("checkpoint_step_*.pt")):
                print(f"  - {ckpt}")
        return None, None
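
    # Load the checkpoint onto the GPU when one is available, otherwise the CPU.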
    checkpoint = torch.load(checkpoint_path, map_location="cuda" if torch.cuda.is_available() else "cpu")
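
    # Rebuild the architecture, restore the trained weights, and switch to eval mode.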
    config = NeoMiniConfig()
    model = NeoMini(config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    print(f"Model loaded on {device}")
    tokenizer_path = "data/tokenizer"
    if os.path.exists(tokenizer_path):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    else:
        print("Using GPT-2 tokenizer as fallback...")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Tokenizer vocab size: {tokenizer.vocab_size}")
    print(f"Model parameters: {model.get_num_params():,}")

    return model, tokenizer


def generate_text(model, tokenizer, prompt, max_length=100,
                  temperature=0.7,
                  top_k=50,
                  top_p=0.9,
                  repetition_penalty=1.1):
    """Generate text with advanced sampling techniques"""

    device = next(model.parameters()).device
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    original_length = input_ids.size(1)

    print(f"Generating with: temp={temperature}, top_k={top_k}, top_p={top_p}")
    with torch.no_grad():
        for step in range(max_length):
            logits = model(input_ids)
            next_token_logits = logits[0, -1, :] / temperature
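
            # Repetition penalty: make tokens already in the context less likely
            # by scaling their logits (divide positive logits, multiply negative ones).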
            if repetition_penalty != 1.0:
                for token_id in set(input_ids[0].tolist()):
                    if next_token_logits[token_id] < 0:
                        next_token_logits[token_id] *= repetition_penalty
                    else:
                        next_token_logits[token_id] /= repetition_penalty
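
            # Top-k filtering: mask out everything below the k-th highest logit.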
            if top_k > 0:
                top_k_logits, _ = torch.topk(next_token_logits, top_k)
                min_top_k = top_k_logits[-1]
                next_token_logits[next_token_logits < min_top_k] = float('-inf')
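
            # Nucleus (top-p) filtering: keep the smallest set of tokens whose
            # cumulative probability exceeds top_p.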
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                sorted_indices_to_remove = cumulative_probs > top_p
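                # Shift the mask right so the first token past the threshold is
                # kept, guaranteeing at least one candidate always survives.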
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float('-inf')
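
            # Sample the next token from the filtered, renormalized distribution.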
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
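
            # Stop early once the model emits its end-of-sequence token.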
            if next_token.item() == tokenizer.eos_token_id:
                print(f"  → Stopped at EOS token (step {step+1})")
                break
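
            # Stop once the sequence reaches the 1024-token context limit.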
            if input_ids.size(1) >= 1024:
                print(f"  → Stopped at max context length (step {step+1})")
                break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


def compare_generation_settings(model, tokenizer, prompt):
    """Compare different generation settings"""
    print(f"\n{'='*80}")
    print("COMPARING GENERATION SETTINGS")
    print(f"Prompt: '{prompt}'")
    print(f"{'='*80}")

    settings = [
        {"name": "Conservative", "temp": 0.5, "top_k": 20, "top_p": 0.8},
        {"name": "Balanced", "temp": 0.7, "top_k": 50, "top_p": 0.9},
        {"name": "Creative", "temp": 0.9, "top_k": 100, "top_p": 0.95},
        {"name": "Focused", "temp": 0.3, "top_k": 10, "top_p": 0.7}
    ]

    for setting in settings:
        print(f"\n--- {setting['name']} Generation ---")
        generated = generate_text(
            model, tokenizer, prompt, max_length=80,
            temperature=setting['temp'],
            top_k=setting['top_k'],
            top_p=setting['top_p']
        )
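
        # Show only the continuation, not the echoed prompt.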
        generated_only = generated[len(prompt):].strip()
        print(f"Output: {generated_only}")


def interactive_mode(model, tokenizer):
    """Interactive text generation"""
    print(f"\n{'='*60}")
    print("INTERACTIVE MODE - Enter prompts (or 'quit' to exit)")
    print(f"{'='*60}")

    while True:
        try:
            prompt = input("\nEnter your prompt: ").strip()
            if prompt.lower() in ['quit', 'exit', 'q']:
                break

            if not prompt:
                continue
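
            # Ask for sampling parameters; fall back to the defaults on bad input.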
            try:
                temp = float(input("Temperature (0.1-1.5, default 0.7): ") or "0.7")
                top_k = int(input("Top-K (1-100, default 50): ") or "50")
                top_p = float(input("Top-P (0.1-1.0, default 0.9): ") or "0.9")
                max_len = int(input("Max length (10-200, default 100): ") or "100")
            except ValueError:
                print("Using default parameters...")
                temp, top_k, top_p, max_len = 0.7, 50, 0.9, 100

            print("\nGenerating...")
            generated = generate_text(
                model, tokenizer, prompt,
                max_length=max_len, temperature=temp,
                top_k=top_k, top_p=top_p
            )

            print(f"\nFull Output:\n{'-'*40}")
            print(generated)
            print(f"{'-'*40}")

        except KeyboardInterrupt:
            break

    print("\nExiting interactive mode...")


def main():
    print("MAP-NEO Mini Text Generator")
    print("=" * 50)

    model, tokenizer = load_model()

    if model is None or tokenizer is None:
        print("Failed to load model. Exiting.")
        return

    test_prompts = [
        "The future of artificial intelligence",
        "In a world where technology",
        "Scientists have discovered",
        "The key to success is",
        "Climate change is",
        "The importance of education",
        "Once upon a time, there was",
        "To solve this problem, we need to"
    ]

    print(f"\n{'='*60}")
    print("BASIC GENERATION TEST")
    print(f"{'='*60}")
    for i, prompt in enumerate(test_prompts[:3], 1):
        print(f"\n--- Test {i}/3 ---")
        print(f"Prompt: {prompt}")
        print("-" * 50)

        generated = generate_text(
            model, tokenizer, prompt,
            max_length=80, temperature=0.7,
            top_k=50, top_p=0.9
        )

        generated_only = generated[len(prompt):].strip()
        print(f"Generated: {generated_only}")
    compare_generation_settings(
        model, tokenizer,
        "The most important discovery in science was"
    )

    print(f"\n{'='*60}")
    choice = input("Start interactive mode? (y/n): ").lower().strip()
    if choice in ['y', 'yes']:
        interactive_mode(model, tokenizer)

    print("\nText generation complete!")


if __name__ == "__main__":
    main()