Spaces:
Sleeping
Sleeping
import io
import tempfile
import uuid
import zipfile
from functools import lru_cache
from pathlib import Path

import gradio as gr
import imageio.v2 as imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from matplotlib.figure import Figure
from mpl_toolkits.axes_grid1 import make_axes_locatable

from models.geometric_deeponet.geometric_deeponet import GeometricDeepONetTime
| # ---------------- Config ---------------- | |
REPO_ID = "arabeh/DeepONet-FlowBench-FPO"

# History length s (string key, as offered in the UI) -> checkpoint file inside REPO_ID.
CKPTS = {
    history: f"checkpoints/time-dependent-deeponet_{history}in.ckpt"
    for history in ("1", "4", "8", "16")
}

# Local directory holding the sample_{id}_input.npy / sample_{id}_output.npy files.
SAMPLES_DIR = Path("sample_cases")

# Run on the GPU when one is visible, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Root under which per-request output directories are created.
TMP = Path(tempfile.gettempdir())

# Fixed colour limits (vmin, vmax) per plotted field, shared by the GT and prediction panels.
RANGES = dict(u=(-2.0, 2.0), v=(-1.0, 1.0))
| def _tag() -> str: | |
| # unique per request (avoids filename collisions across sessions) | |
| return next(tempfile._get_candidate_names()) | |
def _tmp(tag: str, name: str) -> str:
    """Return the path (as a string) of file ``name`` in the scratch dir for ``tag``.

    The directory ``<TMP>/deeponet_fpo_<tag>`` is created on demand, so callers can
    write to the returned path immediately.
    """
    folder = TMP.joinpath(f"deeponet_fpo_{tag}")
    folder.mkdir(parents=True, exist_ok=True)
    target = folder.joinpath(name)
    return str(target)
| # ---------------- Samples ---------------- | |
def list_samples(samples_dir=None):
    """Return the IDs of the available sample cases, sorted numerically.

    A sample ``<id>`` is listed only when ``<id>`` is a decimal number and BOTH
    ``sample_<id>_input.npy`` and ``sample_<id>_output.npy`` exist, because
    ``load_sample`` reads both files.

    Args:
        samples_dir: directory to scan (``str`` or ``Path``); defaults to ``SAMPLES_DIR``.

    Returns:
        Sample IDs as strings, e.g. ``["2", "10"]``; empty if the directory is missing.
    """
    root = SAMPLES_DIR if samples_dir is None else Path(samples_dir)
    if not root.is_dir():
        return []
    prefix, suffix = "sample_", "_input"
    ids = set()
    for p in root.glob("sample_*_input.npy"):
        # Strip the fixed prefix/suffix rather than split("_")[1], which would
        # turn "sample_1_2_input" into the wrong id "1".
        sid = p.stem[len(prefix):-len(suffix)]
        # isdecimal (not isdigit): every accepted id must survive int() below.
        if sid.isdecimal() and (root / f"sample_{sid}_output.npy").is_file():
            ids.add(sid)
    return sorted(ids, key=int)
def load_sample(sample_id: str):
    """Load one sample case from ``SAMPLES_DIR`` as float32 arrays.

    Args:
        sample_id: decimal sample ID, as returned by ``list_samples``.

    Returns:
        ``(sdf, y)``: the geometry input (expected shape [1, H, W]) and the
        ground-truth trajectory (expected shape [T, 2, H, W]).

    Raises:
        ValueError: if ``sample_id`` is not a plain decimal number. The value may
            arrive through the app's API rather than the Radio widget, so it is
            checked before being interpolated into a filesystem path.
        FileNotFoundError: if either ``.npy`` file is missing.
    """
    sid = str(sample_id)
    if not sid.isdecimal():
        raise ValueError(f"Invalid sample id {sample_id!r}: expected a decimal number")
    sdf = np.load(SAMPLES_DIR / f"sample_{sid}_input.npy").astype(np.float32)  # [1,H,W]
    y = np.load(SAMPLES_DIR / f"sample_{sid}_output.npy").astype(np.float32)  # [T,2,H,W]
    return sdf, y
| # ---------------- Model ---------------- | |
@lru_cache(maxsize=len(CKPTS))
def load_model(history_s: int) -> GeometricDeepONetTime:
    """Download (if needed) and load the checkpoint for history length ``history_s``.

    Results are memoised per ``history_s`` (at most one entry per checkpoint).
    Without the cache, every request would re-read the checkpoint and rebuild the
    network. The returned module is in eval mode on ``DEVICE``.
    NOTE(review): the cached module is shared by all requests; this assumes its
    forward pass keeps no mutable state - confirm.

    Args:
        history_s: number of input frames; must match a key of ``CKPTS`` when
            converted with ``str``.

    Raises:
        ValueError: if no checkpoint exists for ``history_s``.
    """
    try:
        filename = CKPTS[str(history_s)]
    except KeyError:
        raise ValueError(
            f"No checkpoint for history length {history_s!r}; "
            f"available: {', '.join(CKPTS)}"
        ) from None
    ckpt_path = hf_hub_download(REPO_ID, filename)
    model = GeometricDeepONetTime.load_from_checkpoint(ckpt_path, map_location=DEVICE)
    return model.eval().to(DEVICE)
