Spaces:
Sleeping
Sleeping
| # app.py, single file Space | |
| # FieldSpace, Prime-Only Machine, mod 2^K exactness proofs | |
| import sys, subprocess, time, math, json | |
def _ensure(package_spec: str):
    """Make the module named in *package_spec* importable, pip-installing it if needed.

    The import probe uses only the distribution name, i.e. the text before any
    version operator (``==``, ``>=``, ``<``, ``!=``, ``~=``, ...), extras
    bracket, environment marker or whitespace. The complete spec is what is
    passed to ``pip install``. A module that already imports is left alone,
    even if its installed version does not satisfy the spec.

    Args:
        package_spec: A pip requirement string such as ``"torch>=2.2,<2.4"``.

    Raises:
        subprocess.CalledProcessError: If ``pip install`` fails.
    """
    import importlib  # local: keeps the file's top-level import line unchanged
    import re

    module_name = re.split(r"[\s\[;<>=!~]", package_spec.strip(), maxsplit=1)[0]
    try:
        __import__(module_name)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package_spec])
        # The import system caches directory listings; without this a package
        # installed a moment ago may not be visible to the next import.
        importlib.invalidate_caches()
# Self install: this single-file repo has no external requirements file, so
# each dependency is checked (and pip-installed if missing) here, in order.
for spec in ("gradio==4.44.1", "numpy>=1.26", "torch>=2.2,<2.4"):
    _ensure(spec)
| import numpy as np | |
| import torch | |
| import gradio as gr | |
# Global torch configuration for the proofs below: CPU only, a single
# intra-op thread, and autograd switched off.
DEVICE = "cpu"
torch.set_num_threads(1)
torch.set_grad_enabled(False)
| # ----------------------------- | |
| # Utilities (mod 2^K) | |
| # ----------------------------- | |
def seed_all(s: int):
    """Create a ``torch.Generator`` on ``DEVICE`` seeded with ``int(s)``.

    Despite the name, no global RNG state is touched. Callers must pass the
    returned generator explicitly (``generator=...``) to the sampling calls.
    """
    rng = torch.Generator(device=DEVICE)
    rng.manual_seed(int(s))
    return rng
def choose_K(max_abs: int, min_bits: int = 24) -> int:
    """Pick the modulus exponent K so that arithmetic mod 2**K is exact.

    Integers with ``|v| <= max_abs`` survive the round trip through
    ``mod2_encode`` / ``mod2_center`` when ``2**(K-1) > max_abs``. That needs
    ``K >= max_abs.bit_length() + 1``. The result is at least ``min_bits``
    and at most 62, so ``MOD2 = 1 << K`` and the centering subtraction stay
    within int64.

    Args:
        max_abs: Largest magnitude that must be represented.
        min_bits: Lower bound on K.

    Returns:
        ``K`` in ``[min(min_bits, 62), 62]`` with ``2**(K-1) > max_abs``.

    Raises:
        ValueError: If ``max_abs`` needs more than 62 bits. Clamping would
            return a K for which the exactness guarantee does not hold.
    """
    need = int(max_abs).bit_length() + 1
    if need > 62:
        raise ValueError(
            f"max_abs={int(max_abs):,} needs K={need} bits; int64 residues support at most 62"
        )
    return min(62, max(need, min_bits))
def mod2_encode(x: torch.Tensor, MOD2: int) -> torch.Tensor:
    """Reduce integers to residues modulo ``MOD2`` and return them as int64.

    Callers in this file pass ``MOD2 = 1 << K``, so the reduction is a mask of
    the low K bits. With two's-complement integers this yields the
    non-negative residue in ``[0, MOD2)`` for negative inputs as well.
    """
    low_bits = MOD2 - 1
    residues = x & low_bits
    return residues.to(torch.int64)
def mod2_center(v: torch.Tensor, MOD2: int) -> torch.Tensor:
    """Lift residues in ``[0, MOD2)`` back to signed integers.

    Residues strictly greater than ``2**(K-1)`` (with ``MOD2 == 2**K``) are
    shifted down by ``MOD2``, giving values in ``(-2**(K-1), 2**(K-1)]``.
    A new tensor is returned; ``v`` is not modified.
    """
    half = 1 << (MOD2.bit_length() - 2)  # same as 1 << (K - 1)
    return torch.where(v > half, v - MOD2, v)
