"""finetune_fixed_clean.py - Fixed fine-tuning with better parameters""" import os import torch import torch.nn as nn from torch.utils.data import Dataset, DataLoader from transformers import AutoTokenizer, get_cosine_schedule_with_warmup from datasets import load_dataset from model_neo import NeoMini, NeoMiniConfig import argparse from tqdm import tqdm import json class CleanConversationDataset(Dataset): def __init__(self, data_path, tokenizer, max_length=1024): # Reduced length self.tokenizer = tokenizer self.max_length = max_length # Load data self.data = load_dataset('json', data_files=data_path)['train'] print(f"Loaded {len(self.data)} examples") # Filter for quality self.filtered_data = [] for example in self.data: instruction = example.get('instruction', '').strip() output = example.get('output', '').strip() # Quality filters if (len(instruction) > 10 and len(output) > 10 and len(instruction) < 500 and len(output) < 500 and not any(url in output.lower() for url in ['http', 'www', '.com'])): self.filtered_data.append(example) print(f"Filtered to {len(self.filtered_data)} high-quality examples") def __len__(self): return len(self.filtered_data) def __getitem__(self, idx): example = self.filtered_data[idx] instruction = example.get('instruction', '').strip() input_text = example.get('input', '').strip() output = example.get('output', '').strip() # Simple format if input_text: prompt = f"Human: {instruction}\nInput: {input_text}\nAssistant:" else: prompt = f"Human: {instruction}\nAssistant:" # Create full sequence: prompt + response + EOS full_text = f"{prompt} {output}{self.tokenizer.eos_token}" # Tokenize tokens = self.tokenizer( full_text, truncation=True, max_length=self.max_length, padding='max_length', return_tensors='pt' ) input_ids = tokens['input_ids'].squeeze() attention_mask = tokens['attention_mask'].squeeze() # Create labels - mask prompt tokens prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)['input_ids'] labels = input_ids.clone() labels[:len(prompt_tokens)] = -100 # Mask prompt return { 'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels } class SimpleTrainer: def __init__(self, model, tokenizer, dataset, args): self.model = model self.tokenizer = tokenizer self.dataset = dataset self.args = args self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.model = self.model.to(self.device) # Much lower learning rate self.optimizer = torch.optim.AdamW( self.model.parameters(), lr=args.lr, weight_decay=0.01, betas=(0.9, 0.95) ) self.dataloader = DataLoader( dataset, batch_size=args.batch_size, shuffle=True, num_workers=0 ) # Scheduler total_steps = len(self.dataloader) * args.epochs self.scheduler = get_cosine_schedule_with_warmup( self.optimizer, num_warmup_steps=50, num_training_steps=total_steps ) print(f"Training setup: {total_steps} total steps") def train(self): print("\n๐ŸŽฏ Starting CLEAN fine-tuning...") print("="*50) self.model.train() total_loss = 0 step = 0 for epoch in range(self.args.epochs): print(f"\n๐Ÿ“š Epoch {epoch + 1}/{self.args.epochs}") epoch_loss = 0 for batch in tqdm(self.dataloader, desc=f"Epoch {epoch + 1}"): # Move to device input_ids = batch['input_ids'].to(self.device) attention_mask = batch['attention_mask'].to(self.device) labels = batch['labels'].to(self.device) # Forward pass outputs = self.model(input_ids) # Calculate loss manually shift_logits = outputs[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = nn.CrossEntropyLoss(ignore_index=-100) loss = 
loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) # Backward pass loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) # Gradient clipping self.optimizer.step() self.scheduler.step() self.optimizer.zero_grad() # Track loss current_loss = loss.item() total_loss += current_loss epoch_loss += current_loss step += 1 # Save checkpoint if step % 200 == 0: self.save_checkpoint(step, current_loss) avg_loss = epoch_loss / len(self.dataloader) print(f"Epoch {epoch + 1} completed - Average loss: {avg_loss:.4f}") # Early stopping if loss is very low if avg_loss < 1.0: print("โœ… Loss converged, stopping early") break self.save_final_model() print(f"โœ… Training completed! Final average loss: {total_loss/step:.4f}") def save_checkpoint(self, step, loss): os.makedirs(self.args.output_dir, exist_ok=True) torch.save({ 'model_state_dict': self.model.state_dict(), 'step': step, 'loss': loss }, f"{self.args.output_dir}/checkpoint_step_{step}.pt") def save_final_model(self): os.makedirs(self.args.output_dir, exist_ok=True) # Save model torch.save({ 'model_state_dict': self.model.state_dict(), 'config': vars(self.model.config) }, f"{self.args.output_dir}/clean_conversational_model.pt") # Save tokenizer self.tokenizer.save_pretrained(self.args.output_dir) print(f"โœ… Clean model saved to {self.args.output_dir}") def main(): parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, default='data/conversation_final/conversation_train.jsonl') parser.add_argument('--output_dir', type=str, default='clean_conversational_neo') parser.add_argument('--epochs', type=int, default=2) # Fewer epochs parser.add_argument('--batch_size', type=int, default=2) # Slightly larger parser.add_argument('--lr', type=float, default=1e-5) # Much lower LR parser.add_argument('--max_length', type=int, default=1024) # Shorter sequences args = parser.parse_args() print("๐Ÿงน MAP-NEO Mini CLEAN Conversational Fine-Tuning") print("="*60) # Load tokenizer tokenizer = AutoTokenizer.from_pretrained('data/tokenizer') if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Load model (use original, not extended for stability) print("Loading original model for stability...") checkpoint = torch.load('checkpoints/checkpoint_step_99999.pt', map_location='cpu') config = NeoMiniConfig() config.max_seq_len = 2048 # Standard context model = NeoMini(config) model.load_state_dict(checkpoint['model_state_dict']) print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") # Load clean dataset dataset = CleanConversationDataset(args.dataset, tokenizer, args.max_length) # Train trainer = SimpleTrainer(model, tokenizer, dataset, args) trainer.train() if __name__ == '__main__': main()
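# --- Example usage (a sketch; the paths simply echo the argparse defaults above,
# so adjust them to your own setup) ---
#
#   python finetune_fixed_clean.py \
#       --dataset data/conversation_final/conversation_train.jsonl \
#       --output_dir clean_conversational_neo \
#       --epochs 2 --batch_size 2 --lr 1e-5 --max_length 1024
#
# --- Reloading the saved model later (a minimal sketch; it mirrors the save format in
# save_final_model() and the config-attribute pattern used in main(), but model_neo's
# actual API may differ) ---
#
#   ckpt = torch.load('clean_conversational_neo/clean_conversational_model.pt', map_location='cpu')
#   config = NeoMiniConfig()
#   for key, value in ckpt['config'].items():
#       setattr(config, key, value)
#   model = NeoMini(config)
#   model.load_state_dict(ckpt['model_state_dict'])
#   tokenizer = AutoTokenizer.from_pretrained('clean_conversational_neo')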