Peiran commited on
Commit
88f2a10
·
1 Parent(s): b25a877

Pairing improvements: filter already-evaluated pairs from /data, round-robin schedule across test_ids, alternate A/B order per pair; ensure submit maps scores to correct model columns and auto-advance

Browse files
Files changed (1) hide show
  1. app.py +93 -6
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import csv
2
  import itertools
 
3
  import json
4
  import os
5
  import uuid
@@ -15,6 +16,7 @@ except Exception: # optional dependency at runtime
15
 
16
 
17
  BASE_DIR = os.path.dirname(__file__)
 
18
  # Persistent local storage inside HF Spaces
19
  PERSIST_DIR = os.environ.get("PERSIST_DIR", "/data")
20
  TASK_CONFIG = {
@@ -35,6 +37,11 @@ def _csv_path_for_task(task_name: str, filename: str) -> str:
35
  return os.path.join(BASE_DIR, folder, filename)
36
 
37
 
 
 
 
 
 
38
  def _resolve_image_path(path: str) -> str:
39
  return path if os.path.isabs(path) else os.path.join(BASE_DIR, path)
40
 
@@ -87,12 +94,72 @@ def _build_image_pairs(rows: List[Dict[str, str]], task_name: str) -> List[Dict[
87
  return pairs
88
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def load_task(task_name: str):
91
  if not task_name:
92
  raise gr.Error("请先选择任务。")
93
 
94
  rows = _load_task_rows(task_name)
95
  pairs = _build_image_pairs(rows, task_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  if not pairs:
97
  raise gr.Error("没有找到可评测的图片对,请检查数据文件。")
98
 
@@ -204,13 +271,16 @@ def on_task_change(task_name: str, _state_pairs: List[Dict[str, str]]):
204
  header = _format_pair_header(pair)
205
  # Defaults for A and B (8 sliders total)
206
  default_scores = [3, 3, 3, 3, 3, 3, 3, 3]
 
 
 
207
  return (
208
  pairs,
209
  gr.update(value=0, minimum=0, maximum=len(pairs) - 1, visible=(len(pairs) > 1)),
210
  gr.update(value=header),
211
  _resolve_image_path(pair["org_img"]),
212
- _resolve_image_path(pair["model1_path"]),
213
- _resolve_image_path(pair["model2_path"]),
214
  *default_scores,
215
  gr.update(value=f"共 {len(pairs)} 个待评测的图片对。"),
216
  )
@@ -223,12 +293,14 @@ def on_pair_navigate(index: int, pairs: List[Dict[str, str]]):
223
  index = max(0, min(index, len(pairs) - 1))
224
  pair = pairs[index]
225
  header = _format_pair_header(pair)
 
 
226
  return (
227
  gr.update(value=index),
228
  gr.update(value=header),
229
  _resolve_image_path(pair["org_img"]),
230
- _resolve_image_path(pair["model1_path"]),
231
- _resolve_image_path(pair["model2_path"]),
232
  3, 3, 3, 3, # A
233
  3, 3, 3, 3, # B
234
  )
@@ -266,6 +338,19 @@ def on_submit(
266
  "model2_semantic_functional_alignment_score": int(b_semantic_score),
267
  "model2_overall_photorealism_score": int(b_overall_score),
268
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  row = _build_eval_row(pair, score_map)
270
  ok_local = _append_local_persist_csv(task_name, row)
271
  ok_hub, hub_msg = _upload_eval_record_to_dataset(task_name, row)
@@ -278,12 +363,14 @@ def on_submit(
278
  if next_index != index:
279
  pair = pairs[next_index]
280
  header = _format_pair_header(pair)
 
 
281
  return (
282
  gr.update(value=next_index),
283
  gr.update(value=header),
284
  _resolve_image_path(pair["org_img"]),
285
- _resolve_image_path(pair["model1_path"]),
286
- _resolve_image_path(pair["model2_path"]),
287
  3, 3, 3, 3,
288
  3, 3, 3, 3,
289
  gr.update(value=info + f" 自动跳转到下一组({next_index + 1}/{len(pairs)})。"),
 
1
  import csv
2
  import itertools
3
+ import random
4
  import json
5
  import os
6
  import uuid
 
16
 
17
 
18
  BASE_DIR = os.path.dirname(__file__)
19
+ PERSIST_DIR = os.environ.get("PERSIST_DIR", "/data")
20
  # Persistent local storage inside HF Spaces
21
  PERSIST_DIR = os.environ.get("PERSIST_DIR", "/data")
22
  TASK_CONFIG = {
 
37
  return os.path.join(BASE_DIR, folder, filename)
38
 
39
 
40
+ def _persist_csv_path_for_task(task_name: str) -> str:
41
+ folder = TASK_CONFIG[task_name]["folder"]
42
+ return os.path.join(PERSIST_DIR, folder, "evaluation_results.csv")
43
+
44
+
45
  def _resolve_image_path(path: str) -> str:
46
  return path if os.path.isabs(path) else os.path.join(BASE_DIR, path)
47
 
 
94
  return pairs
95
 
96
 
97
+ def _read_existing_eval_keys(task_name: str) -> set:
98
+ """Read already-evaluated pair keys from persistent CSV, return a set of keys.
99
+ Key is (test_id, frozenset({model1_name, model2_name}), org_img) to ignore A/B order.
100
+ """
101
+ keys = set()
102
+ csv_path = _persist_csv_path_for_task(task_name)
103
+ if not os.path.exists(csv_path):
104
+ return keys
105
+ try:
106
+ with open(csv_path, newline="", encoding="utf-8") as f:
107
+ reader = csv.DictReader(f)
108
+ for r in reader:
109
+ tid = str(r.get("test_id", "")).strip()
110
+ m1 = str(r.get("model1_name", "")).strip()
111
+ m2 = str(r.get("model2_name", "")).strip()
112
+ org = str(r.get("org_img", "")).strip()
113
+ if tid and m1 and m2 and org:
114
+ keys.add((tid, frozenset({m1, m2}), org))
115
+ except Exception:
116
+ pass
117
+ return keys
118
+
119
+
120
+ def _schedule_round_robin_by_test_id(pairs: List[Dict[str, str]], seed: int | None = None) -> List[Dict[str, str]]:
121
+ """Interleave pairs across test_ids for balanced coverage; shuffle within each group.
122
+ """
123
+ groups: Dict[str, List[Dict[str, str]]] = {}
124
+ for p in pairs:
125
+ groups.setdefault(p["test_id"], []).append(p)
126
+ rnd = random.Random(seed)
127
+ for lst in groups.values():
128
+ rnd.shuffle(lst)
129
+ # round-robin drain
130
+ ordered: List[Dict[str, str]] = []
131
+ while True:
132
+ progressed = False
133
+ for tid in sorted(groups.keys(), key=lambda x: (int(x) if x.isdigit() else x)):
134
+ if groups[tid]:
135
+ ordered.append(groups[tid].pop())
136
+ progressed = True
137
+ if not progressed:
138
+ break
139
+ return ordered
140
+
141
+
142
  def load_task(task_name: str):
143
  if not task_name:
144
  raise gr.Error("请先选择任务。")
145
 
146
  rows = _load_task_rows(task_name)
147
  pairs = _build_image_pairs(rows, task_name)
148
+ # Filter out already evaluated pairs from persistent CSV
149
+ done_keys = _read_existing_eval_keys(task_name)
150
+ def key_of(p: Dict[str, str]):
151
+ return (p["test_id"], frozenset({p["model1_name"], p["model2_name"]}), p["org_img"])
152
+ pairs = [p for p in pairs if key_of(p) not in done_keys]
153
+
154
+ # Balanced schedule across test_ids with a stable randomization
155
+ seed_env = os.environ.get("SCHEDULE_SEED")
156
+ seed = int(seed_env) if seed_env and seed_env.isdigit() else None
157
+ pairs = _schedule_round_robin_by_test_id(pairs, seed=seed)
158
+
159
+ # Assign A/B order to counteract position bias: alternate after scheduling
160
+ for idx, p in enumerate(pairs):
161
+ p["swap"] = bool(idx % 2) # True -> A=B's image; False -> A=A's image
162
+
163
  if not pairs:
164
  raise gr.Error("没有找到可评测的图片对,请检查数据文件。")
165
 
 
271
  header = _format_pair_header(pair)
272
  # Defaults for A and B (8 sliders total)
273
  default_scores = [3, 3, 3, 3, 3, 3, 3, 3]
274
+ # Pick display order according to swap flag
275
+ a_path = pair["model2_path"] if pair.get("swap") else pair["model1_path"]
276
+ b_path = pair["model1_path"] if pair.get("swap") else pair["model2_path"]
277
  return (
278
  pairs,
279
  gr.update(value=0, minimum=0, maximum=len(pairs) - 1, visible=(len(pairs) > 1)),
280
  gr.update(value=header),
281
  _resolve_image_path(pair["org_img"]),
282
+ _resolve_image_path(a_path),
283
+ _resolve_image_path(b_path),
284
  *default_scores,
285
  gr.update(value=f"共 {len(pairs)} 个待评测的图片对。"),
286
  )
 
293
  index = max(0, min(index, len(pairs) - 1))
294
  pair = pairs[index]
295
  header = _format_pair_header(pair)
296
+ a_path = pair["model2_path"] if pair.get("swap") else pair["model1_path"]
297
+ b_path = pair["model1_path"] if pair.get("swap") else pair["model2_path"]
298
  return (
299
  gr.update(value=index),
300
  gr.update(value=header),
301
  _resolve_image_path(pair["org_img"]),
302
+ _resolve_image_path(a_path),
303
+ _resolve_image_path(b_path),
304
  3, 3, 3, 3, # A
305
  3, 3, 3, 3, # B
306
  )
 
338
  "model2_semantic_functional_alignment_score": int(b_semantic_score),
339
  "model2_overall_photorealism_score": int(b_overall_score),
340
  }
341
+ # Map A/B scores to the correct model columns depending on swap
342
+ if pair.get("swap"):
343
+ # UI A == model2, UI B == model1
344
+ score_map = {
345
+ "model1_physical_interaction_fidelity_score": int(b_physical_score),
346
+ "model1_optical_effect_accuracy_score": int(b_optical_score),
347
+ "model1_semantic_functional_alignment_score": int(b_semantic_score),
348
+ "model1_overall_photorealism_score": int(b_overall_score),
349
+ "model2_physical_interaction_fidelity_score": int(a_physical_score),
350
+ "model2_optical_effect_accuracy_score": int(a_optical_score),
351
+ "model2_semantic_functional_alignment_score": int(a_semantic_score),
352
+ "model2_overall_photorealism_score": int(a_overall_score),
353
+ }
354
  row = _build_eval_row(pair, score_map)
355
  ok_local = _append_local_persist_csv(task_name, row)
356
  ok_hub, hub_msg = _upload_eval_record_to_dataset(task_name, row)
 
363
  if next_index != index:
364
  pair = pairs[next_index]
365
  header = _format_pair_header(pair)
366
+ a_path = pair["model2_path"] if pair.get("swap") else pair["model1_path"]
367
+ b_path = pair["model1_path"] if pair.get("swap") else pair["model2_path"]
368
  return (
369
  gr.update(value=next_index),
370
  gr.update(value=header),
371
  _resolve_image_path(pair["org_img"]),
372
+ _resolve_image_path(a_path),
373
+ _resolve_image_path(b_path),
374
  3, 3, 3, 3,
375
  3, 3, 3, 3,
376
  gr.update(value=info + f" 自动跳转到下一组({next_index + 1}/{len(pairs)})。"),