anonymous-author-129 committed
Commit d85a156 · verified · 1 Parent(s): 17a1a2a

Update utils/hooks.py

Files changed (1): utils/hooks.py (+79 -5)
utils/hooks.py CHANGED
@@ -1,5 +1,31 @@
 import torch
-import spaces
+
+class TimedHook:
+    def __init__(self, hook_fn, total_steps, apply_at_steps=None):
+        self.hook_fn = hook_fn
+        self.total_steps = total_steps
+        self.apply_at_steps = apply_at_steps
+        self.current_step = 0
+
+    def identity(self, module, input, output):
+        return output
+
+    def __call__(self, module, input, output):
+        if self.apply_at_steps is not None:
+            if self.current_step in self.apply_at_steps:
+                self.__increment()
+                return self.hook_fn(module, input, output)
+            else:
+                self.__increment()
+                return self.identity(module, input, output)
+
+        return self.identity(module, input, output)
+
+    def __increment(self):
+        if self.current_step < self.total_steps:
+            self.current_step += 1
+        else:
+            self.current_step = 0
 
 @torch.no_grad()
 def add_feature(sae, feature_idx, value, module, input, output):
@@ -10,9 +36,14 @@ def add_feature(sae, feature_idx, value, module, input, output):
     to_add = mask @ sae.decoder.weight.T
     return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
 
+@torch.no_grad()
+def add_feature_on_area_base(sae, feature_idx, activation_map, module, input, output):
+    return add_feature_on_area_base_both(sae, feature_idx, activation_map, module, input, output)
 
 @torch.no_grad()
-def add_feature_on_area(sae, feature_idx, activation_map, module, input, output):
+def add_feature_on_area_base_both(sae, feature_idx, activation_map, module, input, output):
+    # add the feature to cond and subtract from uncond
+    # this assumes diff.shape[0] == 2
     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
     activated = sae.encode(diff)
     mask = torch.zeros_like(activated, device=diff.device)
@@ -20,11 +51,54 @@ def add_feature_on_area(sae, feature_idx, activation_map, module, input, output)
         activation_map = activation_map.unsqueeze(0)
     mask[..., feature_idx] = activation_map.to(mask.device)
     to_add = mask @ sae.decoder.weight.T
-    return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
+    to_add = to_add.chunk(2)
+    output[0][0] -= to_add[0].permute(0, 3, 1, 2).to(output[0].device)[0]
+    output[0][1] += to_add[1].permute(0, 3, 1, 2).to(output[0].device)[0]
+    return output
 
 
 @torch.no_grad()
-def replace_with_feature(sae, feature_idx, value, module, input, output):
+def add_feature_on_area_base_cond(sae, feature_idx, activation_map, module, input, output):
+    # add the feature to cond
+    # this assumes diff.shape[0] == 2
+    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+    diff_uncond, diff_cond = diff.chunk(2)
+    activated = sae.encode(diff_cond)
+    mask = torch.zeros_like(activated, device=diff_cond.device)
+    if len(activation_map) == 2:
+        activation_map = activation_map.unsqueeze(0)
+    mask[..., feature_idx] = activation_map.to(mask.device)
+    to_add = mask @ sae.decoder.weight.T
+    output[0][1] += to_add.permute(0, 3, 1, 2).to(output[0].device)[0]
+    return output
+
+
+@torch.no_grad()
+def replace_with_feature_base(sae, feature_idx, value, module, input, output):
+    # this assumes diff.shape[0] == 2
+    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+    diff_uncond, diff_cond = diff.chunk(2)
+    activated = sae.encode(diff_cond)
+    mask = torch.zeros_like(activated, device=diff_cond.device)
+    mask[..., feature_idx] = value
+    to_add = mask @ sae.decoder.weight.T
+    input[0][1] += to_add.permute(0, 3, 1, 2).to(output[0].device)[0]
+    return input
+
+
+@torch.no_grad()
+def add_feature_on_area_turbo(sae, feature_idx, activation_map, module, input, output):
+    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+    activated = sae.encode(diff)
+    mask = torch.zeros_like(activated, device=diff.device)
+    if len(activation_map) == 2:
+        activation_map = activation_map.unsqueeze(0)
+    mask[..., feature_idx] = activation_map.to(mask.device)
+    to_add = mask @ sae.decoder.weight.T
+    return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
+
+@torch.no_grad()
+def replace_with_feature_turbo(sae, feature_idx, value, module, input, output):
     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
     activated = sae.encode(diff)
     mask = torch.zeros_like(activated, device=diff.device)
@@ -43,4 +117,4 @@ def reconstruct_sae_hook(sae, module, input, output):
 
 @torch.no_grad()
 def ablate_block(module, input, output):
-    return input
+    return input
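
For orientation, a minimal usage sketch (not part of the commit) of how hooks in this style are typically attached. Here `sae`, `unet_block`, `feature_idx`, and `activation_map` are hypothetical stand-ins; the only standard API used is torch.nn.Module.register_forward_hook, which calls its hook as hook(module, input, output), so functools.partial is used to bind the SAE-specific arguments to match that signature:

import functools

import torch

from utils.hooks import TimedHook, add_feature_on_area_turbo

feature_idx = 0                      # hypothetical SAE feature to steer
activation_map = torch.ones(16, 16)  # hypothetical (H, W) strength map

# Bind the SAE arguments so the callable matches (module, input, output).
hook_fn = functools.partial(add_feature_on_area_turbo, sae, feature_idx, activation_map)

# Fire the edit only on the first five denoising steps; identity elsewhere.
timed_hook = TimedHook(hook_fn, total_steps=50, apply_at_steps=[0, 1, 2, 3, 4])
handle = unet_block.register_forward_hook(timed_hook)
try:
    ...  # run the diffusion sampling loop; the hook runs once per block call
finally:
    handle.remove()  # detach the hook so later runs are unaffected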
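
The _base variants rely on what appears to be the classifier-free-guidance batch layout (hence the "this assumes diff.shape[0] == 2" comments): the unconditional and conditional passes are stacked along the batch dimension, and chunk(2) splits them apart. A self-contained illustration of that convention, with made-up tensor shapes:

import torch

# CFG layout assumed by the *_base hooks: the batch stacks [uncond, cond].
hidden = torch.zeros(2, 4, 8, 8)       # (batch=2, channels, H, W)
steer = torch.ones(2, 4, 8, 8)         # decoded feature direction to inject

uncond_add, cond_add = steer.chunk(2)  # each chunk has shape (1, 4, 8, 8)
hidden[0] -= uncond_add[0]             # push the unconditional pass away
hidden[1] += cond_add[0]               # push the conditional pass toward it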