#!/usr/bin/env python3
"""
Enhanced Foundational Model Implementation with Instruct, Dense/MoE,
Interactive Chat, ONNX Export, and Enhanced Logging & Evaluation

This file implements a foundational model inspired by GPT-4, Claude-3, Llama-2, and others.
New features include:
- Instruct mode: an extra dense projection for instruction tuning.
- Option for standard dense feedforward or a Mixture-of-Experts (MoE) variant.
- Compatibility with GPT-2 tokenization.
- Interactive chat mode for multi-turn conversation.
- Enhanced training logging with gradient accumulation and checkpointing.
- ONNX model export.
- Evaluation on a validation dataset (perplexity).
- Hidden state visualization (if matplotlib is available).

Usage examples (see CLI help for details).
"""
import os
import math
import time
import random
import argparse
import urllib.request
import json
import csv
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# =============================================================================
# TOKENIZER SETUP (using Hugging Face GPT2TokenizerFast)
# =============================================================================
try:
    from transformers import GPT2TokenizerFast
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    TOKENIZER_NAME = "gpt2"
except ImportError:
    print("transformers not installed. Falling back to tiktoken or simple tokenizer.")
    try:
        import tiktoken
        tokenizer = tiktoken.get_encoding("gpt2")
        TOKENIZER_NAME = "gpt2"
    except ImportError:
        print("tiktoken not installed. Using fallback simple tokenizer.")
        import numpy as np

        class SimpleTokenizer:
            def __init__(self):
                self.name = "simple"

            def encode(self, text):
                arr = np.frombuffer(text.encode('utf-8'), dtype=np.uint8)
                return arr.tolist()

            def decode(self, tokens):
                return bytes(tokens).decode('utf-8', errors='ignore')

        tokenizer = SimpleTokenizer()
        TOKENIZER_NAME = "simple"

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
def set_seed(seed: int = 42):
    random.seed(seed)
    torch.manual_seed(seed)
    print(f"Seed set to {seed}")

# -----------------------------------------------------------------------------
# ROTARY POSITIONAL EMBEDDINGS
# -----------------------------------------------------------------------------
def apply_rotary_pos_emb(q, k):
    """
    Applies rotary positional embeddings (RoPE) to queries and keys.
    q, k: Tensors of shape (B, n_head, T, head_dim)
    """
    T = q.size(-2)
    dim = q.size(-1)
    device = q.device
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).float() / dim))
    positions = torch.arange(T, device=device).float()
    sinusoid_inp = torch.einsum("i,j->ij", positions, inv_freq)   # (T, dim/2)
    sin = torch.sin(sinusoid_inp).unsqueeze(0).unsqueeze(0)       # (1, 1, T, dim/2)
    cos = torch.cos(sinusoid_inp).unsqueeze(0).unsqueeze(0)       # (1, 1, T, dim/2)
    q1, q2 = q[..., :dim//2], q[..., dim//2:]
    k1, k2 = k[..., :dim//2], k[..., dim//2:]
    q_rot = torch.cat((q1 * cos - q2 * sin, q2 * cos + q1 * sin), dim=-1)
    k_rot = torch.cat((k1 * cos - k2 * sin, k2 * cos + k1 * sin), dim=-1)
    return q_rot, k_rot

# =============================================================================
# CONFIGURATION CLASS
# =============================================================================
class GPTConfig:
    def __init__(self, variant='7M', instruct=False, use_moe=False, num_experts=4):
        self.vocab_size = 50257
        self.dropout = 0.1
        self.grad_clip = 1.0
        # New flags:
        self.instruct = instruct        # Extra dense projection for instruct mode.
        self.use_moe = use_moe          # Use MoE feedforward network.
        self.num_experts = num_experts  # Number of experts in MoE.
        # Model size parameters by variant:
        if variant == '7M':
            self.n_layer = 4; self.n_head = 4; self.n_embd = 128
            self.block_size = 128; self.batch_size = 16
            self.learning_rate = 3e-4; self.weight_decay = 0.1
            self.max_iters = 1000; self.log_interval = 50; self.eval_interval = 200
        elif variant == '120M':
            self.n_layer = 12; self.n_head = 12; self.n_embd = 768
            self.block_size = 1024; self.batch_size = 32
            self.learning_rate = 6e-4; self.weight_decay = 0.1
            self.max_iters = 2000; self.log_interval = 50; self.eval_interval = 200
        elif variant == '300M':
            self.n_layer = 16; self.n_head = 16; self.n_embd = 768
            self.block_size = 512; self.batch_size = 64
            self.learning_rate = 2e-4; self.weight_decay = 0.1
            self.max_iters = 3000; self.log_interval = 100; self.eval_interval = 500
        elif variant == '500M':
            self.n_layer = 24; self.n_head = 16; self.n_embd = 1024
            self.block_size = 512; self.batch_size = 64
            self.learning_rate = 1.5e-4; self.weight_decay = 0.1
            self.max_iters = 4000; self.log_interval = 100; self.eval_interval = 500
        elif variant == '700M':
            self.n_layer = 28; self.n_head = 16; self.n_embd = 1280
            self.block_size = 512; self.batch_size = 128
            self.learning_rate = 1e-4; self.weight_decay = 0.1
            self.max_iters = 5000; self.log_interval = 100; self.eval_interval = 500
        elif variant == '1B':
            self.n_layer = 32; self.n_head = 16; self.n_embd = 1536
            self.block_size = 1024; self.batch_size = 128
            self.learning_rate = 1e-4; self.weight_decay = 0.1
            self.max_iters = 6000; self.log_interval = 100; self.eval_interval = 500
        else:
            print(f"Variant {variant} not recognized. Defaulting to 7M.")
            self.n_layer = 4; self.n_head = 4; self.n_embd = 128
            self.block_size = 128; self.batch_size = 16
            self.learning_rate = 3e-4; self.weight_decay = 0.1
            self.max_iters = 1000; self.log_interval = 50; self.eval_interval = 200
        # Additional training features:
        self.grad_accum_steps = 1  # Default: no accumulation.
        self.warmup_iters = 100    # For dynamic LR scheduler.
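        # The effective optimization batch size is batch_size * grad_accum_steps
        # (e.g., batch_size=32 with grad_accum_steps=4 updates on ~128 sequences
        # per optimizer step). warmup_iters is stored here, but the Trainer below
        # currently applies plain cosine annealing without a warmup phase.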
self.print_config(variant) def print_config(self, variant): print("=====================================") print(f"Initializing GPT model with variant: {variant}") print(f"Layers: {self.n_layer}, Heads: {self.n_head}, Embedding Dim: {self.n_embd}") print(f"Block Size: {self.block_size}, Batch Size: {self.batch_size}") print(f"Learning Rate: {self.learning_rate}, Max Iters: {self.max_iters}") if self.instruct: print("Instruct mode: ENABLED") if self.use_moe: print(f"Using MoE in MLP with {self.num_experts} experts") print(f"Gradient Accumulation Steps: {self.grad_accum_steps}") print("=====================================") # ============================================================================= # MODEL ARCHITECTURE (with Rotary Embeddings, SwiGLU, Instruct & MoE options) # ============================================================================= class GPT(nn.Module): def __init__(self, config): super(GPT, self).__init__() self.config = config self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) # No absolute positional embeddings; using RoPE instead. self.drop = nn.Dropout(config.dropout) self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) self.ln_f = nn.LayerNorm(config.n_embd) # Extra dense layer for instruct mode: if config.instruct: self.instruct_dense = nn.Linear(config.n_embd, config.n_embd) self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Embedding)): nn.init.normal_(module.weight, mean=0.0, std=0.02) if isinstance(module, nn.Linear) and module.bias is not None: nn.init.zeros_(module.bias) def forward(self, inputs, return_hidden=False, instruct_mode=False): x, targets = inputs # x: (B, T) B, T = x.size() x = self.tok_emb(x) # (B, T, n_embd) x = self.drop(x) hidden_states = [] if return_hidden else None for block in self.blocks: x = block(x) if return_hidden: hidden_states.append(x) x = self.ln_f(x) if self.config.instruct and instruct_mode: x = self.instruct_dense(x) logits = self.head(x) loss = None if targets is not None: loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) if return_hidden: return logits, loss, hidden_states else: return logits, loss class Block(nn.Module): def __init__(self, config): super(Block, self).__init__() self.ln1 = nn.LayerNorm(config.n_embd) self.attn = CausalSelfAttention(config) self.ln2 = nn.LayerNorm(config.n_embd) if config.use_moe: self.mlp = MoEMLP(config) else: self.mlp = MLP(config) def forward(self, x): x = x + self.attn(self.ln1(x)) x = x + self.mlp(self.ln2(x)) return x class CausalSelfAttention(nn.Module): def __init__(self, config): super(CausalSelfAttention, self).__init__() assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head" self.n_head = config.n_head self.head_dim = config.n_embd // config.n_head self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) self.c_proj = nn.Linear(config.n_embd, config.n_embd) self.dropout = nn.Dropout(config.dropout) self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)) .view(1, 1, config.block_size, config.block_size)) def forward(self, x): B, T, C = x.size() qkv = self.c_attn(x) # (B, T, 3*C) q, k, v = qkv.split(C, dim=2) q = q.view(B, T, self.n_head, self.head_dim).transpose(1,2) k = k.view(B, T, self.n_head, self.head_dim).transpose(1,2) v = v.view(B, T, self.n_head, self.head_dim).transpose(1,2) q, k = apply_rotary_pos_emb(q, k) att = (q @ 
k.transpose(-2, -1)) / math.sqrt(self.head_dim) att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf')) att = F.softmax(att, dim=-1) att = self.dropout(att) y = att @ v y = y.transpose(1,2).contiguous().view(B, T, C) y = self.c_proj(y) y = self.dropout(y) return y # ------------------------------- # Standard Dense MLP (SwiGLU) # ------------------------------- class MLP(nn.Module): def __init__(self, config): super(MLP, self).__init__() self.fc1 = nn.Linear(config.n_embd, config.n_embd * 2) self.fc2 = nn.Linear(config.n_embd, config.n_embd) self.dropout = nn.Dropout(config.dropout) def forward(self, x): x_proj = self.fc1(x) x1, x2 = x_proj.chunk(2, dim=-1) x = F.silu(x1) * x2 x = self.fc2(x) x = self.dropout(x) return x # ------------------------------- # MoE MLP Components # ------------------------------- class ExpertMLP(nn.Module): def __init__(self, config): super(ExpertMLP, self).__init__() self.fc1 = nn.Linear(config.n_embd, config.n_embd * 2) self.fc2 = nn.Linear(config.n_embd, config.n_embd) self.dropout = nn.Dropout(config.dropout) def forward(self, x): x_proj = self.fc1(x) x1, x2 = x_proj.chunk(2, dim=-1) x = F.silu(x1) * x2 x = self.fc2(x) x = self.dropout(x) return x class MoEMLP(nn.Module): def __init__(self, config): super(MoEMLP, self).__init__() self.num_experts = config.num_experts self.experts = nn.ModuleList([ExpertMLP(config) for _ in range(self.num_experts)]) self.gate = nn.Linear(config.n_embd, self.num_experts) self.dropout = nn.Dropout(config.dropout) def forward(self, x): gate_scores = self.gate(x) # (B, T, num_experts) gate_probs = F.softmax(gate_scores, dim=-1) # Softmax over experts expert_outputs = [expert(x) for expert in self.experts] # List of (B, T, n_embd) expert_outputs = torch.stack(expert_outputs, dim=2) # (B, T, num_experts, n_embd) gate_probs = gate_probs.unsqueeze(-1) # (B, T, num_experts, 1) output = torch.sum(gate_probs * expert_outputs, dim=2) output = self.dropout(output) return output # ============================================================================= # DATA PROCESSING CLASSES # ============================================================================= class DataProcessorLocal: """Tokenize a local .txt file or URL.""" def __init__(self, dataset_name, txt_file): self.dataset_name = dataset_name self.txt_file = txt_file self.data_dir = f"_data_{dataset_name}_" if not os.path.exists(self.data_dir): os.makedirs(self.data_dir) print(f"Created directory: {self.data_dir}") else: print(f"Directory already exists: {self.data_dir}") def process(self): if self.txt_file.startswith("http://") or self.txt_file.startswith("https://"): print(f"Downloading dataset from: {self.txt_file}") with urllib.request.urlopen(self.txt_file) as response: text = response.read().decode('utf-8') else: print(f"Processing dataset from file: {self.txt_file}") with open(self.txt_file, "r", encoding="utf-8") as f: text = f.read() print("Tokenizing text...") tokens = tokenizer.encode(text) print(f"Token count: {len(tokens)}") data_tensor = torch.tensor(tokens, dtype=torch.long) bin_file = os.path.join(self.data_dir, "data.bin") torch.save(data_tensor, bin_file) print(f"Saved tokenized data to {bin_file}") return data_tensor # ============================================================================= # TRAINING CLASSES (with Gradient Accumulation & Enhanced Logging) # ============================================================================= class Trainer: def __init__(self, model, config, train_data, device, instruct_mode=False): self.device = device 
        self.model = model.to(device)
        self.config = config
        self.train_data = train_data
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate,
                                           weight_decay=config.weight_decay)
        # Cosine annealing schedule (note: config.warmup_iters is not applied here):
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer,
                                                                    T_max=config.max_iters,
                                                                    eta_min=1e-6)
        self.iter_num = 0
        self.accum_steps = config.grad_accum_steps
        self.instruct_mode = instruct_mode  # Use instruct mode during training.
        # Set up logging to file.
        logging.basicConfig(filename='training.log', level=logging.INFO,
                            format='%(asctime)s:%(levelname)s:%(message)s')
        self.metric_log = []

    def get_batch(self, data):
        batch_size = self.config.batch_size
        block_size = self.config.block_size
        ix = torch.randint(0, len(data) - block_size - 1, (batch_size,))
        x = torch.stack([data[i:i+block_size] for i in ix]).to(self.device)
        y = torch.stack([data[i+1:i+block_size+1] for i in ix]).to(self.device)
        return x, y

    def evaluate(self, data_subset=None):
        self.model.eval()
        if data_subset is None:
            data_subset = self.train_data
        x, y = self.get_batch(data_subset)
        with torch.no_grad():
            _, loss = self.model((x, y), instruct_mode=self.instruct_mode)
        perplexity = torch.exp(loss)
        print(f"[Evaluation] Iter {self.iter_num}: Loss = {loss.item():.4f}, Perplexity = {perplexity.item():.2f}")
        self.model.train()
        return loss.item(), perplexity.item()

    def train(self):
        print("Starting training...")
        start_time = time.time()
        accum_loss = 0.0
        self.optimizer.zero_grad()
        for i in tqdm(range(1, self.config.max_iters + 1), desc="Training", unit="iter"):
            self.iter_num = i
            x, y = self.get_batch(self.train_data)
            self.model.train()
            # Pass the instruct_mode flag to the model.
            logits, loss = self.model((x, y), instruct_mode=self.instruct_mode)
            loss = loss / self.accum_steps
            loss.backward()
            accum_loss += loss.item()
            if i % self.accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
                self.optimizer.step()
                self.scheduler.step()
                # Zero gradients only after an optimizer step so that they
                # accumulate across accum_steps micro-batches.
                self.optimizer.zero_grad()
            if i % self.config.log_interval == 0:
                elapsed = time.time() - start_time
                current_lr = self.optimizer.param_groups[0]['lr']
                # accum_loss sums losses already divided by accum_steps, so rescale
                # to report the average per-iteration loss over the interval.
                avg_loss = accum_loss * self.accum_steps / self.config.log_interval
                log_msg = (f"Iter {i}/{self.config.max_iters}: Avg Loss = {avg_loss:.4f}, "
                           f"LR = {current_lr:.2e} (Elapsed: {elapsed:.2f}s)")
                print(log_msg)
                logging.info(log_msg)
                self.metric_log.append([i, avg_loss, current_lr])
                accum_loss = 0.0
            if i % self.config.eval_interval == 0:
                self.evaluate()
            # Save a periodic checkpoint every 500 iterations.
            if i % 500 == 0:
                ckpt_path = f"checkpoint_iter_{i}.pt"
                torch.save(self.model.state_dict(), ckpt_path)
                print(f"Checkpoint saved at {ckpt_path}")
        print("Training completed.")
        # Save training metrics to CSV.
        with open("training_metrics.csv", "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(["Iteration", "Loss", "LearningRate"])
            writer.writerows(self.metric_log)
        print("Training metrics saved to training_metrics.csv")

    def save_checkpoint(self, filename):
        torch.save(self.model.state_dict(), filename)
        print(f"Checkpoint saved to {filename}")


class Distiller:
    def __init__(self, teacher_model, student_model, config, train_data, device,
                 temperature=2.0, alpha=0.5):
        self.device = device
        self.teacher = teacher_model.to(device)
        self.student = student_model.to(device)
        self.config = config
        self.train_data = train_data
        self.teacher.eval()  # Freeze teacher.
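        # Loss combination used in train() below (standard knowledge distillation):
        #   total_loss = alpha * T^2 * KL(teacher_soft || student_soft)
        #              + (1 - alpha) * CE(student_logits, targets)
        # where teacher_soft and student_soft are the temperature-T softened
        # distributions; the T^2 factor keeps the soft-target term on a gradient
        # scale comparable to the cross-entropy term as T grows.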
        self.optimizer = torch.optim.AdamW(self.student.parameters(), lr=config.learning_rate,
                                           weight_decay=config.weight_decay)
        self.temperature = temperature
        self.alpha = alpha
        self.iter_num = 0

    def get_batch(self, data):
        batch_size = self.config.batch_size
        block_size = self.config.block_size
        ix = torch.randint(0, len(data) - block_size - 1, (batch_size,))
        x = torch.stack([data[i:i+block_size] for i in ix]).to(self.device)
        y = torch.stack([data[i+1:i+block_size+1] for i in ix]).to(self.device)
        return x, y

    def train(self):
        self.student.train()
        print("Starting distillation training...")
        start_time = time.time()
        for i in tqdm(range(1, self.config.max_iters + 1), desc="Distillation", unit="iter"):
            self.iter_num = i
            x, y = self.get_batch(self.train_data)
            self.optimizer.zero_grad()
            with torch.no_grad():
                # A local GPT teacher expects an (inputs, targets) tuple; a Hugging Face
                # teacher (e.g., GPT2LMHeadModel) takes the input ids directly.
                if isinstance(self.teacher, GPT):
                    teacher_logits, _ = self.teacher((x, None))
                else:
                    teacher_out = self.teacher(x)
                    teacher_logits = teacher_out.logits if hasattr(teacher_out, "logits") else teacher_out
            student_logits, ce_loss = self.student((x, y))
            T = self.temperature
            teacher_probs = F.softmax(teacher_logits / T, dim=-1)
            student_log_probs = F.log_softmax(student_logits / T, dim=-1)
            distill_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (T * T)
            total_loss = self.alpha * distill_loss + (1 - self.alpha) * ce_loss
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.config.grad_clip)
            self.optimizer.step()
            if i % self.config.log_interval == 0:
                elapsed = time.time() - start_time
                print(f"Iter {i}/{self.config.max_iters}: Total Loss = {total_loss.item():.4f}, "
                      f"CE = {ce_loss.item():.4f}, KL = {distill_loss.item():.4f} (Elapsed: {elapsed:.2f}s)")
        print("Distillation training completed.")

# =============================================================================
# SAMPLING FUNCTIONS
# =============================================================================
def sample(model, prompt, config, length=100, temperature=1.0, top_k=None, instruct_mode=False):
    model.eval()
    device = next(model.parameters()).device
    tokens = tokenizer.encode(prompt)
    tokens = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
    with torch.no_grad():
        for _ in range(length):
            idx = tokens[:, -config.block_size:] if tokens.size(1) > config.block_size else tokens
            logits, _ = model((idx, None), instruct_mode=instruct_mode)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat([tokens, next_token], dim=1)
    return tokenizer.decode(tokens[0].tolist())

def sample_with_hidden(model, prompt, config, length=100, temperature=1.0, top_k=None, instruct_mode=False):
    model.eval()
    device = next(model.parameters()).device
    tokens = tokenizer.encode(prompt)
    tokens = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
    hidden_states_last = None
    with torch.no_grad():
        for _ in range(length):
            idx = tokens[:, -config.block_size:] if tokens.size(1) > config.block_size else tokens
            logits, _, hidden_states = model((idx, None), return_hidden=True, instruct_mode=instruct_mode)
            hidden_states_last = hidden_states
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat([tokens, next_token], dim=1)
    return tokenizer.decode(tokens[0].tolist()), hidden_states_last
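# -----------------------------------------------------------------------------
# Illustrative helper (not called anywhere): a minimal sketch of the top-k /
# temperature filtering used inside sample() above, applied to a standalone
# logits tensor of shape (B, vocab_size). The function name is introduced here
# purely for reference; it is not part of the CLI or training path.
# -----------------------------------------------------------------------------
def _example_top_k_filtering(logits: torch.Tensor, top_k: int = 5, temperature: float = 1.0):
    # Scale by temperature, mask everything below the k-th largest logit to -inf,
    # then renormalize with softmax -- mirroring the generation loop above.
    logits = logits / temperature
    v, _ = torch.topk(logits, top_k)
    logits = logits.masked_fill(logits < v[..., [-1]], float('-inf'))
    return F.softmax(logits, dim=-1)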
# =============================================================================
# EVALUATION FUNCTION (Perplexity on Validation Set)
# =============================================================================
def evaluate_perplexity(model, config, val_data, device):
    model.eval()
    total_loss = 0.0
    count = 0
    with torch.no_grad():
        for i in range(0, len(val_data) - config.block_size, config.batch_size):
            batch_range = range(i, min(i + config.batch_size, len(val_data) - config.block_size))
            x = torch.stack([val_data[j:j+config.block_size] for j in batch_range]).to(device)
            y = torch.stack([val_data[j+1:j+config.block_size+1] for j in batch_range]).to(device)
            _, loss = model((x, y))
            total_loss += loss.item() * x.size(0)
            count += x.size(0)
    if count == 0:
        print("Validation data is shorter than one block; cannot compute perplexity.")
        return float('nan'), float('nan')
    avg_loss = total_loss / count
    perplexity = math.exp(avg_loss)
    print(f"Validation Loss: {avg_loss:.4f}, Perplexity: {perplexity:.2f}")
    return avg_loss, perplexity

# =============================================================================
# ONNX EXPORT FUNCTION
# =============================================================================
def export_to_onnx(model, config, filename="model.onnx"):
    dummy_input = torch.randint(0, config.vocab_size, (1, config.block_size))
    # The model's forward takes a single (input_ids, targets) tuple, so the tuple is
    # nested once more to be passed as one positional argument. Note that the exported
    # graph cannot handle sequences longer than config.block_size because of the
    # fixed causal-mask buffer.
    torch.onnx.export(model, ((dummy_input, None),), filename,
                      input_names=["input"], output_names=["output"],
                      dynamic_axes={"input": {1: "seq_len"}, "output": {1: "seq_len"}})
    print(f"Model exported to ONNX format at {filename}")

# =============================================================================
# INTERACTIVE CHAT MODE
# =============================================================================
def interactive_chat(model, config, device, instruct_mode=False):
    model.eval()
    print("Entering interactive chat mode. Type 'exit' to quit.")
    conversation_history = ""
    while True:
        user_input = input("User: ")
        if user_input.strip().lower() == "exit":
            break
        conversation_history += f"User: {user_input}\nAssistant: "
        full_text = sample(model, conversation_history, config, length=100,
                           temperature=1.0, instruct_mode=instruct_mode)
        # sample() returns the prompt plus the continuation, so strip the history
        # to avoid echoing the whole conversation back to the user.
        response = full_text[len(conversation_history):] if full_text.startswith(conversation_history) else full_text
        print("Assistant:", response)
        conversation_history += response + "\n"

# =============================================================================
# HIDDEN STATE VISUALIZATION (requires matplotlib)
# =============================================================================
def visualize_hidden_states(hidden_states):
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib not installed.
Skipping hidden state visualization.") return norms = [hs.norm().item() for hs in hidden_states] plt.figure(figsize=(10, 5)) plt.plot(range(len(norms)), norms, marker='o') plt.xlabel("Block Index") plt.ylabel("Hidden State Norm") plt.title("Hidden State Norm per Block") plt.grid(True) plt.show() # ============================================================================= # COMMAND-LINE INTERFACE (CLI) # ============================================================================= def main(): set_seed(42) parser = argparse.ArgumentParser(description="Enhanced Foundational Model CLI with Extra Features") subparsers = parser.add_subparsers(dest="command") # Tokenize from local file/URL parser_tokenize = subparsers.add_parser("tokenize", help="Tokenize a .txt file/URL and save as .bin") parser_tokenize.add_argument("--dataset", type=str, required=True, help="Dataset name (e.g., openwebtext)") parser_tokenize.add_argument("--txt", type=str, required=True, help="Path or URL to .txt file") # HF Tokenize (from HuggingFace dataset) parser_hf_tokenize = subparsers.add_parser("hf_tokenize", help="Tokenize a HuggingFace dataset and save as .bin") parser_hf_tokenize.add_argument("--hf_dataset", type=str, required=True, help="HF dataset name") parser_hf_tokenize.add_argument("--hf_config", type=str, default=None, help="(Optional) HF config name") parser_hf_tokenize.add_argument("--hf_split", type=str, default="train", help="Dataset split") parser_hf_tokenize.add_argument("--hf_text_column", type=str, default="text", help="Text column name") parser_hf_tokenize.add_argument("--dataset", type=str, required=True, help="Local dataset name (folder name)") parser_hf_tokenize.add_argument("--streaming", action="store_true", help="Stream dataset from HF instead of downloading") # Train command parser_train = subparsers.add_parser("train", help="Train the GPT model on tokenized data") parser_train.add_argument("--variant", type=str, default="120M", help="Model variant (e.g., 120M)") parser_train.add_argument("--data", type=str, required=True, help="Path to tokenized .bin file") parser_train.add_argument("--instruct", action="store_true", help="Enable instruct mode") parser_train.add_argument("--moe", action="store_true", help="Enable MoE in MLP layers") parser_train.add_argument("--num_experts", type=int, default=4, help="Number of experts (if using MoE)") parser_train.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps") # Sample command parser_sample = subparsers.add_parser("sample", help="Generate text using a trained GPT model") parser_sample.add_argument("--variant", type=str, default="120M", help="Model variant for sampling") parser_sample.add_argument("--data", type=str, required=True, help="Path to tokenized .bin file (for reference)") parser_sample.add_argument("--prompt", type=str, required=True, help="Input prompt") parser_sample.add_argument("--length", type=int, default=100, help="Number of tokens to generate") parser_sample.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature") parser_sample.add_argument("--top_k", type=int, default=None, help="Top-k sampling parameter") parser_sample.add_argument("--model_checkpoint", type=str, default=None, help="Path to model checkpoint") parser_sample.add_argument("--instruct_mode", action="store_true", help="Apply instruct dense layer during sampling") # Analyze command (with hidden state visualization) parser_analyze = subparsers.add_parser("analyze", help="Generate text and analyze hidden state 
norms") parser_analyze.add_argument("--variant", type=str, default="120M", help="Model variant") parser_analyze.add_argument("--model_checkpoint", type=str, required=True, help="Path to model checkpoint") parser_analyze.add_argument("--prompt", type=str, required=True, help="Input prompt") parser_analyze.add_argument("--length", type=int, default=100, help="Number of tokens") parser_analyze.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature") parser_analyze.add_argument("--top_k", type=int, default=None, help="Top-k sampling parameter") parser_analyze.add_argument("--plot", action="store_true", help="Plot hidden state norms if matplotlib is available") # Export to safetensors parser_export = subparsers.add_parser("export", help="Export a trained model and tokenizer config") parser_export.add_argument("--variant", type=str, default="120M", help="Model variant") parser_export.add_argument("--model_path", type=str, required=True, help="Path to trained .pt model") parser_export.add_argument("--output_dir", type=str, default=None, help="Output directory") # ONNX export command parser_export_onnx = subparsers.add_parser("export_onnx", help="Export the model to ONNX format") parser_export_onnx.add_argument("--variant", type=str, default="120M", help="Model variant") parser_export_onnx.add_argument("--model_checkpoint", type=str, required=True, help="Path to model checkpoint") parser_export_onnx.add_argument("--output_file", type=str, default="model.onnx", help="ONNX output filename") # Distill command parser_distill = subparsers.add_parser("distill", help="Knowledge distillation from teacher to student") parser_distill.add_argument("--teacher_model_path", type=str, required=True, help='Teacher model checkpoint path or "hf"') parser_distill.add_argument("--teacher_variant", type=str, default="120M", help="Teacher variant") parser_distill.add_argument("--student_variant", type=str, default="120M", help="Student variant") parser_distill.add_argument("--data", type=str, required=True, help="Path to tokenized .bin file") parser_distill.add_argument("--temperature", type=float, default=2.0, help="Distillation temperature") parser_distill.add_argument("--alpha", type=float, default=0.5, help="Weight for distillation loss") parser_distill.add_argument("--sample_prompt", type=str, default=None, help="Prompt for sample generation") parser_distill.add_argument("--sample_length", type=int, default=100, help="Tokens to generate") parser_distill.add_argument("--sample_temperature", type=float, default=1.0, help="Sampling temperature") # Finetune command parser_finetune = subparsers.add_parser("finetune", help="Fine-tune on local conversation data") parser_finetune.add_argument("--variant", type=str, default="120M", help="Model variant") parser_finetune.add_argument("--model_checkpoint", type=str, required=True, help="Path to model checkpoint") parser_finetune.add_argument("--data", type=str, required=True, help="Path to fine-tune tokenized .bin file") parser_finetune.add_argument("--finetune_iters", type=int, default=1000, help="Number of iterations") parser_finetune.add_argument("--log_interval", type=int, default=50, help="Log interval") parser_finetune.add_argument("--eval_interval", type=int, default=200, help="Evaluation interval") parser_finetune.add_argument("--prompt", type=str, default=None, help="Prompt after fine-tuning") parser_finetune.add_argument("--sample_length", type=int, default=100, help="Tokens to generate") 
parser_finetune.add_argument("--sample_temperature", type=float, default=1.0, help="Sampling temperature") # HF Finetune command parser_hf_finetune = subparsers.add_parser("hf_finetune", help="Fine-tune on a HF conversational dataset") parser_hf_finetune.add_argument("--variant", type=str, default="120M", help="Model variant") parser_hf_finetune.add_argument("--model_checkpoint", type=str, required=True, help="Path to model checkpoint") parser_hf_finetune.add_argument("--hf_dataset", type=str, required=True, help="HF dataset name") parser_hf_finetune.add_argument("--hf_config", type=str, default=None, help="(Optional) HF config name") parser_hf_finetune.add_argument("--hf_split", type=str, default="train", help="Dataset split") parser_hf_finetune.add_argument("--hf_text_column", type=str, default="text", help="Text column") parser_hf_finetune.add_argument("--finetune_iters", type=int, default=1000, help="Iterations") parser_hf_finetune.add_argument("--log_interval", type=int, default=50, help="Log interval") parser_hf_finetune.add_argument("--eval_interval", type=int, default=200, help="Evaluation interval") parser_hf_finetune.add_argument("--prompt", type=str, default=None, help="Prompt after fine-tuning") parser_hf_finetune.add_argument("--sample_length", type=int, default=100, help="Tokens to generate") parser_hf_finetune.add_argument("--sample_temperature", type=float, default=1.0, help="Sampling temperature") # Instruct Finetune command (for conversational instruct tuning) parser_instruct_finetune = subparsers.add_parser("instruct_finetune", help="Fine-tune on instruct-style/conversational data from a HF dataset") parser_instruct_finetune.add_argument("--variant", type=str, default="120M", help="Model variant") parser_instruct_finetune.add_argument("--model_checkpoint", type=str, required=True, help="Path to model checkpoint") parser_instruct_finetune.add_argument("--hf_dataset", type=str, required=True, help="HF dataset name") parser_instruct_finetune.add_argument("--hf_config", type=str, default=None, help="(Optional) HF config name") parser_instruct_finetune.add_argument("--hf_split", type=str, default="train", help="Dataset split") # Use a comma-separated list for text columns. 
# For conversation data, you might use "user,assistant" parser_instruct_finetune.add_argument("--hf_text_column", type=str, default="system_prompt,question,response", help="Comma-separated text columns") parser_instruct_finetune.add_argument("--finetune_iters", type=int, default=1000, help="Iterations") parser_instruct_finetune.add_argument("--log_interval", type=int, default=50, help="Log interval") parser_instruct_finetune.add_argument("--eval_interval", type=int, default=200, help="Evaluation interval") parser_instruct_finetune.add_argument("--prompt", type=str, default=None, help="Prompt after fine-tuning") parser_instruct_finetune.add_argument("--sample_length", type=int, default=100, help="Tokens to generate") parser_instruct_finetune.add_argument("--sample_temperature", type=float, default=1.0, help="Sampling temperature") # Benchmark command parser_bench = subparsers.add_parser("bench", help="Run benchmark evaluations and print results") parser_bench.add_argument("--variant", type=str, default="120M", help="Model variant") parser_bench.add_argument("--model_checkpoint", type=str, required=True, help="Path to model checkpoint") # Instance info command parser_instance = subparsers.add_parser("instance", help="Print recommended Azure VM information") # Compare command parser_compare = subparsers.add_parser("compare", help="Compare our model with external HF models on benchmark prompts") parser_compare.add_argument("--our_variant", type=str, default="120M", help="Our model variant") parser_compare.add_argument("--our_model_checkpoint", type=str, required=True, help="Path to our model checkpoint") parser_compare.add_argument("--llama", type=str, default="meta-llama/Llama-2-1b-hf", help="HF model ID for Llama 2 1B") parser_compare.add_argument("--gpt2", type=str, default="gpt2-xl", help="HF model ID for GPT2 1B") parser_compare.add_argument("--smol", type=str, default="hf-smolLM2-1b", help="HF model ID for HF SmolLM2 1B") parser_compare.add_argument("--qwen", type=str, default="qwen-2-1b", help="HF model ID for Qwen 2 1B") parser_compare.add_argument("--sample_length", type=int, default=100, help="Max tokens for generation") parser_compare.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature") # Evaluate command (perplexity on validation dataset) parser_evaluate = subparsers.add_parser("evaluate", help="Evaluate model perplexity on a validation dataset") parser_evaluate.add_argument("--variant", type=str, default="120M", help="Model variant") parser_evaluate.add_argument("--model_checkpoint", type=str, required=True, help="Path to model checkpoint") parser_evaluate.add_argument("--val_data", type=str, required=True, help="Path to validation tokenized .bin file") # Interactive Chat command parser_chat = subparsers.add_parser("chat", help="Enter interactive chat mode with the model") parser_chat.add_argument("--variant", type=str, default="120M", help="Model variant") parser_chat.add_argument("--model_checkpoint", type=str, required=True, help="Path to model checkpoint") parser_chat.add_argument("--instruct_mode", action="store_true", help="Apply instruct dense layer during chat") args = parser.parse_args() # Force CPU usage (or change to GPU if available) device = torch.device("cpu") torch.set_num_threads(64) print("Using device:", device) # Dispatch commands if args.command == "tokenize": processor = DataProcessorLocal(args.dataset, args.txt) processor.process() elif args.command == "hf_tokenize": try: from datasets import load_dataset except ImportError: 
print("Please install the datasets library: pip install datasets") return if args.streaming: ds = load_dataset(args.hf_dataset, args.hf_config, split=args.hf_split, streaming=True) if args.hf_config else load_dataset(args.hf_dataset, split=args.hf_split, streaming=True) print("Streaming dataset from HuggingFace...") def process_example(ex): txt = ex[args.hf_text_column] return " ".join(txt) if isinstance(txt, list) else txt all_text = " ".join(process_example(example) for example in ds) else: ds = load_dataset(args.hf_dataset, args.hf_config, split=args.hf_split) if args.hf_config else load_dataset(args.hf_dataset, split=args.hf_split) print(f"Loaded {len(ds)} examples from {args.hf_dataset}.") def process_example(ex): txt = ex[args.hf_text_column] return " ".join(txt) if isinstance(txt, list) else txt all_text = " ".join(process_example(example) for example in ds) tokens = tokenizer.encode(all_text) print(f"Token count: {len(tokens)}") data_tensor = torch.tensor(tokens, dtype=torch.long) data_dir = f"_data_{args.hf_dataset.replace('/', '_')}_" os.makedirs(data_dir, exist_ok=True) bin_file = os.path.join(data_dir, "data.bin") torch.save(data_tensor, bin_file) print(f"Saved tokenized data to {bin_file}") elif args.command == "train": if not os.path.exists(args.data): print(f"Data file {args.data} not found.") return train_data = torch.load(args.data) config = GPTConfig(args.variant, instruct=args.instruct, use_moe=args.moe, num_experts=args.num_experts) config.grad_accum_steps = args.grad_accum model = GPT(config) trainer = Trainer(model, config, train_data, device=device, instruct_mode=args.instruct) trainer.train() model_path = f"{args.variant}_model.pt" torch.save(model.state_dict(), model_path) print(f"Model saved as: {model_path}") elif args.command == "sample": config = GPTConfig(args.variant) model = GPT(config) checkpoint = args.model_checkpoint if args.model_checkpoint else f"{args.variant}_model.pt" if not os.path.exists(checkpoint): print(f"Checkpoint {checkpoint} not found.") return model.load_state_dict(torch.load(checkpoint, map_location=device)) model.to(device) generated = sample(model, args.prompt, config, length=args.length, temperature=args.temperature, top_k=args.top_k, instruct_mode=args.instruct_mode) print("=====================================") print("Generated Text:") print(generated) print("=====================================") elif args.command == "analyze": config = GPTConfig(args.variant) model = GPT(config) if not os.path.exists(args.model_checkpoint): print(f"Checkpoint {args.model_checkpoint} not found.") return model.load_state_dict(torch.load(args.model_checkpoint, map_location=device)) model.to(device) generated, hidden_states = sample_with_hidden(model, args.prompt, config, length=args.length, temperature=args.temperature, top_k=args.top_k) print("=====================================") print("Generated Text:") print(generated) print("\nHidden State Analysis (Mean Norm per Block):") for i, hs in enumerate(hidden_states): norm_val = hs.norm().item() print(f"Block {i}: Norm = {norm_val:.4f}") if args.plot: visualize_hidden_states(hidden_states) print("=====================================") elif args.command == "export": try: from safetensors.torch import save_file as safe_save_file except ImportError: print("Please install safetensors: pip install safetensors") return out_dir = args.output_dir if args.output_dir else f"exported_{args.variant}" os.makedirs(out_dir, exist_ok=True) if not os.path.exists(args.model_path): print(f"Model file 
{args.model_path} not found.") return state_dict = torch.load(args.model_path, map_location=device) export_path = os.path.join(out_dir, "model.safetensors") safe_save_file(state_dict, export_path) print(f"Model exported to {export_path}") tokenizer_config = {"tokenizer": TOKENIZER_NAME, "vocab_size": 50257 if TOKENIZER_NAME=="gpt2" else None} with open(os.path.join(out_dir, "tokenizer.json"), "w", encoding="utf-8") as f: json.dump(tokenizer_config, f, ensure_ascii=False, indent=2) print("Tokenizer configuration saved.") elif args.command == "export_onnx": config = GPTConfig(args.variant) model = GPT(config) if not os.path.exists(args.model_checkpoint): print(f"Model checkpoint {args.model_checkpoint} not found.") return model.load_state_dict(torch.load(args.model_checkpoint, map_location=device)) model.to(device) export_to_onnx(model, config, filename=args.output_file) elif args.command == "distill": if args.teacher_model_path.lower() == "hf": try: from transformers import GPT2LMHeadModel teacher_model = GPT2LMHeadModel.from_pretrained("gpt2") print("Loaded teacher model from HF.") except ImportError: print("Please install transformers: pip install transformers") return else: teacher_config = GPTConfig(args.teacher_variant) teacher_model = GPT(teacher_config) if not os.path.exists(args.teacher_model_path): print(f"Teacher model {args.teacher_model_path} not found.") return teacher_model.load_state_dict(torch.load(args.teacher_model_path, map_location=device)) teacher_model.eval() print("Loaded teacher model from local checkpoint.") student_config = GPTConfig(args.student_variant) student_model = GPT(student_config) if not os.path.exists(args.data): print(f"Data file {args.data} not found.") return train_data = torch.load(args.data) distiller = Distiller(teacher_model, student_model, student_config, train_data, device=device, temperature=args.temperature, alpha=args.alpha) distiller.train() model_path = f"distilled_{args.student_variant}_model.pt" torch.save(student_model.state_dict(), model_path) print(f"Distilled model saved as: {model_path}") if args.sample_prompt: sample_text = sample(student_model, args.sample_prompt, student_config, length=args.sample_length, temperature=args.sample_temperature) print("=====================================") print("Sample Generated Text:") print(sample_text) print("=====================================") elif args.command == "finetune": if not os.path.exists(args.data): print(f"Fine-tune data file {args.data} not found.") return conversation_data = torch.load(args.data) config = GPTConfig(args.variant) config.max_iters = args.finetune_iters config.log_interval = args.log_interval config.eval_interval = args.eval_interval model = GPT(config) if not os.path.exists(args.model_checkpoint): print(f"Model checkpoint {args.model_checkpoint} not found.") return model.load_state_dict(torch.load(args.model_checkpoint, map_location=device)) print(f"Loaded model from {args.model_checkpoint}.") trainer = Trainer(model, config, conversation_data, device=device) trainer.train() finetuned_path = f"finetuned_{args.variant}_model.pt" torch.save(model.state_dict(), finetuned_path) print(f"Finetuned model saved as: {finetuned_path}") if args.prompt: sample_text = sample(model, args.prompt, config, length=args.sample_length, temperature=args.sample_temperature) print("=====================================") print("Sample Generated Text:") print(sample_text) print("=====================================") elif args.command == "hf_finetune": try: from datasets import 
load_dataset except ImportError: print("Please install datasets: pip install datasets") return ds = load_dataset(args.hf_dataset, args.hf_config, split=args.hf_split) if args.hf_config else load_dataset(args.hf_dataset, split=args.hf_split) print(f"Loaded {len(ds)} examples from {args.hf_dataset}.") def process_example(ex): txt = ex[args.hf_text_column] return " ".join(txt) if isinstance(txt, list) else txt all_text = " ".join(process_example(example) for example in ds) tokens = tokenizer.encode(all_text) print(f"Token count: {len(tokens)}") conversation_data = torch.tensor(tokens, dtype=torch.long) data_dir = f"_data_{args.hf_dataset.replace('/', '_')}_" os.makedirs(data_dir, exist_ok=True) bin_file = os.path.join(data_dir, "data.bin") torch.save(conversation_data, bin_file) print(f"Saved tokenized dataset to {bin_file}") config = GPTConfig(args.variant) config.max_iters = args.finetune_iters config.log_interval = args.log_interval config.eval_interval = args.eval_interval model = GPT(config) if not os.path.exists(args.model_checkpoint): print(f"Model checkpoint {args.model_checkpoint} not found.") return model.load_state_dict(torch.load(args.model_checkpoint, map_location=device)) print(f"Loaded model from {args.model_checkpoint}.") trainer = Trainer(model, config, conversation_data, device=device) trainer.train() finetuned_path = f"hf_finetuned_{args.variant}_model.pt" torch.save(model.state_dict(), finetuned_path) print(f"Finetuned model saved as: {finetuned_path}") if args.prompt: sample_text = sample(model, args.prompt, config, length=args.sample_length, temperature=args.sample_temperature) print("=====================================") print("Sample Generated Text:") print(sample_text) print("=====================================") elif args.command == "instruct_finetune": try: from datasets import load_dataset except ImportError: print("Please install the datasets library: pip install datasets") return ds = load_dataset(args.hf_dataset, args.hf_config, split=args.hf_split) if args.hf_config else load_dataset(args.hf_dataset, split=args.hf_split) print(f"Loaded {len(ds)} examples from {args.hf_dataset}.") def process_example(ex): cols = [col.strip() for col in args.hf_text_column.split(",")] # If conversation fields (e.g., user and assistant) are provided: if len(cols) == 2 and cols[0].lower() == "user" and cols[1].lower() == "assistant": return f"User: {ex.get(cols[0], '')}\nAssistant: {ex.get(cols[1], '')}" elif len(cols) == 3: sys_prompt = str(ex.get(cols[0], "")) question = str(ex.get(cols[1], "")) response = str(ex.get(cols[2], "")) return f"System: {sys_prompt}\nInstruction: {question}\nResponse: {response}" elif len(cols) == 2: instruction = str(ex.get(cols[0], "")) response = str(ex.get(cols[1], "")) return f"Instruction: {instruction}\nResponse: {response}" else: texts = [str(ex.get(col, "")) for col in cols] return " ".join(texts) all_text = " ".join(process_example(example) for example in ds) tokens = tokenizer.encode(all_text) print(f"Token count: {len(tokens)}") instruct_data = torch.tensor(tokens, dtype=torch.long) data_dir = f"_data_{args.hf_dataset.replace('/', '_')}_instruct_" os.makedirs(data_dir, exist_ok=True) bin_file = os.path.join(data_dir, "data.bin") torch.save(instruct_data, bin_file) print(f"Saved tokenized instruct data to {bin_file}") # Enable instruct mode in the config config = GPTConfig(args.variant, instruct=True) config.max_iters = args.finetune_iters config.log_interval = args.log_interval config.eval_interval = args.eval_interval model = 
GPT(config) if not os.path.exists(args.model_checkpoint): print(f"Model checkpoint {args.model_checkpoint} not found.") return model.load_state_dict(torch.load(args.model_checkpoint, map_location=device)) print(f"Loaded model from {args.model_checkpoint}.") # Pass instruct_mode=True to the Trainer for instruct-tuning trainer = Trainer(model, config, instruct_data, device=device, instruct_mode=True) trainer.train() finetuned_path = f"instruct_finetuned_{args.variant}_model.pt" torch.save(model.state_dict(), finetuned_path) print(f"Instruct fine-tuned model saved as: {finetuned_path}") if args.prompt: sample_text = sample(model, args.prompt, config, length=args.sample_length, temperature=args.sample_temperature, instruct_mode=True) print("=====================================") print("Sample Generated Text:") print(sample_text) print("=====================================") elif args.command == "bench": config = GPTConfig(args.variant) model = GPT(config) if not os.path.exists(args.model_checkpoint): print(f"Checkpoint {args.model_checkpoint} not found.") return model.load_state_dict(torch.load(args.model_checkpoint, map_location=device)) model.to(device) print("Loaded model for benchmarking.") # Run a series of placeholder benchmarks: prompts = [ "ARC Benchmark: What is the main purpose of photosynthesis?", "MMLU Benchmark: Who was the first President of the United States?", "HumanEval Benchmark: Write a Python function that calculates the factorial of a number." ] for prompt in prompts: print("=====================================") print("Prompt:") print(prompt) output = sample(model, prompt, config, length=100, temperature=1.0) print("Output:") print(output) print("=====================================") elif args.command == "instance": print("=====================================") print("Recommended Azure VM:") print("Standard_E64ds_v4 CPU instance: 64 cores, 504GB RAM, 2400GB storage.") print("Monitor training times as CPU training is slower than GPU.") print("=====================================") elif args.command == "compare": config = GPTConfig(args.our_variant) model = GPT(config) if not os.path.exists(args.our_model_checkpoint): print(f"Our model checkpoint {args.our_model_checkpoint} not found.") return model.load_state_dict(torch.load(args.our_model_checkpoint, map_location=device)) model.to(device) try: from transformers import pipeline except ImportError: print("Please install transformers: pip install transformers") return hf_pipelines = { "Llama2 1B": pipeline("text-generation", model=args.llama, device=0 if torch.cuda.is_available() else -1), "GPT2 1B": pipeline("text-generation", model=args.gpt2, device=0 if torch.cuda.is_available() else -1), "HF SmolLM2 1B": pipeline("text-generation", model=args.smol, device=0 if torch.cuda.is_available() else -1), "Qwen 2 1B": pipeline("text-generation", model=args.qwen, device=0 if torch.cuda.is_available() else -1) } prompts = [ "ARC Benchmark: What is the main purpose of photosynthesis?", "MMLU Benchmark: Who was the first President of the United States?", "HumanEval Benchmark: Write a Python function that calculates the factorial of a number." 
        ]
        for prompt in prompts:
            print("=====================================")
            print("Prompt:")
            print(prompt)
            print("\nOur Model's Output:")
            our_output = sample(model, prompt, config, length=args.sample_length, temperature=args.temperature)
            print(our_output)
            for model_name, pipe in hf_pipelines.items():
                print(f"\n{model_name} Output:")
                result = pipe(prompt, max_length=args.sample_length, temperature=args.temperature,
                              num_return_sequences=1)
                print(result[0]["generated_text"])
            print("=====================================")

    elif args.command == "evaluate":
        config = GPTConfig(args.variant)
        model = GPT(config)
        if not os.path.exists(args.model_checkpoint):
            print(f"Checkpoint {args.model_checkpoint} not found.")
            return
        model.load_state_dict(torch.load(args.model_checkpoint, map_location=device))
        model.to(device)
        if not os.path.exists(args.val_data):
            print(f"Validation data file {args.val_data} not found.")
            return
        val_data = torch.load(args.val_data)
        evaluate_perplexity(model, config, val_data, device)

    elif args.command == "chat":
        config = GPTConfig(args.variant)
        model = GPT(config)
        if not os.path.exists(args.model_checkpoint):
            print(f"Model checkpoint {args.model_checkpoint} not found.")
            return
        model.load_state_dict(torch.load(args.model_checkpoint, map_location=device))
        model.to(device)
        interactive_chat(model, config, device, instruct_mode=args.instruct_mode)

    else:
        parser.print_help()


if __name__ == "__main__":
    main()

###############################################################################
#                               EXTRA COMMENTS                                #
###############################################################################
# In this updated version, we have modified both the training and fine-tuning
# routines to support instruct-tuning for conversations. Specifically:
# - The Trainer class now accepts an "instruct_mode" flag and passes it to the model.
# - The "instruct_finetune" command creates a GPT configuration with instruct=True
#   and processes data to include explicit role markers when using conversation data
#   (e.g., "User:" and "Assistant:").
#
# Enjoy experimenting with and extending your enhanced foundational model!
###############################################################################
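###############################################################################
#                     EXAMPLE INVOCATIONS (illustrative)                      #
###############################################################################
# The commands below only use subcommands and flags defined in main() above;
# the script name "model.py" and all file/dataset names are placeholders --
# substitute this file's actual name and your own paths.
#
# 1) Tokenize a local text file (or URL) into _data_<name>_/data.bin:
#      python model.py tokenize --dataset mydata --txt ./corpus.txt
#
# 2) Train the 120M variant with gradient accumulation (saves 120M_model.pt):
#      python model.py train --variant 120M --data _data_mydata_/data.bin --grad_accum 4
#
# 3) Sample from the trained checkpoint with top-k sampling:
#      python model.py sample --variant 120M --data _data_mydata_/data.bin \
#          --model_checkpoint 120M_model.pt --prompt "Once upon a time" --top_k 40
#
# 4) Chat interactively with an instruct-tuned checkpoint:
#      python model.py chat --variant 120M \
#          --model_checkpoint instruct_finetuned_120M_model.pt --instruct_mode
###############################################################################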