hbXNov committed on
Commit 6c25a84 · verified · 1 Parent(s): 10c0670

Create eval_qwen.py

Files changed (1)
  1. eval_qwen.py +205 -0
eval_qwen.py ADDED
@@ -0,0 +1,205 @@
import os
import io
import json
import base64
import logging
import argparse
import requests
from dataclasses import dataclass
from typing import List, Dict, Optional, Union
from io import BytesIO
from pathlib import Path
from enum import Enum

import torch
from datasets import load_dataset
from tqdm import tqdm
from PIL import Image

# Import custom modules
from data import (
    DatasetType,
    DatasetConfig,
    get_dataset_config,
    get_formatted_instruction,
    process_response,
    save_descriptions,
    load_image_dataset,
    get_processed_response,
)
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import torch.distributed as dist
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
from vllm import LLM, SamplingParams

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('evaluation.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

INSTRUCTION = "\n\nYour final answer MUST BE put in \\boxed{}."


def pil_to_base64(image_pil, format="PNG"):
    """Encode a PIL image as a base64 string."""
    buffered = io.BytesIO()
    image_pil.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return img_str


def base64_to_pil(base64_string):
    """Decode a base64 string back into a PIL image."""
    img_data = base64.b64decode(base64_string)
    image_pil = Image.open(io.BytesIO(img_data))
    return image_pil

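# Note: images are shuttled through base64 strings (encoded in InstanceDataset.__getitem__,
# decoded again in main) presumably so they survive DataLoader collation, since the default
# collate_fn batches strings cleanly but not PIL.Image objects.
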
class InstanceDataset(Dataset):

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        for k in item:
            if k == 'options' or k == 'choices':
                if item[k] is None:
                    item[k] = ""
                else:
                    item[k] = str(item[k])
        if 'image_url' in item:
            image_url = item['image_url']
            image_str = pil_to_base64(image_url)
            item['image_url'] = image_str
        instance = {'index': index, 'item': item}
        return instance

def main():
    parser = argparse.ArgumentParser(description='Evaluate model on various math datasets')
    parser.add_argument('--dataset', type=str,
                        choices=['mathvista', 'mathverse', 'mathvision', 'mathvision-mini', 'hallusionbench',
                                 'mmmu-pro-vision', 'we-math', 'math500', 'gpqa', 'dynamath', 'logicvista'],
                        default='mathvista', help='Dataset to evaluate on')
    parser.add_argument('--model_path', type=str, help='Path to the model', default="Qwen/Qwen3-VL-2B-Instruct")
    parser.add_argument('--name', type=str, help='Model save name', default="plm")
    parser.add_argument('--bsz', type=int, help='Batch size', default=2)

    args = parser.parse_args()

    # device = int(os.environ['LOCAL_RANK'])
    # torch.cuda.set_device(f'cuda:{device}')

    # Configuration
    dataset_type = DatasetType(args.dataset)
    dataset_config = get_dataset_config(dataset_type)

    output_folder = f"./outputs/{dataset_type.value}_{args.name}"
    os.makedirs(output_folder, exist_ok=True)

    MODEL_PATH = args.model_path
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    vlm = LLM(MODEL_PATH, limit_mm_per_prompt={"image": 1}, tensor_parallel_size=torch.cuda.device_count())
    sampling_params = SamplingParams(max_tokens=2048, temperature=0.7, top_p=0.8, top_k=20,
                                     repetition_penalty=1.0, presence_penalty=1.5)

    # Load dataset
    logger.info(f"Loading dataset {dataset_config.name}")
    data = load_image_dataset(dataset_config)

    # dist.init_process_group()
    dataset = InstanceDataset(data)
    # sampler = DistributedSampler(dataset, shuffle=False)
    dataloader = DataLoader(dataset, batch_size=args.bsz)

    # Load model
    # local_rank = int(os.environ['LOCAL_RANK'])
    # logger.info(f"Loaded model {args.model_path} | local rank: {local_rank}")

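    # For each batch, prompts are built only for examples whose output JSON is not
    # already on disk (so a partially finished run can simply be re-launched), and a
    # single vlm.generate call then covers the remaining examples.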
    for batch in tqdm(dataloader):

        indices = batch['index']

        run_input_instances = []
        run_indices = []
        run_processed_responses = []
        run_items = []
        run_formatted_instructions = []

        for j in range(len(indices)):
            index = indices[j].item()
            output_file = os.path.join(output_folder, f'{index}.json')
            global_item = batch['item']
            if not os.path.exists(output_file):
                item = {}
                for k in global_item:
                    item[k] = global_item[k][j]

                for k in item:
                    if len(item[k]) > 0:
                        if k == 'choices' or k == 'options':
                            # print(f'item[k]: {item[k]}')
                            try:
                                # choices/options were stringified for collation; recover the list
                                item[k] = eval(item[k])
                            except Exception:
                                pass
                        if k == 'image_url':
                            item['image_url'] = base64_to_pil(item['image_url'])

                formatted_instruction = get_formatted_instruction(dataset_type, item)
                formatted_instruction = formatted_instruction + INSTRUCTION

                if 'image_url' in item:
                    message = [{"role": "user", "content": [{"type": "image", "image": ""}, {"type": "text", "text": formatted_instruction}]}]
                else:
                    message = [{"role": "user", "content": [{"type": "text", "text": formatted_instruction}]}]

                text = processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
                if 'image_url' in item:
                    input_instance = {'prompt': text, 'multi_modal_data': {'image': item['image_url']}}
                else:
                    input_instance = {'prompt': text}

                # print(f'input_instance: {input_instance}')

                run_input_instances.append(input_instance)
                run_indices.append(index)

                processed_response = get_processed_response(dataset_type, item)
                # print(f'response: {item["response"]} | processed_response: {processed_response} | choices: {item["choices"]} | ')
                run_processed_responses.append(processed_response)
                run_items.append(item)
                run_formatted_instructions.append(formatted_instruction)

        outputs = vlm.generate(run_input_instances, sampling_params=sampling_params)

        for j in range(len(run_indices)):
            answer = outputs[j].outputs[0].text
            processed_response = run_processed_responses[j]
            item = run_items[j]
            formatted_instruction = run_formatted_instructions[j]

            if 'image_url' in item:
                del item['image_url']

            # Write each result to its own per-example file, keyed by dataset index
            output_file = os.path.join(output_folder, f'{run_indices[j]}.json')
            description = {
                'index': run_indices[j],
                'item': json.dumps(item),
                'formatted_instruction': formatted_instruction,
                'processed_response': processed_response,
                'answer': answer
            }

            with open(output_file, 'w') as f:
                json.dump(description, f, indent=4)

if __name__ == "__main__":
    main()

# Example launch command:
# VLLM_WORKER_MULTIPROC_METHOD=spawn VLLM_DISABLE_COMPILE_CACHE=1 CUDA_VISIBLE_DEVICES=3,4,5,6 python eval_qwen_multi_vllm.py --dataset mathvista --name qwen3_vl_2b_instruct_vllm
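#
# The per-example JSONs written above can be scored offline. A minimal sketch, assuming
# the default --dataset/--name values and that answers follow the \boxed{...} convention
# requested in INSTRUCTION (exact string match is only a rough metric):
#
#   import glob, json, re
#   files = glob.glob('./outputs/mathvista_plm/*.json')
#   correct = 0
#   for path in files:
#       with open(path) as f:
#           d = json.load(f)
#       boxed = re.findall(r'\\boxed\{(.*?)\}', d['answer'])
#       pred = boxed[-1].strip() if boxed else ''
#       correct += int(pred == str(d['processed_response']).strip())
#   print(f'accuracy: {correct / max(len(files), 1):.4f}')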