# modeling_creativity_iti.py
"""
Auto-apply creativity ITI wrapper for LLaMA 3.1 8B
"""
import torch
import pickle
import json
from pathlib import Path
from transformers import LlamaForCausalLM
from huggingface_hub import hf_hub_download

class CreativityITILlamaForCausalLM(LlamaForCausalLM):
    """LLaMA with automatic creativity ITI (inference-time intervention) application."""

    def __init__(self, config):
        super().__init__(config)
        try:
            # Get the repo id from the config so the ITI artifacts can be fetched
            model_name = getattr(config, "_name_or_path", "")

            # Download ITI files from the model repo
            print("Loading Creativity ITI components...")
            top_heads_path = hf_hub_download(
                repo_id=model_name,
                filename="iti_top_heads.pkl",
                repo_type="model",
            )
            directions_path = hf_hub_download(
                repo_id=model_name,
                filename="iti_directions.pkl",
                repo_type="model",
            )
            config_path = hf_hub_download(
                repo_id=model_name,
                filename="iti_config.json",
                repo_type="model",
            )

            # Load the artifacts:
            #   top_heads:  list of dicts with 'layer' and 'head' keys
            #   directions: dict mapping (layer, head) -> intervention vector
            with open(top_heads_path, "rb") as f:
                self.top_heads = pickle.load(f)
            with open(directions_path, "rb") as f:
                self.directions = pickle.load(f)
            with open(config_path, "r") as f:
                iti_config = json.load(f)
            self.alpha = iti_config["alpha"]

            # Model dimensions
            self.num_heads = config.num_attention_heads
            self.head_dim = config.hidden_size // self.num_heads

            # Register forward hooks that apply the intervention
            self._register_iti_hooks()
            print(f"✓ Creativity ITI active: α={self.alpha}, {len(self.top_heads)} heads")
        except Exception as e:
            print(f"Warning: Could not load ITI: {e}")
            self.top_heads = []
            self.directions = {}
            self.alpha = 0
    def _register_iti_hooks(self):
        """Register forward hooks that add the ITI directions to the selected heads."""
        if not self.top_heads:
            return

        # Group the selected heads by layer so each layer gets a single hook
        heads_by_layer = {}
        for head_info in self.top_heads:
            layer = head_info["layer"]
            head = head_info["head"]
            if layer not in heads_by_layer:
                heads_by_layer[layer] = []
            heads_by_layer[layer].append(head)

        for layer_idx, head_indices in heads_by_layer.items():
            def make_hook(layer_idx, head_indices):
                def hook_fn(module, input, output):
                    if isinstance(output, tuple):
                        hidden_states = output[0]
                    else:
                        hidden_states = output

                    batch_size, seq_len, hidden_size = hidden_states.shape
                    # View the hidden states per head so individual heads can be steered
                    hidden_reshaped = hidden_states.view(
                        batch_size, seq_len, self.num_heads, self.head_dim
                    )
                    for head_idx in head_indices:
                        if (layer_idx, head_idx) in self.directions:
                            direction = torch.tensor(
                                self.directions[(layer_idx, head_idx)],
                                dtype=hidden_reshaped.dtype,
                                device=hidden_reshaped.device,
                            )
                            # Shift the last token's activation along the creativity direction
                            hidden_reshaped[:, -1, head_idx, :] += self.alpha * direction
                    hidden_states = hidden_reshaped.view(batch_size, seq_len, hidden_size)

                    if isinstance(output, tuple):
                        return (hidden_states,) + output[1:]
                    else:
                        return hidden_states
                return hook_fn

            hook = make_hook(layer_idx, head_indices)
            self.model.layers[layer_idx].self_attn.o_proj.register_forward_hook(hook)
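

# A minimal usage sketch, not part of the uploaded module. It assumes the checkpoint's
# config.json maps AutoModelForCausalLM to this class via auto_map, so trust_remote_code
# triggers the ITI setup in __init__. The repo id below is inferred from the repo name
# and may need adjusting.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "syed-aliredha/llama-31-8b-creativity-it-40-percent"  # assumed repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        trust_remote_code=True,      # required so this custom class is loaded
        torch_dtype=torch.bfloat16,
    )

    # Generate with the creativity intervention applied by the registered hooks
    inputs = tokenizer("Write an unusual metaphor for time:", return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))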