"""Compare model inference scores against ground-truth solutions.

Both inputs are JSONL files whose records are joined on their ``audios``
list. The script prints overall and per-class accuracy, lists the
mispredicted samples whose ground truth is 1, and writes every comparison
to a JSON file.
"""
import json
import re
from collections import defaultdict

infer_result_path = '/root/autodl-tmp/output_7B_GRPO/v28-20250722-002940/checkpoint-870/infer_result/53_HH.jsonl'
test_path = '/root/autodl-tmp/ms-swift/all_audio_test_50.jsonl'
output_path = 'inference_comparison_result.json'


def extract_overall_score(response_text):
    """Return the first run of digits in *response_text* as an int.

    Args:
        response_text: Raw model response. Anything that is not a ``str``
            (for example ``None`` from a JSON ``null``) is treated as
            containing no score.

    Returns:
        The integer value of the first digit run, or ``None`` when the
        text has no digits or is not a string.
    """
    if not isinstance(response_text, str):
        return None
    match = re.search(r'(\d+)', response_text)
    if match:
        return int(match.group(1))
    return None


def _read_jsonl(path):
    """Yield one parsed JSON object per non-blank line of *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            # Blank lines (e.g. a trailing newline) are not records.
            if line.strip():
                yield json.loads(line)


def main(infer_path=None, gt_path=None, result_path=None):
    """Compare predictions with ground truth and report accuracy.

    Args:
        infer_path: JSONL file of inference results; each record needs a
            ``response`` key and may carry an ``audios`` list. Defaults to
            the module-level ``infer_result_path``.
        gt_path: JSONL file of test samples; each record needs a
            ``solution`` key and may carry an ``audios`` list. Defaults to
            the module-level ``test_path``.
        result_path: Destination JSON file for all comparison results.
            Defaults to the module-level ``output_path``.

    Raises:
        KeyError: If an inference record lacks ``response`` or a test
            record lacks ``solution``.
    """
    # Resolve defaults at call time so reassigning the module globals
    # before calling main() still takes effect.
    if infer_path is None:
        infer_path = infer_result_path
    if gt_path is None:
        gt_path = test_path
    if result_path is None:
        result_path = output_path

    # Read the inference file and map each audio tuple to its score.
    # Records sharing the same audios overwrite earlier ones.
    infer_audio2score = {}
    for data in _read_jsonl(infer_path):
        audios = tuple(data.get('audios', []))
        infer_audio2score[audios] = {
            'score': extract_overall_score(data['response']),
            'raw_response': data['response'],
        }

    # Read the test file and map each audio tuple to its solution.
    test_audio2solution = {}
    for data in _read_jsonl(gt_path):
        audios = tuple(data.get('audios', []))
        test_audio2solution[audios] = data['solution']

    # Tally results, collect the wrong samples and every comparison.
    stats_per_class = defaultdict(lambda: {'correct': 0, 'incorrect': 0})
    incorrect_samples_solution1 = []
    all_results = []
    total = 0
    correct = 0

    for audios, solution in test_audio2solution.items():
        infer_entry = infer_audio2score.get(audios, None)
        # A sample with no matching prediction counts as incorrect.
        infer_score = infer_entry['score'] if infer_entry else None
        raw_response = infer_entry['raw_response'] if infer_entry else None
        # NOTE(review): the predicted score is an int; if `solution` is
        # stored as a string this never matches - confirm the test schema.
        match = infer_score == solution

        # Collect every result.
        all_results.append({
            'audios': audios,
            'gt_solution': solution,
            'predicted_score': infer_score,
            'match': match,
            'response': raw_response,
        })

        if match:
            correct += 1
            stats_per_class[solution]['correct'] += 1
        else:
            stats_per_class[solution]['incorrect'] += 1
            if solution == 1:
                incorrect_samples_solution1.append({
                    'audios': audios,
                    'gt_solution': solution,
                    'predicted_score': infer_score,
                    'response': raw_response,
                })
        total += 1

    # Overall accuracy; an empty test set reports 0% instead of dividing
    # by zero.
    overall_accuracy = correct / total if total > 0 else 0.0
    print(f'\nOverall Accuracy: {correct}/{total} = {overall_accuracy:.2%}\n')

    # Per-class accuracy.
    print("Per-Class Accuracy:")
    for solution, stats in sorted(stats_per_class.items()):
        total_class = stats['correct'] + stats['incorrect']
        accuracy = stats['correct'] / total_class if total_class > 0 else 0.0
        print(f'Class {solution}: Correct={stats["correct"]}, Incorrect={stats["incorrect"]}, Accuracy={accuracy:.2%}')

    # List the mispredicted samples whose solution is 1.
    print("\nIncorrect Samples for solution = 1:")
    for sample in incorrect_samples_solution1:
        print(json.dumps(sample, indent=2, ensure_ascii=False))

    # Write every result to the JSON file.
    with open(result_path, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)

    print(f"\nAll inference comparison results saved to: {result_path}")


if __name__ == '__main__':
    main()