import torch
import torch.nn as nn
import transformers

from .gptq import *
from .modelutils import *
from .quant import *
from transformers import BloomForCausalLM as LM
class SakuraForCausalLM(LM):
    def __init__(self, *args, **kwargs):
        # Disable random weight initialization: the real weights are loaded
        # from a checkpoint afterwards, so initializing here is wasted work.
        def noop(*args, **kwargs):
            pass

        torch.nn.init.kaiming_uniform_ = noop
        torch.nn.init.uniform_ = noop
        torch.nn.init.normal_ = noop
        transformers.modeling_utils._init_weights = False

        # Build the model in half precision, then restore the default dtype.
        torch.set_default_dtype(torch.half)
        super().__init__(*args, **kwargs)
        torch.set_default_dtype(torch.float)

        self.eval()

        # Replace every linear layer except the output head with its
        # quantized counterpart (8-bit weights, groupsize 128).
        layers = find_layers(self)
        for name in ['lm_head']:
            if name in layers:
                del layers[name]
        make_quant(self, layers, 8, 128)
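

# --- Usage sketch (not part of the original module; the model id and
# checkpoint path below are illustrative assumptions) ---
if __name__ == "__main__":
    from transformers import AutoConfig

    # Build the quantized skeleton from a config only; the patched __init__
    # above skips weight initialization, so construction is fast.
    config = AutoConfig.from_pretrained("bigscience/bloom-560m")  # assumed model id
    model = SakuraForCausalLM(config)

    # A previously produced 8-bit GPTQ checkpoint would then be loaded, e.g.:
    # model.load_state_dict(torch.load("bloom-8bit-128g.pt"))  # hypothetical path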