# extend_context.py - Extend the MAP-NEO Mini context window beyond its original 2048 tokens
import torch

from model_neo import NeoMiniConfig, NeoMini


def extend_model_context(checkpoint_path="checkpoints/checkpoint_step_149999.pt",
                         new_max_len=16384):
    """Extend the model's context window from 2048 to `new_max_len` tokens."""
    print(f"Extending context window to {new_max_len} tokens...")

    # Load the original config and extend its maximum sequence length
    config = NeoMiniConfig()
    config.max_seq_len = new_max_len

    # Create a new model with the extended context window
    extended_model = NeoMini(config)

    # Load the original weights
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    original_state = checkpoint['model_state_dict']

    # Transfer weights; position embeddings are interpolated to the new length
    extended_state = extended_model.state_dict()
    for key in original_state:
        if key not in extended_state:
            continue
        if 'pos' in key and extended_state[key].shape != original_state[key].shape:
            # Interpolate the (seq_len, dim) position-embedding table along the
            # sequence dimension. F.interpolate with mode='linear' expects a 3D
            # (N, C, L) tensor, so the table is transposed to (1, dim, old_len),
            # resized to (1, dim, new_max_len), then transposed back.
            print(f"Interpolating position embeddings: {key}")
            old_pos_emb = original_state[key]
            new_pos_emb = torch.nn.functional.interpolate(
                old_pos_emb.t().unsqueeze(0),
                size=new_max_len,
                mode='linear',
                align_corners=False,
            ).squeeze(0).t()
            extended_state[key] = new_pos_emb
        else:
            extended_state[key] = original_state[key]

    extended_model.load_state_dict(extended_state)

    # Save the extended model
    extended_checkpoint = {
        'model_state_dict': extended_model.state_dict(),
        'config': config.to_dict(),
    }
    output_path = "checkpoints/extended_context_model.pt"
    torch.save(extended_checkpoint, output_path)
    print(f"Extended model saved to {output_path}")

    return extended_model, config


if __name__ == "__main__":
    extend_model_context()
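

# --- Optional sanity check (illustrative sketch, not part of the original workflow) ---
# A minimal, hedged example of how the extended checkpoint could be verified: reload the
# saved file and run a forward pass on a dummy sequence longer than the original 2048-token
# window. It assumes NeoMiniConfig can be rebuilt from the saved dict via keyword arguments,
# that the config exposes `vocab_size`, and that NeoMini's forward accepts a (batch, seq_len)
# tensor of token ids; adjust these to match the actual model_neo API.
def verify_extended_context(checkpoint_path="checkpoints/extended_context_model.pt",
                            test_len=4096):
    """Smoke-test the extended model on a sequence longer than the original window."""
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    config = NeoMiniConfig(**ckpt['config'])  # assumption: config accepts keyword args
    model = NeoMini(config)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    # Random token ids of the requested length (assumption: config.vocab_size exists)
    dummy_ids = torch.randint(0, config.vocab_size, (1, test_len))
    with torch.no_grad():
        output = model(dummy_ids)  # assumption: forward takes token ids directly
    print(f"Forward pass on {test_len} tokens succeeded; output type: {type(output)}")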