# Author: jackal79
# Commit message: Single file Space, app.py self installs deps, CNN and 1 head attention proofs
# Commit: e13fb8e (verified)
# app.py, single file Space
# FieldSpace, Prime-Only Machine, mod 2^K exactness proofs
import sys, subprocess, time, math, json
def _ensure(package_spec: str):
try:
__import__(package_spec.split("==")[0].split(">=")[0].split("<")[0])
return
except Exception:
subprocess.check_call([sys.executable, "-m", "pip", "install", package_spec])
# Self install, single file repo, no external requirements.
# Each spec is import-checked by _ensure and pip-installed only if missing.
_REQUIREMENTS = ("gradio==4.44.1", "numpy>=1.26", "torch>=2.2,<2.4")
for spec in _REQUIREMENTS:
    _ensure(spec)
import numpy as np
import torch
import gradio as gr
# Global torch runtime settings for the proofs: CPU only, a single intra-op
# thread, and autograd disabled (no gradients are ever needed here).
DEVICE = "cpu"
torch.set_num_threads(1)
torch.set_grad_enabled(False)
# -----------------------------
# Utilities (mod 2^K)
# -----------------------------
def seed_all(s: int):
    """Return a new ``torch.Generator`` on ``DEVICE`` seeded with ``int(s)``.

    Despite the name, no global RNG state is touched; callers must pass the
    returned generator explicitly (e.g. ``generator=g`` in ``torch.randint``).
    """
    gen = torch.Generator(device=DEVICE)
    gen.manual_seed(int(s))
    return gen
def choose_K(max_abs: int, min_bits: int = 24, max_bits: int = 62) -> int:
    """Pick the exponent K of the modulus ``2**K`` used for exact reconstruction.

    The uncapped result is the smallest K with ``2**(K-1) > |max_abs|``. With
    that K, every integer in ``[-max_abs, max_abs]`` survives reduction mod
    ``2**K`` followed by centering. The result is then raised to at least
    ``min_bits`` and capped at ``max_bits``.

    The default cap of 62 keeps ``1 << K`` (and thus ``MOD2``) inside int64
    range for the int64 tensors used in this file. When the cap binds, the
    guarantee ``2**(K-1) > max_abs`` no longer holds, and reconstruction can
    fail.

    Args:
        max_abs: largest absolute value that must be represented (sign ignored).
        min_bits: lower bound on K.
        max_bits: upper bound on K.

    Returns:
        The chosen K.
    """
    need = int(max_abs).bit_length() + 1
    return min(max_bits, max(need, min_bits))
def mod2_encode(x: torch.Tensor, MOD2: int) -> torch.Tensor:
    """Reduce integer tensor ``x`` to residues in ``[0, MOD2)`` as int64.

    ``MOD2`` is expected to be a power of two: the reduction is a bitwise AND
    with ``MOD2 - 1``. On two's-complement integers this yields the
    non-negative residue for negative entries as well.
    """
    mask = MOD2 - 1
    return torch.bitwise_and(x, mask).to(torch.int64)
def mod2_center(v: torch.Tensor, MOD2: int) -> torch.Tensor:
    """Map values mod ``MOD2`` to signed representatives in ``(-MOD2/2, MOD2/2]``.

    ``v`` is first reduced into ``[0, MOD2)``, so any int64 input is decoded
    consistently. This includes un-reduced matmul outputs and values that
    int64 overflow wrapped negative. Every entry above ``MOD2/2`` then has
    ``MOD2`` subtracted. A residue therefore decodes to the exact integer
    whenever that integer lies in ``(-MOD2/2, MOD2/2]``. ``v`` itself is not
    modified.

    Raises:
        ValueError: if ``MOD2`` is not a power of two >= 2.
    """
    K = MOD2.bit_length() - 1
    if K < 1 or MOD2 != (1 << K):
        raise ValueError(f"MOD2 must be a power of two >= 2, got {MOD2}")
    half = 1 << (K - 1)
    out = v & (MOD2 - 1)  # new tensor with entries in [0, MOD2)
    out[out > half] -= MOD2
    return out