def static_tensors(hparams, sdf_np: np.ndarray):
    """Build the per-geometry model inputs that stay fixed during a rollout.

    Args:
        hparams: model hyper-parameters; only ``domain_length_x`` and
            ``domain_length_y`` are read.
        sdf_np: geometry field, unpacked as shape [1, H, W].

    Returns:
        ``(sdf_t, coords_t, re_t, H, W)``, all tensors on ``DEVICE``:
        - ``sdf_t`` has shape [1, 1, H, W];
        - ``coords_t`` has shape [1, 2, H, W] and holds uniform x then y coordinates
          spanning [0, domain_length];
        - ``re_t`` is an all-zero tensor shaped like ``sdf_t``.
    """
    _, height, width = sdf_np.shape
    x_axis = np.linspace(0.0, float(hparams.domain_length_x), width, dtype=np.float32)
    y_axis = np.linspace(0.0, float(hparams.domain_length_y), height, dtype=np.float32)
    # "ij" indexing yields [H, W] grids: rows vary with y, columns with x.
    grid_y, grid_x = np.meshgrid(y_axis, x_axis, indexing="ij")
    coords = np.expand_dims(np.stack((grid_x, grid_y)), 0)  # [1,2,H,W], x channel first

    sdf_t = torch.from_numpy(sdf_np).unsqueeze(0).to(DEVICE)  # [1,1,H,W]
    coords_t = torch.from_numpy(coords).to(DEVICE)  # [1,2,H,W]
    re_t = torch.zeros_like(sdf_t)  # [1,1,H,W]
    return sdf_t, coords_t, re_t, height, width
| # ---------------- Rollout + metrics ---------------- | |
def rollout(sample_id: str, history_s: str):
    """Autoregressively roll the DeepONet out over a sample's whole time horizon.

    The first ``s`` ground-truth frames seed a sliding history window. Each later
    frame is predicted from the ``s`` most recent frames and then appended to the
    window, so predictions are fed back as inputs.

    Args:
        sample_id: ID of a sample in ``SAMPLES_DIR``.
        history_s: history length as a string (a key of ``CKPTS``).

    Returns:
        ``(y_true, y_pred, s)``:
        - ``y_true`` and ``y_pred`` both have shape [T, 2, H, W];
        - the first ``s`` frames of ``y_pred`` are copied from ``y_true``;
        - ``s`` is the history length used.

    Raises:
        ValueError: if the sample does not have exactly 2 channels (u, v), or has
            ``T <= s`` frames so nothing is left to predict.
    """
    s = int(history_s)
    sdf, y_true = load_sample(sample_id)
    T, C, H, W = y_true.shape
    if C != 2:
        raise ValueError(f"Expected 2 channels (u,v), got {C}")
    # Each checkpoint is loaded for a fixed s. The old `s = min(s, T - 1)` clamp
    # would have fed it fewer history frames than that (presumably a channel-count
    # mismatch in the branch input), so fail clearly instead.
    if s >= T:
        raise ValueError(
            f"History length s={s} needs at least {s + 1} frames; sample has T={T}"
        )
    model = load_model(s)
    sdf_t, coords_t, re_t, _, _ = static_tensors(model.hparams, sdf)

    y_pred = np.zeros_like(y_true)
    y_pred[:s] = y_true[:s]
    history = y_true[:s].copy()  # [s,2,H,W], oldest frame first
    with torch.no_grad():
        for t in range(s, T):
            branch = rearrange(history, "nb c h w -> (nb c) h w")[None]  # [1,s*2,H,W]
            branch_t = torch.from_numpy(branch).to(DEVICE)
            y_hat = model((branch_t, re_t, coords_t, sdf_t))  # [1,1,p,2], p == H*W
            # reshape (not view) also works if the model output is non-contiguous.
            frame = y_hat[0, 0].reshape(H, W, 2).permute(2, 0, 1).cpu().numpy()  # [2,H,W]
            y_pred[t] = frame
            # Drop the oldest frame and append the newest; also correct for s == 1.
            history = np.concatenate([history[1:], frame[None]], axis=0)
    return y_true, y_pred, s
