"""
interactive_chat.py - Fully Customizable MAP-NEO Mini Interactive Chat Interface

Features: Real-time parameter tuning, conversation memory, context management, multiple responses
"""

import gc
import json
import os
import time
from datetime import datetime
from pathlib import Path

import torch
from transformers import AutoTokenizer

from model_neo import NeoMini, NeoMiniConfig

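# Usage sketch (assumes model_neo.py, a trained checkpoint, and optionally a local
# tokenizer under data/tokenizer are available; without the local tokenizer the
# script falls back to the Hugging Face GPT-2 tokenizer):
#
#   python interactive_chat.py                        # default checkpoint path
#   python interactive_chat.py path/to/checkpoint.pt  # explicit checkpoint
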
class InteractiveChat:
    def __init__(self, checkpoint_path="checkpoints/extended_context_model.pt"):
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.conversation_history = []
        self.max_context_length = 16384

        self.params = {
            'temperature': 0.7,
            'top_k': 50,
            'top_p': 0.9,
            'repetition_penalty': 1.1,
            'max_length': 150,
            'do_sample': True,
            'num_responses': 1
        }
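        # Rough intuition for the defaults above (general guidance, not values tuned
        # for this specific checkpoint):
        #   temperature        < 1.0 sharpens the distribution, > 1.0 flattens it
        #   top_k              keep only the k most likely tokens before sampling
        #   top_p              keep the smallest token set whose probabilities sum to p
        #   repetition_penalty > 1.0 discourages tokens that already appeared
        #   max_length         cap on newly generated tokens per reply
        #   num_responses      candidate replies generated per user turn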
print("π MAP-NEO Mini Interactive Chat Interface")
|
|
|
print("=" * 60)
|
|
|
self.load_model(checkpoint_path)
|
|
|
|
|
|
    def clear_gpu_cache(self):
        """Clear GPU memory cache"""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()
    def get_memory_usage(self):
        """Get current GPU memory usage"""
        if not torch.cuda.is_available():
            return "CPU only"

        # allocated = memory occupied by live tensors; reserved = the larger pool
        # held by PyTorch's caching allocator.
        allocated = torch.cuda.memory_allocated(0) / 1024**3
        cached = torch.cuda.memory_reserved(0) / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        return f"{allocated:.2f}GB/{total:.2f}GB (cached: {cached:.2f}GB)"
    def load_model(self, checkpoint_path):
        """Load model and tokenizer"""
        print(f"Loading model from {checkpoint_path}...")

        if not os.path.exists(checkpoint_path):
            print(f"Checkpoint not found: {checkpoint_path}")
            return False

        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)

            # The checkpoint may carry its training config; use it to pick up the
            # extended context length if present.
            if 'config' in checkpoint:
                self.max_context_length = checkpoint['config'].get('max_seq_len', 16384)

            config = NeoMiniConfig()
            config.max_seq_len = self.max_context_length
            self.model = NeoMini(config)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.eval()
            self.model = self.model.to(self.device)

            # Prefer the project's own tokenizer; fall back to GPT-2 if it is missing.
            tokenizer_path = "data/tokenizer"
            if Path(tokenizer_path).exists():
                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            else:
                print("Using GPT-2 tokenizer as fallback...")
                self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            print("Model loaded successfully!")
            print(f"Parameters: {self.model.get_num_params():,}")
            print(f"Context window: {self.max_context_length:,} tokens")
            print(f"Memory: {self.get_memory_usage()}")
            return True

        except Exception as e:
            print(f"Error loading model: {e}")
            return False
    def format_conversation_context(self):
        """Format conversation history for model input"""
        if not self.conversation_history:
            return ""

        context = "The following is a conversation between a human and an AI assistant. The AI assistant is helpful, harmless, and honest.\n\n"

        for exchange in self.conversation_history:
            context += f"Human: {exchange['human']}\n"
            context += f"AI: {exchange['ai']}\n\n"

        return context
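    # Illustrative prompt layout produced by format_conversation_context() plus the
    # new user turn appended in generate_response():
    #
    #   The following is a conversation between a human and an AI assistant. ...
    #
    #   Human: Hello!
    #   AI: Hi there!
    #
    #   Human: <latest user input>
    #   AI: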
    def generate_response(self, user_input, num_responses=None):
        """Generate AI response(s) to user input"""
        if num_responses is None:
            num_responses = self.params['num_responses']

        context = self.format_conversation_context()
        full_prompt = context + f"Human: {user_input}\nAI: "

        input_ids = self.tokenizer.encode(full_prompt, return_tensors="pt").to(self.device)
        prompt_length = input_ids.size(1)

        print(f"Context: {prompt_length:,}/{self.max_context_length:,} tokens")

        if prompt_length >= self.max_context_length:
            print("Context too long, trimming conversation history...")
            self.trim_conversation_history()
            context = self.format_conversation_context()
            full_prompt = context + f"Human: {user_input}\nAI: "
            input_ids = self.tokenizer.encode(full_prompt, return_tensors="pt").to(self.device)
            prompt_length = input_ids.size(1)

        responses = []
        for i in range(num_responses):
            print(f"Generating response {i+1}/{num_responses}...")

            with torch.no_grad():
                generated = input_ids.clone()
                max_new_tokens = min(self.params['max_length'], self.max_context_length - prompt_length)

                for step in range(max_new_tokens):
                    logits = self.model(generated)
                    next_token_logits = logits[0, -1, :] / self.params['temperature']

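                    # Repetition penalty (CTRL-style): push down the logit of every
                    # token already present in the sequence. Positive logits are
                    # divided by the penalty and negative ones multiplied, so a
                    # penalty > 1.0 always makes repeats less likely.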
                    if self.params['repetition_penalty'] != 1.0:
                        for token_id in set(generated[0].tolist()):
                            if next_token_logits[token_id] < 0:
                                next_token_logits[token_id] *= self.params['repetition_penalty']
                            else:
                                next_token_logits[token_id] /= self.params['repetition_penalty']

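                    # Top-k filtering: keep only the k highest logits and mask the
                    # rest to -inf so they receive zero probability after softmax.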
                    if self.params['top_k'] > 0:
                        top_k_logits, _ = torch.topk(next_token_logits, self.params['top_k'])
                        min_top_k = top_k_logits[-1]
                        next_token_logits[next_token_logits < min_top_k] = float("-inf")

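                    # Nucleus (top-p) filtering: sort tokens by probability and drop
                    # everything outside the smallest set whose cumulative probability
                    # exceeds top_p. E.g. with top_p=0.9 and sorted probs 0.5, 0.3,
                    # 0.15, 0.05, the first three tokens (0.95 cumulative) survive.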
                    if self.params['top_p'] < 1.0:
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > self.params['top_p']
                        # Shift right so the token that first crosses the threshold is kept.
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        indices_to_remove = sorted_indices[sorted_indices_to_remove]
                        next_token_logits[indices_to_remove] = float("-inf")

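                    # Sample from the filtered distribution, or fall back to greedy
                    # argmax decoding when do_sample is disabled.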
                    if self.params['do_sample']:
                        probs = torch.softmax(next_token_logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                    else:
                        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                    generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)

                    if next_token.item() == self.tokenizer.eos_token_id:
                        break

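                    # Heuristic early stop: once a handful of tokens have been produced,
                    # end the reply at the first sentence boundary to keep turns short.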
                    if step > 10:
                        decoded = self.tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)
                        if decoded.strip().endswith(('.', '!', '?', '\n\n')):
                            break

            # Decode only the newly generated tokens so the reply does not depend on
            # an exact string round-trip of the prompt.
            ai_response = self.tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True).strip()

            # Cut the reply off where the model starts writing the next human turn.
            if '\nHuman:' in ai_response:
                ai_response = ai_response.split('\nHuman:')[0].strip()

            responses.append(ai_response)

        return responses
    def trim_conversation_history(self):
        """Remove oldest conversation turns to fit context"""
        while len(self.conversation_history) > 1:
            self.conversation_history.pop(0)
            context = self.format_conversation_context()
            if len(self.tokenizer.encode(context)) < self.max_context_length // 2:
                break
        print(f"Trimmed conversation history to {len(self.conversation_history)} turns")
    def update_parameters(self):
        """Interactive parameter adjustment"""
        print("\nCurrent Generation Parameters:")
        for key, value in self.params.items():
            print(f"  {key}: {value}")

        print("\nEnter new values (press Enter to keep current):")

        temp_input = input(f"Temperature (0.1-2.0, current: {self.params['temperature']}): ").strip()
        if temp_input:
            try:
                self.params['temperature'] = max(0.1, min(2.0, float(temp_input)))
            except ValueError:
                print("Invalid temperature, keeping current value")

        topk_input = input(f"Top-k (0-100, current: {self.params['top_k']}): ").strip()
        if topk_input:
            try:
                self.params['top_k'] = max(0, min(100, int(topk_input)))
            except ValueError:
                print("Invalid top-k, keeping current value")

        topp_input = input(f"Top-p (0.1-1.0, current: {self.params['top_p']}): ").strip()
        if topp_input:
            try:
                self.params['top_p'] = max(0.1, min(1.0, float(topp_input)))
            except ValueError:
                print("Invalid top-p, keeping current value")

        maxlen_input = input(f"Max length (10-500, current: {self.params['max_length']}): ").strip()
        if maxlen_input:
            try:
                self.params['max_length'] = max(10, min(500, int(maxlen_input)))
            except ValueError:
                print("Invalid max length, keeping current value")

        num_resp = input(f"Number of responses (1-3, current: {self.params['num_responses']}): ").strip()
        if num_resp:
            try:
                self.params['num_responses'] = max(1, min(3, int(num_resp)))
            except ValueError:
                print("Invalid number, keeping current value")

        print("Parameters updated!")
    def save_conversation(self):
        """Save conversation to file"""
        if not self.conversation_history:
            print("No conversation to save")
            return

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"conversation_{timestamp}.json"

        conversation_data = {
            'timestamp': timestamp,
            'model_info': {
                'max_context': self.max_context_length,
                'parameters': self.model.get_num_params()
            },
            'generation_params': self.params,
            'conversation': self.conversation_history
        }

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(conversation_data, f, indent=2, ensure_ascii=False)

        print(f"Conversation saved to {filename}")
    def load_conversation(self, filename):
        """Load conversation from file"""
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                conversation_data = json.load(f)

            self.conversation_history = conversation_data['conversation']
            print(f"Loaded conversation with {len(self.conversation_history)} turns")

        except Exception as e:
            print(f"Error loading conversation: {e}")
    def show_help(self):
        """Show available commands"""
        print("\nAvailable Commands:")
        print("  /help        - Show this help message")
        print("  /params      - Adjust generation parameters")
        print("  /clear       - Clear conversation history")
        print("  /save        - Save current conversation")
        print("  /load <file> - Load conversation from file")
        print("  /memory      - Show GPU memory usage")
        print("  /context     - Show current context usage")
        print("  /multi <n>   - Generate n responses to next input")
        print("  /exit        - Exit the chat")
        print("  /quit        - Exit the chat")
    def run(self):
        """Main chat loop"""
        if not self.model or not self.tokenizer:
            print("Model not loaded. Exiting.")
            return

        print(f"\nChat started! Context window: {self.max_context_length:,} tokens")
        print("Type /help for commands, /exit to quit")
        print("-" * 60)

        while True:
            try:
                user_input = input("\nYou: ").strip()

                if not user_input:
                    continue

                # Slash commands are handled here; anything else is sent to the model.
                if user_input.startswith('/'):
                    command = user_input.lower()

                    if command in ['/exit', '/quit']:
                        print("Goodbye!")
                        break
                    elif command == '/help':
                        self.show_help()
                    elif command == '/params':
                        self.update_parameters()
                    elif command == '/clear':
                        self.conversation_history = []
                        self.clear_gpu_cache()
                        print("Conversation history cleared")
                    elif command == '/save':
                        self.save_conversation()
                    elif command.startswith('/load '):
                        # Take the filename from the raw input so its case is preserved.
                        filename = user_input[6:].strip()
                        self.load_conversation(filename)
                    elif command == '/memory':
                        print(f"GPU Memory: {self.get_memory_usage()}")
                    elif command == '/context':
                        context_length = len(self.tokenizer.encode(self.format_conversation_context()))
                        print(f"Current context: {context_length:,}/{self.max_context_length:,} tokens")
                    elif command.startswith('/multi '):
                        try:
                            num = int(command[7:].strip())
                            self.params['num_responses'] = max(1, min(3, num))
                            print(f"Next response will generate {self.params['num_responses']} options")
                        except ValueError:
                            print("Invalid number format")
                    else:
                        print("Unknown command. Type /help for available commands.")
                    continue

                start_time = time.time()
                responses = self.generate_response(user_input)
                generation_time = time.time() - start_time

                if len(responses) == 1:
                    print(f"\nAI: {responses[0]}")
                    chosen_response = responses[0]
                else:
                    print(f"\nAI generated {len(responses)} responses:")
                    for i, response in enumerate(responses, 1):
                        print(f"\n[{i}] {response}")

                    while True:
                        choice = input(f"\nChoose response (1-{len(responses)}, Enter for 1): ").strip()
                        if not choice:
                            choice = "1"
                        try:
                            choice_idx = int(choice) - 1
                            if 0 <= choice_idx < len(responses):
                                chosen_response = responses[choice_idx]
                                break
                            else:
                                print(f"Invalid choice. Enter 1-{len(responses)}")
                        except ValueError:
                            print("Invalid input. Enter a number.")

                self.conversation_history.append({
                    'human': user_input,
                    'ai': chosen_response,
                    'timestamp': datetime.now().isoformat(),
                    'generation_time': round(generation_time, 2)
                })

                # /multi only applies to a single turn; reset to one response afterwards.
                if self.params['num_responses'] != 1:
                    self.params['num_responses'] = 1

                print(f"Generated in {generation_time:.2f}s | Memory: {self.get_memory_usage()}")

            except KeyboardInterrupt:
                print("\n\nChat interrupted. Goodbye!")
                break
            except Exception as e:
                print(f"\nError: {e}")
                self.clear_gpu_cache()
def main():
    """Entry point: optionally take a checkpoint path as the first CLI argument."""
    import sys

    checkpoint_path = "checkpoints/extended_context_model.pt"
    if len(sys.argv) > 1:
        checkpoint_path = sys.argv[1]

    chat = InteractiveChat(checkpoint_path)
    chat.run()


if __name__ == "__main__":
    main()