# -----------------------------
# CNN (Conv -> Act -> Conv -> Act -> FC)
# -----------------------------
def im2col_nchw_int(x, kH, kW, stride=1, pad=0):
    """Unfold an NCHW tensor into a matrix of flattened sliding windows.

    Works for any dtype (used here with int64). Zero padding of ``pad`` pixels
    is applied on every spatial side.

    Args:
        x: tensor of shape [N, C, H, W].
        kH, kW: window height / width.
        stride: step between windows (same in both spatial dims).
        pad: zero padding per side.

    Returns:
        ``(cols, Hout, Wout)``. ``cols`` has shape ``[N*Hout*Wout, C*kH*kW]``.
        Rows are ordered by (n, i, j) with j fastest, and each row is the
        window flattened in (C, kH, kW) order.
    """
    N, C, H, W = x.shape
    Hout = (H + 2*pad - kH)//stride + 1
    Wout = (W + 2*pad - kW)//stride + 1
    if pad:
        x_pad = x.new_zeros((N, C, H + 2*pad, W + 2*pad))
        x_pad[:, :, pad:pad+H, pad:pad+W] = x
    else:
        x_pad = x
    # Tensor.unfold is a dtype-agnostic view: [N, C, Hout, Wout, kH, kW].
    windows = x_pad.unfold(2, kH, stride).unfold(3, kW, stride)
    # -> [N, Hout, Wout, C, kH, kW] so rows come out (n, i, j)-major and each
    # row flattens as (C, kH, kW). Explicit sizes also handle N == 0.
    cols = windows.permute(0, 2, 3, 1, 4, 5).reshape(N * Hout * Wout, C * kH * kW)
    return cols, Hout, Wout
def conv_int64_im2col(X, W, kH, kW, stride=1, pad=0):
    """2-D convolution (cross-correlation, no bias) via im2col + one matmul.

    Args:
        X: NCHW input tensor.
        W: kernel tensor whose first dim is the output-channel count.
            Presumably [Cout, Cin, kH, kW]; its remaining dims are flattened to
            match the (C, kH, kW) row layout of im2col_nchw_int.
        kH, kW, stride, pad: forwarded to im2col_nchw_int.

    Returns:
        Contiguous tensor of shape [N, Cout, Hout, Wout].
    """
    n_batch, c_out = X.shape[0], W.shape[0]
    patches, h_out, w_out = im2col_nchw_int(X, kH, kW, stride=stride, pad=pad)
    kernel_rows = W.reshape(c_out, -1)               # [Cout, C*kH*kW]
    flat = torch.matmul(patches, kernel_rows.t())    # [N*Hout*Wout, Cout]
    nhwc = flat.reshape(n_batch, h_out, w_out, c_out)
    return nhwc.permute(0, 3, 1, 2).contiguous()
def _cnn_activation(Z, act_kind, poly_shift):
    """Return the CNN activation of integer tensor ``Z`` without modifying ``Z``.

    ``act_kind == "relu"`` gives ``max(Z, 0)``. Any other value gives the
    integer polynomial ``(Z * Z) >> poly_shift``.
    """
    if act_kind == "relu":
        return Z.clamp_min(0)  # out of place: Z keeps its negative entries
    return (Z * Z) >> int(poly_shift)


