# app.py, single file Space
# FieldSpace, Prime-Only Machine, mod 2^K exactness proofs
#
# Each "proof" runs a small integer-only network twice: once with plain int64
# arithmetic (the reference) and once with every linear stage reduced
# mod 2^K and re-centred to a signed value. The two runs are then compared
# stage by stage for bit-exact equality.

import sys
import subprocess
import time
import math
import json
from typing import Any, Dict, Tuple


def _ensure(package_spec: str) -> None:
    """Import the package named in *package_spec*, pip-installing it on failure.

    The importable name is taken as the text before the first ``==``, ``>=``
    or ``<`` in the spec. The installed version is NOT checked against the
    spec; the spec is only used when an install is needed.
    """
    module_name = package_spec.split("==")[0].split(">=")[0].split("<")[0]
    try:
        __import__(module_name)
        return
    except Exception:
        # Deliberately broad: an installed-but-broken package can fail with
        # something other than ImportError, and reinstalling the pinned
        # version is the best-effort recovery in a single-file Space.
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", package_spec]
        )


# Self install, single file repo, no external requirements
for spec in ["gradio==4.44.1", "numpy>=1.26", "torch>=2.2,<2.4"]:
    _ensure(spec)

import numpy as np
import torch
import gradio as gr

torch.set_grad_enabled(False)
torch.set_num_threads(1)
DEVICE = "cpu"


# -----------------------------
# Utilities (mod 2^K)
# -----------------------------
def seed_all(s: int) -> torch.Generator:
    """Return a fresh ``torch.Generator`` on DEVICE seeded with ``int(s)``."""
    return torch.Generator(device=DEVICE).manual_seed(int(s))


def choose_K(max_abs: int, min_bits: int = 24) -> int:
    """Pick a bit width K such that 2^(K-1) > max_abs.

    K is at least *min_bits* and at most 62. When the cap of 62 applies,
    2^(K-1) > max_abs is no longer guaranteed, so the signed re-centring in
    ``mod2_center`` may not recover the true value.
    """
    need = max(min_bits, int(max_abs).bit_length() + 1)
    return min(62, need)


def mod2_encode(x: torch.Tensor, MOD2: int) -> torch.Tensor:
    """Reduce *x* into [0, MOD2) by masking; MOD2 must be a power of two."""
    return (x & (MOD2 - 1)).to(torch.int64)


def mod2_center(v: torch.Tensor, MOD2: int) -> torch.Tensor:
    """Map residues in [0, MOD2) back to signed representatives.

    Residues strictly greater than MOD2/2 have MOD2 subtracted; everything
    else is returned unchanged. The input tensor is not modified.
    """
    K = MOD2.bit_length() - 1
    half = 1 << (K - 1)
    return torch.where(v > half, v - MOD2, v)


# -----------------------------
# CNN (Conv -> Act -> Conv -> Act -> FC)
# -----------------------------
def im2col_nchw_int(x, kH, kW, stride=1, pad=0):
    """Unroll sliding kH x kW patches of an NCHW tensor into matrix rows.

    Args:
        x: tensor of shape [N, C, H, W] (any dtype; used here with int64).
        kH, kW: patch height and width.
        stride: step between patches along both spatial axes.
        pad: zero padding added on every spatial side.

    Returns:
        ``(cols, Hout, Wout)`` where ``cols`` has shape
        [N*Hout*Wout, C*kH*kW]. Rows are ordered by (n, i, j) and columns by
        (c, kh, kw).
    """
    N, C, H, W = x.shape
    Hout = (H + 2 * pad - kH) // stride + 1
    Wout = (W + 2 * pad - kW) // stride + 1
    if pad:
        x_pad = torch.zeros(
            (N, C, H + 2 * pad, W + 2 * pad), dtype=x.dtype, device=x.device
        )
        x_pad[:, :, pad:pad + H, pad:pad + W] = x
    else:
        x_pad = x
    # Tensor.unfold is dtype-agnostic (unlike F.unfold) and yields
    # [N, C, Hout, Wout, kH, kW] with
    # windows[n, c, i, j, a, b] == x_pad[n, c, i*stride + a, j*stride + b].
    windows = x_pad.unfold(2, kH, stride).unfold(3, kW, stride)
    out = windows.permute(0, 2, 3, 1, 4, 5).reshape(
        N * Hout * Wout, C * kH * kW
    )
    return out, Hout, Wout


