"""advanced_generate.py - Advanced text generation with instruction prompts,
context window info, and GPU monitoring."""

import gc
import os
from pathlib import Path

import torch
from transformers import AutoTokenizer

from model_neo import NeoMini, NeoMiniConfig


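# Example usage (a sketch; it assumes the default paths used below, i.e. a
# checkpoint at checkpoints/extended_context_model.pt and a tokenizer under
# data/tokenizer):
#
#   python advanced_generate.py
#
# or interactively:
#
#   from advanced_generate import load_model, generate_text
#   model, tokenizer, max_len = load_model()
#   print(generate_text(model, tokenizer, max_len, "Hello, world", max_length=50))

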
def clear_gpu_cache():
    """Clear GPU memory cache to free up VRAM."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print("🧹 GPU cache cleared")


def force_garbage_collection():
    """Force garbage collection and clear caches."""
    # Note: empty_cache() only releases blocks held by PyTorch's caching
    # allocator; it cannot free tensors that are still referenced.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    print("🗑️ Garbage collection and cache clearing completed")


def reset_gpu():
    """Quick GPU reset function for interactive use."""
    force_garbage_collection()
    print(f"🔄 GPU reset: {get_gpu_memory_info()}")


def get_gpu_memory_info():
    """Get GPU memory usage information."""
    if not torch.cuda.is_available():
        return "CUDA not available"

    try:
        # memory_allocated counts live tensors; memory_reserved also includes
        # the caching allocator's pool, so "cached" is always >= "allocated".
        allocated = torch.cuda.memory_allocated(0) / 1024**3
        cached = torch.cuda.memory_reserved(0) / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3

        return f"GPU Memory: {allocated:.2f}GB allocated, {cached:.2f}GB cached, {total:.2f}GB total"
    except Exception as e:
        return f"Could not get GPU memory info: {e}"


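# A minimal sketch of the checkpoint layout that load_model() below expects
# (inferred from the loader itself; the training-side save code is not part of
# this file, and the 'config' entry is optional with a 2048-token fallback):
#
#   torch.save(
#       {
#           "model_state_dict": model.state_dict(),
#           "config": {"max_seq_len": config.max_seq_len},
#       },
#       "checkpoints/extended_context_model.pt",
#   )

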
def load_model(checkpoint_path="checkpoints/extended_context_model.pt"):
    """Load the NeoMini model, tokenizer, and max context length from a checkpoint."""
    print(f"Loading model from {checkpoint_path}...")

    print("🧹 Clearing cache before model loading...")
    force_garbage_collection()

    if not os.path.exists(checkpoint_path):
        print(f"❌ Checkpoint {checkpoint_path} not found.")
        return None, None, None

    checkpoint = torch.load(checkpoint_path, map_location="cuda" if torch.cuda.is_available() else "cpu")

    # Read the trained context length from the checkpoint, falling back to 2048.
    if 'config' in checkpoint:
        max_seq_len = checkpoint['config'].get('max_seq_len', 2048)
    else:
        max_seq_len = 2048

    config = NeoMiniConfig()
    config.max_seq_len = max_seq_len

    model = NeoMini(config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    tokenizer_path = "data/tokenizer"
    if Path(tokenizer_path).exists():
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    else:
        print("Tokenizer path not found, falling back to the GPT-2 tokenizer.")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"✅ Model loaded on {device}")
    print(f"📊 Tokenizer vocab size: {tokenizer.vocab_size:,}")
    print(f"🧠 Model parameters: {model.get_num_params():,}")
    print(f"📏 Max context window: {max_seq_len:,} tokens")

    clear_gpu_cache()
    print(f"💾 After model load: {get_gpu_memory_info()}")

    return model, tokenizer, max_seq_len


def generate_text(model, tokenizer, max_context_length, prompt, max_length=100,
                  temperature=0.4, top_k=20, top_p=0.8, repetition_penalty=1.2):
    """Generate text token by token with temperature, top-k, top-p, and repetition penalty."""
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Clear the cache up front for long prompts to reduce fragmentation.
    if input_ids.size(1) > 1000:
        clear_gpu_cache()

    prompt_length = input_ids.size(1)
    print(f"📝 Prompt length: {prompt_length:,} tokens")

    if prompt_length >= max_context_length:
        print(f"⚠️ Warning: Prompt ({prompt_length}) exceeds max context ({max_context_length})")
        return "Error: Prompt too long for context window"

    # Cap generation so prompt + new tokens fit inside the context window.
    available_tokens = max_context_length - prompt_length
    if max_length > available_tokens:
        print(f"⚠️ Adjusting max_length from {max_length} to {available_tokens} (context limit)")
        max_length = available_tokens

    print(f"🎯 Generating max {max_length} tokens (temp={temperature}, top_k={top_k}, top_p={top_p}, rep_penalty={repetition_penalty})")
    print(f"💾 Before generation: {get_gpu_memory_info()}")

    with torch.no_grad():
        generated = input_ids
        tokens_generated = 0

        for step in range(max_length):
            # Periodic progress report and cache clearing for long sequences.
            if step % 100 == 0 and step > 0:
                current_length = generated.size(1)
                print(f"   📊 Step {step}: {current_length:,}/{max_context_length:,} tokens")
                if current_length > 2000:
                    clear_gpu_cache()
                    print(f"   {get_gpu_memory_info()}")

            logits = model(generated)
            next_token_logits = logits[0, -1, :] / temperature

            # Repetition penalty: discourage tokens already present in the sequence.
            if repetition_penalty != 1.0:
                for token_id in set(generated[0].tolist()):
                    if next_token_logits[token_id] < 0:
                        next_token_logits[token_id] *= repetition_penalty
                    else:
                        next_token_logits[token_id] /= repetition_penalty

            # Top-k filtering: keep only the k highest-scoring tokens.
            if top_k > 0:
                top_k_logits, _ = torch.topk(next_token_logits, top_k)
                min_top_k = top_k_logits[-1]
                next_token_logits[next_token_logits < min_top_k] = float("-inf")

            # Top-p (nucleus) filtering: keep the smallest set of tokens whose
            # cumulative probability exceeds top_p.
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float("-inf")

            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)
            tokens_generated += 1

            if next_token.item() == tokenizer.eos_token_id:
                print(f"🛑 Stopped at EOS token (generated {tokens_generated} tokens)")
                break

            if generated.size(1) >= max_context_length:
                print(f"🛑 Stopped at max context length {max_context_length:,} (generated {tokens_generated} tokens)")
                break

    final_length = generated.size(1)
    print(f"✅ Generation complete: {final_length:,} total tokens ({tokens_generated} new tokens)")

    if final_length > 2000:
        clear_gpu_cache()

    print(f"💾 Final: {get_gpu_memory_info()}")

    return tokenizer.decode(generated[0], skip_special_tokens=True)


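# Worked example of the top-k / top-p filtering used in generate_text() above
# (illustrative numbers, not output from a real model):
#
#   >>> logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
#   >>> top_vals, _ = torch.topk(logits, 2)            # keeps 2.0 and 1.0
#   >>> logits[logits < top_vals[-1]] = float("-inf")
#   >>> torch.softmax(logits, dim=-1)                  # ~[0.73, 0.27, 0.00, 0.00]
#
# With top_p=0.8 the cumulative probability first exceeds 0.8 at the second
# surviving token; the shift-by-one in the mask keeps the token that crosses
# the threshold, so both remaining tokens stay sampleable.

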
def test_context_window_limits(model, tokenizer, max_context_length):
    """Test how much context the model can actually handle."""
    print(f"\n🧪 Testing Context Window Limits (Max: {max_context_length:,} tokens)")
    print("=" * 60)

    base_text = "This is a test of the context window. " * 20

    for multiplier in [1, 5, 10, 20, 50, 100, 150]:
        print(f"\n🧹 Clearing GPU cache before test {multiplier}...")
        force_garbage_collection()

        test_prompt = base_text * multiplier
        token_count = len(tokenizer.encode(test_prompt))

        print(f"\n📏 Test prompt length: {token_count:,} tokens")

        if token_count > max_context_length:
            print(f"⚠️ Exceeds context limit ({max_context_length:,}), skipping...")
            continue

        print(f"💾 Before generation: {get_gpu_memory_info()}")

        try:
            result = generate_text(model, tokenizer, max_context_length,
                                    test_prompt + " In conclusion,", max_length=50, temperature=0.7)
            print(f"✅ Success at {token_count:,} tokens")
            print(f"💾 After generation: {get_gpu_memory_info()}")

            clear_gpu_cache()
            print(f"💾 After cache clear: {get_gpu_memory_info()}")

        except Exception as e:
            print(f"❌ Failed at {token_count:,} tokens: {e}")
            print("🧹 Cleaning up after failure...")
            force_garbage_collection()
            break


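# Rough scale of the prompts above (assuming a GPT-2-style BPE, where the base
# sentence is on the order of 9 tokens): base_text is roughly 180 tokens, so
# multiplier 10 is about 1,800 tokens and multiplier 150 approaches 27,000
# tokens; any prompt longer than max_context_length is skipped by the check
# inside the loop.

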
def test_instruction_prompts(model, tokenizer, max_context_length):
    """Run a handful of instruction-style prompts to check instruction following."""
    print(f"\n🎯 Testing Instruction Following")
    print("=" * 60)

    prompts = [
        "Complete this sentence in a helpful way: The weather today is",
        "Write a short explanation: Why is exercise important?",
        "Answer in 2-3 sentences: What is artificial intelligence?",
        "Continue this story logically: The scientist walked into the lab and saw",
    ]

    for idx, prompt in enumerate(prompts, 1):
        print(f"\n--- Instruction Prompt {idx} ---")
        print(f"Prompt: {prompt}")

        # Clear the cache between prompts to keep VRAM usage stable.
        if idx > 1:
            clear_gpu_cache()

        output = generate_text(model, tokenizer, max_context_length, prompt, max_length=100)
        print(f"Output: {output}")


def test_long_context(model, tokenizer, max_context_length):
    """Run a conversational prompt to exercise longer-context generation."""
    print(f"\n💬 Testing Long Context Conversation")
    print("=" * 60)

    clear_gpu_cache()

    prompt = """The following is a conversation between a human and an AI assistant. The AI assistant is helpful, harmless, and honest.

Human: Hello, who are you?
AI: I am a large language model trained to assist you.
Human: What can you do for me?
AI: """

    output = generate_text(model, tokenizer, max_context_length, prompt, max_length=200)
    print(f"Output: {output}")


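# Note: generate_text() applies no chat template; the Human/AI turns above are
# plain text that the model simply continues, so the quality of the "AI" reply
# depends entirely on how the base model was trained.

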
def main():
    print("🚀 MAP-NEO Mini Advanced Text Generation with Context & VRAM Monitoring")
    print("=" * 80)

    print("🧹 Initial system cleanup...")
    force_garbage_collection()

    model, tokenizer, max_context_length = load_model()
    if model is None or tokenizer is None:
        print("❌ Failed to load model or tokenizer.")
        return

    print(f"\n🔥 Model ready! Context window: {max_context_length:,} tokens")

    print("\n" + "=" * 40 + " TESTS " + "=" * 40)

    test_instruction_prompts(model, tokenizer, max_context_length)
    force_garbage_collection()
    print(f"💾 After instructions: {get_gpu_memory_info()}")

    test_long_context(model, tokenizer, max_context_length)
    force_garbage_collection()
    print(f"💾 After long context: {get_gpu_memory_info()}")

    test_context_window_limits(model, tokenizer, max_context_length)

    print(f"\n🎉 All tests complete!")
    print(f"💾 Final GPU state: {get_gpu_memory_info()}")


if __name__ == "__main__":
    main()