# DeepONet-FPO-demo / app_v2.py
# Hugging Face file-viewer header (kept for provenance):
#   author: arabeh, last commit: "minor edits" (c7a4de3), file size: 7.39 kB
from pathlib import Path
import io
import zipfile
import tempfile
from functools import lru_cache
import numpy as np
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt
import imageio.v2 as imageio
from mpl_toolkits.axes_grid1 import make_axes_locatable
from einops import rearrange
from models.geometric_deeponet.geometric_deeponet import GeometricDeepONetTime
# ---------------- Config ----------------
# Hugging Face Hub repository from which the checkpoints in CKPTS are downloaded.
REPO_ID = "arabeh/DeepONet-FlowBench-FPO"
# Checkpoint file (path inside REPO_ID) per supported history length.
# Keys are the history length as a string, matching the UI radio choices.
CKPTS = {
    "1": "checkpoints/time-dependent-deeponet_1in.ckpt",
    "4": "checkpoints/time-dependent-deeponet_4in.ckpt",
    "8": "checkpoints/time-dependent-deeponet_8in.ckpt",
    "16": "checkpoints/time-dependent-deeponet_16in.ckpt",
}
# v2 samples live here (only 16 ground-truth timesteps per sample).
SAMPLES_DIR = Path("sample_cases") / "few_timesteps"
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# System temporary directory; per-run output folders are created beneath it.
TMP = Path(tempfile.gettempdir())
# Fixed colour limits (vmin, vmax) per field label, used when rendering frames.
RANGES = {
    "u": (-2.0, 2.0),
    "v": (-1.0, 1.0),
}
def _tag() -> str:
return next(tempfile._get_candidate_names())
def _tmp(tag: str, name: str) -> str:
    """Return the path of ``name`` inside the output directory of run ``tag``.

    The directory ``deeponet_fpo_<tag>`` is created under ``TMP`` when it does
    not exist yet. The file itself is not created.

    Args:
        tag: Token identifying the run (see ``_tag``).
        name: File name to place inside the run directory.

    Returns:
        The full path as a string.
    """
    run_dir = TMP.joinpath(f"deeponet_fpo_{tag}")
    run_dir.mkdir(parents=True, exist_ok=True)
    target = run_dir.joinpath(name)
    return str(target)
# ---------------- Samples ----------------
def list_samples():
    """Return the ids of the usable samples in ``SAMPLES_DIR``, sorted numerically.

    A sample ``<id>`` is usable when ``<id>`` is a decimal number and both
    ``sample_<id>_input.npy`` and ``sample_<id>_output.npy`` exist, because
    ``load_sample`` reads both files unconditionally.

    Returns:
        A list of id strings ordered by their integer value; an empty list
        when ``SAMPLES_DIR`` is not a directory.
    """
    if not SAMPLES_DIR.is_dir():
        return []

    prefix, suffix = "sample_", "_input"
    ids = set()
    for path in SAMPLES_DIR.glob("sample_*_input.npy"):
        # Take everything between the prefix and suffix, so that a name such
        # as "sample_1_2_input.npy" is not mistaken for sample "1".
        sid = path.stem[len(prefix):-len(suffix)]
        # isdecimal() only accepts characters that int() can parse; isdigit()
        # also accepts e.g. superscript digits, which would crash the sort.
        if not sid.isdecimal():
            continue
        # Skip inputs without a matching output file.
        if not (SAMPLES_DIR / f"sample_{sid}_output.npy").is_file():
            continue
        ids.add(sid)
    return sorted(ids, key=int)
def load_sample(sample_id: str):
    """Load the input and output arrays of one sample as ``float32``.

    Reads ``sample_<sample_id>_input.npy`` and ``sample_<sample_id>_output.npy``
    from ``SAMPLES_DIR``.

    Args:
        sample_id: Identifier used in the file names.

    Returns:
        A tuple ``(sdf, y16)``. According to the original notes the expected
        shapes are ``[1,H,W]`` and ``[16,2,H,W]``; this function does not
        check them.
    """
    input_path = SAMPLES_DIR / f"sample_{sample_id}_input.npy"
    output_path = SAMPLES_DIR / f"sample_{sample_id}_output.npy"
    sdf = np.load(input_path).astype(np.float32)
    y16 = np.load(output_path).astype(np.float32)
    return sdf, y16
# ---------------- Model ----------------
@lru_cache(maxsize=4)
def load_model(history_s: int) -> GeometricDeepONetTime:
    """Download (if needed) and load the checkpoint for a history length.

    Results are memoised per ``history_s`` (up to four entries), so each
    checkpoint is loaded at most once per process.

    Args:
        history_s: History length; ``str(history_s)`` must be a key of ``CKPTS``.

    Returns:
        The model in evaluation mode, moved to ``DEVICE``.

    Raises:
        KeyError: If there is no checkpoint for ``history_s``.
    """
    ckpt_file = CKPTS[str(history_s)]
    local_path = hf_hub_download(REPO_ID, ckpt_file)
    net = GeometricDeepONetTime.load_from_checkpoint(local_path, map_location=DEVICE)
    net = net.eval()
    return net.to(DEVICE)