| # ----------------------------- | |
| # CNN (Conv -> Act -> Conv -> Act -> FC) | |
| # ----------------------------- | |
def im2col_nchw_int(x, kH, kW, stride=1, pad=0):
    """Unfold an NCHW tensor into convolution patches (im2col), preserving dtype.

    Args:
        x: Tensor of shape ``[N, C, H, W]`` (integer dtypes included).
        kH, kW: Kernel height and width.
        stride: Step between neighbouring windows, used for both H and W.
        pad: Zero padding added on every side of H and W.

    Returns:
        ``(cols, Hout, Wout)``, where ``cols`` has shape
        ``[N*Hout*Wout, C*kH*kW]``. Rows are ordered by batch, then output
        row, then output column. Each row is one ``[C, kH, kW]`` window
        flattened channel-major, matching ``W.reshape(Cout, -1)``.
    """
    N, C, H, W = x.shape
    Hout = (H + 2 * pad - kH) // stride + 1
    Wout = (W + 2 * pad - kW) // stride + 1
    if pad:
        # An explicit zero buffer keeps the input dtype (int64 stays int64).
        x_pad = torch.zeros((N, C, H + 2 * pad, W + 2 * pad), dtype=x.dtype, device=x.device)
        x_pad[:, :, pad:pad + H, pad:pad + W] = x
    else:
        x_pad = x
    # Strided sliding-window view instead of a Python loop over every output
    # pixel: [N, C, Hout, Wout, kH, kW] -> [N, Hout, Wout, C, kH, kW].
    windows = x_pad.unfold(2, kH, stride).unfold(3, kW, stride)
    windows = windows.permute(0, 2, 3, 1, 4, 5)
    return windows.reshape(N * Hout * Wout, C * kH * kW), Hout, Wout
def conv_int64_im2col(X, W, kH, kW, stride=1, pad=0):
    """Exact 2-D cross-correlation of NCHW input via im2col followed by one matmul.

    Args:
        X: Input of shape ``[N, Cin, H, W]``.
        W: Kernels of shape ``[Cout, Cin, kH, kW]``.
        kH, kW, stride, pad: Forwarded to ``im2col_nchw_int``.

    Returns:
        Contiguous output of shape ``[N, Cout, Hout, Wout]``.
    """
    patches, Hout, Wout = im2col_nchw_int(X, kH, kW, stride=stride, pad=pad)
    batch = X.shape[0]
    out_channels = W.shape[0]
    kernel_matrix = W.reshape(out_channels, -1)          # [Cout, Cin*kH*kW]
    flat = patches @ kernel_matrix.t()                   # [N*Hout*Wout, Cout]
    nhwc = flat.reshape(batch, Hout, Wout, out_channels)
    return nhwc.permute(0, 3, 1, 2).contiguous()
def _cnn_activation(z, act_kind, poly_shift):
    """Return the CNN proof's integer activation of ``z`` without modifying ``z``.

    ``act_kind == "relu"`` gives ``max(z, 0)``; any other value gives the
    polynomial activation ``(z * z) >> poly_shift``.
    """
    if act_kind == "relu":
        return z.clamp_min(0)
    return (z * z) >> int(poly_shift)


