import os import json import torch from tqdm import tqdm from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoModel import chromadb import re # ======== 사용자 설정 ======== # base_model = "/home/sooh5090/axolotl/output/spell-finetune2/checkpoint-139" test_file = "../data/json/korean_language_rag_V1.0_test.json" output_file = "../output/test_predictions_2.json" max_new_tokens = 256 # ======== GPU 디바이스 분리 ======== # device_llm = torch.device("cuda:3" if torch.cuda.is_available() else "cpu") # LLM device_rag = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") # 임베딩 # ======== 1. 모델 로드 ======== # print("🔄 모델 로드 중...") tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) tokenizer.pad_token = "<|end_of_text|>" tokenizer.eos_token = "<|eot_id|>" tokenizer.bos_token = "<|begin_of_text|>" model = AutoModelForCausalLM.from_pretrained( base_model, device_map={"": device_llm}, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, trust_remote_code=True ) model.eval() print("✅ FP16 모델 로드 완료 (GPU 3번)") # ======== 2. RAG 검색기 설정 (GPU 2번) ======== # embed_model_id = "dragonkue/snowflake-arctic-embed-l-v2.0-ko" embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_id) embed_model = AutoModel.from_pretrained(embed_model_id).to(device_rag).eval() client = chromadb.PersistentClient(path="../grammar_db") collection = client.get_collection(name="korean_grammar_rules", embedding_function=None) # ======== 3. 임베딩 함수 ======== # def embed_query(text, chunk_size=512): tokens = embed_tokenizer(text, add_special_tokens=False)["input_ids"] chunks = [tokens[i:i + chunk_size] for i in range(0, len(tokens), chunk_size)] embeddings = [] for chunk in chunks: inputs = torch.tensor([embed_tokenizer.build_inputs_with_special_tokens(chunk)]).to(device_rag) with torch.no_grad(): output = embed_model(input_ids=inputs).last_hidden_state valid_token_count = (inputs != embed_tokenizer.pad_token_id).sum(dim=1, keepdim=True) chunk_emb = output.sum(dim=1) / valid_token_count embeddings.append(chunk_emb.cpu()) return torch.stack(embeddings).mean(dim=0).squeeze(0).tolist() # ======== 4. 2개 규범 반환 ======== # def retrieve_context(query_text, top_k=3): query_vec = embed_query(query_text) results = collection.query(query_embeddings=[query_vec], n_results=top_k * 4) # 넉넉하게 받아옴 docs = results["documents"][0] metas = results["metadatas"][0] seen_rules = set() contexts = [] for doc, meta in zip(docs, metas): rule_text = meta.get("rule", "").strip() if rule_text in seen_rules: continue # 동일 규범(rule) 중복 제거 seen_rules.add(rule_text) # 참고용: 제목 + 해당 규범 내용 (doc는 안 써도 됨) context = f"[{meta['title']}]\n{rule_text}" contexts.append(context) if len(contexts) == top_k: break return "\n\n".join(contexts) # ======== 4-1. 질문 전처리 함수 ======== # def extract_query_from_question(q_type, question): if q_type == "선택형": matches = re.findall(r"\{([^}]+)\}", question) parens_matches = re.findall(r"\(([^\)]+)\)", question) tokens = [] if matches: tokens.extend(re.split(r"[ /]", matches[0])) for m in parens_matches: if m.strip(): tokens.extend(re.split(r"[ /]", m.strip())) return " ".join(tokens) if tokens else question else: # 교정형 dash_lines = [line.strip("― ").strip() for line in question.splitlines() if line.strip().startswith("―")] if dash_lines: return dash_lines[0] quote_matches = re.findall(r"\"([^\"]+)\"", question) if quote_matches: return quote_matches[0] parens_matches = re.findall(r"\(([^\)]+)\)", question) if parens_matches: return parens_matches[0] return question # ======== 5. Instruction 템플릿 ======== # INSTRUCTION_TEMPLATES = { "교정형": """당신은 한국어 어문 규범(맞춤법, 띄어쓰기, 표준어, 문장부호, 외래어 표기법 등)에 따라 문장을 교정하고 그 이유를 설명하는 AI입니다. [문제 유형: 교정형] - 문제와 함께 관련 규범이 주어질 수 있으나, 규범의 타당성을 스스로 판단하여 정답을 도출해야 합니다. - 제시된 문장은 반드시 어문 규범에 어긋난 표현을 포함하고 있습니다. - 틀린 표현을 정확히 찾아 수정하되, 그 외 표현은 변경하지 마십시오. - 답변은 반드시 다음 형식을 따라야 합니다: "{수정문}이 옳다. {이유}" [예시] 문제: 다음 문장에서 어문 규범에 부합하지 않는 부분을 찾아 고치고, 그렇게 고친 이유를 설명하세요. "어서 쾌차하시길 바래요." 정답: "어서 쾌차하시길 바라요."가 옳다. 동사 '바라다'에 어미 '-아요'가 결합한 형태이므로 '바라요'로 써야 한다. '바래요'는 비표준어다.""", "선택형": """당신은 한국어 어문 규범(맞춤법, 띄어쓰기, 표준어, 문장부호, 외래어 표기법 등)에 따라 문장에서 올바른 표현을 선택하고 그 이유를 설명하는 AI입니다. [문제 유형: 선택형] - 문제와 함께 관련 규범이 주어질 수 있으나, 규범의 타당성을 스스로 판단하여 정답을 도출해야 합니다. - 보기 중 단 하나의 정답이 있으며, 반드시 어문 규범에 따라 하나의 표현만 선택해야 합니다. - 선택한 표현 이외의 문장 구성은 수정하지 마십시오. - 답변은 반드시 다음 형식을 따라야 합니다: "{정답문}이 옳다. {이유}" [예시] 문제: "가축을 기를 때에는 {먹이량/먹이양}을 조절해 주어야 한다." 가운데 올바른 것을 선택하고, 그 이유를 설명하세요. 정답: "가축을 기를 때에는 먹이양을 조절해 주어야 한다."가 옳다. '먹이'는 고유어이므로, 한자어 '量'은 두음 법칙에 따라 '양'으로 표기해야 한다.""" } # ======== 6. 테스트 데이터 로드 ======== # with open(test_file, "r", encoding="utf-8") as f: test_data = json.load(f) # ======== 7. 예측 ======== # predictions = [] for sample in tqdm(test_data, desc="🔍 Test 예측 중"): q_type = sample.get("input", {}).get("question_type") question = sample.get("input", {}).get("question", "").strip() search_query = extract_query_from_question(q_type, question) retrieved = retrieve_context(search_query) instruction = INSTRUCTION_TEMPLATES.get(q_type, INSTRUCTION_TEMPLATES["교정형"]) input_text = f"[참고 규범]\n{retrieved}\n=====\n문제: {question}\n정답:" # ======== 카나나 프롬프트 ======== # prompt = ( "<|begin_of_text|>\n" f"[|system|]{instruction}<|eot_id|>\n" f"[|user|]{input_text}<|eot_id|>\n" "[|assistant|]" ) # LLM (GPU 3번) inputs = tokenizer(prompt, return_tensors="pt").to(device_llm) inputs.pop("token_type_ids", None) with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.0, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id ) decoded = tokenizer.decode(outputs[0], skip_special_tokens=False) prediction = decoded.split("[|assistant|]")[-1].split(tokenizer.eos_token)[0].strip() print("\n=============================") print(f"📝 질문: {question}") print(f"🔍 검색 쿼리: {search_query}") print(f"📚 검색 컨텍스트:\n{retrieved}") print(f"🤖 모델 답변: {prediction}") print("=============================\n") predictions.append({ "id": sample.get("id", ""), "input": sample.get("input", {}), "output": {"answer": prediction} }) # ======== 8. 결과 저장 ======== # os.makedirs(os.path.dirname(output_file), exist_ok=True) with open(output_file, "w", encoding="utf-8") as f: json.dump(predictions, f, ensure_ascii=False, indent=2) print(f"\n📄 테스트 결과 저장 완료: {output_file}")