""" |
|
|
Auto-apply creativity ITI wrapper for LLaMA 3.1 8B |
|
|
""" |
|
|
|
|
|
import json
import pickle

import torch
from huggingface_hub import hf_hub_download
from transformers import LlamaForCausalLM


class CreativityITILlamaForCausalLM(LlamaForCausalLM):
    """LLaMA with automatic creativity ITI (inference-time intervention) application."""

    def __init__(self, config):
        super().__init__(config)

        try:
            # The ITI artifacts are expected to live in the same Hub repo as the
            # model weights, so the repo id is taken from the config's name_or_path.
            model_name = getattr(config, "_name_or_path", "")

            print("Loading Creativity ITI components...")

            top_heads_path = hf_hub_download(
                repo_id=model_name,
                filename="iti_top_heads.pkl",
                repo_type="model",
            )
            directions_path = hf_hub_download(
                repo_id=model_name,
                filename="iti_directions.pkl",
                repo_type="model",
            )
            config_path = hf_hub_download(
                repo_id=model_name,
                filename="iti_config.json",
                repo_type="model",
            )

            # top_heads: list of dicts with the 'layer' and 'head' indices chosen for intervention.
            with open(top_heads_path, 'rb') as f:
                self.top_heads = pickle.load(f)

            # directions: dict mapping (layer, head) -> steering direction of length head_dim.
            with open(directions_path, 'rb') as f:
                self.directions = pickle.load(f)

            # iti_config: holds the intervention strength alpha.
            with open(config_path, 'r') as f:
                iti_config = json.load(f)
            self.alpha = iti_config['alpha']

            # Per-head geometry used to slice the attention output into head blocks.
            self.num_heads = config.num_attention_heads
            self.head_dim = config.hidden_size // self.num_heads

            self._register_iti_hooks()
            print(f"✓ Creativity ITI active: α={self.alpha}, {len(self.top_heads)} heads")

        except Exception as e:
            # Fall back to a vanilla model if the ITI artifacts cannot be loaded.
            print(f"Warning: Could not load ITI: {e}")
            self.top_heads = []
            self.directions = {}
            self.alpha = 0

    def _register_iti_hooks(self):
        """Register ITI intervention hooks on the attention output projections."""
        if not self.top_heads:
            return

        # Group the selected heads by layer so each layer gets a single hook.
        heads_by_layer = {}
        for head_info in self.top_heads:
            layer = head_info['layer']
            head = head_info['head']
            if layer not in heads_by_layer:
                heads_by_layer[layer] = []
            heads_by_layer[layer].append(head)

        for layer_idx, head_indices in heads_by_layer.items():
            # Capture layer_idx/head_indices by value via a factory function to
            # avoid the usual late-binding problem with closures defined in a loop.
            def make_hook(layer_idx, head_indices):
                def hook_fn(module, input, output):
                    if isinstance(output, tuple):
                        hidden_states = output[0]
                    else:
                        hidden_states = output

                    batch_size, seq_len, hidden_size = hidden_states.shape

                    # View the projected hidden state as (batch, seq, heads, head_dim)
                    # so individual head blocks can be shifted.
                    hidden_reshaped = hidden_states.view(
                        batch_size, seq_len, self.num_heads, self.head_dim
                    )

                    # Shift only the last token position of each selected head
                    # along its creativity direction, scaled by alpha.
                    for head_idx in head_indices:
                        if (layer_idx, head_idx) in self.directions:
                            direction = torch.tensor(
                                self.directions[(layer_idx, head_idx)],
                                dtype=hidden_reshaped.dtype,
                                device=hidden_reshaped.device,
                            )
                            hidden_reshaped[:, -1, head_idx, :] += self.alpha * direction

                    hidden_states = hidden_reshaped.view(batch_size, seq_len, hidden_size)

                    if isinstance(output, tuple):
                        return (hidden_states,) + output[1:]
                    else:
                        return hidden_states

                return hook_fn

            hook = make_hook(layer_idx, head_indices)
            self.model.layers[layer_idx].self_attn.o_proj.register_forward_hook(hook)
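

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): shows how the wrapper is intended to be
# loaded and used for generation. The repo id below is a placeholder, not the
# actual Hub repo; substitute the repository that hosts the model weights
# together with iti_top_heads.pkl, iti_directions.pkl, and iti_config.json.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    repo_id = "your-username/llama-3.1-8b-creativity-iti"  # placeholder repo id

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = CreativityITILlamaForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",  # assumes accelerate is installed; use .to("cuda") otherwise
    )

    prompt = "Write an unusual opening line for a short story."
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=64, do_sample=True)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))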