import torch

from model_neo import NeoMiniConfig, NeoMini

def extend_model_context(checkpoint_path="checkpoints/checkpoint_step_149999.pt",
                         new_max_len=16384):
    """Extend the model's context window to new_max_len tokens (default 16384)
    by interpolating the learned position embeddings of a trained checkpoint."""
    print(f"Extending context window to {new_max_len} tokens...")
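    # Build a fresh config for the longer context window; every other
    # hyperparameter keeps its NeoMiniConfig default.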
    config = NeoMiniConfig()
    config.max_seq_len = new_max_len
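    # Instantiate a model at the new size; its weights are randomly
    # initialised until the trained weights are copied in below.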
    extended_model = NeoMini(config)
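    # Load the trained checkpoint on CPU so this also works without a GPU.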
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    original_state = checkpoint['model_state_dict']
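    # Start from the extended model's freshly initialised state dict and
    # overwrite its entries with the trained weights.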
    extended_state = extended_model.state_dict()
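    # Copy every tensor across; position-embedding tensors whose shape changed
    # are resized by interpolating along the sequence dimension.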
    for key in original_state:
        if key in extended_state:
            if 'pos' in key and extended_state[key].shape != original_state[key].shape:
                print(f"Interpolating position embeddings: {key}")
                old_pos_emb = original_state[key]
                # F.interpolate with mode='linear' expects a 3D (N, C, L) input, so
                # (assuming the embedding is stored as (seq_len, d_model)) move the
                # embedding dim into the channel axis and resample along the sequence.
                new_pos_emb = torch.nn.functional.interpolate(
                    old_pos_emb.T.unsqueeze(0),  # (1, d_model, old_len)
                    size=new_max_len,
                    mode='linear',
                    align_corners=False,
                ).squeeze(0).T  # (new_max_len, d_model)
                extended_state[key] = new_pos_emb
            else:
                extended_state[key] = original_state[key]
    extended_model.load_state_dict(extended_state)
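    # Save the resized weights together with the updated config so the
    # extended model can be reloaded directly.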
    extended_checkpoint = {
        'model_state_dict': extended_model.state_dict(),
        'config': config.to_dict(),
    }
    output_path = "checkpoints/extended_context_model.pt"
    torch.save(extended_checkpoint, output_path)
    print(f"Extended model saved to {output_path}")
    return extended_model, config


if __name__ == "__main__":
    extend_model_context()
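    # Optional sanity check, a minimal sketch: it assumes NeoMini takes a
    # (batch, seq_len) LongTensor of token ids and that NeoMiniConfig exposes
    # vocab_size; adjust to the real forward signature before uncommenting.
    # model, cfg = extend_model_context()
    # dummy_ids = torch.randint(0, cfg.vocab_size, (1, cfg.max_seq_len))
    # with torch.no_grad():
    #     _ = model(dummy_ids)
    # print("Forward pass over the full extended context succeeded.")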