def rollout_errors(y_true: np.ndarray, y_pred: np.ndarray, s: int):
    """Compute the per-timestep relative L2 error of the predicted frames.

    Only frames ``t >= s`` are scored, because the first ``s`` frames of ``y_pred``
    are copies of the ground truth.

    Args:
        y_true, y_pred: arrays of shape [T, C, ...] with u at channel 0 and v at
            channel 1.
        s: history length, i.e. the index of the first predicted frame.

    Returns:
        ``(ts, err_u, err_v, avg_u, avg_v)``:
        - ``ts`` are the scored timesteps;
        - ``err_u`` and ``err_v`` hold ``||pred - true||_2 / ||true||_2`` per
          timestep for u and v;
        - ``avg_u`` and ``avg_v`` are their means (NaN when no timestep is scored).
    """
    yt = y_true[s:]
    yp = y_pred[s:]
    ts = np.arange(s, y_true.shape[0])
    n = yt.shape[0]

    def _flat(a: np.ndarray) -> np.ndarray:
        # [n, C, *spatial] -> [n, C, prod(spatial)]. The size is given explicitly
        # because reshape(..., -1) cannot infer it when n == 0.
        return a.reshape(a.shape[0], a.shape[1], int(np.prod(a.shape[2:])))

    num = np.linalg.norm(_flat(yp - yt), axis=2)  # [n, C]
    den = np.linalg.norm(_flat(yt), axis=2)  # [n, C]
    # An all-zero ground-truth frame would give 0/0 = NaN (or x/0 = inf) and poison
    # the average. Clipping at eps leaves every non-degenerate value unchanged.
    rel = num / np.maximum(den, np.finfo(den.dtype).eps)
    err_u = rel[:, 0]
    err_v = rel[:, 1]
    avg_u = float(err_u.mean()) if n else float("nan")
    avg_v = float(err_v.mean()) if n else float("nan")
    return ts, err_u, err_v, avg_u, avg_v
def pair_png(gt2d: np.ndarray, pred2d: np.ndarray, label: str, t: int) -> bytes:
    """Render a ground-truth / prediction pair for one field and timestep as PNG.

    Both panels use the fixed colour limits ``RANGES[label]`` (``(-1, 1)`` for an
    unknown label), so frames are comparable across panels and timesteps. A
    colourbar matching the prediction panel's height is attached on the right.

    Args:
        gt2d, pred2d: 2-D fields, drawn with ``origin="lower"``.
        label: field name used for the colour range and the titles (e.g. "u").
        t: timestep shown in the titles.

    Returns:
        The encoded PNG image bytes.
    """
    vmin, vmax = RANGES.get(label, (-1.0, 1.0))  # fallback if label changes
    # Use a bare Figure instead of pyplot. pyplot keeps a global registry of open
    # figures that is not thread-safe, and a web server may call this from several
    # request threads. An unregistered Figure also cannot leak if savefig raises,
    # so no close() is needed.
    fig = Figure(figsize=(6.5, 2.6))
    ax = fig.subplots(1, 2)
    ax[0].imshow(gt2d, origin="lower", vmin=vmin, vmax=vmax)
    ax[0].set_title(f"{label} GT — t={t}")
    ax[0].axis("off")
    im2 = ax[1].imshow(pred2d, origin="lower", vmin=vmin, vmax=vmax)
    ax[1].set_title(f"{label} Pred — t={t}")
    ax[1].axis("off")
    # Colorbar height == ax[1] image height
    divider = make_axes_locatable(ax[1])
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im2, cax=cax)
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", dpi=110)
    return buf.getvalue()
def write_gif(tag: str, y_true: np.ndarray, y_pred: np.ndarray, comp: int, label: str) -> str:
    """Write an animated GIF of GT-vs-prediction frames for one field channel.

    Every timestep of ``y_true`` becomes one frame rendered by ``pair_png`` from
    channel ``comp``. The animation loops forever (``loop=0``).

    Returns:
        Path of the GIF inside the request's scratch directory.
    """
    gif_path = _tmp(tag, f"{label}_rollout.gif")
    n_frames = y_true.shape[0]
    with imageio.get_writer(gif_path, mode="I", duration=0.1, loop=0) as writer:
        for step in range(n_frames):
            frame_png = pair_png(y_true[step, comp], y_pred[step, comp], label, step)
            decoded = imageio.imread(io.BytesIO(frame_png))
            writer.append_data(decoded)
    return gif_path
def write_zip(tag: str, y_true: np.ndarray, y_pred: np.ndarray, comp: int, label: str) -> str:
    """Bundle every timestep's GT/prediction PNG for one field channel into a ZIP.

    Members are named ``<label>_frame_<t>.png``, with ``t`` zero-padded to at least
    three digits.

    Returns:
        Path of the archive inside the request's scratch directory.
    """
    archive_path = _tmp(tag, f"{label}_frames.zip")
    n_steps = y_true.shape[0]
    with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for step in range(n_steps):
            member = f"{label}_frame_{step:03d}.png"
            image_bytes = pair_png(y_true[step, comp], y_pred[step, comp], label, step)
            archive.writestr(member, image_bytes)
    return archive_path
