try:
    # UI stack is optional at import time so the numeric core stays importable
    # (and testable) in environments without gradio installed; the app itself
    # still requires gradio to build `demo` below.
    import gradio as gr
except ImportError:
    gr = None
import numpy as np

# -----------------------------
# Utilities: RNS + CRT
# -----------------------------

def pick_moduli(max_abs, primes=(257, 263, 269, 271, 277, 281)):
    """Pick an RNS basis {2^K, p_1..p_j} for values bounded by |v| <= max_abs.

    K is chosen so that 2^(K-1) > max_abs, i.e. every signed value already
    fits in the 2^K channel (two's-complement encoding); the odd primes add
    redundancy for the CRT reconstruction.

    Returns (K, 2**K, primes_used).
    """
    # choose K so 2^(K-1) > max_abs (bit_length+1 guarantees this; 20 is a floor)
    K = max(20, int(max_abs).bit_length() + 1)
    MOD2 = 1 << K
    # ensure product of all moduli > 2*max_abs (plenty of margin for toy sizes)
    M = MOD2
    ps = []
    for p in primes:
        ps.append(p)
        M *= p
        # NOTE(review): MOD2 alone already exceeds 2*max_abs, so this breaks
        # after the first prime; the loop form leaves room for smaller 2^K.
        if M > 2 * max_abs:
            break
    return K, MOD2, tuple(ps)


def encode_rns(x, MOD2, primes):
    """Encode an integer array x as residues [x mod 2^K, x mod p_1, ...].

    dtype=object keeps exact Python-int arithmetic; negative values map to
    their two's-complement representative in the 2^K channel.

    Returns an object array of shape [1+len(primes), *x.shape].
    """
    outs = [np.asarray(x, dtype=object) & (MOD2 - 1)]
    for p in primes:
        outs.append(np.asarray(x, dtype=object) % p)
    return np.stack(outs, axis=0)  # [L, *shape]


