from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import logging
pretrained_model_dir: str = "models/WizardLM-7B-Uncensored"
quantized_model_dir: str = "./"

config: dict = dict(
    quantize_config=dict(
        bits=8,  # quantize weights to 8-bit precision
        desc_act=True,  # act-order: quantize columns in order of decreasing activation magnitude
        true_sequential=True,  # quantize sub-layers one at a time, reusing already-quantized outputs
        model_file_base_name='WizardLM-7B-Uncensored',
    ),
    use_safetensors=True,
)

logging.basicConfig(
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
)

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)

# Calibration data for GPTQ: a list of tokenized examples, each a BatchEncoding
# carrying "input_ids" and "attention_mask".
examples = [tokenizer("It was a cold night")]

model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, BaseQuantizeConfig(**config['quantize_config']))
model.quantize(examples)  # run GPTQ calibration over the examples and quantize the weights
model.save_quantized(quantized_model_dir, use_safetensors=config['use_safetensors'])
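
# A minimal inference sketch (an assumption, not part of the original script):
# TextGenerationPipeline is imported above but never used, so this shows one way
# the just-saved weights could be loaded and exercised. from_quantized and the
# pipeline call follow AutoGPTQ's documented API; the prompt is illustrative.
inference_model = AutoGPTQForCausalLM.from_quantized(
    quantized_model_dir, device="cuda:0", use_safetensors=True
)
pipeline = TextGenerationPipeline(model=inference_model, tokenizer=tokenizer)
print(pipeline("It was a cold night")[0]["generated_text"])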