def write_error_assets(tag: str, ts: np.ndarray, err_u: np.ndarray, err_v: np.ndarray):
    """Write the relative-L2-vs-time table (CSV) and plot (PNG) for one request.

    Args:
        tag: per-request scratch-directory tag.
        ts: scored timesteps.
        err_u, err_v: per-timestep relative L2 errors for u and v.

    Returns:
        ``(png_path, csv_path)`` as strings.
    """
    png = _tmp(tag, "relL2_vs_time.png")
    csv = _tmp(tag, "relL2_vs_time.csv")
    # Timesteps are integers, so write them as such. The default "%.18e" format
    # would render 16 as "1.600000000000000000e+01".
    np.savetxt(
        csv,
        np.c_[ts, err_u, err_v],
        fmt=("%d", "%.8e", "%.8e"),
        delimiter=",",
        header="timestep,rel_L2_u,rel_L2_v",
        comments="",
    )
    # Bare Figure instead of pyplot: no global figure registry, so this is safe
    # when called from concurrent request threads and needs no plt.close().
    fig = Figure(figsize=(5, 3))
    ax = fig.subplots()
    ax.plot(ts, err_u, label="u")
    ax.plot(ts, err_v, label="v")
    ax.set_xlabel("Timestep")
    ax.set_ylabel("Relative L2")
    ax.set_title("Rollout rel. L2 vs time")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.savefig(png, dpi=120, bbox_inches="tight")
    return png, csv
| # ---------------- Gradio callback ---------------- | |
def predict_rollout(sample_id: str, history_s: str):
    """Gradio callback: run the rollout for one sample and write all output assets.

    Args:
        sample_id: sample ID chosen in the UI.
        history_s: history length chosen in the UI, as a string.

    Returns:
        A 9-tuple, in this order:
        u GIF (display), u GIF (download), u frames ZIP,
        v GIF (display), v GIF (download), v frames ZIP,
        error plot PNG, error CSV, summary text.
    """
    tag = _tag()  # one scratch directory per request
    y_true, y_pred, s = rollout(sample_id, history_s)
    ts, err_u, err_v, avg_u, avg_v = rollout_errors(y_true, y_pred, s)
    u_gif = write_gif(tag, y_true, y_pred, 0, "u")
    v_gif = write_gif(tag, y_true, y_pred, 1, "v")
    u_zip = write_zip(tag, y_true, y_pred, 0, "u")
    v_zip = write_zip(tag, y_true, y_pred, 1, "v")
    err_png, csv = write_error_assets(tag, ts, err_u, err_v)
    # "≥" was previously mis-encoded (UTF-8 read as ISO-8859-7) as "β₯".
    metrics = (
        f"Rollout relative L2 error (averaged over t ≥ {s}):\n"
        f" u: {avg_u:.3e}\n"
        f" v: {avg_v:.3e}"
    )
    return (u_gif, u_gif, u_zip, v_gif, v_gif, v_zip, err_png, csv, metrics)
| # ---------------- UI builder ---------------- | |
def build_demo():
    """Build the Gradio interface: choose a sample and a history length, get rollout assets.

    The nine output components are listed in the same order as the tuple returned
    by ``predict_rollout``.
    """
    sample_choices = list_samples() or ["0"]  # fallback when no samples are found
    # Offer exactly the history lengths that have a checkpoint, instead of
    # repeating the list by hand.
    history_choices = list(CKPTS)
    return gr.Interface(
        fn=predict_rollout,
        inputs=[
            gr.Radio(sample_choices, value=sample_choices[0], label="Sample ID"),
            gr.Radio(
                history_choices,
                value=max(history_choices, key=int),
                label="History length s",
            ),
        ],
        outputs=[
            gr.Image(type="filepath", label="u rollout (GIF)"),
            gr.File(label="Download u rollout (GIF)"),
            gr.File(label="Download all u frames (ZIP)"),
            gr.Image(type="filepath", label="v rollout (GIF)"),
            gr.File(label="Download v rollout (GIF)"),
            gr.File(label="Download all v frames (ZIP)"),
            gr.Image(type="filepath", label="Relative L2 vs time"),
            gr.File(label="Download L2 vs time (CSV)"),
            gr.Textbox(label="Summary metrics"),
        ],
        # The em dash was previously mis-encoded (UTF-8 read as ISO-8859-7) as "β".
        title="Time-Dependent DeepONet — FPO Rollout Demo",
        description=(
            "Auto-regressive 60-step rollout of u and v fields for a selected sample. "
            f"Choose history length s ({', '.join(history_choices)}). "
            "Download videos/frames and relative error vs time (CSV)."
        ),
    )
if __name__ == "__main__":
    # Script entry point: build the Gradio interface, then start serving it.
    demo = build_demo()
    demo.launch()