def run_cnn_proof(seed: int = 0, Xmax: int = 31, Wmax: int = 31, act_kind: str = "relu", poly_shift: int = 7):
    """Show that a small random integer CNN is reproduced exactly mod 2**K.

    The model is Conv3x3(pad 1) -> act -> Conv3x3(pad 1) -> act -> FC(10).
    A seeded input of shape ``[4, 1, 16, 16]``, with entries in
    ``[-Xmax, Xmax]`` and weights in ``[-Wmax, Wmax]``, is run through it
    twice:

    - once in plain int64;
    - once with every linear layer evaluated on residues mod ``2**K`` and the
      activations applied to the centred lift.

    K is chosen from the largest magnitude seen in the reference pass.

    Args:
        seed: Seed for the generator that draws X and all weights.
        Xmax: Bound for the uniform integer draws of X.
        Wmax: Bound for the uniform integer draws of the weights.
        act_kind: ``"relu"``; any other value selects ``(z*z) >> poly_shift``.
        poly_shift: Right shift used by the polynomial activation.

    Returns:
        ``(text, report)``: a console log and a dict with keys ``ok_all``
        (every stage matched exactly), ``argmax_ok``, ``K``, ``act``,
        ``poly_shift``, ``ref_top1`` and ``mod2_top1``.
    """
    g = seed_all(seed)
    B, Cin, H, W = 4, 1, 16, 16
    C1, C2 = 8, 8
    kH, kW = 3, 3
    STRIDE, PAD = 1, 1
    CLS = 10

    # Fixed draw order, so a given seed always reproduces the same problem.
    X = torch.randint(-Xmax, Xmax + 1, (B, Cin, H, W), dtype=torch.int64, generator=g, device=DEVICE)
    W1 = torch.randint(-Wmax, Wmax + 1, (C1, Cin, kH, kW), dtype=torch.int64, generator=g, device=DEVICE)
    W2 = torch.randint(-Wmax, Wmax + 1, (C2, C1, kH, kW), dtype=torch.int64, generator=g, device=DEVICE)
    Wf = torch.randint(-Wmax, Wmax + 1, (C2 * H * W, CLS), dtype=torch.int64, generator=g, device=DEVICE)

    # Reference pass in plain int64. Activations are out of place, so Z1/Z2
    # keep their negative pre-activations for the range check and the stage
    # comparison (an in-place clamp_min_ would overwrite them).
    t0 = time.perf_counter()
    Z1 = conv_int64_im2col(X, W1, kH, kW, STRIDE, PAD)
    A1 = _cnn_activation(Z1, act_kind, poly_shift)
    Z2 = conv_int64_im2col(A1, W2, kH, kW, STRIDE, PAD)
    A2 = _cnn_activation(Z2, act_kind, poly_shift)
    Y = A2.reshape(B, -1) @ Wf
    t1 = time.perf_counter()

    max_abs = max(int(t.abs().max().item()) for t in (Z1, A1, Z2, A2, Y))
    K = choose_K(max_abs, min_bits=24)
    MOD2 = 1 << K

    # Mod 2**K pass: each linear layer runs on residues, is reduced, then centred.
    Z1m = mod2_encode(conv_int64_im2col(mod2_encode(X, MOD2), mod2_encode(W1, MOD2), kH, kW, STRIDE, PAD), MOD2)
    Z1i = mod2_center(Z1m, MOD2)
    A1i = _cnn_activation(Z1i, act_kind, poly_shift)
    Z2m = mod2_encode(conv_int64_im2col(mod2_encode(A1i, MOD2), mod2_encode(W2, MOD2), kH, kW, STRIDE, PAD), MOD2)
    Z2i = mod2_center(Z2m, MOD2)
    A2i = _cnn_activation(Z2i, act_kind, poly_shift)
    Ym = mod2_encode(mod2_encode(A2i.reshape(B, -1), MOD2) @ mod2_encode(Wf, MOD2), MOD2)
    Yi = mod2_center(Ym, MOD2)

    stage_ok = {
        name: bool(torch.equal(ref, got))
        for name, ref, got in (("Z1", Z1, Z1i), ("A1", A1, A1i), ("Z2", Z2, Z2i), ("A2", A2, A2i), ("Y", Y, Yi))
    }
    mark = {name: ("OK" if ok else "NO") for name, ok in stage_ok.items()}
    ok_all = all(stage_ok.values())
    arg_ok = bool(torch.equal(Y.argmax(1), Yi.argmax(1)))
    ref_top1 = Y.argmax(1).tolist()
    mod2_top1 = Yi.argmax(1).tolist()

    txt = [
        f"K={K} (2^(K-1)={1<<(K-1):,} > max_abs={max_abs:,}), act={act_kind} ({'ReLU' if act_kind=='relu' else f'poly SHIFT={poly_shift}'})",
        "Stage by stage equality, exact:",
        f"Z1: {mark['Z1']} | A1: {mark['A1']} | Z2: {mark['Z2']} | Y: {mark['Y']}",
        f"Argmax match: {arg_ok} (ref={ref_top1}, mod2={mod2_top1})",
        f"Timing ms: ref≈{(t1-t0)*1000:.2f}",
    ]
    return "\n".join(txt), {
        "ok_all": ok_all, "argmax_ok": arg_ok, "K": K,
        "act": act_kind, "poly_shift": int(poly_shift),
        "ref_top1": ref_top1, "mod2_top1": mod2_top1,
    }
