"""advanced_generate.py - Advanced text generation with instruction prompts, context window info, and GPU monitoring""" import torch from transformers import AutoTokenizer from model_neo import NeoMini, NeoMiniConfig import os from pathlib import Path import gc def clear_gpu_cache(): """Clear GPU memory cache to free up VRAM""" if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() print("๐Ÿงน GPU cache cleared") def force_garbage_collection(): """Force garbage collection and clear caches""" gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() print("๐Ÿ—‘๏ธ Garbage collection and cache clearing completed") def reset_gpu(): """Quick GPU reset function for interactive use""" force_garbage_collection() print(f"๐Ÿ”„ GPU reset: {get_gpu_memory_info()}") def get_gpu_memory_info(): """Get GPU memory usage information""" if not torch.cuda.is_available(): return "CUDA not available" try: # Get current GPU memory usage allocated = torch.cuda.memory_allocated(0) / 1024**3 # Convert to GB cached = torch.cuda.memory_reserved(0) / 1024**3 total = torch.cuda.get_device_properties(0).total_memory / 1024**3 return f"GPU Memory: {allocated:.2f}GB allocated, {cached:.2f}GB cached, {total:.2f}GB total" except Exception as e: return f"Could not get GPU memory info: {e}" def load_model(checkpoint_path="checkpoints/extended_context_model.pt"): print(f"Loading model from {checkpoint_path}...") # Clear cache before loading print("๐Ÿงน Clearing cache before model loading...") force_garbage_collection() if not os.path.exists(checkpoint_path): print(f"โŒ Checkpoint {checkpoint_path} not found.") return None, None, None checkpoint = torch.load(checkpoint_path, map_location="cuda" if torch.cuda.is_available() else "cpu") # Get config from checkpoint or use default if 'config' in checkpoint: max_seq_len = checkpoint['config'].get('max_seq_len', 2048) else: max_seq_len = 2048 # fallback config = NeoMiniConfig() config.max_seq_len = max_seq_len model = NeoMini(config) model.load_state_dict(checkpoint['model_state_dict']) model.eval() device = "cuda" if torch.cuda.is_available() else "cpu" model = model.to(device) tokenizer_path = "data/tokenizer" if Path(tokenizer_path).exists(): tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) else: print("Tokenizer path not found, fallback to GPT-2 tokenizer.") tokenizer = AutoTokenizer.from_pretrained("gpt2") if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token print(f"โœ… Model loaded on {device}") print(f"๐Ÿ“Š Tokenizer vocab size: {tokenizer.vocab_size:,}") print(f"๐Ÿง  Model parameters: {model.get_num_params():,}") print(f"๐Ÿ“ Max context window: {max_seq_len:,} tokens") # Clear cache after model loading clear_gpu_cache() print(f"๐Ÿ’พ After model load: {get_gpu_memory_info()}") return model, tokenizer, max_seq_len def generate_text(model, tokenizer, max_context_length, prompt, max_length=100, temperature=0.4, top_k=20, top_p=0.8, repetition_penalty=1.2): device = next(model.parameters()).device input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device) # Clear cache for long contexts if input_ids.size(1) > 1000: clear_gpu_cache() # Check initial prompt length prompt_length = input_ids.size(1) print(f"๐Ÿ“ Prompt length: {prompt_length:,} tokens") if prompt_length >= max_context_length: print(f"โš ๏ธ Warning: Prompt ({prompt_length}) exceeds max context ({max_context_length})") return "Error: Prompt too long for context window" # Adjust max_length if needed available_tokens = max_context_length - prompt_length if max_length > available_tokens: print(f"โš ๏ธ Adjusting max_length from {max_length} to {available_tokens} (context limit)") max_length = available_tokens print(f"๐ŸŽฏ Generating max {max_length} tokens (temp={temperature}, top_k={top_k}, top_p={top_p}, rep_penalty={repetition_penalty})") print(f"๐Ÿ’พ Before generation: {get_gpu_memory_info()}") with torch.no_grad(): generated = input_ids tokens_generated = 0 for step in range(max_length): # Check memory and clear cache periodically for long generations if step % 100 == 0 and step > 0: current_length = generated.size(1) print(f" ๐Ÿ“Š Step {step}: {current_length:,}/{max_context_length:,} tokens") if current_length > 2000: # Clear cache for very long contexts clear_gpu_cache() print(f" {get_gpu_memory_info()}") logits = model(generated) next_token_logits = logits[0, -1, :] / temperature # Repetition penalty if repetition_penalty != 1.0: for token_id in set(generated[0].tolist()): if next_token_logits[token_id] < 0: next_token_logits[token_id] *= repetition_penalty else: next_token_logits[token_id] /= repetition_penalty # Top-k filtering if top_k > 0: top_k_logits, _ = torch.topk(next_token_logits, top_k) min_top_k = top_k_logits[-1] next_token_logits[next_token_logits < min_top_k] = float("-inf") # Top-p filtering if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices[sorted_indices_to_remove] next_token_logits[indices_to_remove] = float("-inf") probs = torch.softmax(next_token_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1) tokens_generated += 1 # Check stopping conditions if next_token.item() == tokenizer.eos_token_id: print(f"๐Ÿ›‘ Stopped at EOS token (generated {tokens_generated} tokens)") break if generated.size(1) >= max_context_length: print(f"๐Ÿ›‘ Stopped at max context length {max_context_length:,} (generated {tokens_generated} tokens)") break final_length = generated.size(1) print(f"โœ… Generation complete: {final_length:,} total tokens ({tokens_generated} new tokens)") # Clear cache after long generations if final_length > 2000: clear_gpu_cache() print(f"๐Ÿ’พ Final: {get_gpu_memory_info()}") return tokenizer.decode(generated[0], skip_special_tokens=True) def test_context_window_limits(model, tokenizer, max_context_length): """Test how much context the model can actually handle""" print(f"\n๐Ÿงช Testing Context Window Limits (Max: {max_context_length:,} tokens)") print("="*60) # Create a long repetitive prompt to test limits base_text = "This is a test of the context window. " * 20 # ~140 tokens per repeat # Extended multipliers for better testing for multiplier in [1, 5, 10, 20, 50, 100, 150]: # Clear cache before each test print(f"\n๐Ÿงน Clearing GPU cache before test {multiplier}...") force_garbage_collection() test_prompt = base_text * multiplier token_count = len(tokenizer.encode(test_prompt)) print(f"\n๐Ÿ“ Test prompt length: {token_count:,} tokens") if token_count > max_context_length: print(f"โš ๏ธ Exceeds context limit ({max_context_length:,}), skipping...") continue print(f"๐Ÿ’พ Before generation: {get_gpu_memory_info()}") try: result = generate_text(model, tokenizer, max_context_length, test_prompt + " In conclusion,", max_length=50, temperature=0.7) print(f"โœ… Success at {token_count:,} tokens") print(f"๐Ÿ’พ After generation: {get_gpu_memory_info()}") # Clear cache after successful test clear_gpu_cache() print(f"๐Ÿ’พ After cache clear: {get_gpu_memory_info()}") except Exception as e: print(f"โŒ Failed at {token_count:,} tokens: {e}") print("๐Ÿงน Cleaning up after failure...") force_garbage_collection() break def test_instruction_prompts(model, tokenizer, max_context_length): print(f"\n๐ŸŽฏ Testing Instruction Following") print("="*60) prompts = [ "Complete this sentence in a helpful way: The weather today is", "Write a short explanation: Why is exercise important?", "Answer in 2-3 sentences: What is artificial intelligence?", "Continue this story logically: The scientist walked into the lab and saw" ] for idx, prompt in enumerate(prompts, 1): print(f"\n--- Instruction Prompt {idx} ---") print(f"Prompt: {prompt}") # Clear cache before each instruction test if idx > 1: # Not needed for first test clear_gpu_cache() output = generate_text(model, tokenizer, max_context_length, prompt, max_length=100) print(f"Output: {output}") def test_long_context(model, tokenizer, max_context_length): print(f"\n๐Ÿ’ฌ Testing Long Context Conversation") print("="*60) # Clear cache before long context test clear_gpu_cache() prompt = """The following is a conversation between a human and an AI assistant. The AI assistant is helpful, harmless, and honest. Human: Hello, who are you? AI: I am a large language model trained to assist you. Human: What can you do for me? AI: """ output = generate_text(model, tokenizer, max_context_length, prompt, max_length=200) print(f"Output: {output}") def main(): print("๐Ÿš€ MAP-NEO Mini Advanced Text Generation with Context & VRAM Monitoring") print("="*80) # Force clear at startup print("๐Ÿงน Initial system cleanup...") force_garbage_collection() # Load model and get context info model, tokenizer, max_context_length = load_model() if model is None or tokenizer is None: print("โŒ Failed to load model or tokenizer.") return print(f"\n๐Ÿ”ฅ Model ready! Context window: {max_context_length:,} tokens") # Run tests with cache management print("\n" + "="*40 + " TESTS " + "="*40) # Test 1: Instructions test_instruction_prompts(model, tokenizer, max_context_length) force_garbage_collection() print(f"๐Ÿ’พ After instructions: {get_gpu_memory_info()}") # Test 2: Long context test_long_context(model, tokenizer, max_context_length) force_garbage_collection() print(f"๐Ÿ’พ After long context: {get_gpu_memory_info()}") # Test 3: Context limits (most memory intensive) test_context_window_limits(model, tokenizer, max_context_length) print(f"\n๐ŸŽ‰ All tests complete!") print(f"๐Ÿ’พ Final GPU state: {get_gpu_memory_info()}") if __name__ == "__main__": main()