# coding=utf-8
"""Merge a PEFT (LoRA) adapter into its base model and save the merged weights.

Loads the base causal LM in bfloat16, applies the adapter from
``--peft_model_path``, folds the adapter deltas into the base weights, and
writes the standalone merged checkpoint to ``<peft_model_path>_merged``
alongside the adapter's tokenizer files.
"""
import argparse
import os
import shutil

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Tokenizer artifacts to carry over from the adapter directory to the
# merged-model directory. Not every tokenizer ships every file (e.g.
# ``tokenizer.model`` only exists for SentencePiece-based tokenizers).
TOKENIZER_FILES = (
    "special_tokens_map.json",
    "tokenizer.model",
    "tokenizer_config.json",
)


def main(args):
    """Load base model + PEFT adapter, merge, and save the result.

    Args:
        args: Namespace with ``base_model_path``, ``peft_model_path`` and
            ``save_dir`` attributes (``save_dir`` is derived in ``__main__``).
    """
    # bfloat16 keeps the 7B/13B checkpoints within reasonable memory bounds.
    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_model_path, torch_dtype=torch.bfloat16
    )
    model = PeftModel.from_pretrained(base_model, args.peft_model_path)
    print("\n>>> Base Model + PEFT before merging:\n", model)

    # Fold the LoRA deltas into the base weights and strip the adapter
    # wrappers so the result is a plain AutoModelForCausalLM.
    model = model.merge_and_unload()
    print("\n>>> Base Model + PEFT after merging:\n", model)

    print("\n>>> Save model into {}".format(args.save_dir))
    model.save_pretrained(args.save_dir, safe_serialization=True)

    print("\n>>> Copy tokenization files...")
    for filename in TOKENIZER_FILES:
        src = os.path.join(args.peft_model_path, filename)
        # Skip files the adapter directory does not provide instead of
        # crashing with FileNotFoundError (tokenizer layouts vary by model).
        if os.path.isfile(src):
            shutil.copyfile(src, os.path.join(args.save_dir, filename))
        else:
            print(">>> Skip missing file: {}".format(src))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Merge PEFT models")
    parser.add_argument("--base_model_path", default="/home/shuai/pretrained/meta-llama/Llama-2-13b-hf", type=str)
    parser.add_argument("--peft_model_path", default="/home/shuai/output/rlhf/sim_conf_sft_llama_13b_001", type=str)
    args = parser.parse_args()
    # Merged checkpoint lands next to the adapter directory.
    args.save_dir = args.peft_model_path + "_merged"
    print(vars(args))
    main(args)