| # ----------------------------- | |
| # 1 head Attention | |
| # ----------------------------- | |
def split_heads(Z, B, T, C, H):
    """Reshape ``[B, T, C]`` into per-head layout ``[B, H, T, C // H]``.

    The result is contiguous; ``merge_heads`` performs the inverse.
    """
    head_dim = C // H
    per_token = Z.reshape(B, T, H, head_dim)   # [B, T, H, Dh]
    return per_token.transpose(1, 2).contiguous()
def merge_heads(Zh, B, T, C, H):
    """Inverse of ``split_heads``: ``[B, H, T, C // H]`` back to contiguous ``[B, T, C]``."""
    token_major = Zh.transpose(1, 2)   # [B, T, H, Dh]
    return token_major.reshape(B, T, C).contiguous()
def pick_shift_for_cap(smax: int, cap: int = 8) -> int:
    """Return the smallest right shift ``s >= 0`` such that ``smax >> s <= cap``.

    Args:
        smax: Largest value that must be brought down to the cap.
        cap: Upper bound for ``smax >> s``.

    Returns:
        ``0`` when ``smax <= cap``, otherwise the minimal shift.

    Raises:
        ValueError: If ``cap < 0`` while ``smax > cap``. ``smax >> s`` bottoms
            out at 0 or -1 and so never reaches such a cap, so the search
            would not terminate.
    """
    if smax <= cap:
        return 0
    if cap < 0:
        raise ValueError(f"cap must be >= 0 when smax > cap (got smax={smax}, cap={cap})")
    shift = 0
    while (smax >> shift) > cap:
        shift += 1
    return shift
def _mod2_matmul(a, b, MOD2):
    """Return ``a @ b`` computed on residues mod ``MOD2`` and reduced into ``[0, MOD2)``.

    Plain and batched (broadcasting) matmul shapes both work.
    """
    return mod2_encode(mod2_encode(a, MOD2) @ mod2_encode(b, MOD2), MOD2)


