'''
Adapted from
https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/train.py
'''
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from typing import Callable, Iterable, Iterator, Optional, List

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ReduceOp

from SAE.dataset_iterator import ActivationsDataloader
from SAE.sae import SparseAutoencoder, unit_norm_decoder_, unit_norm_decoder_grad_adjustment_
from SAE.sae_utils import SAETrainingConfig, Config

from types import SimpleNamespace
import json
import tqdm

def weighted_average(points: torch.Tensor, weights: torch.Tensor):
    weights = weights / weights.sum()
    return (points * weights.view(-1, 1)).sum(dim=0)


def geometric_median_objective(
    median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
) -> torch.Tensor:
    norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
    return (norms * weights).sum()
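
# compute_geometric_median below runs Weiszfeld's fixed-point iteration: given the
# current estimate m, each point x_i is re-weighted by w_i / max(||x_i - m||, eps)
# and the next estimate is the weighted average of the points under those weights.
# The quantity being minimized is the weighted sum of Euclidean distances defined above.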

def compute_geometric_median(
    points: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
    maxiter: int = 100,
    ftol: float = 1e-20,
    do_log: bool = False,
):
    """
    :param points: ``torch.Tensor`` of shape ``(n, d)``.
    :param weights: Optional ``torch.Tensor`` of shape ``(n,)``.
    :param eps: Smallest allowed value of the denominator, to avoid division by zero.
        Equivalently, a smoothing parameter. Default 1e-6.
    :param maxiter: Maximum number of Weiszfeld iterations. Default 100.
    :param ftol: Terminate if the objective value improves by less than this fraction. Default 1e-20.
    :param do_log: If true, also return the objective values encountered during the run.
    :return: SimpleNamespace object with fields
        - `median`: estimate of the geometric median, a ``torch.Tensor`` of shape ``(d,)``.
        - `termination`: string explaining how the algorithm terminated.
        - `logs`: list of objective values encountered during the run (None if do_log is false).
    """
    with torch.no_grad():
        if weights is None:
            weights = torch.ones((points.shape[0],), device=points.device)
        # initialize median estimate at the weighted mean
        new_weights = weights
        median = weighted_average(points, weights)
        objective_value = geometric_median_objective(median, points, weights)
        if do_log:
            logs = [objective_value]
        else:
            logs = None

        # Weiszfeld iterations
        early_termination = False
        pbar = tqdm.tqdm(range(maxiter))
        for _ in pbar:
            prev_obj_value = objective_value
            norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
            new_weights = weights / torch.clamp(norms, min=eps)
            median = weighted_average(points, new_weights)
            objective_value = geometric_median_objective(median, points, weights)
            if logs is not None:
                logs.append(objective_value)
            if abs(prev_obj_value - objective_value) <= ftol * objective_value:
                early_termination = True
                break
            pbar.set_description(f"Objective value: {objective_value:.4f}")

    median = weighted_average(points, new_weights)  # allow autodiff to track it
    return SimpleNamespace(
        median=median,
        new_weights=new_weights,
        termination=(
            "function value converged within tolerance"
            if early_termination
            else "maximum iterations reached"
        ),
        logs=logs,
    )
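
# Example call (shapes are illustrative only):
#   result = compute_geometric_median(torch.randn(1000, 512))
#   result.median       # tensor of shape (512,)
#   result.termination  # how the iteration stopped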

def maybe_transpose(x):
    # Return a contiguous view of x when one is available (kept from the upstream script; unused here).
    return x.T if not x.is_contiguous() and x.T.is_contiguous() else x

import wandb

RANK = 0  # hard-coded rank; with a single process, the Logger below is always enabled
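
# Logger prefixes every metric key with the owning SAE's name, so several SAEs trained
# in the same process write to disjoint wandb keys. Logging is disabled on non-zero
# ranks and when constructed with dummy=True.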

class Logger:
    def __init__(self, sae_name, **kws):
        self.vals = {}
        self.enabled = (RANK == 0) and not kws.pop("dummy", False)
        self.sae_name = sae_name

    def logkv(self, k, v):
        if self.enabled:
            self.vals[f'{self.sae_name}/{k}'] = v.detach() if isinstance(v, torch.Tensor) else v
        return v

    def dumpkvs(self, step):
        if self.enabled:
            wandb.log(self.vals, step=step)
            self.vals = {}
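
# FeaturesStats tracks, for each latent direction, how many batch rows activated it
# since the last reinit(); log() records log10 of that activation frequency, with a
# small epsilon so latents that never fired still map to a finite value.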

class FeaturesStats:
    def __init__(self, dim, logger):
        self.dim = dim
        self.logger = logger
        self.reinit()

    def reinit(self):
        self.n_activated = torch.zeros(self.dim, dtype=torch.long, device="cuda")
        self.n = 0

    def update(self, inds):
        self.n += inds.shape[0]
        inds = inds.flatten().detach()
        self.n_activated.scatter_add_(0, inds, torch.ones_like(inds))

    def log(self):
        self.logger.logkv('activated', (self.n_activated / self.n + 1e-9).log10().cpu().numpy())

def training_loop_(
    aes,
    train_acts_iter,
    loss_fn,
    log_interval,
    save_interval,
    loggers,
    sae_cfgs,
):
    # One (autoencoder, config, logger, progress bar, feature stats, optimizer) pack per SAE,
    # so several SAEs can be trained on the same activation stream in a single pass.
    sae_packs = []
    for ae, cfg, logger in zip(aes, sae_cfgs, loggers):
        pbar = tqdm.tqdm(unit=" steps", desc="Training Loss: ")
        fstats = FeaturesStats(ae.n_dirs, logger)
        opt = torch.optim.Adam(ae.parameters(), lr=cfg.lr, eps=cfg.eps, fused=True)
        sae_packs.append((ae, cfg, logger, pbar, fstats, opt))

    for i, flat_acts_train_batch in enumerate(train_acts_iter):
        flat_acts_train_batch = flat_acts_train_batch.cuda()

        for ae, cfg, logger, pbar, fstats, opt in sae_packs:
            recons, info = ae(flat_acts_train_batch)
            loss = loss_fn(ae, cfg, flat_acts_train_batch, recons, info, logger)

            fstats.update(info['inds'])

            bs = flat_acts_train_batch.shape[0]
            # Fraction of latents that have not fired in roughly the last 1e4 / 1e6 / 1e7 tokens
            # (stats_last_nonzero tracks steps since each latent last activated).
            logger.logkv('not-activated 1e4', (ae.stats_last_nonzero > 1e4 / bs).mean(dtype=float).item())
            logger.logkv('not-activated 1e6', (ae.stats_last_nonzero > 1e6 / bs).mean(dtype=float).item())
            logger.logkv('not-activated 1e7', (ae.stats_last_nonzero > 1e7 / bs).mean(dtype=float).item())
            logger.logkv('explained variance', explained_variance(recons, flat_acts_train_batch))
            logger.logkv('l2_div', (torch.linalg.norm(recons, dim=1) / torch.linalg.norm(flat_acts_train_batch, dim=1)).mean())

            if (i + 1) % log_interval == 0:
                fstats.log()
                fstats.reinit()
            if (i + 1) % save_interval == 0:
                ae.save_to_disk(f"{cfg.save_path}/{i + 1}")

            loss.backward()
            # Keep the decoder columns unit-norm and remove the gradient component parallel
            # to each decoder direction before the optimizer step.
            unit_norm_decoder_(ae)
            unit_norm_decoder_grad_adjustment_(ae)
            opt.step()
            opt.zero_grad()

            logger.dumpkvs(i)
            pbar.set_description(f"Training Loss {loss.item():.4f}")
            pbar.update(1)

    for ae, cfg, logger, pbar, fstats, opt in sae_packs:
        pbar.close()
        ae.save_to_disk(f"{cfg.save_path}/final")
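
# As in the upstream sparse_autoencoder training script, the pre-activation bias is
# initialized at the geometric median of a sample of activations, a more
# outlier-robust centering than the mean.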

def init_from_data_(ae, stats_acts_sample):
    ae.pre_bias.data = (
        compute_geometric_median(stats_acts_sample[:32768].float().cpu()).median.cuda().float()
    )

def mse(recons, x):
    # return ((recons - x) ** 2).sum(dim=-1).mean()
    return ((recons - x) ** 2).mean()


def normalized_mse(recon: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
    # Only used for the AuxK term: MSE normalized by the MSE of predicting the
    # per-dimension mean, so a value of 1.0 means "no better than the mean baseline".
    xs_mu = xs.mean(dim=0)
    loss = mse(recon, xs) / mse(
        xs_mu[None, :].broadcast_to(xs.shape), xs
    )
    return loss

def explained_variance(recons, x):
    # Compute the variance of the residual
    diff = x - recons
    diff_var = torch.var(diff, dim=0, unbiased=False)
    # Compute the variance of the original tensor
    x_var = torch.var(x, dim=0, unbiased=False)
    # Avoid division by zero
    explained_var = 1 - diff_var / (x_var + 1e-8)
    return explained_var.mean()
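
# Illustrative layout of SAE/config.json, inferred from the fields read in main()
# below (the keys are what the code accesses; the values here are placeholders, not
# the ones shipped with this repo):
# {
#   "wandb_project": "sae", "wandb_name": "run-1",
#   "paths_to_latents": ["..."], "block_name": "...", "bs": 4096,
#   "log_interval": 100, "save_interval": 10000,
#   "saes": [
#     {"sae_name": "sae_16k", "n_dirs": 16384, "d_model": 768, "k": 32,
#      "auxk": 256, "auxk_coef": 0.03125, "dead_toks_threshold": 10000000,
#      "lr": 1e-4, "eps": 6.25e-10, "save_path": "checkpoints/sae_16k"}
#   ]
# }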

def main():
    cfg = Config(json.load(open('SAE/config.json')))

    dataloader = ActivationsDataloader(cfg.paths_to_latents, cfg.block_name, cfg.bs)
    acts_iter = dataloader.iterate()

    # Sample a few batches up front to initialize the pre-bias and to set the MSE scale.
    stats_acts_sample = torch.cat([
        next(acts_iter).cpu() for _ in range(10)
    ], dim=0)

    aes = [
        SparseAutoencoder(
            n_dirs_local=sae.n_dirs,
            d_model=sae.d_model,
            k=sae.k,
            auxk=sae.auxk,
            dead_steps_threshold=sae.dead_toks_threshold // cfg.bs,
        ).cuda()
        for sae in cfg.saes
    ]

    for ae in aes:
        init_from_data_(ae, stats_acts_sample)

    # Scale the reconstruction MSE by the inverse of the mean activation variance in the sample.
    mse_scale = (
        1 / ((stats_acts_sample.float().mean(dim=0) - stats_acts_sample.float()) ** 2).mean()
    )
    mse_scale = mse_scale.item()
    del stats_acts_sample

    wandb.init(
        project=cfg.wandb_project,
        name=cfg.wandb_name,
    )

    loggers = [Logger(
        sae_name=cfg_sae.sae_name,
        dummy=False,
    ) for cfg_sae in cfg.saes]

    training_loop_(
        aes,
        acts_iter,
        # Per-SAE loss: scaled reconstruction MSE plus the AuxK term, which trains the
        # auxiliary latents to reconstruct the reconstruction residual.
        lambda ae, cfg_sae, flat_acts_train_batch, recons, info, logger: (
            # MSE
            logger.logkv("train_recons", mse_scale * mse(recons, flat_acts_train_batch))
            # AuxK
            + logger.logkv(
                "train_maxk_recons",
                cfg_sae.auxk_coef
                * normalized_mse(
                    ae.decode_sparse(
                        info["auxk_inds"],
                        info["auxk_vals"],
                    ),
                    flat_acts_train_batch - recons.detach() + ae.pre_bias.detach(),
                ).nan_to_num(0),
            )
        ),
        sae_cfgs=cfg.saes,
        loggers=loggers,
        log_interval=cfg.log_interval,
        save_interval=cfg.save_interval,
    )


if __name__ == "__main__":
    main()
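
# Note: the relative path 'SAE/config.json' read in main() assumes the script is
# launched from the repository root (e.g. `python SAE/train.py`; the exact filename
# of this script is an assumption).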