def conv_int64_im2col(X, W, kH, kW, stride=1, pad=0):
    """2-D convolution (cross-correlation) as im2col followed by a matmul.

    Args:
        X: input of shape [N, Cin, H, W].
        W: weights of shape [Cout, Cin, kH, kW].

    Returns:
        Contiguous tensor of shape [N, Cout, Hout, Wout].
    """
    Xcol, Hout, Wout = im2col_nchw_int(X, kH, kW, stride=stride, pad=pad)
    Ycol = Xcol @ W.reshape(W.shape[0], -1).t()
    Y = Ycol.reshape(X.shape[0], Hout, Wout, W.shape[0])
    return Y.permute(0, 3, 1, 2).contiguous()


def _activate(z: torch.Tensor, act_kind: str, poly_shift: int) -> torch.Tensor:
    """Apply ReLU (``act_kind == "relu"``) or ``(z*z) >> poly_shift`` otherwise.

    Always returns a new tensor. The previous in-place ``clamp_min_`` aliased
    the activation with its pre-activation, which hid negative magnitudes from
    the K bound and made the pre-activation stage checks vacuous.
    """
    if act_kind == "relu":
        return z.clamp_min(0)
    return (z * z) >> int(poly_shift)


def run_cnn_proof(seed: int = 0, Xmax: int = 31, Wmax: int = 31,
                  act_kind: str = "relu",
                  poly_shift: int = 7) -> Tuple[str, Dict[str, Any]]:
    """Run Conv -> Act -> Conv -> Act -> FC in int64 and again mod 2^K.

    Args:
        seed: RNG seed for inputs and weights.
        Xmax: inputs are drawn uniformly from [-Xmax, Xmax].
        Wmax: weights are drawn uniformly from [-Wmax, Wmax].
        act_kind: ``"relu"`` for ReLU; any other value selects the polynomial
            activation ``(z*z) >> poly_shift``.
        poly_shift: right shift used by the polynomial activation.

    Returns:
        ``(console_text, report_dict)``.
    """
    g = seed_all(seed)
    B, Cin, H, W = 4, 1, 16, 16
    C1, C2 = 8, 8
    kH, kW = 3, 3
    STRIDE, PAD = 1, 1
    CLS = 10

    def _rand(limit, shape):
        return torch.randint(-limit, limit + 1, shape, dtype=torch.int64,
                             generator=g, device=DEVICE)

    # Draw order is part of the behaviour: it fixes the RNG stream per seed.
    X = _rand(Xmax, (B, Cin, H, W))
    W1 = _rand(Wmax, (C1, Cin, kH, kW))
    W2 = _rand(Wmax, (C2, C1, kH, kW))
    Wf = _rand(Wmax, (C2 * H * W, CLS))

    # ---- Reference path: plain int64 ----
    t0 = time.time()
    Z1 = conv_int64_im2col(X, W1, kH, kW, STRIDE, PAD)
    A1 = _activate(Z1, act_kind, poly_shift)
    Z2 = conv_int64_im2col(A1, W2, kH, kW, STRIDE, PAD)
    A2 = _activate(Z2, act_kind, poly_shift)
    Y = A2.reshape(B, -1) @ Wf
    t1 = time.time()

    # K must cover every value that gets re-centred, including negative
    # pre-activations (hence the out-of-place activation above).
    max_abs = int(max(t.abs().max().item() for t in (Z1, A1, Z2, A2, Y)))
    K = choose_K(max_abs, min_bits=24)
    MOD2 = 1 << K

    # ---- Modular path: linear stages mod 2^K, centre before each activation
    Z1m = mod2_encode(
        conv_int64_im2col(mod2_encode(X, MOD2), mod2_encode(W1, MOD2),
                          kH, kW, STRIDE, PAD),
        MOD2,
    )
    Z1i = mod2_center(Z1m, MOD2)
    A1i = _activate(Z1i, act_kind, poly_shift)

    Z2m = mod2_encode(
        conv_int64_im2col(mod2_encode(A1i, MOD2), mod2_encode(W2, MOD2),
                          kH, kW, STRIDE, PAD),
        MOD2,
    )
    Z2i = mod2_center(Z2m, MOD2)
    A2i = _activate(Z2i, act_kind, poly_shift)

    Ym = mod2_encode(
        mod2_encode(A2i.reshape(B, -1), MOD2) @ mod2_encode(Wf, MOD2), MOD2
    )
    Yi = mod2_center(Ym, MOD2)

    ok_all = bool(torch.equal(Y, Yi))
    arg_ok = bool(torch.equal(Y.argmax(1), Yi.argmax(1)))

    txt = []
    txt.append(f"K={K} (2^(K-1)={1<<(K-1):,} > max_abs={max_abs:,}), act={act_kind} ({'ReLU' if act_kind=='relu' else f'poly SHIFT={poly_shift}'})")
    txt.append("Stage by stage equality, exact:")
    txt.append(f"Z1: {'OK' if torch.equal(Z1,Z1i) else 'NO'} | A1: {'OK' if torch.equal(A1,A1i) else 'NO'} | "
               f"Z2: {'OK' if torch.equal(Z2,Z2i) else 'NO'} | Y: {'OK' if torch.equal(Y,Yi) else 'NO'}")
    txt.append(f"Argmax match: {arg_ok} (ref={Y.argmax(1).tolist()}, mod2={Yi.argmax(1).tolist()})")
    txt.append(f"Timing ms: ref≈{(t1-t0)*1000:.2f}")

    return "\n".join(txt), {
        "ok_all": ok_all,
        "argmax_ok": arg_ok,
        "K": K,
        "act": act_kind,
        "poly_shift": int(poly_shift),
        "ref_top1": Y.argmax(1).tolist(),
        "mod2_top1": Yi.argmax(1).tolist(),
    }