def run_cnn_proof(seed:int=0, Xmax:int=31, Wmax:int=31, act_kind:str="relu", poly_shift:int=7):
    """Check that a small integer CNN is reproduced exactly by mod-2^K arithmetic.

    Network (all int64):
        X [4, 1, 16, 16] -> 3x3 conv to 8 ch -> act -> 3x3 conv to 8 ch -> act
        -> flatten -> FC to 10 logits.
    Both convs use stride 1 and pad 1. Inputs are drawn uniformly from
    [-Xmax, Xmax] and weights from [-Wmax, Wmax], using a generator seeded
    with ``seed``.

    K is chosen by choose_K over the largest magnitude of every reference
    tensor, pre-activations included. The mod pass then runs each linear layer
    on residues mod 2**K and centers the result back to a signed integer. The
    activation is applied to that exact integer.

    Args:
        seed: RNG seed for data and weights.
        Xmax, Wmax: value bounds for inputs and weights.
        act_kind: "relu" for ReLU; any other value selects (z*z) >> poly_shift.
        poly_shift: right shift used by the polynomial activation.

    Returns:
        ``(text, report)``. ``text`` is the console summary. ``report`` holds
        ``ok_all`` (final logits equal), ``argmax_ok``, ``K``, ``act``,
        ``poly_shift``, ``ref_top1`` and ``mod2_top1``.
    """
    g = seed_all(seed)
    B, Cin, H, W = 4, 1, 16, 16
    C1, C2 = 8, 8
    kH, kW = 3, 3
    STRIDE, PAD = 1, 1
    CLS = 10
    # Draw order (X, W1, W2, Wf) fixes the seed -> data mapping.
    X = torch.randint(-Xmax, Xmax+1, (B,Cin,H,W), dtype=torch.int64, generator=g, device=DEVICE)
    W1 = torch.randint(-Wmax, Wmax+1, (C1,Cin,kH,kW), dtype=torch.int64, generator=g, device=DEVICE)
    W2 = torch.randint(-Wmax, Wmax+1, (C2,C1,kH,kW), dtype=torch.int64, generator=g, device=DEVICE)
    Wf = torch.randint(-Wmax, Wmax+1, (C2*H*W, CLS), dtype=torch.int64, generator=g, device=DEVICE)

    # ---- Reference pass (plain int64) ----
    t0 = time.perf_counter()
    Z1 = conv_int64_im2col(X, W1, kH, kW, STRIDE, PAD)
    A1 = _cnn_activation(Z1, act_kind, poly_shift)
    Z2 = conv_int64_im2col(A1, W2, kH, kW, STRIDE, PAD)
    A2 = _cnn_activation(Z2, act_kind, poly_shift)
    Y = A2.reshape(B, -1) @ Wf
    t1 = time.perf_counter()

    # Pre-activations still hold their negative values here, so K covers them.
    max_abs = max(int(arr.abs().max().item()) for arr in (Z1, A1, Z2, A2, Y))
    K = choose_K(max_abs, min_bits=24)
    MOD2 = 1 << K

    def enc(arr):
        """Residue of integer tensor ``arr`` mod MOD2, in [0, MOD2)."""
        return mod2_encode(arr, MOD2)

    # ---- Mod 2^K pass ----
    # NOTE(review): matmuls on residues near 2**K can overflow int64. This
    # assumes the overflow wraps modulo 2**64, which preserves residues mod
    # 2**K.
    Z1i = mod2_center(enc(conv_int64_im2col(enc(X), enc(W1), kH, kW, STRIDE, PAD)), MOD2)
    A1i = _cnn_activation(Z1i, act_kind, poly_shift)
    Z2i = mod2_center(enc(conv_int64_im2col(enc(A1i), enc(W2), kH, kW, STRIDE, PAD)), MOD2)
    A2i = _cnn_activation(Z2i, act_kind, poly_shift)
    Yi = mod2_center(enc(enc(A2i.reshape(B, -1)) @ enc(Wf)), MOD2)

    stage_ok = {
        "Z1": torch.equal(Z1, Z1i),
        "A1": torch.equal(A1, A1i),
        "Z2": torch.equal(Z2, Z2i),
        "Y": torch.equal(Y, Yi),
    }
    ok_all = bool(stage_ok["Y"])
    arg_ok = bool(torch.equal(Y.argmax(1), Yi.argmax(1)))

    txt = []
    txt.append(f"K={K} (2^(K-1)={1<<(K-1):,} > max_abs={max_abs:,}), act={act_kind} ({'ReLU' if act_kind=='relu' else f'poly SHIFT={poly_shift}'})")
    txt.append("Stage by stage equality, exact:")
    txt.append(" | ".join(f"{name}: {'OK' if ok else 'NO'}" for name, ok in stage_ok.items()))
    txt.append(f"Argmax match: {arg_ok} (ref={Y.argmax(1).tolist()}, mod2={Yi.argmax(1).tolist()})")
    txt.append(f"Timing ms: ref≈{(t1-t0)*1000:.2f}")
    return "\n".join(txt), {
        "ok_all": ok_all, "argmax_ok": arg_ok, "K": K,
        "act": act_kind, "poly_shift": int(poly_shift),
        "ref_top1": Y.argmax(1).tolist(), "mod2_top1": Yi.argmax(1).tolist()
    }
