"""advanced_generate.py - Advanced text generation with instruction prompts, context window info, and GPU monitoring"""
import torch
from transformers import AutoTokenizer
from model_neo import NeoMini, NeoMiniConfig
import os
from pathlib import Path
import gc

def clear_gpu_cache():
    """Clear GPU memory cache to free up VRAM"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print("🧹 GPU cache cleared")

def force_garbage_collection():
    """Force garbage collection and clear caches"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    print("🗑️ Garbage collection and cache clearing completed")

def reset_gpu():
    """Quick GPU reset function for interactive use"""
    force_garbage_collection()
    print(f"🔄 GPU reset: {get_gpu_memory_info()}")

def get_gpu_memory_info():
    """Get GPU memory usage information"""
    if not torch.cuda.is_available():
        return "CUDA not available"
    try:
        # Get current GPU memory usage
        allocated = torch.cuda.memory_allocated(0) / 1024**3  # Convert to GB
        cached = torch.cuda.memory_reserved(0) / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        return f"GPU Memory: {allocated:.2f}GB allocated, {cached:.2f}GB cached, {total:.2f}GB total"
    except Exception as e:
        return f"Could not get GPU memory info: {e}"

def load_model(checkpoint_path="checkpoints/extended_context_model.pt"):
    print(f"Loading model from {checkpoint_path}...")
    # Clear cache before loading
    print("🧹 Clearing cache before model loading...")
    force_garbage_collection()
    if not os.path.exists(checkpoint_path):
        print(f"❌ Checkpoint {checkpoint_path} not found.")
        return None, None, None
    checkpoint = torch.load(checkpoint_path, map_location="cuda" if torch.cuda.is_available() else "cpu")
    # Get max_seq_len from the checkpoint config, or fall back to the default
    if 'config' in checkpoint:
        max_seq_len = checkpoint['config'].get('max_seq_len', 2048)
    else:
        max_seq_len = 2048  # fallback
    config = NeoMiniConfig()
    config.max_seq_len = max_seq_len
    model = NeoMini(config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    tokenizer_path = "data/tokenizer"
    if Path(tokenizer_path).exists():
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    else:
        print("Tokenizer path not found; falling back to the GPT-2 tokenizer.")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print(f"✅ Model loaded on {device}")
    print(f"📊 Tokenizer vocab size: {tokenizer.vocab_size:,}")
    print(f"🧠 Model parameters: {model.get_num_params():,}")
    print(f"📏 Max context window: {max_seq_len:,} tokens")
    # Clear cache after model loading
    clear_gpu_cache()
    print(f"💾 After model load: {get_gpu_memory_info()}")
    return model, tokenizer, max_seq_len
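
# Usage sketch: load_model() is normally called from main() below, but it can
# also be used on its own, e.g. to load a different checkpoint:
#
#   model, tokenizer, max_ctx = load_model("checkpoints/extended_context_model.pt")
#   if model is not None:
#       print(f"Ready with a {max_ctx:,}-token context window")
#
# Note (assumption about your environment): recent PyTorch releases default
# torch.load() to weights_only=True. If loading this full checkpoint fails there
# and the file is trusted, passing weights_only=False is one option.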

def generate_text(model, tokenizer, max_context_length, prompt, max_length=100,
                  temperature=0.4, top_k=20, top_p=0.8, repetition_penalty=1.2):
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    # Clear cache for long contexts
    if input_ids.size(1) > 1000:
        clear_gpu_cache()
    # Check initial prompt length
    prompt_length = input_ids.size(1)
    print(f"📏 Prompt length: {prompt_length:,} tokens")
    if prompt_length >= max_context_length:
        print(f"⚠️ Warning: Prompt ({prompt_length}) exceeds max context ({max_context_length})")
        return "Error: Prompt too long for context window"
    # Adjust max_length if needed
    available_tokens = max_context_length - prompt_length
    if max_length > available_tokens:
        print(f"⚠️ Adjusting max_length from {max_length} to {available_tokens} (context limit)")
        max_length = available_tokens
    print(f"🎯 Generating max {max_length} tokens (temp={temperature}, top_k={top_k}, top_p={top_p}, rep_penalty={repetition_penalty})")
    print(f"💾 Before generation: {get_gpu_memory_info()}")
    with torch.no_grad():
        generated = input_ids
        tokens_generated = 0
        for step in range(max_length):
            # Check memory and clear cache periodically for long generations
            if step % 100 == 0 and step > 0:
                current_length = generated.size(1)
                print(f"   📊 Step {step}: {current_length:,}/{max_context_length:,} tokens")
                if current_length > 2000:  # Clear cache for very long contexts
                    clear_gpu_cache()
                    print(f"   {get_gpu_memory_info()}")
            logits = model(generated)
            next_token_logits = logits[0, -1, :] / temperature
            # Repetition penalty: dampen tokens that already appear in the sequence
            # (divide positive logits, multiply negative ones, so both become less likely)
            if repetition_penalty != 1.0:
                for token_id in set(generated[0].tolist()):
                    if next_token_logits[token_id] < 0:
                        next_token_logits[token_id] *= repetition_penalty
                    else:
                        next_token_logits[token_id] /= repetition_penalty
            # Top-k filtering: keep only the k highest-scoring tokens
            if top_k > 0:
                top_k_logits, _ = torch.topk(next_token_logits, top_k)
                min_top_k = top_k_logits[-1]
                next_token_logits[next_token_logits < min_top_k] = float("-inf")
            # Top-p (nucleus) filtering: keep the smallest set of tokens whose
            # cumulative probability exceeds top_p
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift the mask right so the token that crosses the threshold is kept
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float("-inf")
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)
            tokens_generated += 1
            # Check stopping conditions
            if next_token.item() == tokenizer.eos_token_id:
                print(f"🛑 Stopped at EOS token (generated {tokens_generated} tokens)")
                break
            if generated.size(1) >= max_context_length:
                print(f"🛑 Stopped at max context length {max_context_length:,} (generated {tokens_generated} tokens)")
                break
    final_length = generated.size(1)
    print(f"✅ Generation complete: {final_length:,} total tokens ({tokens_generated} new tokens)")
    # Clear cache after long generations
    if final_length > 2000:
        clear_gpu_cache()
    print(f"💾 Final: {get_gpu_memory_info()}")
    return tokenizer.decode(generated[0], skip_special_tokens=True)
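
# A thin convenience wrapper (sketch, not called by the tests): generate_text()
# returns the prompt plus the continuation, so callers that only want the new
# text have to strip the prompt themselves. The helper name is illustrative.
def quick_generate(model, tokenizer, max_context_length, prompt, **sampling_kwargs):
    """Return only the newly generated continuation for a prompt."""
    full_text = generate_text(model, tokenizer, max_context_length, prompt, **sampling_kwargs)
    # Decoding may not reproduce the prompt byte-for-byte, so only strip it when it matches.
    if isinstance(full_text, str) and full_text.startswith(prompt):
        return full_text[len(prompt):]
    return full_text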

def test_context_window_limits(model, tokenizer, max_context_length):
    """Test how much context the model can actually handle"""
    print(f"\n🧪 Testing Context Window Limits (Max: {max_context_length:,} tokens)")
    print("="*60)
    # Create a long repetitive prompt to test limits
    base_text = "This is a test of the context window. " * 20  # roughly 140-180 tokens, depending on the tokenizer
    # Extended multipliers for better testing
    for multiplier in [1, 5, 10, 20, 50, 100, 150]:
        # Clear cache before each test
        print(f"\n🧹 Clearing GPU cache before test {multiplier}...")
        force_garbage_collection()
        test_prompt = base_text * multiplier
        token_count = len(tokenizer.encode(test_prompt))
        print(f"\n📏 Test prompt length: {token_count:,} tokens")
        if token_count > max_context_length:
            print(f"⚠️ Exceeds context limit ({max_context_length:,}), skipping...")
            continue
        print(f"💾 Before generation: {get_gpu_memory_info()}")
        try:
            result = generate_text(model, tokenizer, max_context_length,
                                   test_prompt + " In conclusion,", max_length=50, temperature=0.7)
            print(f"✅ Success at {token_count:,} tokens")
            print(f"💾 After generation: {get_gpu_memory_info()}")
            # Clear cache after successful test
            clear_gpu_cache()
            print(f"💾 After cache clear: {get_gpu_memory_info()}")
        except Exception as e:
            print(f"❌ Failed at {token_count:,} tokens: {e}")
            print("🧹 Cleaning up after failure...")
            force_garbage_collection()
            break
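
# Optional refinement (sketch, nothing below calls it): instead of fixed
# multipliers, a doubling probe narrows in on the largest prompt size that still
# generates without running out of memory. The function name, the step size,
# and the probe sentence are illustrative choices.
def find_max_usable_context(model, tokenizer, max_context_length, start_tokens=512):
    """Return the largest probed prompt length (in tokens) that generated successfully."""
    sentence = "This is a test of the context window. "
    tokens_per_sentence = max(1, len(tokenizer.encode(sentence)))
    best = 0
    target = start_tokens
    while target < max_context_length:
        # Build a prompt of roughly `target` tokens by repeating the probe sentence.
        prompt = sentence * (target // tokens_per_sentence + 1)
        force_garbage_collection()
        try:
            generate_text(model, tokenizer, max_context_length, prompt, max_length=20)
            best = len(tokenizer.encode(prompt))
            target *= 2  # double the probe size until generation fails or the window is reached
        except RuntimeError:
            # Most likely a CUDA out-of-memory error; report the last size that worked.
            force_garbage_collection()
            break
    return best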

def test_instruction_prompts(model, tokenizer, max_context_length):
    print(f"\n🎯 Testing Instruction Following")
    print("="*60)
    prompts = [
        "Complete this sentence in a helpful way: The weather today is",
        "Write a short explanation: Why is exercise important?",
        "Answer in 2-3 sentences: What is artificial intelligence?",
        "Continue this story logically: The scientist walked into the lab and saw"
    ]
    for idx, prompt in enumerate(prompts, 1):
        print(f"\n--- Instruction Prompt {idx} ---")
        print(f"Prompt: {prompt}")
        # Clear cache before each instruction test
        if idx > 1:  # Not needed for first test
            clear_gpu_cache()
        output = generate_text(model, tokenizer, max_context_length, prompt, max_length=100)
        print(f"Output: {output}")

def test_long_context(model, tokenizer, max_context_length):
    print(f"\n💬 Testing Long Context Conversation")
    print("="*60)
    # Clear cache before long context test
    clear_gpu_cache()
    prompt = """The following is a conversation between a human and an AI assistant. The AI assistant is helpful, harmless, and honest.
Human: Hello, who are you?
AI: I am a large language model trained to assist you.
Human: What can you do for me?
AI: """
    output = generate_text(model, tokenizer, max_context_length, prompt, max_length=200)
    print(f"Output: {output}")

def main():
    print("🚀 MAP-NEO Mini Advanced Text Generation with Context & VRAM Monitoring")
    print("="*80)
    # Force clear at startup
    print("🧹 Initial system cleanup...")
    force_garbage_collection()
    # Load model and get context info
    model, tokenizer, max_context_length = load_model()
    if model is None or tokenizer is None:
        print("❌ Failed to load model or tokenizer.")
        return
    print(f"\n🔥 Model ready! Context window: {max_context_length:,} tokens")
    # Run tests with cache management
    print("\n" + "="*40 + " TESTS " + "="*40)
    # Test 1: Instructions
    test_instruction_prompts(model, tokenizer, max_context_length)
    force_garbage_collection()
    print(f"💾 After instructions: {get_gpu_memory_info()}")
    # Test 2: Long context
    test_long_context(model, tokenizer, max_context_length)
    force_garbage_collection()
    print(f"💾 After long context: {get_gpu_memory_info()}")
    # Test 3: Context limits (most memory intensive)
    test_context_window_limits(model, tokenizer, max_context_length)
    print(f"\n🎉 All tests complete!")
    print(f"💾 Final GPU state: {get_gpu_memory_info()}")

if __name__ == "__main__":
    main()
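
# Usage note (sketch): run the script directly to execute all three tests in order:
#   python advanced_generate.py
# From an interactive session, reset_gpu() can be imported and called between
# experiments to reclaim VRAM.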