def static_tensors(hparams, sdf_np: np.ndarray):
    """Build the time-independent model inputs for one sample.

    Args:
        hparams: Object exposing ``domain_length_x`` and ``domain_length_y``.
        sdf_np: Array of shape ``[C,H,W]`` (three axes are required; the
            original notes expect ``C == 1``).

    Returns:
        ``(sdf_t, coords_t, re_t, H, W)`` where ``sdf_t`` is ``sdf_np`` with a
        leading batch axis, ``coords_t`` is a ``[1,2,H,W]`` grid holding the x
        coordinates in channel 0 and the y coordinates in channel 1, and
        ``re_t`` is an all-zero tensor shaped like ``sdf_t``. All tensors are
        on ``DEVICE``.
    """
    _, H, W = sdf_np.shape
    len_x = float(hparams.domain_length_x)
    len_y = float(hparams.domain_length_y)
    xs = np.linspace(0.0, len_x, W, dtype=np.float32)
    ys = np.linspace(0.0, len_y, H, dtype=np.float32)

    # "xy" indexing yields [H,W] grids with x varying along the last axis.
    xv, yv = np.meshgrid(xs, ys, indexing="xy")
    coords = np.stack((xv, yv), axis=0)[np.newaxis]  # [1,2,H,W]

    sdf_t = torch.from_numpy(sdf_np).unsqueeze(0).to(DEVICE)
    coords_t = torch.from_numpy(coords).to(DEVICE)
    re_t = torch.zeros_like(sdf_t)
    return sdf_t, coords_t, re_t, H, W
# ---------------- Rollout ----------------
def rollout_pred(sample_id: str, history_s: str, n_steps: int):
    """Run an auto-regressive rollout seeded with ground-truth frames.

    The first ``s`` frames of the sample's output array seed the model; every
    later frame is predicted from the ``s`` most recent frames (seed or
    predicted) and fed back as input for the next step.

    Args:
        sample_id: Sample to load via ``load_sample``.
        history_s: History length ``s`` (converted with ``int``); selects the
            checkpoint via ``load_model``.
        n_steps: Total number of frames to return. Values below ``s`` are
            raised to ``s``.

    Returns:
        ``(y_out, s)`` where ``y_out`` is a float32 array of shape
        ``[N,2,H,W]`` whose first ``s`` frames are the seed.

    Raises:
        ValueError: If ``n_steps`` is not positive, the sample output is not
            shaped ``[T,2,H,W]``, or it has fewer than ``s`` timesteps.
    """
    s = int(history_s)
    total = int(n_steps)
    if total <= 0:
        raise ValueError("Number of rollout steps must be a positive integer.")
    # At least s frames are needed to hold the seed.
    total = max(total, s)

    model = load_model(s)
    sdf, y16 = load_sample(sample_id)

    # Only the first s ground-truth frames are ever used.
    if not (y16.ndim == 4 and y16.shape[1] == 2):
        raise ValueError(f"Expected y shape [T,2,H,W], got {y16.shape}")
    if y16.shape[0] < s:
        raise ValueError(f"Sample only has {y16.shape[0]} timesteps, but checkpoint needs s={s}.")

    H, W = y16.shape[2], y16.shape[3]
    sdf_t, coords_t, re_t, _, _ = static_tensors(model.hparams, sdf)

    window = y16[:s].copy()  # [s,2,H,W], sliding input window
    y_out = np.zeros((total, 2, H, W), dtype=np.float32)
    y_out[:s] = window

    for t in range(s, total):
        # Fold time and channel axes together: [s,2,H,W] -> [1,s*2,H,W].
        stacked = rearrange(window, "nb c h w -> (nb c) h w")[None]
        branch_t = torch.from_numpy(stacked).to(DEVICE)
        with torch.no_grad():
            y_hat = model((branch_t, re_t, coords_t, sdf_t))
        # Original note: model output is [1,1,p,2]; reshaped here to [2,H,W].
        frame = y_hat[0, 0].view(H, W, 2).permute(2, 0, 1).cpu().numpy().astype(np.float32)
        y_out[t] = frame
        # Slide the window: drop the oldest frame, append the new one.
        if s == 1:
            window = frame[None]
        else:
            window = np.concatenate([window[1:], frame[None]], axis=0)

    return y_out, s
# ---------------- Rendering (prediction-only) ----------------
def single_png(field2d: np.ndarray, label: str, t: int) -> bytes:
    """Render one 2-D field as a PNG image and return the encoded bytes.

    Args:
        field2d: 2-D array to display.
        label: Field name shown in the title; also selects the colour limits
            from ``RANGES`` (``(-1.0, 1.0)`` when the label is unknown).
        t: Timestep index shown in the title.

    Returns:
        The PNG file contents.
    """
    vmin, vmax = RANGES.get(label, (-1.0, 1.0))
    fig, ax = plt.subplots(1, 1, figsize=(3.4, 2.8))
    try:
        im = ax.imshow(field2d, origin="lower", vmin=vmin, vmax=vmax)
        ax.set_title(f"{label} – t={t}")
        ax.axis("off")
        # Attach the colour bar to a dedicated axes next to the image.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="6%", pad=0.05)
        fig.colorbar(im, cax=cax)
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=120)
    finally:
        # Always release the figure: pyplot keeps figures alive until they
        # are closed, so an exception above would otherwise leak one.
        plt.close(fig)
    return buf.getvalue()