def batched_crt(res, MOD2, primes):
    """CRT-reconstruct signed integers from their residues.

    res: [L, *shape] with L = 1+len(primes), dtype=object (Python ints).
    Values are decoded into the symmetric range (-M/2, M/2], M = 2^K * prod(p).

    Returns an int64 array of shape res.shape[1:].
    """
    moduli = (MOD2,) + tuple(primes)
    L = res.shape[0]
    M = 1
    for m in moduli:
        M *= m
    Mi = [M // m for m in moduli]
    # per-modulus CRT coefficients: inv[i] = (M/m_i)^{-1} mod m_i
    inv = [pow(int(Mi[i] % moduli[i]), -1, moduli[i]) for i in range(L)]
    flat = res.reshape(L, -1)
    out = []
    half = M // 2
    for k in range(flat.shape[1]):
        acc = 0
        for i in range(L):
            acc = (acc + (flat[i, k] * inv[i] * Mi[i])) % M
        if acc > half:  # map to the signed representative
            acc -= M
        out.append(int(acc))
    return np.array(out, dtype=np.int64).reshape(res.shape[1:])


# -----------------------------
# CNN proof (tiny shapes)
# -----------------------------

def im2col_nchw_int(x, kH, kW, stride=1, pad=0):
    """im2col for integer NCHW tensors with zero padding.

    Returns (cols, Hout, Wout) where cols has shape [N*Hout*Wout, C*kH*kW],
    so a convolution becomes a single matmul against flattened filters.
    """
    N, C, H, W = x.shape
    Hp = H + 2 * pad
    Wp = W + 2 * pad
    xp = np.zeros((N, C, Hp, Wp), dtype=np.int64)
    xp[:, :, pad:pad + H, pad:pad + W] = x
    Hout = (Hp - kH) // stride + 1
    Wout = (Wp - kW) // stride + 1
    cols = []
    for i in range(Hout):
        for j in range(Wout):
            patch = xp[:, :, i * stride:i * stride + kH, j * stride:j * stride + kW]
            cols.append(patch.reshape(N, -1))
    out = np.stack(cols, axis=1)  # [N, Hout*Wout, C*kH*kW]
    out = out.reshape(N * Hout * Wout, -1)
    return out, Hout, Wout


def conv_ref(X, W, kH, kW, stride=1, pad=0):
    """Reference integer convolution (NCHW) via im2col.

    BUGFIX: the original unpacked ``N,C,H,W = X.shape``, which shadowed the
    weight tensor ``W`` with the image width and crashed on ``W.shape``.
    Only N (batch) and Cout are actually needed.
    """
    N = X.shape[0]
    Cout = W.shape[0]
    Xcol, Hout, Wout = im2col_nchw_int(X, kH, kW, stride, pad)
    Ycol = Xcol @ W.reshape(Cout, -1).T
    Y = Ycol.reshape(N, Hout, Wout, Cout).transpose(0, 3, 1, 2).copy()
    return Y
def relu_ref(x):
    """Reference ReLU on plain int64 arrays."""
    return np.where(x < 0, 0, x)


def poly_ref(x, s):
    """Reference polynomial activation: floor(x^2 / 2^s) (always >= 0)."""
    return (x * x) >> s


def linear_ref(X, W):
    """Reference fully-connected layer (exact integer matmul)."""
    return X @ W


def relu_rns(Xr, MOD2, primes):
    """ReLU computed directly on RNS residues.

    The sign is read from the 2^K channel: a residue below 2^(K-1) encodes a
    non-negative value (two's complement); negative values are zeroed in
    every channel.  Xr: [L, *shape] object array; returns the same layout.
    """
    K = MOD2.bit_length() - 1
    x2 = Xr[0]
    mask = (x2 < (1 << (K - 1))).astype(np.int64)  # 1 if non-negative, else 0
    out = [(x2 * mask) & (MOD2 - 1)]
    for i, p in enumerate(primes, start=1):
        out.append((Xr[i] * mask) % p)
    return np.stack(out, axis=0)


def poly_rns(Xr, s, MOD2, primes):
    """Polynomial activation q = floor(x^2 / 2^s) on RNS residues.

    2^K channel: decode the signed value (valid because |x| < 2^(K-1) by the
    bound used when the moduli were picked), square exactly in Python ints,
    then shift.  Odd-prime channels: q ≡ (x^2 - r) * inv(2^s) (mod p), where
    r = x^2 mod 2^s depends only on the low s bits of x, which are available
    exactly in the 2^K channel.

    BUGFIX(reconstruction): the original text of this function was garbled;
    note that shifting the *unsigned* 2^K residue would be wrong for
    negative x when s >= 2, hence the explicit signed decode.
    """
    K = MOD2.bit_length() - 1
    half = 1 << (K - 1)
    mask_s = (1 << s) - 1
    x2 = Xr[0]
    signed = np.where(x2 >= half, x2 - MOD2, x2)  # exact signed decode
    out = [((signed * signed) >> s) & (MOD2 - 1)]
    # r = x^2 mod 2^s, from the low s bits of x only
    low = x2 & mask_s
    r = (low * low) & mask_s
    for i, p in enumerate(primes, start=1):
        y2p = (Xr[i] * Xr[i]) % p
        inv = pow(2, -s, p)
        out.append(((y2p - (r % p)) * inv) % p)
    return np.stack(out, axis=0)


def matmul_rns(Ar, Br, MOD2, primes):
    """Channel-wise matmul on residues: Ar [L, M, K] @ Br [L, K, N] -> [L, M, N]."""
    outs = [(Ar[0] @ Br[0]) & (MOD2 - 1)]
    for i, p in enumerate(primes, start=1):
        outs.append((Ar[i] @ Br[i]) % p)
    return np.stack(outs, axis=0)


def conv_rns(Xr, Wr, kH, kW, stride=1, pad=0, MOD2=0, primes=()):
    """Channel-wise convolution on residues via an object-dtype im2col.

    Works entirely in exact Python-int (object) arithmetic, so residues that
    do not fit in int64 (large K, e.g. poly activation with a small SHIFT)
    cannot overflow.  The original round-tripped through int64, which relied
    on mod-2^64 wraparound and raised on residues >= 2^63.
    """
    def _cols_object(x):
        # im2col that keeps dtype=object for exact arithmetic
        N, C, H, W = x.shape
        Hp, Wp = H + 2 * pad, W + 2 * pad
        xp = np.zeros((N, C, Hp, Wp), dtype=object)
        xp[:, :, pad:pad + H, pad:pad + W] = x
        Ho = (Hp - kH) // stride + 1
        Wo = (Wp - kW) // stride + 1
        cols = []
        for i in range(Ho):
            for j in range(Wo):
                patch = xp[:, :, i * stride:i * stride + kH, j * stride:j * stride + kW]
                cols.append(patch.reshape(N, -1))
        return np.stack(cols, axis=1).reshape(N * Ho * Wo, -1), Ho, Wo

    L = Xr.shape[0]
    N = Xr[0].shape[0]
    Cout = Wr[0].shape[0]
    Ymods = []
    for i in range(L):
        Xcol, Hout, Wout = _cols_object(Xr[i])
        Wcol = Wr[i].reshape(Cout, -1)
        if i == 0:
            Ycol = (Xcol @ Wcol.T) & (MOD2 - 1)
        else:
            Ycol = (Xcol @ Wcol.T) % primes[i - 1]
        Y = Ycol.reshape(N, Hout, Wout, Cout).transpose(0, 3, 1, 2).copy()
        Ymods.append(Y.astype(object))
    return np.stack(Ymods, axis=0)


def cnn_proof(seed, xmax, wmax, activation, shift):
    """End-to-end exactness proof for a tiny Conv -> activation -> FC net.

    Runs the pipeline once in plain int64 and once purely on RNS residues,
    then compares the CRT reconstruction against the reference stage by
    stage.  Returns (console_log, json_report).
    """
    # Tiny shapes to keep it snappy
    B, Cin, H, W = 2, 1, 8, 8
    Cout = 2
    KH = KW = 3
    CLS = 8
    rng = np.random.default_rng(seed)
    X = rng.integers(-xmax, xmax + 1, size=(B, Cin, H, W), dtype=np.int64)
    Wc = rng.integers(-wmax, wmax + 1, size=(Cout, Cin, KH, KW), dtype=np.int64)
    Wfc = rng.integers(-wmax, wmax + 1, size=(Cout * H * W, CLS), dtype=np.int64)

    # Reference (int64)
    # NOTE(review): for poly with very small SHIFT the int64 reference itself
    # can wrap; the default SHIFT=7 with |X|,|W| <= 31 stays within int64.
    Y = conv_ref(X, Wc, KH, KW, 1, 1)
    A = relu_ref(Y) if activation == "relu" else poly_ref(Y, shift)
    Z = linear_ref(A.reshape(B, -1), Wfc)

    # Bounds (very conservative) used to size the 2^K channel
    b1 = int(Cin * KH * KW * xmax * wmax)
    a1 = b1 if activation == "relu" else (b1 * b1) >> shift
    b2 = int(Cout * KH * KW * a1 * wmax)
    a2 = b2 if activation == "relu" else (b2 * b2) >> shift
    fc = int((Cout * H * W) * a2 * wmax)
    MAX_ABS = max(abs(b1), abs(a1), abs(b2), abs(a2), abs(fc), 1)
    K, MOD2, primes = pick_moduli(max_abs=MAX_ABS)

    # Encode
    Xr = encode_rns(X, MOD2, primes)
    Wcr = encode_rns(Wc, MOD2, primes)
    Wfr = encode_rns(Wfc, MOD2, primes)

    # RNS path
    Yr = conv_rns(Xr, Wcr, KH, KW, 1, 1, MOD2, primes)
    Ar = relu_rns(Yr, MOD2, primes) if activation == "relu" else poly_rns(Yr, shift, MOD2, primes)
    Arf = Ar.reshape(Ar.shape[0], B, -1)
    Zr = matmul_rns(Arf, Wfr, MOD2, primes)
    Zrec = batched_crt(Zr, MOD2, primes)

    ok = np.array_equal(Z, Zrec)
    arg = np.array_equal(Z.argmax(1), Zrec.argmax(1))
    log = []
    log.append("Stage-by-stage equality (exact):")
    log.append(f"conv : {'✅' if np.array_equal(batched_crt(Yr,MOD2,primes), Y) else '❌'}")
    log.append(f"act : {'✅' if np.array_equal(batched_crt(Ar,MOD2,primes), A) else '❌'}")
    log.append(f"fc : {'✅' if ok else '❌'}")
    log.append(f"Argmax match: {arg} (ref={Z.argmax(1).tolist()}, prime={Zrec.argmax(1).tolist()})")
    report = {
        "activation": activation,
        "shift": int(shift),
        "K": int(K),
        "primes": list(map(int, primes)),
        "exact": bool(ok),
        "argmax_exact": bool(arg),
        "ref_logits_sample": Z[0, :min(8, CLS)].tolist(),
        "rec_logits_sample": Zrec[0, :min(8, CLS)].tolist(),
    }
    return "\n".join(log), report


# -----------------------------
# Attention (exact numerators)
# -----------------------------

def attn_numerators_proof(seed):
    """Exactness proof for one-head attention numerators.

    S = (Q K^T) >> t (exact floor), N = S V, computed both in int64 and on
    RNS residues, then CRT-reconstructed and compared.
    Returns (console_log, json_report).
    """
    rng = np.random.default_rng(seed)
    T, d = 2, 4
    xmax = 15
    t = 7  # scale shift
    Q = rng.integers(-xmax, xmax + 1, size=(T, d), dtype=np.int64)
    K_ = rng.integers(-xmax, xmax + 1, size=(T, d), dtype=np.int64)
    V = rng.integers(-xmax, xmax + 1, size=(T, d), dtype=np.int64)
    Sraw = Q @ K_.T
    S = Sraw >> t  # exact floor
    N = S @ V  # numerators

    # conservative bounds; sraw_max must be covered too because the RNS path
    # holds the *unscaled* scores before the shift
    sraw_max = int(d * xmax * xmax)
    s_max = sraw_max >> t
    n_max = int(T * s_max * xmax)
    Kpow, MOD2, primes = pick_moduli(max_abs=max(sraw_max, abs(s_max), abs(n_max), 1))
    Qr = encode_rns(Q, MOD2, primes)
    Kr = encode_rns(K_, MOD2, primes)
    Vr = encode_rns(V, MOD2, primes)
    Sraw_r = matmul_rns(Qr, Kr.transpose(0, 2, 1), MOD2, primes)

    # Exact floor-shift on residues (BUGFIX: shifting the unsigned 2^K
    # residue and multiplying prime residues by inv(2^t) without removing the
    # remainder is wrong whenever Sraw is negative or not divisible by 2^t):
    #   2^K channel : decode sign (|Sraw| < 2^(K-1)), arithmetic shift, re-encode
    #   odd primes  : S ≡ (Sraw - r) * inv(2^t) (mod p), r = Sraw mod 2^t,
    #                 where r comes from the low t bits of the 2^K residue.
    half = 1 << (Kpow - 1)
    r_low = Sraw_r[0] & ((1 << t) - 1)
    signed = np.where(Sraw_r[0] >= half, Sraw_r[0] - MOD2, Sraw_r[0])
    Sr = [(signed >> t) & (MOD2 - 1)]
    for i, p in enumerate(primes, start=1):
        inv = pow(2, -t, p)
        Sr.append(((Sraw_r[i] - (r_low % p)) * inv) % p)
    Sr = np.stack(Sr, axis=0)

    Nr = matmul_rns(Sr, Vr, MOD2, primes)
    Srec = batched_crt(Sr, MOD2, primes)
    Nrec = batched_crt(Nr, MOD2, primes)
    okS = np.array_equal(S, Srec)
    okN = np.array_equal(N, Nrec)
    log = []
    log.append("Attention (exact numerators):")
    log.append(f"S : {'✅' if okS else '❌'}")
    log.append(f"SV : {'✅' if okN else '❌'}")
    report = {
        "K": int(Kpow),
        "primes": list(map(int, primes)),
        "S_exact": bool(okS),
        "N_exact": bool(okN),
        "S_sample": S.tolist(),
        "Srec_sample": Srec.tolist(),
        "N_sample": N.tolist(),
        "Nrec_sample": Nrec.tolist(),
    }
    return "\n".join(log), report


# -----------------------------
# Gradio UI
# -----------------------------
with gr.Blocks(title="FieldSpace — Prime-Only Machine (RNS+CRT)") as demo:
    gr.Markdown("# FieldSpace — Prime-Only Machine (RNS+CRT) — Exactness Proofs")
    gr.Markdown("This demo proves exactness for a tiny CNN and a one-head attention numerator block using only residues mod $\\{2^K, p_i\\}$ and a single CRT.")
    with gr.Tab("CNN (Conv + ReLU/Poly + FC)"):
        with gr.Row():
            seed = gr.Number(value=0, precision=0, label="Seed")
            xmax = gr.Number(value=31, precision=0, label="|X|max")
            wmax = gr.Number(value=31, precision=0, label="|W|max")
            activation = gr.Radio(choices=["relu", "poly"], value="relu", label="Activation")
            shift = gr.Slider(1, 12, value=7, step=1, label="SHIFT for poly (x^2 >> SHIFT)")
        run_btn = gr.Button("Run CNN Proof", variant="primary")
        out_text = gr.Textbox(label="Console", lines=10)
        out_json = gr.JSON(label="JSON Report")
        # BUGFIX: gr.Blocks has no `add_named_endpoint`; named API routes are
        # registered with `api_name` on the event listener.
        run_btn.click(
            fn=lambda a, b, c, d, e: cnn_proof(int(a), int(b), int(c), d, int(e)),
            inputs=[seed, xmax, wmax, activation, shift],
            outputs=[out_text, out_json],
            api_name="run_cnn_proof",
        )
    with gr.Tab("Attention (Exact Numerators)"):
        aseed = gr.Number(value=0, precision=0, label="Seed")
        runA = gr.Button("Run Attention Proof", variant="secondary")
        AT = gr.Textbox(label="Console", lines=8)
        AJ = gr.JSON(label="JSON Report")
        runA.click(
            fn=lambda s: attn_numerators_proof(int(s)),
            inputs=[aseed],
            outputs=[AT, AJ],
            api_name="run_attn_proof",
        )
    gr.Markdown("**Use via API**")
    # BUGFIX: gradio_client passes positional inputs plus api_name, not the
    # route as the first argument.
    gr.Code("""from gradio_client import Client
client = Client('https://huggingface.co/spaces/jackal79/fieldspace-prime-only')
txt, rep = client.predict(0, 31, 31, 'relu', 7, api_name='/run_cnn_proof')""")

if __name__ == "__main__":
    demo.queue().launch()