|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
import math
|
|
|
import time
|
|
|
import json
|
|
|
from pathlib import Path
|
|
|
from dataclasses import dataclass
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
|
from accelerate import Accelerator
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
from model_neo import NeoMini, NeoMiniConfig
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class TrainingConfig:
    """Settings for one NeoMini training run.

    The defaults aim at a GPU with roughly 8 GB of VRAM: a single sequence per
    micro-batch, gradients accumulated across 32 micro-batches, and bf16
    mixed precision handed to Accelerate.
    """

    # Data: packed token file (one sequence per line) and the exact length kept.
    data_path: str = "data/tokens/packed_1024.txt"
    seq_length: int = 1024

    # Model: file passed to NeoMini.from_config; None (or a missing file) means
    # a default NeoMiniConfig() is used.
    model_config_path: Optional[str] = None

    # Optimisation schedule. The training loop counts micro-batches.
    batch_size: int = 1
    gradient_accumulation_steps: int = 32
    max_steps: int = 150000
    warmup_steps: int = 3750

    # Resume point; training starts fresh when this path does not exist.
    resume_from_checkpoint: Optional[str] = "checkpoints/checkpoint_step_15000.pt"

    # AdamW hyper-parameters and the gradient-norm clipping threshold.
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0

    # Memory / numeric precision.
    mixed_precision: str = "bf16"
    gradient_checkpointing: bool = True

    # Logging / checkpoint cadence (in loop steps) and where outputs go.
    log_interval: int = 10
    eval_interval: int = 500  # NOTE(review): not read by the visible training code
    save_interval: int = 7500
    output_dir: str = "checkpoints"

    # Whether to torch.compile the model.
    compile_model: bool = False  # NOTE(review): not read by the visible training code
|
|
|
|
|
|
|
|
|
|
|
|
class PackedDataset(Dataset):
    """Map-style dataset over pre-tokenized, packed sequences.

    ``data_path`` holds one sequence per line as whitespace-separated integer
    token ids. Only lines with exactly ``seq_length`` tokens are kept; all
    other lines are skipped. Each item is an ``(input_ids, targets)`` pair of
    ``torch.long`` tensors of length ``seq_length - 1``, where ``targets`` is
    the same sequence shifted left by one position.
    """

    def __init__(self, data_path: str, seq_length: int = 1024):
        self.data_path = Path(data_path)
        self.seq_length = seq_length

        print(f"Loading data from {data_path}...")
        with open(self.data_path, 'r', encoding='utf-8') as handle:
            # Every line is parsed (a non-integer id raises ValueError);
            # afterwards only rows of the expected length are retained.
            parsed_rows = (
                [int(piece) for piece in raw.strip().split()] for raw in handle
            )
            self.sequences = [row for row in parsed_rows if len(row) == seq_length]
        print(f"Loaded {len(self.sequences)} sequences of length {seq_length}")

    def __len__(self):
        """Return how many sequences were retained."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return the shifted ``(input_ids, targets)`` pair for row ``idx``."""
        row = self.sequences[idx]
        shifted_inputs, shifted_targets = row[:-1], row[1:]
        return (
            torch.tensor(shifted_inputs, dtype=torch.long),
            torch.tensor(shifted_targets, dtype=torch.long),
        )