def write_gif(tag: str, y: np.ndarray, comp: int, label: str) -> str:
    """Render one component of a rollout as an animated GIF.

    Args:
        tag: Run token selecting the output directory (see ``_tmp``).
        y: Rollout array indexed as ``y[t, comp]``.
        comp: Index of the component to render.
        label: Field name used for the file name and frame titles.

    Returns:
        Path of the written ``<label>_rollout.gif`` file.
    """
    path = _tmp(tag, f"{label}_rollout.gif")
    with imageio.get_writer(path, mode="I", duration=0.1, loop=0) as writer:
        for t, step in enumerate(y):
            encoded = single_png(step[comp], label, t)
            # Decode the PNG back into an image array for the GIF writer.
            image = imageio.imread(io.BytesIO(encoded))
            writer.append_data(image)
    return path
def write_zip(tag: str, y: np.ndarray, comp: int, label: str) -> str:
    """Write every frame of one rollout component as a PNG into a ZIP archive.

    Args:
        tag: Run token selecting the output directory (see ``_tmp``).
        y: Rollout array indexed as ``y[t, comp]``.
        comp: Index of the component to render.
        label: Field name used for the archive and member names.

    Returns:
        Path of the written ``<label>_frames.zip`` file.
    """
    path = _tmp(tag, f"{label}_frames.zip")
    with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for t, step in enumerate(y):
            member = f"{label}_frame_{t:03d}.png"
            archive.writestr(member, single_png(step[comp], label, t))
    return path
# ---------------- Gradio callback ----------------
def run_v2(sample_id: str, history_s: str, n_steps: int):
    """Gradio callback: run a rollout and export GIFs and frame archives.

    Args:
        sample_id: Sample chosen in the UI.
        history_s: History length ``s`` chosen in the UI.
        n_steps: Requested total number of frames.

    Returns:
        A 7-tuple matching the interface outputs: the ``u`` GIF path twice
        (preview and download), the ``u`` ZIP path, the same three for ``v``,
        and a text summary.

    Raises:
        ValueError: Propagated from ``rollout_pred`` for invalid inputs.
    """
    tag = _tag()
    y, s = rollout_pred(sample_id, history_s, n_steps)

    u_gif = write_gif(tag, y, comp=0, label="u")
    v_gif = write_gif(tag, y, comp=1, label="v")
    u_zip = write_zip(tag, y, comp=0, label="u")
    v_zip = write_zip(tag, y, comp=1, label="v")

    # The frame titles only show the timestep, so say which frames are seed
    # and which are predictions rather than claiming the frames are labelled.
    summary = (
        f"Seeded with s={s} timesteps from {SAMPLES_DIR}.\n"
        f"Generated rollout length N={y.shape[0]} "
        f"(frames with t<s are the ground-truth seed, frames with t≥s are predictions)."
    )
    return (
        u_gif, u_gif, u_zip,
        v_gif, v_gif, v_zip,
        summary,
    )
# ---------------- UI builder ----------------
def build_demo():
    """Assemble the Gradio interface for the rollout demo.

    The sample selector lists the ids found by ``list_samples`` and falls
    back to the single choice ``"0"`` when none are found.

    Returns:
        A ``gr.Interface`` wired to ``run_v2``.
    """
    available = list_samples()
    sample_choices = available if available else ["0"]
    history_choices = ["1", "4", "8", "16"]

    input_widgets = [
        gr.Radio(sample_choices, value=sample_choices[0], label="Sample ID"),
        gr.Radio(history_choices, value="16", label="History length s (checkpoint)"),
        gr.Number(value=60, precision=0, label="Rollout steps N (total frames)"),
    ]
    # Order must match the tuple returned by run_v2.
    output_widgets = [
        gr.Image(type="filepath", label="u rollout (GIF)"),
        gr.File(label="Download u rollout (GIF)"),
        gr.File(label="Download all u frames (ZIP)"),
        gr.Image(type="filepath", label="v rollout (GIF)"),
        gr.File(label="Download v rollout (GIF)"),
        gr.File(label="Download all v frames (ZIP)"),
        gr.Textbox(label="Run summary"),
    ]
    blurb = (
        "Auto-regressive rollout of u and v fields for a selected sample. "
        "Choose history length s (1, 4, 8, 16). Download videos/frames."
    )
    return gr.Interface(
        fn=run_v2,
        inputs=input_widgets,
        outputs=output_widgets,
        title="Time-Dependent DeepONet – FPO Rollout Demo",
        description=blurb,
    )
if __name__ == "__main__":
    # Build the interface and start the Gradio server when run as a script.
    demo = build_demo()
    demo.launch()