# Source: Map-NEO / extend_context.py, uploaded by Austin207 via huggingface_hub
# (commit a683148, verified, 2.14 kB).
# extend_context.py - Extend the MAP-NEO Mini context window (default target: 16384 tokens)
from model_neo import NeoMiniConfig, NeoMini
import torch
def _resample_positions(old_pos_emb, new_len):
    """Linearly resample a 2-D (num_positions, dim) embedding table to ``new_len`` rows.

    Raises:
        ValueError: If ``old_pos_emb`` is not 2-D.
    """
    if old_pos_emb.dim() != 2:
        raise ValueError(
            "Expected a 2-D (num_positions, dim) position table, got shape "
            f"{tuple(old_pos_emb.shape)}"
        )
    # 'linear' mode of F.interpolate only accepts 3-D (batch, channels, length)
    # input, so treat the embedding dim as channels and positions as length.
    # Work in float32 to avoid precision loss / missing low-precision kernels,
    # then cast back to the checkpoint's dtype.
    signal = old_pos_emb.detach().float().t().unsqueeze(0)  # (1, dim, old_len)
    resampled = torch.nn.functional.interpolate(
        signal, size=new_len, mode='linear', align_corners=False
    )
    return resampled.squeeze(0).t().contiguous().to(old_pos_emb.dtype)


def extend_model_context(checkpoint_path="checkpoints/checkpoint_step_149999.pt",
                         new_max_len=16384,
                         output_path="checkpoints/extended_context_model.pt"):
    """Extend a MAP-NEO Mini checkpoint to a longer context window.

    Builds a fresh ``NeoMini`` with ``config.max_seq_len = new_max_len`` and
    copies every checkpoint weight into it. Any tensor whose state-dict key
    contains ``'pos'`` and whose shape differs from the new model's is
    linearly resampled along its position axis. The extended weights are
    saved to ``output_path`` together with ``config.to_dict()``.

    Args:
        checkpoint_path: Checkpoint file holding a ``'model_state_dict'`` entry.
        new_max_len: Target context length written to ``config.max_seq_len``.
        output_path: Destination file for the extended checkpoint; missing
            parent directories are created.

    Returns:
        Tuple ``(extended_model, config)``.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        ValueError: If a mismatched ``'pos'`` tensor is not 2-D.
        RuntimeError: From ``load_state_dict`` if a non-positional tensor's
            shape differs between the checkpoint and the new model.
    """
    from pathlib import Path

    print(f"Extending context window to {new_max_len} tokens...")

    # NOTE(review): assumes the default NeoMiniConfig matches the architecture
    # the checkpoint was trained with - TODO confirm against the training config.
    config = NeoMiniConfig()
    config.max_seq_len = new_max_len
    extended_model = NeoMini(config)

    # torch.load unpickles the file: only use this on trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    original_state = checkpoint['model_state_dict']

    extended_state = extended_model.state_dict()
    skipped = []
    for key, old_tensor in original_state.items():
        if key not in extended_state:
            skipped.append(key)
            continue
        target = extended_state[key]
        if 'pos' in key and target.shape != old_tensor.shape:
            print(f"Interpolating position embeddings: {key}")
            # Resample to the new model's own row count so the result always
            # matches the destination shape expected by load_state_dict.
            extended_state[key] = _resample_positions(old_tensor, target.shape[0])
        else:
            extended_state[key] = old_tensor

    if skipped:
        print(f"Warning: {len(skipped)} checkpoint key(s) not present in the "
              f"extended model were skipped: {skipped}")

    extended_model.load_state_dict(extended_state)

    extended_checkpoint = {
        'model_state_dict': extended_model.state_dict(),
        'config': config.to_dict(),
    }
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(extended_checkpoint, output_path)
    print(f"Extended model saved to {output_path}")
    return extended_model, config
if __name__ == "__main__":
    # Script entry point: optional CLI overrides for the checkpoint to extend
    # and the target context length.
    import argparse

    parser = argparse.ArgumentParser(
        description="Extend a MAP-NEO Mini checkpoint to a longer context window."
    )
    # Defaults are SUPPRESSed so omitted flags fall through to
    # extend_model_context's own defaults (no duplicated constants); running
    # with no arguments is therefore identical to calling it with none.
    parser.add_argument(
        "--checkpoint-path", dest="checkpoint_path", default=argparse.SUPPRESS,
        help="checkpoint to extend (default: extend_model_context's default)",
    )
    parser.add_argument(
        "--new-max-len", dest="new_max_len", type=int, default=argparse.SUPPRESS,
        help="target context length (default: extend_model_context's default)",
    )
    extend_model_context(**vars(parser.parse_args()))