import torch


class TimedHook:
    """Forward-hook wrapper that applies `hook_fn` only on selected denoising steps."""

    def __init__(self, hook_fn, total_steps, apply_at_steps=None):
        self.hook_fn = hook_fn
        self.total_steps = total_steps
        self.apply_at_steps = apply_at_steps
        self.current_step = 0

    def identity(self, module, input, output):
        return output

    def __call__(self, module, input, output):
        if self.apply_at_steps is not None:
            if self.current_step in self.apply_at_steps:
                self.__increment()
                return self.hook_fn(module, input, output)
            else:
                self.__increment()
                return self.identity(module, input, output)
        return self.identity(module, input, output)

    def __increment(self):
        # wrap around after the final step so the hook can be reused across generations
        if self.current_step < self.total_steps:
            self.current_step += 1
        else:
            self.current_step = 0
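

# Usage sketch (added for illustration; `pipe` and `block` are assumed to be a
# diffusers-style pipeline and a UNet submodule, neither defined in this file).
# TimedHook pairs with functools.partial so the SAE arguments are pre-bound and
# only (module, input, output) arrive from torch's forward-hook machinery;
# add_feature_on_area_turbo is defined further down in this module.
def example_timed_feature_hook(pipe, block, sae, feature_idx, activation_map, prompt):
    import functools

    hook = TimedHook(
        functools.partial(add_feature_on_area_turbo, sae, feature_idx, activation_map),
        total_steps=4,          # e.g. a 4-step turbo sampler
        apply_at_steps=[0, 1],  # intervene only on the first two steps
    )
    handle = block.register_forward_hook(hook)
    try:
        return pipe(prompt)
    finally:
        handle.remove()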


def add_feature(sae, feature_idx, value, module, input, output):
    # encode the block's contribution, then add `value` along one SAE decoder direction
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    activated = sae.encode(diff)
    mask = torch.zeros_like(activated, device=diff.device)
    mask[..., feature_idx] = value
    to_add = mask @ sae.decoder.weight.T
    return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)


def add_feature_on_area_base(sae, feature_idx, activation_map, module, input, output):
    return add_feature_on_area_base_both(sae, feature_idx, activation_map, module, input, output)


def add_feature_on_area_base_both(sae, feature_idx, activation_map, module, input, output):
    # add the feature to cond and subtract it from uncond
    # this assumes diff.shape[0] == 2 (uncond at index 0, cond at index 1)
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    activated = sae.encode(diff)
    mask = torch.zeros_like(activated, device=diff.device)
    if len(activation_map.shape) == 2:
        # (H, W) -> (1, H, W) so the map broadcasts over the batch
        activation_map = activation_map.unsqueeze(0)
    mask[..., feature_idx] = activation_map.to(mask.device)
    to_add = mask @ sae.decoder.weight.T
    to_add = to_add.chunk(2)
    output[0][0] -= to_add[0].permute(0, 3, 1, 2).to(output[0].device)[0]
    output[0][1] += to_add[1].permute(0, 3, 1, 2).to(output[0].device)[0]
    return output
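

# Illustration (not from the original file): the *_base hooks above rely on the
# classifier-free-guidance batch layout, index 0 = unconditional and index 1 =
# conditional, which is why they touch output[0][0] and output[0][1]. Below is a
# minimal sketch of a spatial activation_map; the 16x16 default is an assumption
# and simply has to match the hooked block's feature-map resolution.
def example_square_activation_map(size=16, strength=10.0):
    m = torch.zeros(size, size)  # one steering scalar per spatial location
    m[size // 4 : 3 * size // 4, size // 4 : 3 * size // 4] = strength
    return m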


def add_feature_on_area_base_cond(sae, feature_idx, activation_map, module, input, output):
    # add the feature to cond only
    # this assumes diff.shape[0] == 2
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    diff_uncond, diff_cond = diff.chunk(2)
    activated = sae.encode(diff_cond)
    mask = torch.zeros_like(activated, device=diff_cond.device)
    if len(activation_map.shape) == 2:
        activation_map = activation_map.unsqueeze(0)
    mask[..., feature_idx] = activation_map.to(mask.device)
    to_add = mask @ sae.decoder.weight.T
    output[0][1] += to_add.permute(0, 3, 1, 2).to(output[0].device)[0]
    return output


def replace_with_feature_base(sae, feature_idx, value, module, input, output):
    # this assumes diff.shape[0] == 2
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    diff_uncond, diff_cond = diff.chunk(2)
    activated = sae.encode(diff_cond)
    mask = torch.zeros_like(activated, device=diff_cond.device)
    mask[..., feature_idx] = value
    to_add = mask @ sae.decoder.weight.T
    input[0][1] += to_add.permute(0, 3, 1, 2).to(output[0].device)[0]
    # returning the modified input replaces the block's output entirely,
    # so the cond branch keeps only the injected feature direction
    return input


def add_feature_on_area_turbo(sae, feature_idx, activation_map, module, input, output):
    # turbo variant: no cond/uncond split, the whole batch is steered
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    activated = sae.encode(diff)
    mask = torch.zeros_like(activated, device=diff.device)
    if len(activation_map.shape) == 2:
        activation_map = activation_map.unsqueeze(0)
    mask[..., feature_idx] = activation_map.to(mask.device)
    to_add = mask @ sae.decoder.weight.T
    return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)


def replace_with_feature_turbo(sae, feature_idx, value, module, input, output):
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    activated = sae.encode(diff)
    mask = torch.zeros_like(activated, device=diff.device)
    mask[..., feature_idx] = value
    to_add = mask @ sae.decoder.weight.T
    return (input[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)


def reconstruct_sae_hook(sae, module, input, output):
    # replace the block's contribution with its SAE reconstruction
    # (useful for checking reconstruction fidelity end to end)
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    activated = sae.encode(diff)
    reconstructed = sae.decoder(activated) + sae.pre_bias
    return (input[0] + reconstructed.permute(0, 3, 1, 2).to(output[0].device),)


def ablate_block(module, input, output):
    # skip the block entirely by passing its input through unchanged
    return input
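

# Usage sketch (illustrative; `pipe` and `block` are again assumed to be a
# diffusers-style pipeline and the hooked UNet submodule): generate once
# normally and once with the block ablated, with a fixed seed so the two
# images differ only by the intervention.
def example_ablation_comparison(pipe, block, prompt, seed=0):
    baseline = pipe(prompt, generator=torch.Generator().manual_seed(seed)).images[0]
    handle = block.register_forward_hook(ablate_block)
    try:
        ablated = pipe(prompt, generator=torch.Generator().manual_seed(seed)).images[0]
    finally:
        handle.remove()
    return baseline, ablated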