"""finetune_fixed_clean.py - Fixed fine-tuning with better parameters"""
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
from datasets import load_dataset
from model_neo import NeoMini, NeoMiniConfig
import argparse
from tqdm import tqdm
import json


class CleanConversationDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=1024):  # Reduced length
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Load data
        self.data = load_dataset('json', data_files=data_path)['train']
        print(f"Loaded {len(self.data)} examples")

        # Filter for quality
        self.filtered_data = []
        for example in self.data:
            instruction = example.get('instruction', '').strip()
            output = example.get('output', '').strip()

            # Quality filters
            if (len(instruction) > 10 and len(output) > 10 and
                    len(instruction) < 500 and len(output) < 500 and
                    not any(url in output.lower() for url in ['http', 'www', '.com'])):
                self.filtered_data.append(example)

        print(f"Filtered to {len(self.filtered_data)} high-quality examples")

    def __len__(self):
        return len(self.filtered_data)

    def __getitem__(self, idx):
        example = self.filtered_data[idx]
        instruction = example.get('instruction', '').strip()
        input_text = example.get('input', '').strip()
        output = example.get('output', '').strip()

        # Simple format
        if input_text:
            prompt = f"Human: {instruction}\nInput: {input_text}\nAssistant:"
        else:
            prompt = f"Human: {instruction}\nAssistant:"

        # Create full sequence: prompt + response + EOS
        full_text = f"{prompt} {output}{self.tokenizer.eos_token}"

        # Tokenize
        tokens = self.tokenizer(
            full_text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )

        input_ids = tokens['input_ids'].squeeze()
        attention_mask = tokens['attention_mask'].squeeze()

        # Create labels - mask prompt tokens so loss is computed only on the response
        # (this assumes the tokenizer does not prepend special tokens to full_text)
        prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)['input_ids']
        labels = input_ids.clone()
        labels[:len(prompt_tokens)] = -100   # Mask prompt
        labels[attention_mask == 0] = -100   # Mask padding positions as well

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }
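
# Each dataset item is tokenized from a single string of the form
#   "Human: <instruction>\nInput: <optional input>\nAssistant: <output><EOS>"
# where <EOS> is tokenizer.eos_token; the prompt portion and all padding
# positions are set to -100 in `labels`, so only response tokens contribute
# to the loss computed in SimpleTrainer.train() below.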


class SimpleTrainer:
    def __init__(self, model, tokenizer, dataset, args):
        self.model = model
        self.tokenizer = tokenizer
        self.dataset = dataset
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)

        # Much lower learning rate
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=args.lr,
            weight_decay=0.01,
            betas=(0.9, 0.95)
        )

        self.dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=0
        )

        # Scheduler
        total_steps = len(self.dataloader) * args.epochs
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=50,
            num_training_steps=total_steps
        )

        print(f"Training setup: {total_steps} total steps")

    def train(self):
        print("\n🎯 Starting CLEAN fine-tuning...")
        print("=" * 50)

        self.model.train()
        total_loss = 0
        step = 0

        for epoch in range(self.args.epochs):
            print(f"\n📚 Epoch {epoch + 1}/{self.args.epochs}")
            epoch_loss = 0

            for batch in tqdm(self.dataloader, desc=f"Epoch {epoch + 1}"):
                # Move to device
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)

                # Forward pass; the model is called with input_ids only here, so
                # padding is excluded from training via the -100 labels rather
                # than via attention_mask
                outputs = self.model(input_ids)

                # Calculate loss manually with the usual causal-LM shift:
                # logits at position t are scored against the token at t + 1
                shift_logits = outputs[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

                # Backward pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)  # Gradient clipping
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

                # Track loss
                current_loss = loss.item()
                total_loss += current_loss
                epoch_loss += current_loss
                step += 1

                # Save checkpoint
                if step % 200 == 0:
                    self.save_checkpoint(step, current_loss)

            avg_loss = epoch_loss / len(self.dataloader)
            print(f"Epoch {epoch + 1} completed - Average loss: {avg_loss:.4f}")

            # Early stopping if loss is very low
            if avg_loss < 1.0:
                print("✅ Loss converged, stopping early")
                break

        self.save_final_model()
        print(f"✅ Training completed! Final average loss: {total_loss / step:.4f}")

    def save_checkpoint(self, step, loss):
        os.makedirs(self.args.output_dir, exist_ok=True)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'step': step,
            'loss': loss
        }, f"{self.args.output_dir}/checkpoint_step_{step}.pt")

    def save_final_model(self):
        os.makedirs(self.args.output_dir, exist_ok=True)

        # Save model
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'config': vars(self.model.config)
        }, f"{self.args.output_dir}/clean_conversational_model.pt")

        # Save tokenizer
        self.tokenizer.save_pretrained(self.args.output_dir)
        print(f"✅ Clean model saved to {self.args.output_dir}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='data/conversation_final/conversation_train.jsonl')
    parser.add_argument('--output_dir', type=str, default='clean_conversational_neo')
    parser.add_argument('--epochs', type=int, default=2)  # Fewer epochs
    parser.add_argument('--batch_size', type=int, default=2)  # Slightly larger
    parser.add_argument('--lr', type=float, default=1e-5)  # Much lower LR
    parser.add_argument('--max_length', type=int, default=1024)  # Shorter sequences
    args = parser.parse_args()

    print("🧹 MAP-NEO Mini CLEAN Conversational Fine-Tuning")
    print("=" * 60)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained('data/tokenizer')
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model (use original, not extended, for stability)
    print("Loading original model for stability...")
    checkpoint = torch.load('checkpoints/checkpoint_step_99999.pt', map_location='cpu')

    config = NeoMiniConfig()
    config.max_seq_len = 2048  # Standard context
    model = NeoMini(config)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Load clean dataset
    dataset = CleanConversationDataset(args.dataset, tokenizer, args.max_length)

    # Train
    trainer = SimpleTrainer(model, tokenizer, dataset, args)
    trainer.train()


if __name__ == '__main__':
    main()