|
|
|
|
|
|
|
|
|
|
|
|
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.1):
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The multiplier applied to each param group's base learning rate is:

    * ``step / num_warmup_steps`` while ``step < num_warmup_steps`` (ramp up from 0);
    * a half-cosine from 1.0 down to ``min_lr_ratio``, reached at
      ``num_training_steps``;
    * ``min_lr_ratio`` for every later step. Without this clamp the cosine
      term would rise again past ``num_training_steps``.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        num_warmup_steps: Length of the linear warmup, in scheduler steps.
        num_training_steps: Scheduler step at which the decay reaches its floor.
        min_lr_ratio: Floor of the multiplier after warmup.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` instance.
    """
    warmup_span = max(1, num_warmup_steps)
    decay_span = max(1, num_training_steps - num_warmup_steps)

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return current_step / warmup_span
        # Clamp so the schedule stays at its floor instead of cycling back up.
        progress = min(1.0, (current_step - num_warmup_steps) / decay_span)
        return min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
|
|
|
|
|
|
|
|
|
|
|
def compute_loss(logits, targets, ignore_index: int = -100):
    """Return the mean token-level cross-entropy between ``logits`` and ``targets``.

    Args:
        logits: Tensor whose last dimension is the vocabulary; all leading
            dimensions are flattened into a single token axis.
        targets: Integer class ids with one entry per leading position of
            ``logits``.
        ignore_index: Target value excluded from both the loss and the mean.

    Returns:
        A scalar loss tensor.
    """
    vocab_size = logits.size(-1)
    # reshape (not view): also works when logits are non-contiguous, e.g.
    # produced by a transpose or slice inside the model.
    logits_flat = logits.reshape(-1, vocab_size)
    targets_flat = targets.reshape(-1)
    return nn.functional.cross_entropy(logits_flat, targets_flat, ignore_index=ignore_index)
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(model, optimizer, scheduler, step, loss, config, checkpoint_dir):
    """Write a training checkpoint to ``checkpoint_dir/checkpoint_step_{step}.pt``.

    The saved dict has the keys ``model_state_dict``, ``optimizer_state_dict``,
    ``scheduler_state_dict``, ``step``, ``loss`` and ``config`` (the
    ``__dict__`` of ``config``). If ``model`` has a ``config`` attribute, its
    ``to_dict()`` is also written to ``checkpoint_dir/model_config.json``.

    The checkpoint is first written to a temporary file beside the target and
    then moved into place with ``os.replace``. An interrupted save therefore
    never leaves a truncated file under the final checkpoint name.

    Args:
        model: Module to save (pass the unwrapped model).
        optimizer: Optimizer whose state is saved.
        scheduler: LR scheduler whose state is saved.
        step: Training step recorded in the file name and payload.
        loss: Loss value stored alongside the state.
        config: Object whose ``__dict__`` is stored as the run configuration.
        checkpoint_dir: Directory (created if needed) for the checkpoint.

    Returns:
        ``Path`` of the written checkpoint file.
    """
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'step': step,
        'loss': loss,
        'config': config.__dict__,
    }

    checkpoint_path = checkpoint_dir / f"checkpoint_step_{step}.pt"
    tmp_path = checkpoint_path.with_name(checkpoint_path.name + ".tmp")
    try:
        torch.save(checkpoint, tmp_path)
        # Atomic on the same filesystem: readers see either the old file or the complete new one.
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        # Includes KeyboardInterrupt: remove the partial temp file, then propagate.
        tmp_path.unlink(missing_ok=True)
        raise

    if hasattr(model, 'config'):
        config_path = checkpoint_dir / "model_config.json"
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(model.config.to_dict(), f, indent=2)

    print(f"Checkpoint saved: {checkpoint_path}")
    return checkpoint_path
|
|
|
|
|
|
|
|
|
|
|
|
def load_checkpoint(checkpoint_path, model, optimizer, scheduler):
    """Restore model, optimizer and scheduler state from a checkpoint file.

    The file must be a dict with the keys written by ``save_checkpoint``:
    ``model_state_dict``, ``optimizer_state_dict``, ``scheduler_state_dict``,
    ``step`` and ``loss``. Tensors are mapped to CPU on load.
    ``load_state_dict`` then copies them into the given objects in place.

    Args:
        checkpoint_path: Path (str or ``Path``) of the ``.pt`` file.
        model: Module to load weights into (pass the unwrapped model).
        optimizer: Optimizer whose state is restored in place.
        scheduler: LR scheduler whose state is restored in place.

    Returns:
        ``(step, loss)`` exactly as stored in the checkpoint.

    Raises:
        KeyError: If one of the expected keys is missing from the file.
    """
    print(f"Loading checkpoint from {checkpoint_path}...")

    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    start_step = checkpoint['step']
    last_loss = checkpoint['loss']

    print("✅ Checkpoint loaded successfully!")
    print(f" Resuming from step: {start_step}")
    print(f" Last loss: {last_loss:.4f}")

    return start_step, last_loss
|
|
|
|
|
|
|
|
|
|
|
|
def generate_sample(model, tokenizer, prompt="The future of AI", max_length=100, temperature=0.8):
    """Autoregressively generate a continuation of ``prompt`` for a quick qualitative check.

    Args:
        model: Module called as ``model(input_ids)`` on a ``(1, seq)`` id tensor;
            its output is indexed as ``[batch, position, vocab]`` logits.
        tokenizer: Presumably a Hugging Face-style tokenizer. It must provide
            ``encode(text, return_tensors="pt")``,
            ``decode(ids, skip_special_tokens=True)`` and ``eos_token_id``.
        prompt: Text to condition on.
        max_length: Maximum number of new tokens to generate.
        temperature: Softmax temperature for sampling. Values ``<= 0`` select
            greedy (argmax) decoding instead of dividing by zero.

    Returns:
        The decoded prompt plus generated continuation.

    The model's previous train/eval mode is restored afterwards, even if
    generation raises.
    """
    was_training = model.training
    model.eval()
    try:
        device = next(model.parameters()).device
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            for _ in range(max_length):
                # NOTE(review): the full sequence is re-fed every step and never
                # cropped to the model's context window; confirm prompt + max_length fits.
                logits = model(input_ids)
                next_token_logits = logits[0, -1, :]

                if temperature > 0:
                    probs = torch.softmax(next_token_logits / temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                # next_token has shape (1,); make it (1, 1) to append along the sequence axis.
                input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

                if next_token.item() == tokenizer.eos_token_id:
                    break

        generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    finally:
        model.train(was_training)
    return generated_text
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Train NeoMini on the packed token dataset described by ``TrainingConfig``.

    Steps:
    * Build the Accelerator, dataset/dataloader and model.
    * Create AdamW with a warmup + cosine LR schedule.
    * Optionally resume from ``config.resume_from_checkpoint``.
    * Run the micro-batch loop with gradient accumulation, logging every
      ``log_interval`` loop steps and checkpointing every ``save_interval``
      loop steps plus once at the end.
    """
    config = TrainingConfig()

    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=config.output_dir,
    )

    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # accelerator.log() does nothing until trackers are initialised.
    accelerator.init_trackers("neo_mini")

    print("Loading dataset...")
    dataset = PackedDataset(config.data_path, config.seq_length)
    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=0,
        persistent_workers=False,
    )

    print("Creating model...")
    if config.model_config_path and Path(config.model_config_path).exists():
        model = NeoMini.from_config(config.model_config_path)
    else:
        model = NeoMini(NeoMiniConfig())
    print(f"Model has {model.get_num_params():,} parameters")

    if config.gradient_checkpointing:
        # Overwriting the attribute with a stub would enable nothing, so call
        # the model's own hook when it exists and say so plainly when it doesn't.
        enable_checkpointing = getattr(model, "gradient_checkpointing_enable", None)
        if callable(enable_checkpointing):
            enable_checkpointing()
            print("Gradient checkpointing enabled")
        else:
            print("Warning: model has no gradient_checkpointing_enable(); "
                  "gradient checkpointing is NOT active")

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=(config.beta1, config.beta2),
        weight_decay=config.weight_decay,
    )

    # NOTE(review): the loop below counts micro-batches up to max_steps. Once
    # wrapped by accelerator.prepare(), however, the scheduler presumably
    # advances only on gradient-sync steps (every gradient_accumulation_steps
    # micro-batches). If so, warmup_steps/max_steps are consumed in optimizer
    # steps, and the cosine decay never completes within max_steps
    # micro-batches. Confirm the intended units before rescaling; changing them
    # shifts the LR of resumed runs.
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=config.max_steps,
    )

    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )

    # start_step: first loop step this run executes.
    # total_loss: sum of per-step losses over steps 0..current, so that
    # total_loss / (step + 1) is the running mean stored in checkpoints.
    start_step = 0
    total_loss = 0.0
    if config.resume_from_checkpoint and Path(config.resume_from_checkpoint).exists():
        # Load into the unwrapped model, the same object save_checkpoint receives.
        unwrapped_model = accelerator.unwrap_model(model)
        checkpoint_step, last_loss = load_checkpoint(
            config.resume_from_checkpoint,
            unwrapped_model,
            optimizer,
            scheduler,
        )
        # The checkpoint was written after step `checkpoint_step` completed.
        # Restarting at that same step would repeat it, and for periodic
        # checkpoints would immediately re-save over the file being resumed from.
        start_step = checkpoint_step + 1
        # The stored loss is the mean over steps 0..checkpoint_step.
        total_loss = last_loss * start_step
        print(f"Resuming training from step {start_step}")
    else:
        print("Starting fresh training")

    print("Starting training...")
    model.train()

    log_loss = 0.0
    log_count = 0
    dataloader_iter = iter(dataloader)
    progress_bar = tqdm(
        range(start_step, config.max_steps),
        desc="Training",
        disable=not accelerator.is_local_main_process,
    )

    for step in progress_bar:
        # Cycle through the data indefinitely; a fresh iterator reshuffles.
        try:
            batch = next(dataloader_iter)
        except StopIteration:
            dataloader_iter = iter(dataloader)
            batch = next(dataloader_iter)
        input_ids, targets = batch

        # Per the Accelerate gradient-accumulation docs, the prepared
        # optimizer/scheduler only take effect on gradient-sync steps inside
        # accumulate().
        with accelerator.accumulate(model):
            logits = model(input_ids)
            loss = compute_loss(logits, targets)
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), config.grad_clip)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        loss_value = loss.item()  # one host sync per step
        total_loss += loss_value
        log_loss += loss_value
        log_count += 1

        if step % config.log_interval == 0 and step > 0:
            # Average over the steps actually accumulated: the first window
            # (fresh or resumed) is not exactly log_interval long.
            avg_loss = log_loss / log_count
            lr = scheduler.get_last_lr()[0]
            progress_bar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'lr': f'{lr:.2e}',
                'step': step,
            })
            accelerator.log({
                'train_loss': avg_loss,
                'learning_rate': lr,
                'step': step,
            }, step=step)
            log_loss = 0.0
            log_count = 0

        if step % config.save_interval == 0 and step > 0:
            if accelerator.is_main_process:
                unwrapped_model = accelerator.unwrap_model(model)
                save_checkpoint(
                    unwrapped_model, optimizer, scheduler,
                    step, total_loss / (step + 1), config, output_dir,
                )

    if start_step >= config.max_steps:
        # The loop never ran, so `step` is undefined and there is nothing new to save.
        print(f"Nothing to train: start step {start_step} is not below max_steps ({config.max_steps})")
    else:
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            unwrapped_model = accelerator.unwrap_model(model)
            final_checkpoint = save_checkpoint(
                unwrapped_model, optimizer, scheduler,
                step, total_loss / (step + 1), config, output_dir,
            )
            print(f"Training completed! Final checkpoint: {final_checkpoint}")

    accelerator.end_training()
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|