# MAP-NEO Mini Training Script
# Optimized for 8GB VRAM with mixed precision, gradient accumulation, and checkpointing

import os
import math
import time
import json
from pathlib import Path
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

from accelerate import Accelerator
from tqdm import tqdm

from model_neo import NeoMini, NeoMiniConfig


@dataclass
class TrainingConfig:
    """Training configuration optimized for 8GB VRAM."""
    # Data
    data_path: str = "data/tokens/packed_1024.txt"
    seq_length: int = 1024

    # Model
    model_config_path: Optional[str] = None

    # Training
    batch_size: int = 1                    # CHANGED: back to 1 for speed (was 2)
    gradient_accumulation_steps: int = 32  # CHANGED: back to 32 (was 16)
    max_steps: int = 150000
    warmup_steps: int = 3750

    # Resume training
    resume_from_checkpoint: Optional[str] = "checkpoints/checkpoint_step_15000.pt"  # ADDED: resume from an existing checkpoint

    # Optimization
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0

    # Memory optimization
    mixed_precision: str = "bf16"          # bfloat16 is supported on RTX 5070
    gradient_checkpointing: bool = True

    # Logging and checkpointing
    log_interval: int = 10
    eval_interval: int = 500               # currently unused; reserved for periodic evaluation
    save_interval: int = 7500
    output_dir: str = "checkpoints"

    # Hardware
    compile_model: bool = False            # compilation disabled for now (can cause issues on Windows)


class PackedDataset(Dataset):
    """Dataset for pre-tokenized, packed sequences (one sequence of token ids per line)."""

    def __init__(self, data_path: str, seq_length: int = 1024):
        self.data_path = Path(data_path)
        self.seq_length = seq_length

        # Load all sequences into memory (fine for small datasets)
        print(f"Loading data from {data_path}...")
        with open(self.data_path, 'r', encoding='utf-8') as f:
            self.sequences = []
            for line in f:
                tokens = list(map(int, line.strip().split()))
                if len(tokens) == seq_length:
                    self.sequences.append(tokens)

        print(f"Loaded {len(self.sequences)} sequences of length {seq_length}")

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        tokens = self.sequences[idx]
        # Input: tokens[:-1], target: tokens[1:] (next-token prediction)
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        targets = torch.tensor(tokens[1:], dtype=torch.long)
        return input_ids, targets


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.1):
    """Cosine learning rate schedule with linear warmup."""
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return current_step / max(1, num_warmup_steps)
        progress = (current_step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        return min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def compute_loss(logits, targets):
    """Compute cross-entropy loss over all token positions."""
    # Flatten (batch, seq, vocab) -> (batch * seq, vocab) for loss computation
    logits_flat = logits.view(-1, logits.size(-1))
    targets_flat = targets.view(-1)
    loss = nn.functional.cross_entropy(logits_flat, targets_flat, ignore_index=-100)
    return loss


def save_checkpoint(model, optimizer, scheduler, step, loss, config, checkpoint_dir):
    """Save training checkpoint."""
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'step': step,
        'loss': loss,
        'config': config.__dict__
    }
    # Save checkpoint
    checkpoint_path = checkpoint_dir / f"checkpoint_step_{step}.pt"
    torch.save(checkpoint, checkpoint_path)

    # Save model config alongside the weights
    if hasattr(model, 'config'):
        config_path = checkpoint_dir / "model_config.json"
        with open(config_path, 'w') as f:
            json.dump(model.config.to_dict(), f, indent=2)

    print(f"Checkpoint saved: {checkpoint_path}")
    return checkpoint_path


def load_checkpoint(checkpoint_path, model, optimizer, scheduler):
    """ADDED: Load a training checkpoint and return the step/loss to resume from."""
    print(f"Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Load model state
    model.load_state_dict(checkpoint['model_state_dict'])

    # Load optimizer state
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Load scheduler state
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Get training progress
    start_step = checkpoint['step']
    last_loss = checkpoint['loss']

    print("✅ Checkpoint loaded successfully!")
    print(f"   Resuming from step: {start_step}")
    print(f"   Last loss: {last_loss:.4f}")

    return start_step, last_loss


def generate_sample(model, tokenizer, prompt="The future of AI", max_length=100, temperature=0.8):
    """Generate a text sample for qualitative evaluation.

    Expects a Hugging Face-style tokenizer (encode/decode, eos_token_id); not yet wired
    into the training loop.
    """
    model.eval()
    device = next(model.parameters()).device

    # Encode prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        for _ in range(max_length):
            # Forward pass
            logits = model(input_ids)
            next_token_logits = logits[0, -1, :] / temperature

            # Sample next token
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

            # Stop at EOS
            if next_token.item() == tokenizer.eos_token_id:
                break

    # Decode generated text
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    model.train()
    return generated_text


def main():
    # Initialize training config
    config = TrainingConfig()

    # Setup accelerator for mixed precision and gradient accumulation
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=config.output_dir
    )
    # Initialize trackers so accelerator.log() actually writes TensorBoard events
    # (the run name below is arbitrary)
    accelerator.init_trackers("map_neo_mini")

    # Create output directory
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load dataset
    print("Loading dataset...")
    dataset = PackedDataset(config.data_path, config.seq_length)
    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=0,            # CHANGED: back to 0 for stability (was 2)
        persistent_workers=False  # CHANGED: disabled for stability (was True)
    )

    # Create model
    print("Creating model...")
    if config.model_config_path and Path(config.model_config_path).exists():
        model = NeoMini.from_config(config.model_config_path)
    else:
        model_config = NeoMiniConfig()
        model = NeoMini(model_config)

    print(f"Model has {model.get_num_params():,} parameters")

    # Enable gradient checkpointing for memory savings, but only if the model implements it
    if config.gradient_checkpointing:
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable()
            print("Gradient checkpointing enabled")
        else:
            print("Gradient checkpointing requested, but the model does not support it; skipping")

    # Create optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=(config.beta1, config.beta2),
        weight_decay=config.weight_decay
    )

    # Create scheduler
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=config.max_steps
    )

    # Prepare with accelerator
    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )

    # ADDED: Resume from checkpoint if specified
    start_step = 0
    total_loss = 0.0
    if config.resume_from_checkpoint and Path(config.resume_from_checkpoint).exists():
        # Unwrap the model for loading (the accelerator wraps it)
        unwrapped_model = accelerator.unwrap_model(model)
        start_step, last_loss = load_checkpoint(
            config.resume_from_checkpoint, unwrapped_model, optimizer, scheduler
        )
        # total_loss stays at 0.0, so the checkpointed loss below is averaged over steps since resume
        print(f"🚀 Resuming training from step {start_step}")
    else:
        print("🚀 Starting fresh training")

    # Training loop
    print("Starting training...")
    model.train()

    log_loss = 0.0

    # Cycle through the dataloader indefinitely
    dataloader_iter = iter(dataloader)

    # MODIFIED: Start progress bar from start_step
    progress_bar = tqdm(range(start_step, config.max_steps), desc="Training")

    for step in progress_bar:
        # Get batch, restarting the iterator when the dataset is exhausted
        try:
            batch = next(dataloader_iter)
        except StopIteration:
            dataloader_iter = iter(dataloader)
            batch = next(dataloader_iter)

        input_ids, targets = batch

        with accelerator.accumulate(model):
            # Forward pass
            logits = model(input_ids)
            loss = compute_loss(logits, targets)

            # Backward pass
            accelerator.backward(loss)

            # Gradient clipping (only when gradients sync at an accumulation boundary)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), config.grad_clip)

            # Optimizer step (a no-op on accumulation-only iterations)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        # Logging
        total_loss += loss.item()
        log_loss += loss.item()

        if step % config.log_interval == 0 and step > 0:
            avg_loss = log_loss / config.log_interval
            lr = scheduler.get_last_lr()[0]

            progress_bar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'lr': f'{lr:.2e}',
                'step': step
            })

            # Log to accelerator trackers (tensorboard)
            accelerator.log({
                'train_loss': avg_loss,
                'learning_rate': lr,
                'step': step
            }, step=step)

            log_loss = 0.0

        # Checkpointing
        if step % config.save_interval == 0 and step > 0:
            if accelerator.is_main_process:
                # Unwrap model for saving
                unwrapped_model = accelerator.unwrap_model(model)
                save_checkpoint(
                    unwrapped_model, optimizer, scheduler, step,
                    total_loss / (step + 1 - start_step),  # MODIFIED: average loss over steps since resume
                    config, output_dir
                )

        # Early stopping check (redundant with the bounded range, kept as a safeguard)
        if step >= config.max_steps:
            break

    # Final checkpoint
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        final_checkpoint = save_checkpoint(
            unwrapped_model, optimizer, scheduler, step,
            total_loss / (step + 1 - start_step),  # MODIFIED: average loss over steps since resume
            config, output_dir
        )
        print(f"Training completed! Final checkpoint: {final_checkpoint}")

    accelerator.end_training()


if __name__ == "__main__":
    main()
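

# Notes
# -----
# Expected data format (see PackedDataset): each line of data/tokens/packed_1024.txt is one
# packed sequence of exactly seq_length space-separated integer token ids.
#
# Example launch (a sketch; the filename train_neo.py is illustrative, and `accelerate` should
# already be configured for this machine, e.g. via `accelerate config`):
#
#   accelerate launch train_neo.py
#
# Running `python train_neo.py` directly also works for a single-GPU setup. Checkpoints and
# TensorBoard logs are written under TrainingConfig.output_dir ("checkpoints" by default).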