"""
interactive_chat.py - Fully Customizable MAP-NEO Mini Interactive Chat Interface
Features: Real-time parameter tuning, conversation memory, context management, multiple responses
"""
import torch
from transformers import AutoTokenizer
from model_neo import NeoMini, NeoMiniConfig
import os
import json
import time
from pathlib import Path
from datetime import datetime
import gc
class InteractiveChat:
    def __init__(self, checkpoint_path="checkpoints/extended_context_model.pt"):
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.conversation_history = []
        self.max_context_length = 16384

        # Generation parameters (fully customizable)
        self.params = {
            'temperature': 0.7,
            'top_k': 50,
            'top_p': 0.9,
            'repetition_penalty': 1.1,
            'max_length': 150,
            'do_sample': True,
            'num_responses': 1
        }

        print("🚀 MAP-NEO Mini Interactive Chat Interface")
        print("=" * 60)
        self.load_model(checkpoint_path)
    def clear_gpu_cache(self):
        """Clear GPU memory cache"""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()
    def get_memory_usage(self):
        """Get current GPU memory usage"""
        if not torch.cuda.is_available():
            return "CPU only"
        allocated = torch.cuda.memory_allocated(0) / 1024**3
        cached = torch.cuda.memory_reserved(0) / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        return f"{allocated:.2f}GB/{total:.2f}GB (cached: {cached:.2f}GB)"
    def load_model(self, checkpoint_path):
        """Load model and tokenizer"""
        print(f"📂 Loading model from {checkpoint_path}...")
        if not os.path.exists(checkpoint_path):
            print(f"❌ Checkpoint not found: {checkpoint_path}")
            return False
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)

            # Get context length from config
            if 'config' in checkpoint:
                self.max_context_length = checkpoint['config'].get('max_seq_len', 16384)

            # Load model
            config = NeoMiniConfig()
            config.max_seq_len = self.max_context_length
            self.model = NeoMini(config)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.eval()
            self.model = self.model.to(self.device)

            # Load tokenizer
            tokenizer_path = "data/tokenizer"
            if Path(tokenizer_path).exists():
                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            else:
                print("Using GPT-2 tokenizer as fallback...")
                self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            print("✅ Model loaded successfully!")
            print(f"🧠 Parameters: {self.model.get_num_params():,}")
            print(f"📏 Context window: {self.max_context_length:,} tokens")
            print(f"💾 Memory: {self.get_memory_usage()}")
            return True
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            return False
    def format_conversation_context(self):
        """Format conversation history for model input"""
        if not self.conversation_history:
            return ""
        context = "The following is a conversation between a human and an AI assistant. The AI assistant is helpful, harmless, and honest.\n\n"
        for exchange in self.conversation_history:
            context += f"Human: {exchange['human']}\n"
            context += f"AI: {exchange['ai']}\n\n"
        return context
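    # Illustrative layout of the prompt that generate_response() builds from this
    # context (the literal text comes from the strings above and below):
    #
    #   The following is a conversation between a human and an AI assistant. ...
    #
    #   Human: <previous user turn>
    #   AI: <previous model turn>
    #
    #   Human: <current user input>
    #   AI: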
    def generate_response(self, user_input, num_responses=None):
        """Generate AI response(s) to user input"""
        if num_responses is None:
            num_responses = self.params['num_responses']

        # Build full context
        context = self.format_conversation_context()
        full_prompt = context + f"Human: {user_input}\nAI: "

        # Check context length
        input_ids = self.tokenizer.encode(full_prompt, return_tensors="pt").to(self.device)
        prompt_length = input_ids.size(1)
        print(f"📏 Context: {prompt_length:,}/{self.max_context_length:,} tokens")
        if prompt_length >= self.max_context_length:
            print("⚠️ Context too long, trimming conversation history...")
            self.trim_conversation_history()
            context = self.format_conversation_context()
            full_prompt = context + f"Human: {user_input}\nAI: "
            input_ids = self.tokenizer.encode(full_prompt, return_tensors="pt").to(self.device)
            prompt_length = input_ids.size(1)

        # Generate response(s)
        responses = []
        for i in range(num_responses):
            print(f"🤖 Generating response {i + 1}/{num_responses}...")
            with torch.no_grad():
                generated = input_ids.clone()
                max_new_tokens = min(self.params['max_length'], self.max_context_length - prompt_length)
                for step in range(max_new_tokens):
                    logits = self.model(generated)
                    next_token_logits = logits[0, -1, :] / self.params['temperature']

                    # Apply repetition penalty: already-generated tokens become less likely
                    # (negative logits are pushed further down, positive logits are shrunk)
                    if self.params['repetition_penalty'] != 1.0:
                        for token_id in set(generated[0].tolist()):
                            if next_token_logits[token_id] < 0:
                                next_token_logits[token_id] *= self.params['repetition_penalty']
                            else:
                                next_token_logits[token_id] /= self.params['repetition_penalty']

                    # Top-k filtering
                    if self.params['top_k'] > 0:
                        top_k_logits, _ = torch.topk(next_token_logits, self.params['top_k'])
                        min_top_k = top_k_logits[-1]
                        next_token_logits[next_token_logits < min_top_k] = float("-inf")

                    # Top-p (nucleus) filtering
                    if self.params['top_p'] < 1.0:
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > self.params['top_p']
                        # Shift the mask right so the first token past the threshold is kept
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        indices_to_remove = sorted_indices[sorted_indices_to_remove]
                        next_token_logits[indices_to_remove] = float("-inf")

                    # Sample next token
                    if self.params['do_sample']:
                        probs = torch.softmax(next_token_logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                    else:
                        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                    generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)

                    # Check stopping conditions
                    if next_token.item() == self.tokenizer.eos_token_id:
                        break
                    # Check for natural stopping points
                    if step > 10:  # Only check after minimum generation
                        decoded = self.tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)
                        if decoded.strip().endswith(('.', '!', '?', '\n\n')):
                            break

            # Extract just the AI response
            full_response = self.tokenizer.decode(generated[0], skip_special_tokens=True)
            ai_response = full_response[len(full_prompt):].strip()

            # Clean up response
            if '\nHuman:' in ai_response:
                ai_response = ai_response.split('\nHuman:')[0].strip()
            responses.append(ai_response)
        return responses
    def trim_conversation_history(self):
        """Remove oldest conversation turns to fit context"""
        while len(self.conversation_history) > 1:
            self.conversation_history.pop(0)
            context = self.format_conversation_context()
            if len(self.tokenizer.encode(context)) < self.max_context_length // 2:
                break
        print(f"🧹 Trimmed conversation history to {len(self.conversation_history)} turns")
    def update_parameters(self):
        """Interactive parameter adjustment"""
        print("\n🎛️ Current Generation Parameters:")
        for key, value in self.params.items():
            print(f"  {key}: {value}")
        print("\nEnter new values (press Enter to keep current):")

        # Temperature
        temp_input = input(f"Temperature (0.1-2.0, current: {self.params['temperature']}): ").strip()
        if temp_input:
            try:
                self.params['temperature'] = max(0.1, min(2.0, float(temp_input)))
            except ValueError:
                print("❌ Invalid temperature, keeping current value")

        # Top-k
        topk_input = input(f"Top-k (0-100, current: {self.params['top_k']}): ").strip()
        if topk_input:
            try:
                self.params['top_k'] = max(0, min(100, int(topk_input)))
            except ValueError:
                print("❌ Invalid top-k, keeping current value")

        # Top-p
        topp_input = input(f"Top-p (0.1-1.0, current: {self.params['top_p']}): ").strip()
        if topp_input:
            try:
                self.params['top_p'] = max(0.1, min(1.0, float(topp_input)))
            except ValueError:
                print("❌ Invalid top-p, keeping current value")

        # Max length
        maxlen_input = input(f"Max length (10-500, current: {self.params['max_length']}): ").strip()
        if maxlen_input:
            try:
                self.params['max_length'] = max(10, min(500, int(maxlen_input)))
            except ValueError:
                print("❌ Invalid max length, keeping current value")

        # Number of responses
        num_resp = input(f"Number of responses (1-3, current: {self.params['num_responses']}): ").strip()
        if num_resp:
            try:
                self.params['num_responses'] = max(1, min(3, int(num_resp)))
            except ValueError:
                print("❌ Invalid number, keeping current value")

        print("✅ Parameters updated!")
    def save_conversation(self):
        """Save conversation to file"""
        if not self.conversation_history:
            print("❌ No conversation to save")
            return
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"conversation_{timestamp}.json"
        conversation_data = {
            'timestamp': timestamp,
            'model_info': {
                'max_context': self.max_context_length,
                'parameters': self.model.get_num_params()
            },
            'generation_params': self.params,
            'conversation': self.conversation_history
        }
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(conversation_data, f, indent=2, ensure_ascii=False)
        print(f"💾 Conversation saved to {filename}")
    def load_conversation(self, filename):
        """Load conversation from file"""
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                conversation_data = json.load(f)
            self.conversation_history = conversation_data['conversation']
            print(f"📂 Loaded conversation with {len(self.conversation_history)} turns")
        except Exception as e:
            print(f"❌ Error loading conversation: {e}")
    def show_help(self):
        """Show available commands"""
        print("\n🔧 Available Commands:")
        print("  /help         - Show this help message")
        print("  /params       - Adjust generation parameters")
        print("  /clear        - Clear conversation history")
        print("  /save         - Save current conversation")
        print("  /load <file>  - Load conversation from file")
        print("  /memory       - Show GPU memory usage")
        print("  /context      - Show current context usage")
        print("  /multi <n>    - Generate n responses to next input")
        print("  /exit         - Exit the chat")
        print("  /quit         - Exit the chat")
    def run(self):
        """Main chat loop"""
        if not self.model or not self.tokenizer:
            print("❌ Model not loaded. Exiting.")
            return

        print(f"\n💬 Chat started! Context window: {self.max_context_length:,} tokens")
        print("Type /help for commands, /exit to quit")
        print("-" * 60)

        while True:
            try:
                # Get user input
                user_input = input("\n👤 You: ").strip()
                if not user_input:
                    continue

                # Handle commands
                if user_input.startswith('/'):
                    command = user_input.lower()
                    if command in ['/exit', '/quit']:
                        print("👋 Goodbye!")
                        break
                    elif command == '/help':
                        self.show_help()
                    elif command == '/params':
                        self.update_parameters()
                    elif command == '/clear':
                        self.conversation_history = []
                        self.clear_gpu_cache()
                        print("🧹 Conversation history cleared")
                    elif command == '/save':
                        self.save_conversation()
                    elif command.startswith('/load '):
                        # Take the filename from the raw input so its case is preserved
                        filename = user_input[6:].strip()
                        self.load_conversation(filename)
                    elif command == '/memory':
                        print(f"💾 GPU Memory: {self.get_memory_usage()}")
                    elif command == '/context':
                        context_length = len(self.tokenizer.encode(self.format_conversation_context()))
                        print(f"📏 Current context: {context_length:,}/{self.max_context_length:,} tokens")
                    elif command.startswith('/multi '):
                        try:
                            num = int(command[7:].strip())
                            self.params['num_responses'] = max(1, min(3, num))
                            print(f"🎯 Next response will generate {self.params['num_responses']} options")
                        except ValueError:
                            print("❌ Invalid number format")
                    else:
                        print("❌ Unknown command. Type /help for available commands.")
                    continue

                # Generate response(s)
                start_time = time.time()
                responses = self.generate_response(user_input)
                generation_time = time.time() - start_time

                # Display response(s)
                if len(responses) == 1:
                    print(f"\n🤖 AI: {responses[0]}")
                    chosen_response = responses[0]
                else:
                    print(f"\n🤖 AI generated {len(responses)} responses:")
                    for i, response in enumerate(responses, 1):
                        print(f"\n[{i}] {response}")
                    while True:
                        choice = input(f"\nChoose response (1-{len(responses)}, Enter for 1): ").strip()
                        if not choice:
                            choice = "1"
                        try:
                            choice_idx = int(choice) - 1
                            if 0 <= choice_idx < len(responses):
                                chosen_response = responses[choice_idx]
                                break
                            else:
                                print(f"❌ Invalid choice. Enter 1-{len(responses)}")
                        except ValueError:
                            print("❌ Invalid input. Enter a number.")

                # Add to conversation history
                self.conversation_history.append({
                    'human': user_input,
                    'ai': chosen_response,
                    'timestamp': datetime.now().isoformat(),
                    'generation_time': round(generation_time, 2)
                })

                # Reset num_responses if it was changed
                if self.params['num_responses'] != 1:
                    self.params['num_responses'] = 1

                print(f"⏱️ Generated in {generation_time:.2f}s | 💾 {self.get_memory_usage()}")

            except KeyboardInterrupt:
                print("\n\n👋 Chat interrupted. Goodbye!")
                break
            except Exception as e:
                print(f"\n❌ Error: {e}")
                self.clear_gpu_cache()
def main():
    # Allow custom checkpoint path
    import sys
    checkpoint_path = "checkpoints/extended_context_model.pt"
    if len(sys.argv) > 1:
        checkpoint_path = sys.argv[1]

    chat = InteractiveChat(checkpoint_path)
    chat.run()


if __name__ == "__main__":
    main()