# MAP-NEO Mini Training Script
# Optimized for 8GB VRAM with mixed precision, gradient accumulation, and checkpointing


import os
import math
import time
import json
from pathlib import Path
from dataclasses import dataclass
from typing import Optional


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from accelerate import Accelerator
from tqdm import tqdm


from model_neo import NeoMini, NeoMiniConfig



@dataclass
class TrainingConfig:
    """Training configuration optimized for 8GB VRAM"""
    # Data
    data_path: str = "data/tokens/packed_1024.txt"
    seq_length: int = 1024
    
    # Model
    model_config_path: Optional[str] = None
    
    # Training
    batch_size: int = 1  # Micro-batch size; 1 keeps activations within 8GB VRAM
    gradient_accumulation_steps: int = 32  # Effective batch = 1 * 32 = 32 sequences per optimizer step
    max_steps: int = 150000
    warmup_steps: int = 3750
    
    # Resume training
    resume_from_checkpoint: Optional[str] = "checkpoints/checkpoint_step_15000.pt"  # Path to resume from, or None to start fresh
    
    # Optimization
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0
    
    # Memory optimization
    mixed_precision: str = "bf16"  # Use bfloat16 for RTX 5070
    gradient_checkpointing: bool = True
    
    # Logging and checkpointing
    log_interval: int = 10
    eval_interval: int = 500
    save_interval: int = 7500
    output_dir: str = "checkpoints"
    
    # Hardware
    compile_model: bool = False  # torch.compile can be unstable on Windows; this flag is not yet wired into main()
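

# Illustrative helper (an addition for clarity; the training loop below never
# calls it): the effective batch the defaults above imply per optimizer step.
def effective_tokens_per_step(cfg: TrainingConfig) -> int:
    """Tokens processed per optimizer step under gradient accumulation."""
    # Each packed line of seq_length tokens yields seq_length - 1 input/target
    # pairs (see PackedDataset.__getitem__ below), so the defaults give
    # 1 * 32 * 1023 = 32,736 tokens per optimizer step.
    return cfg.batch_size * cfg.gradient_accumulation_steps * (cfg.seq_length - 1)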



class PackedDataset(Dataset):
    """Dataset for pre-tokenized and packed sequences"""
    def __init__(self, data_path: str, seq_length: int = 1024):
        self.data_path = Path(data_path)
        self.seq_length = seq_length
        
        # Load all sequences into memory (for small datasets)
        print(f"Loading data from {data_path}...")
        with open(self.data_path, 'r', encoding='utf-8') as f:
            self.sequences = []
            for line in f:
                tokens = list(map(int, line.strip().split()))
                if len(tokens) == seq_length:
                    self.sequences.append(tokens)
        
        print(f"Loaded {len(self.sequences)} sequences of length {seq_length}")
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        tokens = self.sequences[idx]
        # Input: tokens[:-1], Target: tokens[1:]
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        targets = torch.tensor(tokens[1:], dtype=torch.long)
        return input_ids, targets
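
# Usage sketch (values are illustrative):
#   ds = PackedDataset("data/tokens/packed_1024.txt", seq_length=1024)
#   x, y = ds[0]  # x.shape == y.shape == torch.Size([1023]); y is x shifted by one
# Each line of the data file must hold exactly seq_length space-separated token
# ids; lines of any other length are silently skipped in __init__.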



def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.1):
    """Cosine learning rate schedule with warmup"""
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return current_step / max(1, num_warmup_steps)
        
        progress = (current_step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        return min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress))
    
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
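
# Worked example of the multiplier lr_lambda produces with the TrainingConfig
# defaults (num_warmup_steps=3750, num_training_steps=150000, min_lr_ratio=0.1):
#   lr_lambda(0)      -> 0.0   (start of warmup)
#   lr_lambda(3750)   -> 1.0   (warmup done, full base learning rate)
#   lr_lambda(150000) -> 0.1   (cosine fully decayed to the min_lr_ratio floor)
# The scheduler multiplies the optimizer's base lr (3e-4 here) by this value.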



def compute_loss(logits, targets):
    """Compute cross-entropy loss"""
    # Flatten for loss computation
    logits_flat = logits.view(-1, logits.size(-1))
    targets_flat = targets.view(-1)
    
    loss = nn.functional.cross_entropy(logits_flat, targets_flat, ignore_index=-100)
    return loss
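
# Shape note: logits arrives as (batch, seq, vocab) and targets as (batch, seq);
# flattening lets cross_entropy compare (batch*seq, vocab) against (batch*seq,).
# ignore_index=-100 follows the common padding convention, although the
# PackedDataset above never actually emits -100.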



def save_checkpoint(model, optimizer, scheduler, step, loss, config, checkpoint_dir):
    """Save training checkpoint"""
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'step': step,
        'loss': loss,
        'config': config.__dict__
    }
    
    # Save checkpoint
    checkpoint_path = checkpoint_dir / f"checkpoint_step_{step}.pt"
    torch.save(checkpoint, checkpoint_path)
    
    # Save model config
    if hasattr(model, 'config'):
        config_path = checkpoint_dir / "model_config.json"
        with open(config_path, 'w') as f:
            json.dump(model.config.to_dict(), f, indent=2)
    
    print(f"Checkpoint saved: {checkpoint_path}")
    return checkpoint_path
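
# To inspect a saved checkpoint from a REPL (paths are illustrative):
#   ckpt = torch.load("checkpoints/checkpoint_step_7500.pt", map_location="cpu")
#   print(ckpt["step"], ckpt["loss"])
#   print(ckpt["config"]["learning_rate"])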



def load_checkpoint(checkpoint_path, model, optimizer, scheduler):
    """ADDED: Load training checkpoint and resume"""
    print(f"Loading checkpoint from {checkpoint_path}...")
    
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    
    # Load model state
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Load optimizer state
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    # Load scheduler state
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    # Get training progress
    start_step = checkpoint['step']
    last_loss = checkpoint['loss']
    
    print(f"โœ… Checkpoint loaded successfully!")
    print(f"   Resuming from step: {start_step}")
    print(f"   Last loss: {last_loss:.4f}")
    
    return start_step, last_loss



def generate_sample(model, tokenizer, prompt="The future of AI", max_length=100, temperature=0.8):
    """Generate text sample for evaluation"""
    model.eval()
    device = next(model.parameters()).device
    
    # Encode prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    
    with torch.no_grad():
        for _ in range(max_length):
            # Forward pass
            logits = model(input_ids)
            next_token_logits = logits[0, -1, :] / temperature
            
            # Sample next token
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            # Append to sequence
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
            
            # Stop early when the EOS token is sampled (max length is the loop bound)
            if next_token.item() == tokenizer.eos_token_id:
                break
    
    # Decode generated text
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    model.train()
    return generated_text
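
# generate_sample assumes a Hugging Face-style tokenizer exposing
# .encode(..., return_tensors="pt"), .decode(), and .eos_token_id. A possible
# wiring during evaluation (the tokenizer name is an assumption; nothing in
# main() below calls this function):
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   print(generate_sample(model, tok, prompt="The future of AI", max_length=50))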



def main():
    # Initialize training config
    config = TrainingConfig()
    
    # Setup accelerator for mixed precision and optimization
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=config.output_dir
    )
    # Register the tensorboard tracker; without this, accelerator.log() below
    # has no trackers attached and silently records nothing
    accelerator.init_trackers("neo_mini", config=config.__dict__)
    
    # Create output directory
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Load dataset
    print("Loading dataset...")
    dataset = PackedDataset(config.data_path, config.seq_length)
    dataloader = DataLoader(
        dataset, 
        batch_size=config.batch_size, 
        shuffle=True,
        pin_memory=True,
        num_workers=0,  # Single-process loading; worker processes proved flaky on Windows
        persistent_workers=False  # Only meaningful when num_workers > 0
    )
    
    # Create model
    print("Creating model...")
    if config.model_config_path and Path(config.model_config_path).exists():
        model = NeoMini.from_config(config.model_config_path)
    else:
        model_config = NeoMiniConfig()
        model = NeoMini(model_config)
    
    print(f"Model has {model.get_num_params():,} parameters")
    
    # Enable gradient checkpointing to trade compute for memory
    if config.gradient_checkpointing:
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable()
            print("Gradient checkpointing enabled")
        else:
            print("Warning: model has no gradient_checkpointing_enable(); skipping")
    
    # Create optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=(config.beta1, config.beta2),
        weight_decay=config.weight_decay
    )
    
    # Create scheduler
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=config.max_steps
    )
    
    # Prepare with accelerator
    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )
    
    # Resume from checkpoint if specified
    start_step = 0
    total_loss = 0
    
    if config.resume_from_checkpoint and Path(config.resume_from_checkpoint).exists():
        # Unwrap model for loading (since accelerator wraps it)
        unwrapped_model = accelerator.unwrap_model(model)
        start_step, last_loss = load_checkpoint(
            config.resume_from_checkpoint, 
            unwrapped_model, 
            optimizer, 
            scheduler
        )
        total_loss = last_loss * start_step  # Approximate the cumulative loss up to the resume point
        print(f"๐Ÿš€ Resuming training from step {start_step}")
    else:
        print("๐Ÿš€ Starting fresh training")
    
    # Training loop
    print("Starting training...")
    model.train()
    
    log_loss = 0
    
    # Create infinite dataloader
    dataloader_iter = iter(dataloader)
    
    # Start the progress bar from the resumed step
    progress_bar = tqdm(range(start_step, config.max_steps), desc="Training")
    
    for step in progress_bar:
        # Get batch
        try:
            batch = next(dataloader_iter)
        except StopIteration:
            dataloader_iter = iter(dataloader)
            batch = next(dataloader_iter)
        
        input_ids, targets = batch
        
        with accelerator.accumulate(model):
            # Forward pass
            logits = model(input_ids)
            loss = compute_loss(logits, targets)
            
            # Backward pass
            accelerator.backward(loss)
            
            # Gradient clipping
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), config.grad_clip)
            
            # Optimizer step
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        
        # Logging
        total_loss += loss.item()
        log_loss += loss.item()
        
        if step % config.log_interval == 0 and step > 0:
            avg_loss = log_loss / config.log_interval
            lr = scheduler.get_last_lr()[0]
            
            progress_bar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'lr': f'{lr:.2e}',
                'step': step
            })
            
            # Log to accelerator (tensorboard)
            accelerator.log({
                'train_loss': avg_loss,
                'learning_rate': lr,
                'step': step
            }, step=step)
            
            log_loss = 0
        
        # Checkpointing
        if step % config.save_interval == 0 and step > 0:
            if accelerator.is_main_process:
                # Unwrap model for saving
                unwrapped_model = accelerator.unwrap_model(model)
                save_checkpoint(
                    unwrapped_model, optimizer, scheduler, 
                    step, total_loss / (step + 1), config, output_dir  # Running-average loss over all steps so far
                )
        
        # Redundant safety guard: range() above already stops at max_steps
        if step >= config.max_steps:
            break
    
    # Final checkpoint
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        final_checkpoint = save_checkpoint(
            unwrapped_model, optimizer, scheduler, 
            step, total_loss / (step + 1), config, output_dir  # Running-average loss over all steps so far
        )
        print(f"Training completed! Final checkpoint: {final_checkpoint}")
    
    accelerator.end_training()



if __name__ == "__main__":
    main()