# -----------------------------
# 1 head Attention
# -----------------------------
def split_heads(Z, B,T,C,H):
    """Reshape [B, T, C] into per-head layout [B, H, T, C//H] (contiguous).

    Channel c goes to head ``c // (C//H)`` at position ``c % (C//H)``.
    """
    head_dim = C // H
    per_token = Z.reshape(B, T, H, head_dim)
    return per_token.transpose(1, 2).contiguous()
def merge_heads(Zh, B,T,C,H):
    """Inverse of split_heads: [B, H, T, C//H] -> contiguous [B, T, C].

    ``H`` is accepted for signature symmetry with split_heads but not used.
    """
    tokens_first = Zh.transpose(1, 2)
    return tokens_first.reshape(B, T, C).contiguous()
def pick_shift_for_cap(smax:int, cap:int=8) -> int:
    """Return the smallest shift ``s >= 0`` such that ``smax >> s <= cap``.

    Args:
        smax: largest (integer) score to be shifted.
        cap: upper bound the shifted value must not exceed.

    Returns:
        0 when ``smax <= cap``; otherwise the minimal right shift.

    Raises:
        ValueError: if ``cap < 0`` and ``smax > cap``. Right-shifting never
            goes below -1 (or 0 for non-negative values), so no shift
            satisfies the bound.
    """
    if smax <= cap:
        return 0
    if cap < 0:
        raise ValueError(f"no shift brings smax={smax} down to negative cap={cap}")
    # floor(smax / 2**s) <= cap  <=>  smax < (cap+1) * 2**s
    #                            <=>  smax // (cap+1) < 2**s,
    # so the minimal s is the bit length of smax // (cap+1).
    return (smax // (cap + 1)).bit_length()
def run_attn_proof(seed:int=0, Xmax:int=7, Wmax:int=7, B:int=1, T:int=8, C:int=16, H:int=1):
    """Check that integer multi-head attention is reproduced exactly mod 2^K.

    Reference pass (plain int64):
        Q, K, V = X @ Wq, X @ Wk, X @ Wv, split into H heads.
        Sraw = per-head Q K^T.
        S = max(Sraw >> SHIFT, 0), with SHIFT picked so the largest shifted
            score is <= 8.
        E = 2**S.
        O = floor((E @ V) / sum(E)), merged back over heads.
        Y = O @ Wo.
    X is drawn from [0, Xmax] and all weights from [0, Wmax], using a
    generator seeded with ``seed``.

    Mod pass: every matmul runs on residues mod 2**Kbits and is centered back
    to the exact signed integer. The shift, clamp, exponent and integer
    division then run on those exact integers, just as in the reference.

    Args:
        seed: RNG seed.
        Xmax, Wmax: upper bounds for inputs / weights.
        B, T, C, H: batch, sequence length, width, number of heads.

    Returns:
        ``(text, report)``. ``text`` is the console summary. ``report`` holds
        ``ok_all`` (final Y equal), ``argmax_ok``, ``K``, ``ref_top1`` and
        ``mod2_top1``.

    Raises:
        ValueError: if ``H < 1`` or ``H`` does not divide ``C``.
    """
    if H < 1 or C % H != 0:
        raise ValueError(f"H must be a positive divisor of C (got C={C}, H={H})")
    g = seed_all(seed)
    # Draw order (X, Wq, Wk, Wv, Wo) fixes the seed -> data mapping.
    X = torch.randint(0, Xmax+1, (B,T,C), dtype=torch.int64, generator=g, device=DEVICE)
    Wq = torch.randint(0, Wmax+1, (C,C), dtype=torch.int64, generator=g, device=DEVICE)
    Wk = torch.randint(0, Wmax+1, (C,C), dtype=torch.int64, generator=g, device=DEVICE)
    Wv = torch.randint(0, Wmax+1, (C,C), dtype=torch.int64, generator=g, device=DEVICE)
    Wo = torch.randint(0, Wmax+1, (C,C), dtype=torch.int64, generator=g, device=DEVICE)

    # ---- Reference pass (plain int64) ----
    t0 = time.perf_counter()
    Q = X @ Wq
    K = X @ Wk
    V = X @ Wv
    Qh, Kh, Vh = split_heads(Q,B,T,C,H), split_heads(K,B,T,C,H), split_heads(V,B,T,C,H)
    Sraw = torch.einsum("bhtd,bhTd->bhtT", Qh, Kh)
    SHIFT = pick_shift_for_cap(int(Sraw.max().item()), cap=8)
    S = (Sraw >> SHIFT).clamp_min(0)
    E = torch.ones_like(S) << S
    Den = E.sum(-1, keepdim=True)
    Num = torch.einsum("bhtT,bhTd->bhtd", E, Vh)
    O = merge_heads(Num // Den, B, T, C, H)
    Y = O @ Wo
    t1 = time.perf_counter()

    # Every value that is reconstructed from residues (E and Den included)
    # must satisfy |v| < 2**(Kbits-1).
    max_abs = max(int(arr.abs().max().item())
                  for arr in (Q, K, V, Sraw, S, E, Den, Num, O, Y))
    Kbits = choose_K(max_abs, min_bits=24)
    MOD2 = 1 << Kbits

    def enc(arr):
        """Residue of integer tensor ``arr`` mod MOD2, in [0, MOD2)."""
        return mod2_encode(arr, MOD2)

    def dec(arr):
        """Reduce ``arr`` mod MOD2, then center it back to the signed integer."""
        return mod2_center(enc(arr), MOD2)

    # ---- Mod 2^K pass ----
    Qi = dec(enc(X) @ enc(Wq))
    Ki = dec(enc(X) @ enc(Wk))
    Vi = dec(enc(X) @ enc(Wv))
    Qh_i, Kh_i, Vh_i = split_heads(Qi,B,T,C,H), split_heads(Ki,B,T,C,H), split_heads(Vi,B,T,C,H)
    # Batched over (B, H): [B,H,T,Dh] @ [B,H,Dh,T] -> [B,H,T,T].
    Sraw_i = dec(enc(Qh_i) @ enc(Kh_i).transpose(-1, -2))
    # Sraw_i is already the exact signed score, so shift it directly, as the
    # reference does. Shifting the unsigned residue instead would turn
    # negative scores into large positive ones.
    S_i = (Sraw_i >> SHIFT).clamp_min(0)
    E_m = enc(torch.ones_like(S_i) << S_i)
    Den_i = dec(E_m.sum(-1, keepdim=True))
    Num_i = dec(E_m @ enc(Vh_i))
    Oi = merge_heads(Num_i // Den_i, B, T, C, H)
    Yi = dec(enc(Oi) @ enc(Wo))

    ok_all = bool(torch.equal(Y, Yi))
    arg_ok = bool(torch.equal(Y.argmax(-1), Yi.argmax(-1)))
    lines = []
    lines.append(f"Auto SHIFT={SHIFT} cap=8")
    lines.append(f"K={Kbits} 2^(K-1)={1<<(Kbits-1):,} > max_abs={max_abs:,}")
    lines.append("Stage by stage equality, exact:")
    lines.append("Q:{0} K:{1} V:{2} S:{3} Y:{4}".format(
        'OK' if torch.equal(Q,Qi) else 'NO',
        'OK' if torch.equal(K,Ki) else 'NO',
        'OK' if torch.equal(V,Vi) else 'NO',
        'OK' if torch.equal(S,S_i) else 'NO',
        'OK' if torch.equal(Y,Yi) else 'NO'
    ))
    lines.append(f"Argmax match: {arg_ok} ref={Y.argmax(-1).tolist()} mod2={Yi.argmax(-1).tolist()}")
    lines.append(f"Timing ms: ref≈{(t1-t0)*1000:.2f}")
    report = {
        "ok_all": ok_all, "argmax_ok": arg_ok, "K": Kbits,
        "ref_top1": Y.argmax(-1).tolist(), "mod2_top1": Yi.argmax(-1).tolist()
    }
    return "\n".join(lines), report
# -----------------------------
# UI
# -----------------------------
# Gradio UI: one tab per proof. Gradio lays out components in the order they
# are created inside each `with` context, so statement order here defines
# the layout.
# NOTE(review): this section's original indentation was lost. Which
# components sit inside each gr.Row() is reconstructed; confirm the layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# FieldSpace, Prime Only Machine mod $2^K$, Exactness Proofs, single file")
    gr.Markdown("CNN and 1 head attention, all int64, exact mod 2^K reconstruction.")
    with gr.Tab("CNN Proof"):
        # Inputs forwarded to run_cnn_proof (seed, Xmax, Wmax, act_kind, poly_shift).
        with gr.Row():
            seed = gr.Number(value=0, label="Seed", precision=0)
            xmax = gr.Number(value=31, label="|X|max", precision=0)
            wmax = gr.Number(value=31, label="|W|max", precision=0)
        act_sel = gr.Radio(choices=["relu","poly"], value="relu", label="Activation")
        shift_poly= gr.Slider(1, 12, value=7, step=1, label="SHIFT for poly x^2 >> SHIFT")
        run_btn = gr.Button("Run CNN Proof", variant="primary")
        # Outputs: console text and the JSON report returned by run_cnn_proof.
        out_txt = gr.Textbox(label="Console", lines=12)
        out_json = gr.JSON(label="JSON Report")
        # Cast the numeric UI values to int before calling the proof.
        def _run_cnn(s, xM, wM, act, sh): return run_cnn_proof(int(s), int(xM), int(wM), act, int(sh))
        # api_name sets the endpoint name under which this event is exposed in the Gradio API.
        run_btn.click(_run_cnn, [seed,xmax,wmax,act_sel,shift_poly], [out_txt,out_json], api_name="run_cnn_proof")
    with gr.Tab("Attention Proof 1 head"):
        # Data/weight ranges forwarded to run_attn_proof.
        with gr.Row():
            seedA = gr.Number(value=0, label="Seed", precision=0)
            XmaxA = gr.Number(value=7, label="X in [0, Xmax]", precision=0)
            WmaxA = gr.Number(value=7, label="W in [0, Wmax]", precision=0)
        # Shape parameters forwarded to run_attn_proof.
        with gr.Row():
            BA = gr.Number(value=1, label="Batch B", precision=0)
            TA = gr.Number(value=8, label="Seq T", precision=0)
            CA = gr.Number(value=16, label="Width C", precision=0)
            HA = gr.Number(value=1, label="Heads H divides C", precision=0)
        runA = gr.Button("Run Attention Proof", variant="primary")
        outA_t = gr.Textbox(label="Console", lines=12)
        outA_j = gr.JSON(label="JSON Report")
        # Cast every UI value to int before calling the proof.
        def _run_attn(s, xM, wM, B,T,C,H): return run_attn_proof(int(s), int(xM), int(wM), int(B), int(T), int(C), int(H))
        runA.click(_run_attn, [seedA,XmaxA,WmaxA,BA,TA,CA,HA], [outA_t,outA_j], api_name="run_attn_proof")
# Start the app only when executed as a script.
if __name__ == "__main__":
    demo.launch()