# modeling_creativity_iti.py
"""
Auto-apply creativity ITI wrapper for LLaMA 3.1 8B
"""
import torch
import pickle
import json
from transformers import LlamaForCausalLM
from huggingface_hub import hf_hub_download


class CreativityITILlamaForCausalLM(LlamaForCausalLM):
    """LLaMA with automatic creativity ITI (inference-time intervention) application."""

    def __init__(self, config):
        super().__init__(config)
        try:
            # Repo the weights were loaded from; the ITI artifacts live alongside them.
            model_name = getattr(config, "_name_or_path", "")

            # Download the ITI artifacts from the same model repo.
            print("Loading Creativity ITI components...")
            top_heads_path = hf_hub_download(
                repo_id=model_name, filename="iti_top_heads.pkl", repo_type="model"
            )
            directions_path = hf_hub_download(
                repo_id=model_name, filename="iti_directions.pkl", repo_type="model"
            )
            config_path = hf_hub_download(
                repo_id=model_name, filename="iti_config.json", repo_type="model"
            )

            # top_heads: list of {'layer': int, 'head': int} entries selecting heads to steer.
            with open(top_heads_path, "rb") as f:
                self.top_heads = pickle.load(f)
            # directions: {(layer, head): vector} steering direction for each selected head.
            with open(directions_path, "rb") as f:
                self.directions = pickle.load(f)
            # alpha: scalar intervention strength.
            with open(config_path, "r") as f:
                iti_config = json.load(f)
            self.alpha = iti_config["alpha"]

            # Model dimensions used to split hidden states into per-head chunks.
            self.num_heads = config.num_attention_heads
            self.head_dim = config.hidden_size // self.num_heads

            # Register the forward hooks that apply the intervention at generation time.
            self._register_iti_hooks()
            print(f"✓ Creativity ITI active: α={self.alpha}, {len(self.top_heads)} heads")
        except Exception as e:
            # Fall back to a plain LLaMA model if the ITI artifacts cannot be loaded.
            print(f"Warning: Could not load ITI: {e}")
            self.top_heads = []
            self.directions = {}
            self.alpha = 0

    def _register_iti_hooks(self):
        """Register ITI intervention hooks on the attention output projections."""
        if not self.top_heads:
            return

        # Group the selected heads by layer so each layer gets a single hook.
        heads_by_layer = {}
        for head_info in self.top_heads:
            heads_by_layer.setdefault(head_info["layer"], []).append(head_info["head"])

        for layer_idx, head_indices in heads_by_layer.items():
            # Factory function so each hook closes over its own layer/head indices.
            def make_hook(layer_idx, head_indices):
                def hook_fn(module, input, output):
                    hidden_states = output[0] if isinstance(output, tuple) else output
                    batch_size, seq_len, hidden_size = hidden_states.shape

                    # Split the projection output into head-sized chunks and shift the
                    # last token's activation along each stored direction, scaled by alpha.
                    hidden_reshaped = hidden_states.view(
                        batch_size, seq_len, self.num_heads, self.head_dim
                    )
                    for head_idx in head_indices:
                        if (layer_idx, head_idx) in self.directions:
                            direction = torch.tensor(
                                self.directions[(layer_idx, head_idx)],
                                dtype=hidden_reshaped.dtype,
                                device=hidden_reshaped.device,
                            )
                            hidden_reshaped[:, -1, head_idx, :] += self.alpha * direction

                    hidden_states = hidden_reshaped.view(batch_size, seq_len, hidden_size)
                    if isinstance(output, tuple):
                        return (hidden_states,) + output[1:]
                    return hidden_states

                return hook_fn

            hook = make_hook(layer_idx, head_indices)
            self.model.layers[layer_idx].self_attn.o_proj.register_forward_hook(hook)
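

# Minimal usage sketch (not part of the wrapper itself). The repo id below is a
# placeholder: this assumes a hub repo that hosts both the LLaMA weights and the
# iti_top_heads.pkl / iti_directions.pkl / iti_config.json artifacts that __init__
# downloads via hf_hub_download.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    repo_id = "your-username/Llama-3.1-8B-creativity-iti"  # hypothetical repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = CreativityITILlamaForCausalLM.from_pretrained(
        repo_id, torch_dtype=torch.bfloat16
    )
    model.eval()

    prompt = "Write an unusual opening line for a short story:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=64, do_sample=True)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))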