"""finetune_fixed_clean.py - Fixed fine-tuning with better parameters"""
|
|
|
|
|
|
import os
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
|
|
|
from datasets import load_dataset
|
|
|
from model_neo import NeoMini, NeoMiniConfig
|
|
|
import argparse
|
|
|
from tqdm import tqdm
|
|
|
import json
|
|
|
|
|
|
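# Dataset of instruction/response pairs loaded from a JSONL file, filtered for
# quality and formatted as "Human: ... Assistant: ..." training examples.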
class CleanConversationDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.data = load_dataset('json', data_files=data_path)['train']
        print(f"Loaded {len(self.data)} examples")

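        # Heuristic quality filter: keep examples whose instruction and output are
        # between 10 and 500 characters and do not appear to contain URLs.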
        self.filtered_data = []
        for example in self.data:
            instruction = example.get('instruction', '').strip()
            output = example.get('output', '').strip()

            if (len(instruction) > 10 and len(output) > 10 and
                    len(instruction) < 500 and len(output) < 500 and
                    not any(url in output.lower() for url in ['http', 'www', '.com'])):
                self.filtered_data.append(example)

        print(f"Filtered to {len(self.filtered_data)} high-quality examples")

    def __len__(self):
        return len(self.filtered_data)

    def __getitem__(self, idx):
        example = self.filtered_data[idx]
        instruction = example.get('instruction', '').strip()
        input_text = example.get('input', '').strip()
        output = example.get('output', '').strip()

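        # Build the prompt in a simple "Human: ... Assistant:" format, including the
        # optional input field when present.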
        if input_text:
            prompt = f"Human: {instruction}\nInput: {input_text}\nAssistant:"
        else:
            prompt = f"Human: {instruction}\nAssistant:"

        full_text = f"{prompt} {output}{self.tokenizer.eos_token}"

        tokens = self.tokenizer(
            full_text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )

        input_ids = tokens['input_ids'].squeeze()
        attention_mask = tokens['attention_mask'].squeeze()

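        # Mask the prompt tokens with -100 so the loss is computed only on the
        # assistant response (CrossEntropyLoss ignores index -100).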
        prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)['input_ids']
        labels = input_ids.clone()
        labels[:len(prompt_tokens)] = -100
        # Also ignore padded positions (pad_token is the EOS token here), so padding
        # does not contribute to the loss.
        labels[attention_mask == 0] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }


class SimpleTrainer:
    def __init__(self, model, tokenizer, dataset, args):
        self.model = model
        self.tokenizer = tokenizer
        self.dataset = dataset
        self.args = args

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)

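        # AdamW with mild weight decay and betas commonly used for LM training.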
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=args.lr,
            weight_decay=0.01,
            betas=(0.9, 0.95)
        )

        self.dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=0
        )

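        # Cosine learning-rate schedule with a short warmup over the full run.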
        total_steps = len(self.dataloader) * args.epochs
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=50,
            num_training_steps=total_steps
        )

        print(f"Training setup: {total_steps} total steps")

    def train(self):
        print("\n🎯 Starting CLEAN fine-tuning...")
        print("="*50)

        self.model.train()
        total_loss = 0
        step = 0

        for epoch in range(self.args.epochs):
            print(f"\n📚 Epoch {epoch + 1}/{self.args.epochs}")
            epoch_loss = 0

            for batch in tqdm(self.dataloader, desc=f"Epoch {epoch + 1}"):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)

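                # Forward pass produces token logits; shift them against the labels
                # for next-token prediction.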
                outputs = self.model(input_ids)

                shift_logits = outputs[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()

                loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

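                # Backpropagate, clip gradients for stability, then update the
                # weights and the learning rate.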
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

                current_loss = loss.item()
                total_loss += current_loss
                epoch_loss += current_loss
                step += 1

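                # Save an intermediate checkpoint every 200 steps.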
                if step % 200 == 0:
                    self.save_checkpoint(step, current_loss)

            avg_loss = epoch_loss / len(self.dataloader)
            print(f"Epoch {epoch + 1} completed - Average loss: {avg_loss:.4f}")

            if avg_loss < 1.0:
                print("✅ Loss converged, stopping early")
                break

        self.save_final_model()
        print(f"✅ Training completed! Final average loss: {total_loss/step:.4f}")

    def save_checkpoint(self, step, loss):
        os.makedirs(self.args.output_dir, exist_ok=True)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'step': step,
            'loss': loss
        }, f"{self.args.output_dir}/checkpoint_step_{step}.pt")

    def save_final_model(self):
        os.makedirs(self.args.output_dir, exist_ok=True)

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'config': vars(self.model.config)
        }, f"{self.args.output_dir}/clean_conversational_model.pt")

        self.tokenizer.save_pretrained(self.args.output_dir)

        print(f"✅ Clean model saved to {self.args.output_dir}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='data/conversation_final/conversation_train.jsonl')
    parser.add_argument('--output_dir', type=str, default='clean_conversational_neo')
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--max_length', type=int, default=1024)

    args = parser.parse_args()

    print("🧹 MAP-NEO Mini CLEAN Conversational Fine-Tuning")
    print("="*60)

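    # Load the project tokenizer; fall back to EOS as the padding token if unset.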
    tokenizer = AutoTokenizer.from_pretrained('data/tokenizer')
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

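    # Rebuild the base model and restore its pretrained weights before fine-tuning.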
print("Loading original model for stability...")
|
|
|
checkpoint = torch.load('checkpoints/checkpoint_step_99999.pt', map_location='cpu')
|
|
|
config = NeoMiniConfig()
|
|
|
config.max_seq_len = 2048
|
|
|
model = NeoMini(config)
|
|
|
model.load_state_dict(checkpoint['model_state_dict'])
|
|
|
|
|
|
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
|
|
|
|
|
|
|
|
|
    dataset = CleanConversationDataset(args.dataset, tokenizer, args.max_length)

    trainer = SimpleTrainer(model, tokenizer, dataset, args)
    trainer.train()


if __name__ == '__main__':
    main()