# -----------------------------
# 1 head Attention
# -----------------------------
def split_heads(Z, B, T, C, H):
    """Reshape [B, T, C] into per-head layout [B, H, T, C//H]."""
    Dh = C // H
    return Z.reshape(B, T, H, Dh).permute(0, 2, 1, 3).contiguous()


def merge_heads(Zh, B, T, C, H):
    """Inverse of ``split_heads``: [B, H, T, C//H] back to [B, T, C]."""
    return Zh.permute(0, 2, 1, 3).reshape(B, T, C).contiguous()


def pick_shift_for_cap(smax: int, cap: int = 8) -> int:
    """Return the smallest s >= 0 with ``(smax >> s) <= cap``."""
    if smax <= cap:
        return 0
    s = 0
    while (smax >> s) > cap:
        s += 1
    return s


def run_attn_proof(seed: int = 0, Xmax: int = 7, Wmax: int = 7, B: int = 1,
                   T: int = 8, C: int = 16,
                   H: int = 1) -> Tuple[str, Dict[str, Any]]:
    """Run integer attention (base-2 "softmax") in int64 and again mod 2^K.

    Scores are right-shifted so the largest is at most 8, exponentiated as
    ``1 << score``, and normalised with integer floor division.

    Args:
        seed: RNG seed for inputs and weights.
        Xmax: inputs are drawn uniformly from [0, Xmax].
        Wmax: weights are drawn uniformly from [0, Wmax].
        B, T, C, H: batch size, sequence length, model width, head count.

    Returns:
        ``(console_text, report_dict)``.

    Raises:
        ValueError: if H is not positive or does not divide C.
    """
    # Explicit check rather than assert so it survives ``python -O``.
    if H <= 0 or C % H != 0:
        raise ValueError(f"H must be positive and divide C (got C={C}, H={H})")
    Dh = C // H
    g = seed_all(seed)

    def _rand(limit, shape):
        return torch.randint(0, limit + 1, shape, dtype=torch.int64,
                             generator=g, device=DEVICE)

    # Draw order is part of the behaviour: it fixes the RNG stream per seed.
    X = _rand(Xmax, (B, T, C))
    Wq = _rand(Wmax, (C, C))
    Wk = _rand(Wmax, (C, C))
    Wv = _rand(Wmax, (C, C))
    Wo = _rand(Wmax, (C, C))

    # ---- Reference path: plain int64 ----
    t0 = time.time()
    Q = X @ Wq
    Kp = X @ Wk
    V = X @ Wv
    Qh = split_heads(Q, B, T, C, H)
    Kh = split_heads(Kp, B, T, C, H)
    Vh = split_heads(V, B, T, C, H)
    Sraw = torch.einsum("bhtd,bhTd->bhtT", Qh, Kh)
    SHIFT = pick_shift_for_cap(int(Sraw.max().item()), cap=8)
    S = (Sraw >> SHIFT).clamp_min(0)
    E = torch.ones_like(S) << S
    Den = E.sum(-1, keepdim=True)
    Num = torch.einsum("bhtT,bhTd->bhtd", E, Vh)
    Oh_ref = Num // Den
    Oref = merge_heads(Oh_ref, B, T, C, H)
    Y = Oref @ Wo
    t1 = time.time()

    # Den is centred in the modular path, so it has to be inside the bound
    # too (it also dominates E, whose entries are all non-negative).
    max_abs = int(max(
        t.abs().max().item()
        for t in (Q, Kp, V, Sraw, S, Den, Num, Oref, Y)
    ))
    Kbits = choose_K(max_abs, min_bits=24)
    MOD2 = 1 << Kbits

    # ---- Modular path ----
    Xm = mod2_encode(X, MOD2)
    Qi = mod2_center(mod2_encode(Xm @ mod2_encode(Wq, MOD2), MOD2), MOD2)
    Ki = mod2_center(mod2_encode(Xm @ mod2_encode(Wk, MOD2), MOD2), MOD2)
    Vi = mod2_center(mod2_encode(Xm @ mod2_encode(Wv, MOD2), MOD2), MOD2)
    assert torch.equal(Qi, Q) and torch.equal(Ki, Kp) and torch.equal(Vi, V)

    Qh_i = split_heads(Qi, B, T, C, H)
    Kh_i = split_heads(Ki, B, T, C, H)
    Vh_i = split_heads(Vi, B, T, C, H)

    # Batched Q @ K^T over [B, H]; int64 wrap-around is harmless because
    # 2^K divides 2^64, so masking afterwards still gives the right residue.
    Sraw_mod = mod2_encode(
        mod2_encode(Qh_i, MOD2) @ mod2_encode(Kh_i, MOD2).transpose(-1, -2),
        MOD2,
    )
    Sraw_i = mod2_center(Sraw_mod, MOD2)

    if SHIFT == 0:
        S_i = Sraw_i
    else:
        S2 = mod2_encode(Sraw_i, MOD2)
        q2 = (S2 >> SHIFT) & (MOD2 - 1)
        S_i = mod2_center(q2, MOD2)
    S_i = S_i.clamp_min(0)

    E_mod = (torch.ones_like(S_i) << S_i) & (MOD2 - 1)
    Den_i = mod2_center(mod2_encode(E_mod.sum(-1), MOD2), MOD2)

    Num_m = mod2_encode(E_mod @ mod2_encode(Vh_i, MOD2), MOD2)
    Num_i = mod2_center(Num_m, MOD2)

    Oh_i = Num_i // Den_i.unsqueeze(-1)
    Oi = merge_heads(Oh_i, B, T, C, H)

    Ym = (mod2_encode(Oi, MOD2) @ mod2_encode(Wo, MOD2)) & (MOD2 - 1)
    Yi = mod2_center(Ym, MOD2)

    ok_all = bool(torch.equal(Y, Yi))
    arg_ok = bool(torch.equal(Y.argmax(-1), Yi.argmax(-1)))

    lines = []
    lines.append(f"Auto SHIFT={SHIFT} cap=8")
    lines.append(f"K={Kbits} 2^(K-1)={1<<(Kbits-1):,} > max_abs={max_abs:,}")
    lines.append("Stage by stage equality, exact:")
    lines.append("Q:{0} K:{1} V:{2} S:{3} Y:{4}".format(
        'OK' if torch.equal(Q, Qi) else 'NO',
        'OK' if torch.equal(Kp, Ki) else 'NO',
        'OK' if torch.equal(V, Vi) else 'NO',
        'OK' if torch.equal(S, S_i) else 'NO',
        'OK' if torch.equal(Y, Yi) else 'NO'
    ))
    lines.append(f"Argmax match: {arg_ok} ref={Y.argmax(-1).tolist()} mod2={Yi.argmax(-1).tolist()}")
    lines.append(f"Timing ms: ref≈{(t1-t0)*1000:.2f}")

    report = {
        "ok_all": ok_all,
        "argmax_ok": arg_ok,
        "K": Kbits,
        "ref_top1": Y.argmax(-1).tolist(),
        "mod2_top1": Yi.argmax(-1).tolist(),
    }
    return "\n".join(lines), report