def run_attn_proof(seed: int = 0, Xmax: int = 7, Wmax: int = 7, B: int = 1, T: int = 8, C: int = 16, H: int = 1):
    """Show that integer power-of-two attention is reproduced exactly mod 2**K.

    Reference pass, in plain int64, with X in ``[0, Xmax]`` and all weights
    in ``[0, Wmax]``:

    - ``S = max((Q K^T) >> SHIFT, 0)``
    - ``E = 2**S``
    - ``O = (E V) // sum(E)``
    - ``Y = O Wo``

    ``SHIFT`` is the smallest shift that brings the largest raw score down to
    at most 8.

    The mod pass evaluates every matmul on residues mod ``2**K``. It applies
    the element-wise steps (shift, clamp, exponent, floor division) to the
    centred lift and reuses the reference ``SHIFT``. Stage mismatches are
    reported in the output rather than raised.

    Args:
        seed: Seed for the generator that draws X and the four weight matrices.
        Xmax: Upper bound for the uniform integer draws of X.
        Wmax: Upper bound for the uniform integer draws of the weights.
        B: Batch size.
        T: Sequence length.
        C: Model width.
        H: Number of heads; must divide C.

    Returns:
        ``(text, report)``: a console log and a dict with keys ``ok_all``
        (every compared stage matched), ``argmax_ok``, ``K``, ``ref_top1``
        and ``mod2_top1``.

    Raises:
        ValueError: If any of B, T, C, H is < 1, or H does not divide C.
    """
    if min(B, T, C, H) < 1:
        raise ValueError(f"B, T, C and H must be >= 1 (got B={B}, T={T}, C={C}, H={H})")
    if C % H != 0:
        raise ValueError(f"H={H} must divide C={C}")
    g = seed_all(seed)
    # Fixed draw order, so a given seed always reproduces the same problem.
    X = torch.randint(0, Xmax + 1, (B, T, C), dtype=torch.int64, generator=g, device=DEVICE)
    Wq = torch.randint(0, Wmax + 1, (C, C), dtype=torch.int64, generator=g, device=DEVICE)
    Wk = torch.randint(0, Wmax + 1, (C, C), dtype=torch.int64, generator=g, device=DEVICE)
    Wv = torch.randint(0, Wmax + 1, (C, C), dtype=torch.int64, generator=g, device=DEVICE)
    Wo = torch.randint(0, Wmax + 1, (C, C), dtype=torch.int64, generator=g, device=DEVICE)

    # ---- reference pass (plain int64) ----
    t0 = time.perf_counter()
    Q = X @ Wq
    K = X @ Wk
    V = X @ Wv
    Qh = split_heads(Q, B, T, C, H)
    Kh = split_heads(K, B, T, C, H)
    Vh = split_heads(V, B, T, C, H)
    Sraw = torch.einsum("bhtd,bhTd->bhtT", Qh, Kh)
    SHIFT = pick_shift_for_cap(int(Sraw.max().item()), cap=8)
    S = (Sraw >> SHIFT).clamp_min(0)
    E = torch.ones_like(S) << S
    Den = E.sum(-1, keepdim=True)
    Num = torch.einsum("bhtT,bhTd->bhtd", E, Vh)
    O = merge_heads(Num // Den, B, T, C, H)
    Y = O @ Wo
    t1 = time.perf_counter()

    # K must cover every intermediate that passes through a residue, including
    # the exponent weights E and their row sums Den.
    max_abs = max(int(t.abs().max().item()) for t in (Q, K, V, Sraw, S, E, Den, Num, O, Y))
    Kbits = choose_K(max_abs, min_bits=24)
    MOD2 = 1 << Kbits

    # ---- mod 2**K pass: matmuls on residues, element-wise steps on centred lifts ----
    Qi = mod2_center(_mod2_matmul(X, Wq, MOD2), MOD2)
    Ki = mod2_center(_mod2_matmul(X, Wk, MOD2), MOD2)
    Vi = mod2_center(_mod2_matmul(X, Wv, MOD2), MOD2)
    Qh_i = split_heads(Qi, B, T, C, H)
    Kh_i = split_heads(Ki, B, T, C, H)
    Vh_i = split_heads(Vi, B, T, C, H)
    Sraw_i = mod2_center(_mod2_matmul(Qh_i, Kh_i.transpose(-1, -2), MOD2), MOD2)
    # Arithmetic shift on the signed values, exactly as in the reference pass.
    S_i = (Sraw_i >> SHIFT).clamp_min(0)
    E_mod = (torch.ones_like(S_i) << S_i) & (MOD2 - 1)
    Den_i = mod2_center(mod2_encode(E_mod.sum(-1, keepdim=True), MOD2), MOD2)
    Num_i = mod2_center(_mod2_matmul(E_mod, Vh_i, MOD2), MOD2)
    Oi = merge_heads(Num_i // Den_i, B, T, C, H)
    Yi = mod2_center(_mod2_matmul(Oi, Wo, MOD2), MOD2)

    stage_ok = {
        name: bool(torch.equal(ref, got))
        for name, ref, got in (
            ("Q", Q, Qi), ("K", K, Ki), ("V", V, Vi), ("S", S, S_i),
            ("Num", Num, Num_i), ("O", O, Oi), ("Y", Y, Yi),
        )
    }
    mark = {name: ("OK" if ok else "NO") for name, ok in stage_ok.items()}
    ok_all = all(stage_ok.values())
    arg_ok = bool(torch.equal(Y.argmax(-1), Yi.argmax(-1)))
    ref_top1 = Y.argmax(-1).tolist()
    mod2_top1 = Yi.argmax(-1).tolist()

    lines = [
        f"Auto SHIFT={SHIFT} cap=8",
        f"K={Kbits} 2^(K-1)={1<<(Kbits-1):,} > max_abs={max_abs:,}",
        "Stage by stage equality, exact:",
        "Q:{0} K:{1} V:{2} S:{3} Y:{4}".format(mark["Q"], mark["K"], mark["V"], mark["S"], mark["Y"]),
        f"Argmax match: {arg_ok} ref={ref_top1} mod2={mod2_top1}",
        f"Timing ms: ref≈{(t1-t0)*1000:.2f}",
    ]
    report = {
        "ok_all": ok_all, "argmax_ok": arg_ok, "K": Kbits,
        "ref_top1": ref_top1, "mod2_top1": mod2_top1,
    }
    return "\n".join(lines), report
| # ----------------------------- | |
| # UI | |
| # ----------------------------- | |
def _to_int(value, label):
    """Convert a numeric UI value to ``int``, rejecting an empty field.

    A cleared ``gr.Number`` may deliver ``None``. Report that as a readable UI
    error instead of letting ``int(None)`` raise an opaque TypeError.
    """
    if value is None:
        raise gr.Error(f"{label}: please enter a number")
    return int(value)


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # gr.Markdown renders only $$...$$ as LaTeX by default. Enable inline $...$
    # so the title's $2^K$ shows as math rather than literal dollar signs.
    gr.Markdown(
        "# FieldSpace, Prime Only Machine mod $2^K$, Exactness Proofs, single file",
        latex_delimiters=[{"left": "$", "right": "$", "display": False}],
    )
    gr.Markdown("CNN and 1 head attention, all int64, exact mod 2^K reconstruction.")

    with gr.Tab("CNN Proof"):
        # NOTE(review): the source's indentation was lost; act_sel/shift_poly are
        # placed below (not inside) the Row of number inputs. Confirm the layout.
        with gr.Row():
            seed = gr.Number(value=0, label="Seed", precision=0)
            xmax = gr.Number(value=31, label="|X|max", precision=0)
            wmax = gr.Number(value=31, label="|W|max", precision=0)
        act_sel = gr.Radio(choices=["relu", "poly"], value="relu", label="Activation")
        shift_poly = gr.Slider(1, 12, value=7, step=1, label="SHIFT for poly x^2 >> SHIFT")
        run_btn = gr.Button("Run CNN Proof", variant="primary")
        out_txt = gr.Textbox(label="Console", lines=12)
        out_json = gr.JSON(label="JSON Report")

        def _run_cnn(s, xM, wM, act, sh):
            return run_cnn_proof(
                _to_int(s, "Seed"), _to_int(xM, "|X|max"), _to_int(wM, "|W|max"), act, int(sh)
            )

        run_btn.click(_run_cnn, [seed, xmax, wmax, act_sel, shift_poly], [out_txt, out_json], api_name="run_cnn_proof")

    with gr.Tab("Attention Proof 1 head"):
        with gr.Row():
            seedA = gr.Number(value=0, label="Seed", precision=0)
            XmaxA = gr.Number(value=7, label="X in [0, Xmax]", precision=0)
            WmaxA = gr.Number(value=7, label="W in [0, Wmax]", precision=0)
        with gr.Row():
            BA = gr.Number(value=1, label="Batch B", precision=0)
            TA = gr.Number(value=8, label="Seq T", precision=0)
            CA = gr.Number(value=16, label="Width C", precision=0)
            HA = gr.Number(value=1, label="Heads H divides C", precision=0)
        runA = gr.Button("Run Attention Proof", variant="primary")
        outA_t = gr.Textbox(label="Console", lines=12)
        outA_j = gr.JSON(label="JSON Report")

        def _run_attn(s, xM, wM, B, T, C, H):
            return run_attn_proof(
                _to_int(s, "Seed"), _to_int(xM, "Xmax"), _to_int(wM, "Wmax"),
                _to_int(B, "Batch B"), _to_int(T, "Seq T"), _to_int(C, "Width C"), _to_int(H, "Heads H"),
            )

        runA.click(_run_attn, [seedA, XmaxA, WmaxA, BA, TA, CA, HA], [outA_t, outA_j], api_name="run_attn_proof")
if __name__ == "__main__":
    # Only when this file is executed as a script: start the Gradio server for
    # the `demo` UI built above, with launch()'s default settings.
    demo.launch()