# -----------------------------
# UI
# -----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# FieldSpace, Prime Only Machine mod $2^K$, Exactness Proofs, single file")
    gr.Markdown("CNN and 1 head attention, all int64, exact mod 2^K reconstruction.")

    with gr.Tab("CNN Proof"):
        with gr.Row():
            seed = gr.Number(value=0, label="Seed", precision=0)
            xmax = gr.Number(value=31, label="|X|max", precision=0)
            wmax = gr.Number(value=31, label="|W|max", precision=0)
        act_sel = gr.Radio(choices=["relu", "poly"], value="relu", label="Activation")
        shift_poly = gr.Slider(1, 12, value=7, step=1, label="SHIFT for poly x^2 >> SHIFT")
        run_btn = gr.Button("Run CNN Proof", variant="primary")
        out_txt = gr.Textbox(label="Console", lines=12)
        out_json = gr.JSON(label="JSON Report")

        def _run_cnn(s, xM, wM, act, sh):
            # gr.Number may deliver floats; the proof expects ints.
            return run_cnn_proof(int(s), int(xM), int(wM), act, int(sh))

        run_btn.click(_run_cnn, [seed, xmax, wmax, act_sel, shift_poly],
                      [out_txt, out_json], api_name="run_cnn_proof")

    with gr.Tab("Attention Proof 1 head"):
        with gr.Row():
            seedA = gr.Number(value=0, label="Seed", precision=0)
            XmaxA = gr.Number(value=7, label="X in [0, Xmax]", precision=0)
            WmaxA = gr.Number(value=7, label="W in [0, Wmax]", precision=0)
        with gr.Row():
            BA = gr.Number(value=1, label="Batch B", precision=0)
            TA = gr.Number(value=8, label="Seq T", precision=0)
            CA = gr.Number(value=16, label="Width C", precision=0)
            HA = gr.Number(value=1, label="Heads H divides C", precision=0)
        runA = gr.Button("Run Attention Proof", variant="primary")
        outA_t = gr.Textbox(label="Console", lines=12)
        outA_j = gr.JSON(label="JSON Report")

        def _run_attn(s, xM, wM, B, T, C, H):
            # gr.Number may deliver floats; the proof expects ints.
            return run_attn_proof(int(s), int(xM), int(wM),
                                  int(B), int(T), int(C), int(H))

        runA.click(_run_attn, [seedA, XmaxA, WmaxA, BA, TA, CA, HA],
                   [outA_t, outA_j], api_name="run_attn_proof")

if __name__ == "__main__":
    demo.launch()