Spaces:
Sleeping
Sleeping
Deploy FieldSpace app: CNN + Attention numerators (RNS+CRT)
Browse files- README.md +9 -21
- app.py +271 -227
- requirements.txt +1 -3
README.md
CHANGED
|
@@ -1,29 +1,17 @@
|
|
| 1 |
---
|
| 2 |
title: FieldSpace — Prime-Only Machine (RNS+CRT) — Exactness Proofs
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 4.44.0
|
| 8 |
-
|
|
|
|
| 9 |
license: mit
|
| 10 |
-
tags:
|
| 11 |
-
- number-theory
|
| 12 |
-
- residues
|
| 13 |
-
- crt
|
| 14 |
-
- integer-arithmetic
|
| 15 |
-
- exactness
|
| 16 |
-
- cnn
|
| 17 |
-
- gradio
|
| 18 |
---
|
| 19 |
|
| 20 |
-
This Space
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
|
| 23 |
-
```python
|
| 24 |
-
from gradio_client import Client
|
| 25 |
-
client = Client('https://huggingface.co/spaces/jackal79/fieldspace-prime-only')
|
| 26 |
-
txt, rep = client.predict('/run_cnn_proof', 0, 31, 31, 'relu', 7)
|
| 27 |
-
print(txt)
|
| 28 |
-
print(rep)
|
| 29 |
-
```
|
|
|
|
| 1 |
---
|
| 2 |
title: FieldSpace — Prime-Only Machine (RNS+CRT) — Exactness Proofs
|
| 3 |
+
emoji: 🧮
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 4.44.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
python: 3.10
|
| 10 |
license: mit
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
+
This Space proves exact execution for:
|
| 14 |
+
- **CNN**: Conv → ReLU/Poly → FC
|
| 15 |
+
- **Attention (one head)**: exact numerators (QK>>t, then SV)
|
| 16 |
|
| 17 |
+
All ops run in residues mod `{2^K, p_i}` and reconstruct once via a batched CRT.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
|
@@ -1,254 +1,298 @@
|
|
| 1 |
|
| 2 |
-
import
|
|
|
|
| 3 |
|
| 4 |
-
#
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
#
|
| 9 |
-
|
| 10 |
-
N, C, H, W = x.shape
|
| 11 |
-
Hout = (H + 2*pad - kH)//stride + 1
|
| 12 |
-
Wout = (W + 2*pad - kW)//stride + 1
|
| 13 |
-
if pad:
|
| 14 |
-
x_pad = torch.zeros((N, C, H + 2*pad, W + 2*pad), dtype=x.dtype, device=x.device)
|
| 15 |
-
x_pad[:, :, pad:pad+H, pad:pad+W] = x
|
| 16 |
-
else:
|
| 17 |
-
x_pad = x
|
| 18 |
-
cols = []
|
| 19 |
-
for i in range(Hout):
|
| 20 |
-
for j in range(Wout):
|
| 21 |
-
patch = x_pad[:, :, i*stride:i*stride+kH, j*stride:j*stride+kW]
|
| 22 |
-
cols.append(patch.reshape(N, -1))
|
| 23 |
-
out = torch.stack(cols, dim=1) # [N, Hout*Wout, C*kH*kW]
|
| 24 |
-
out = out.reshape(N*Hout*Wout, -1) # [N*Hout*Wout, C*kH*kW]
|
| 25 |
-
return out, Hout, Wout
|
| 26 |
-
|
| 27 |
-
def conv_int64_im2col(X, W, kH, kW, stride=1, pad=0):
|
| 28 |
-
Xcol, Hout, Wout = im2col_nchw_int(X, kH, kW, stride=stride, pad=pad)
|
| 29 |
-
Ycol = Xcol @ W.reshape(W.shape[0], -1).t()
|
| 30 |
-
return Ycol.reshape(X.shape[0], Hout, Wout, W.shape[0]).permute(0,3,1,2).contiguous()
|
| 31 |
-
|
| 32 |
-
def poly_act_floor_int64(x, s): # exact q = floor(x^2 / 2^s)
|
| 33 |
-
return (x * x) >> s
|
| 34 |
-
|
| 35 |
-
def flatten_nchw(x): return x.reshape(x.shape[0], -1)
|
| 36 |
-
def linear_int64(X, W): return X @ W
|
| 37 |
-
|
| 38 |
-
# ===== RNS / CRT =====
|
| 39 |
-
def choose_moduli(MAX_ABS, base_primes=None, K_cap=62):
|
| 40 |
-
# minimal K so 2^(K-1) > MAX_ABS, cap at 62 for int64 safety
|
| 41 |
-
K = min(K_cap, max(2, MAX_ABS.bit_length() + 1))
|
| 42 |
MOD2 = 1 << K
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
M_total *= i
|
| 51 |
-
i = {283:293, 293:307, 307:311, 311:313, 313:317}.get(i, i+2)
|
| 52 |
-
return K, MOD2, primes
|
| 53 |
|
| 54 |
def encode_rns(x, MOD2, primes):
|
| 55 |
-
outs = []
|
| 56 |
-
mask = MOD2 - 1
|
| 57 |
-
x_list = x.reshape(-1).cpu().tolist()
|
| 58 |
-
mod2_res = [(v & mask) for v in x_list] # mod 2^K via mask (python int)
|
| 59 |
-
outs.append(torch.tensor(mod2_res, dtype=torch.int64).reshape(x.shape))
|
| 60 |
for p in primes:
|
| 61 |
-
outs.append((x % p)
|
| 62 |
-
return
|
| 63 |
|
| 64 |
-
def
|
|
|
|
|
|
|
|
|
|
| 65 |
M = 1
|
| 66 |
for m in moduli: M *= m
|
| 67 |
-
Mi = [M
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
flat = res.reshape(L, -1).cpu().tolist()
|
| 74 |
-
N = len(flat[0])
|
| 75 |
-
out = [0]*N
|
| 76 |
-
half = M // 2
|
| 77 |
-
for k in range(N):
|
| 78 |
acc = 0
|
| 79 |
for i in range(L):
|
| 80 |
-
acc = (acc + (flat[i
|
| 81 |
if acc > half: acc -= M
|
| 82 |
-
out
|
| 83 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
L = Xr.shape[0]
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
X_i = Xr[i]
|
| 92 |
-
W_i = Wr[i].reshape(Wr[i].shape[0], -1)
|
| 93 |
-
Xcol, Hout, Wout = im2col_nchw_int(X_i.to(torch.int64), kH, kW, stride=stride, pad=pad)
|
| 94 |
-
Ycol = (Xcol @ W_i.t()) % mod
|
| 95 |
-
Y = Ycol.reshape(X_i.shape[0], Hout, Wout, W_i.shape[0]).permute(0,3,1,2).contiguous()
|
| 96 |
-
Ys.append(Y.to(torch.int64))
|
| 97 |
-
return torch.stack(Ys, dim=0) # [L,B,Cout,H,W]
|
| 98 |
-
|
| 99 |
-
def relu_rns(Yr, MOD2, primes, K):
|
| 100 |
-
# Use sign from 2^K residue: negative iff top bit set.
|
| 101 |
-
L = Yr.shape[0]
|
| 102 |
-
y2k = Yr[0] & (MOD2 - 1)
|
| 103 |
-
neg = ((y2k >> (K-1)) & 1).to(torch.int64) # 1 if negative else 0
|
| 104 |
-
keep = 1 - neg # 0 or 1
|
| 105 |
out = []
|
| 106 |
-
out.append((
|
| 107 |
-
for i,
|
| 108 |
-
out.append((
|
| 109 |
-
return
|
| 110 |
|
| 111 |
-
def
|
| 112 |
-
|
|
|
|
|
|
|
| 113 |
# 2^K channel
|
| 114 |
-
|
| 115 |
-
q2
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
for
|
| 121 |
-
y2p = (
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
qi
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
L = Xr.shape[0]
|
| 130 |
-
|
|
|
|
|
|
|
| 131 |
for i in range(L):
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
b1 = bconv(Cin, KH, x_abs, w_abs)
|
| 162 |
-
a1 = (b1*b1) >> (shift if act_kind=="poly" else 0)
|
| 163 |
-
b2 = int(C1*KH*KH * a1 * w_abs)
|
| 164 |
-
a2 = (b2*b2) >> (shift if act_kind=="poly" else 0)
|
| 165 |
-
fc = int((C2*H*W) * a2 * w_abs)
|
| 166 |
-
MAX_ABS = max(b1,a1,b2,a2,fc)
|
| 167 |
-
|
| 168 |
-
K, MOD2, primes = choose_moduli(MAX_ABS)
|
| 169 |
-
moduli = [MOD2] + primes
|
| 170 |
-
M, Mi_np, inv_np, mods_np = precompute_crt(moduli)
|
| 171 |
-
|
| 172 |
-
# encode & run RNS
|
| 173 |
-
X_r = encode_rns(X, MOD2, primes)
|
| 174 |
-
W1_r = encode_rns(W1, MOD2, primes)
|
| 175 |
-
W2_r = encode_rns(W2, MOD2, primes)
|
| 176 |
-
Wfc_r = encode_rns(Wfc, MOD2, primes)
|
| 177 |
-
|
| 178 |
-
Y1_r = conv_rns(X_r, W1_r, KH, KW, STRIDE, PAD, mods_np)
|
| 179 |
-
A1_r = (relu_rns(Y1_r, MOD2, primes, K) if act_kind=="relu" else poly_act_floor_rns(Y1_r, shift, MOD2, primes))
|
| 180 |
-
Y2_r = conv_rns(A1_r, W2_r, KH, KW, STRIDE, PAD, mods_np)
|
| 181 |
-
A2_r = (relu_rns(Y2_r, MOD2, primes, K) if act_kind=="relu" else poly_act_floor_rns(Y2_r, shift, MOD2, primes))
|
| 182 |
-
t2 = time.time()
|
| 183 |
-
Z_r = linear_rns(A2_r.reshape(len(moduli), B, -1), Wfc_r, mods_np)
|
| 184 |
-
Z_rec= batched_crt(Z_r, M, Mi_np, inv_np, mods_np)
|
| 185 |
-
t3 = time.time()
|
| 186 |
-
rns_ms = (t3 - t2)*1000
|
| 187 |
-
|
| 188 |
-
# checks
|
| 189 |
-
flags = {
|
| 190 |
-
"conv1": bool(torch.equal(batched_crt(Y1_r, M, Mi_np, inv_np, mods_np), Y1_ref)),
|
| 191 |
-
"act1": bool(torch.equal(batched_crt(A1_r, M, Mi_np, inv_np, mods_np), A1_ref)),
|
| 192 |
-
"conv2": bool(torch.equal(batched_crt(Y2_r, M, Mi_np, inv_np, mods_np), Y2_ref)),
|
| 193 |
-
"act2": bool(torch.equal(batched_crt(A2_r, M, Mi_np, inv_np, mods_np), A2_ref)),
|
| 194 |
-
"final": bool(torch.equal(Z_rec, Z_ref)),
|
| 195 |
-
"argmax": bool(torch.equal(Z_rec.argmax(1), Z_ref.argmax(1))),
|
| 196 |
-
}
|
| 197 |
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
report = {
|
| 217 |
-
"
|
| 218 |
-
"argmax_ok": flags["argmax"],
|
| 219 |
-
"act": act_kind,
|
| 220 |
"shift": int(shift),
|
| 221 |
-
"seed": int(seed),
|
| 222 |
-
"x_abs": int(x_abs),
|
| 223 |
-
"w_abs": int(w_abs),
|
| 224 |
"K": int(K),
|
| 225 |
-
"primes": primes,
|
| 226 |
-
"
|
| 227 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
}
|
| 229 |
-
return "\n".join(
|
| 230 |
-
|
| 231 |
-
#
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
)
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
|
| 253 |
if __name__ == "__main__":
|
| 254 |
-
demo.launch()
|
|
|
|
| 1 |
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import numpy as np
|
| 4 |
|
| 5 |
+
# -----------------------------
|
| 6 |
+
# Utilities: RNS + CRT
|
| 7 |
+
# -----------------------------
|
| 8 |
+
def pick_moduli(max_abs, primes=(257,263,269,271,277,281)):
|
| 9 |
+
# choose K so 2^(K-1) > max_abs
|
| 10 |
+
K = max(20, int(max_abs).bit_length()+1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
MOD2 = 1 << K
|
| 12 |
+
# ensure product > 2*max_abs (with plenty margin for toy sizes this is fine)
|
| 13 |
+
M = MOD2
|
| 14 |
+
ps = []
|
| 15 |
+
for p in primes:
|
| 16 |
+
ps.append(p); M *= p
|
| 17 |
+
if M > 2*max_abs: break
|
| 18 |
+
return K, MOD2, tuple(ps)
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def encode_rns(x, MOD2, primes):
|
| 21 |
+
outs = [np.asarray(x, dtype=object) & (MOD2-1)]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
for p in primes:
|
| 23 |
+
outs.append(np.asarray(x, dtype=object) % p)
|
| 24 |
+
return np.stack(outs, axis=0) # [L, *shape]
|
| 25 |
|
| 26 |
+
def batched_crt(res, MOD2, primes):
|
| 27 |
+
# res: [L, *shape], L=1+len(primes), dtype=object (python ints)
|
| 28 |
+
moduli = (MOD2,)+tuple(primes)
|
| 29 |
+
L = res.shape[0]
|
| 30 |
M = 1
|
| 31 |
for m in moduli: M *= m
|
| 32 |
+
Mi = [M//m for m in moduli]
|
| 33 |
+
inv = [pow(int(Mi[i] % moduli[i]), -1, moduli[i]) for i in range(L)]
|
| 34 |
+
flat = res.reshape(L, -1)
|
| 35 |
+
out = []
|
| 36 |
+
half = M//2
|
| 37 |
+
for k in range(flat.shape[1]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
acc = 0
|
| 39 |
for i in range(L):
|
| 40 |
+
acc = (acc + (flat[i,k] * inv[i] * Mi[i])) % M
|
| 41 |
if acc > half: acc -= M
|
| 42 |
+
out.append(int(acc))
|
| 43 |
+
return np.array(out, dtype=np.int64).reshape(res.shape[1:])
|
| 44 |
+
|
| 45 |
+
# -----------------------------
|
| 46 |
+
# CNN proof (tiny shapes)
|
| 47 |
+
# -----------------------------
|
| 48 |
+
def im2col_nchw_int(x, kH, kW, stride=1, pad=0):
|
| 49 |
+
N,C,H,W = x.shape
|
| 50 |
+
Hp = H+2*pad; Wp = W+2*pad
|
| 51 |
+
xp = np.zeros((N,C,Hp,Wp), dtype=np.int64)
|
| 52 |
+
xp[:,:,pad:pad+H, pad:pad+W] = x
|
| 53 |
+
Hout = (Hp - kH)//stride + 1
|
| 54 |
+
Wout = (Wp - kW)//stride + 1
|
| 55 |
+
cols = []
|
| 56 |
+
for i in range(Hout):
|
| 57 |
+
for j in range(Wout):
|
| 58 |
+
patch = xp[:,:,i*stride:i*stride+kH, j*stride:j*stride+kW]
|
| 59 |
+
cols.append(patch.reshape(N,-1))
|
| 60 |
+
out = np.stack(cols, axis=1) # [N, Hout*Wout, C*kH*kW]
|
| 61 |
+
out = out.reshape(N*Hout*Wout, -1)
|
| 62 |
+
return out, Hout, Wout
|
| 63 |
|
| 64 |
+
def conv_ref(X, W, kH, kW, stride=1, pad=0):
|
| 65 |
+
N,C,H,W = X.shape
|
| 66 |
+
Cout,Cin,KH,KW = W.shape
|
| 67 |
+
Xcol,Hout,Wout = im2col_nchw_int(X,kH,kW,stride,pad)
|
| 68 |
+
Ycol = Xcol @ W.reshape(Cout,-1).T
|
| 69 |
+
Y = Ycol.reshape(N,Hout,Wout,Cout).transpose(0,3,1,2).copy()
|
| 70 |
+
return Y
|
| 71 |
+
|
| 72 |
+
def relu_ref(x): return np.where(x<0, 0, x)
|
| 73 |
+
def poly_ref(x, s): return (x*x) >> s
|
| 74 |
+
def linear_ref(X, W): return X @ W
|
| 75 |
+
|
| 76 |
+
def relu_rns(Xr, MOD2, primes):
|
| 77 |
+
# Xr: [L, *shape]
|
| 78 |
L = Xr.shape[0]
|
| 79 |
+
K = MOD2.bit_length()-1
|
| 80 |
+
x2 = Xr[0]
|
| 81 |
+
mask = (x2 < (1<<(K-1))).astype(np.int64) # 1 if non-negative, else 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
out = []
|
| 83 |
+
out.append((x2 * mask) & (MOD2-1))
|
| 84 |
+
for i,p in enumerate(primes, start=1):
|
| 85 |
+
out.append((Xr[i] * mask) % p)
|
| 86 |
+
return np.stack(out, axis=0)
|
| 87 |
|
| 88 |
+
def poly_rns(Xr, s, MOD2, primes):
|
| 89 |
+
# q = floor(x^2 / 2^s)
|
| 90 |
+
mask_s = (1<<s)-1
|
| 91 |
+
out = []
|
| 92 |
# 2^K channel
|
| 93 |
+
y2 = (Xr[0]*Xr[0]) & (MOD2-1)
|
| 94 |
+
q2 = (y2 >> s) & (MOD2-1)
|
| 95 |
+
out.append(q2)
|
| 96 |
+
# r from low s bits of x only
|
| 97 |
+
low = Xr[0] & mask_s
|
| 98 |
+
r = (low*low) & mask_s
|
| 99 |
+
for i,p in enumerate(primes, start=1):
|
| 100 |
+
y2p = (Xr[i]*Xr[i]) % p
|
| 101 |
+
inv = pow(2, -s, p)
|
| 102 |
+
qi = ((y2p - (r % p)) * inv) % p
|
| 103 |
+
out.append(qi)
|
| 104 |
+
return np.stack(out, axis=0)
|
| 105 |
+
|
| 106 |
+
def matmul_rns(Ar, Br, MOD2, primes):
|
| 107 |
+
# Ar: [L, M, K], Br: [L, K, N] -> [L, M, N]
|
| 108 |
+
outs=[]
|
| 109 |
+
outs.append((Ar[0] @ Br[0]) & (MOD2-1))
|
| 110 |
+
for i,p in enumerate(primes, start=1):
|
| 111 |
+
outs.append((Ar[i] @ Br[i]) % p)
|
| 112 |
+
return np.stack(outs, axis=0)
|
| 113 |
+
|
| 114 |
+
def conv_rns(Xr, Wr, kH, kW, stride=1, pad=0, MOD2=0, primes=()):
|
| 115 |
+
# reshape per-modulus and reuse im2col on ints
|
| 116 |
L = Xr.shape[0]
|
| 117 |
+
N,C,H,W = Xr[0].shape
|
| 118 |
+
Cout = Wr[0].shape[0]
|
| 119 |
+
Ymods=[]
|
| 120 |
for i in range(L):
|
| 121 |
+
Xi = Xr[i].astype(object)
|
| 122 |
+
Wi = Wr[i].astype(object)
|
| 123 |
+
Xcol,Hout,Wout = im2col_nchw_int(np.asarray(Xi, dtype=np.int64), kH,kW,stride,pad)
|
| 124 |
+
Wcol = np.asarray(Wi.reshape(Cout,-1), dtype=np.int64)
|
| 125 |
+
if i==0:
|
| 126 |
+
MOD = MOD2
|
| 127 |
+
Ycol = (Xcol @ Wcol.T) & (MOD-1)
|
| 128 |
+
else:
|
| 129 |
+
p = primes[i-1]
|
| 130 |
+
Ycol = (Xcol @ Wcol.T) % p
|
| 131 |
+
Y = Ycol.reshape(N,Hout,Wout,Cout).transpose(0,3,1,2).copy()
|
| 132 |
+
Ymods.append(Y.astype(object))
|
| 133 |
+
return np.stack(Ymods, axis=0)
|
| 134 |
+
|
| 135 |
+
def cnn_proof(seed, xmax, wmax, activation, shift):
|
| 136 |
+
# Tiny shapes to keep it snappy
|
| 137 |
+
B,Cin,H,W = 2,1,8,8
|
| 138 |
+
Cout = 2
|
| 139 |
+
KH=KW=3
|
| 140 |
+
CLS=8
|
| 141 |
+
rng = np.random.default_rng(seed)
|
| 142 |
+
X = rng.integers(-xmax, xmax+1, size=(B,Cin,H,W), dtype=np.int64)
|
| 143 |
+
Wc = rng.integers(-wmax, wmax+1, size=(Cout,Cin,KH,KW), dtype=np.int64)
|
| 144 |
+
Wfc = rng.integers(-wmax, wmax+1, size=(Cout*H*W, CLS), dtype=np.int64)
|
| 145 |
+
|
| 146 |
+
# Reference (int64)
|
| 147 |
+
Y = conv_ref(X,Wc,KH,KW,1,1)
|
| 148 |
+
A = relu_ref(Y) if activation=="relu" else poly_ref(Y, shift)
|
| 149 |
+
Z = linear_ref(A.reshape(B,-1), Wfc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
+
# Bounds (very conservative)
|
| 152 |
+
b1 = int(Cin*KH*KW * xmax * wmax)
|
| 153 |
+
a1 = b1 if activation=="relu" else (b1*b1)>>shift
|
| 154 |
+
b2 = int(Cout*KH*KW * a1 * wmax)
|
| 155 |
+
a2 = b2 if activation=="relu" else (b2*b2)>>shift
|
| 156 |
+
fc = int((Cout*H*W) * a2 * wmax)
|
| 157 |
+
MAX_ABS = max(abs(b1),abs(a1),abs(b2),abs(a2),abs(fc), 1)
|
| 158 |
+
|
| 159 |
+
K, MOD2, primes = pick_moduli(max_abs=MAX_ABS)
|
| 160 |
+
# Encode
|
| 161 |
+
Xr = encode_rns(X, MOD2, primes)
|
| 162 |
+
Wcr = encode_rns(Wc, MOD2, primes)
|
| 163 |
+
Wfr = encode_rns(Wfc, MOD2, primes)
|
| 164 |
+
|
| 165 |
+
# RNS path
|
| 166 |
+
Yr = conv_rns(Xr,Wcr,KH,KW,1,1,MOD2,primes)
|
| 167 |
+
Ar = relu_rns(Yr, MOD2, primes) if activation=="relu" else poly_rns(Yr, shift, MOD2, primes)
|
| 168 |
+
Arf = Ar.reshape(Ar.shape[0], B, -1)
|
| 169 |
+
Zr = matmul_rns(Arf, Wfr, MOD2, primes)
|
| 170 |
+
Zrec= batched_crt(Zr, MOD2, primes)
|
| 171 |
+
|
| 172 |
+
ok = np.array_equal(Z, Zrec)
|
| 173 |
+
arg = np.array_equal(Z.argmax(1), Zrec.argmax(1))
|
| 174 |
+
|
| 175 |
+
log = []
|
| 176 |
+
log.append("Stage-by-stage equality (exact):")
|
| 177 |
+
log.append(f"conv : {'✅' if np.array_equal(batched_crt(Yr,MOD2,primes), Y) else '❌'}")
|
| 178 |
+
log.append(f"act : {'✅' if np.array_equal(batched_crt(Ar,MOD2,primes), A) else '❌'}")
|
| 179 |
+
log.append(f"fc : {'✅' if ok else '❌'}")
|
| 180 |
+
log.append(f"Argmax match: {arg} (ref={Z.argmax(1).tolist()}, prime={Zrec.argmax(1).tolist()})")
|
| 181 |
|
| 182 |
report = {
|
| 183 |
+
"activation": activation,
|
|
|
|
|
|
|
| 184 |
"shift": int(shift),
|
|
|
|
|
|
|
|
|
|
| 185 |
"K": int(K),
|
| 186 |
+
"primes": list(map(int,primes)),
|
| 187 |
+
"exact": bool(ok),
|
| 188 |
+
"argmax_exact": bool(arg),
|
| 189 |
+
"ref_logits_sample": Z[0,:min(8,CLS)].tolist(),
|
| 190 |
+
"rec_logits_sample": Zrec[0,:min(8,CLS)].tolist(),
|
| 191 |
+
}
|
| 192 |
+
return "\n".join(log), report
|
| 193 |
+
|
| 194 |
+
# -----------------------------
|
| 195 |
+
# Attention (exact numerators)
|
| 196 |
+
# -----------------------------
|
| 197 |
+
def attn_numerators_proof(seed):
|
| 198 |
+
rng = np.random.default_rng(seed)
|
| 199 |
+
T,d = 2,4
|
| 200 |
+
xmax = 15
|
| 201 |
+
t = 7 # scale shift
|
| 202 |
+
|
| 203 |
+
Q = rng.integers(-xmax,xmax+1,size=(T,d),dtype=np.int64)
|
| 204 |
+
K_ = rng.integers(-xmax,xmax+1,size=(T,d),dtype=np.int64)
|
| 205 |
+
V = rng.integers(-xmax,xmax+1,size=(T,d),dtype=np.int64)
|
| 206 |
+
|
| 207 |
+
Sraw = Q @ K_.T
|
| 208 |
+
S = Sraw >> t # exact floor
|
| 209 |
+
N = S @ V # numerators
|
| 210 |
+
|
| 211 |
+
# conservative bound
|
| 212 |
+
s_max = int((d * xmax * xmax) >> t)
|
| 213 |
+
n_max = int(T * s_max * xmax)
|
| 214 |
+
Kpow, MOD2, primes = pick_moduli(max_abs=max(abs(s_max),abs(n_max),1))
|
| 215 |
+
|
| 216 |
+
Qr = encode_rns(Q,MOD2,primes)
|
| 217 |
+
Kr = encode_rns(K_,MOD2,primes)
|
| 218 |
+
Vr = encode_rns(V,MOD2,primes)
|
| 219 |
+
|
| 220 |
+
# Sraw_r and S_r
|
| 221 |
+
Sraw_r = matmul_rns(Qr, Kr.transpose(0,2,1), MOD2, primes)
|
| 222 |
+
# floor by right shift in 2^K channel + multiplicative inverse in odd primes
|
| 223 |
+
# (here scaling is power-of-two so we can just shift in 2^K, and multiply by inv(2^t) after peeling low bits = 0)
|
| 224 |
+
# since we already floored by pure shift in reference, it's safe to do:
|
| 225 |
+
Sr = []
|
| 226 |
+
Sr.append( (Sraw_r[0] >> t) & (MOD2-1) )
|
| 227 |
+
invs = [pow(2,-t,p) for p in primes]
|
| 228 |
+
for i,p in enumerate(primes, start=1):
|
| 229 |
+
Sr.append( (Sraw_r[i] * invs[i-1]) % p )
|
| 230 |
+
Sr = np.stack(Sr, axis=0)
|
| 231 |
+
|
| 232 |
+
Nr = matmul_rns(Sr, Vr, MOD2, primes)
|
| 233 |
+
Srec = batched_crt(Sr, MOD2, primes)
|
| 234 |
+
Nrec = batched_crt(Nr, MOD2, primes)
|
| 235 |
+
|
| 236 |
+
okS = np.array_equal(S, Srec)
|
| 237 |
+
okN = np.array_equal(N, Nrec)
|
| 238 |
+
|
| 239 |
+
log = []
|
| 240 |
+
log.append("Attention (exact numerators):")
|
| 241 |
+
log.append(f"S : {'✅' if okS else '❌'}")
|
| 242 |
+
log.append(f"SV : {'✅' if okN else '❌'}")
|
| 243 |
+
|
| 244 |
+
report = {
|
| 245 |
+
"K": int(Kpow),
|
| 246 |
+
"primes": list(map(int,primes)),
|
| 247 |
+
"S_exact": bool(okS),
|
| 248 |
+
"N_exact": bool(okN),
|
| 249 |
+
"S_sample": S.tolist(),
|
| 250 |
+
"Srec_sample": Srec.tolist(),
|
| 251 |
+
"N_sample": N.tolist(),
|
| 252 |
+
"Nrec_sample": Nrec.tolist(),
|
| 253 |
}
|
| 254 |
+
return "\n".join(log), report
|
| 255 |
+
|
| 256 |
+
# -----------------------------
|
| 257 |
+
# Gradio UI
|
| 258 |
+
# -----------------------------
|
| 259 |
+
with gr.Blocks(title="FieldSpace — Prime-Only Machine (RNS+CRT)") as demo:
|
| 260 |
+
gr.Markdown("# FieldSpace — Prime-Only Machine (RNS+CRT) — Exactness Proofs")
|
| 261 |
+
gr.Markdown("This demo proves exactness for a tiny CNN and a one-head attention numerator block using only residues mod $\\{2^K, p_i\\}$ and a single CRT.")
|
| 262 |
+
with gr.Tab("CNN (Conv + ReLU/Poly + FC)"):
|
| 263 |
+
with gr.Row():
|
| 264 |
+
seed = gr.Number(value=0, precision=0, label="Seed")
|
| 265 |
+
xmax = gr.Number(value=31, precision=0, label="|X|max")
|
| 266 |
+
wmax = gr.Number(value=31, precision=0, label="|W|max")
|
| 267 |
+
activation = gr.Radio(choices=["relu","poly"], value="relu", label="Activation")
|
| 268 |
+
shift = gr.Slider(1,12,value=7,step=1,label="SHIFT for poly (x^2 >> SHIFT)")
|
| 269 |
+
run_btn = gr.Button("Run CNN Proof", variant="primary")
|
| 270 |
+
out_text = gr.Textbox(label="Console", lines=10)
|
| 271 |
+
out_json = gr.JSON(label="JSON Report")
|
| 272 |
+
run_btn.click(fn=lambda a,b,c,d,e: cnn_proof(int(a),int(b),int(c),d,int(e)),
|
| 273 |
+
inputs=[seed,xmax,wmax,activation,shift],
|
| 274 |
+
outputs=[out_text,out_json])
|
| 275 |
+
with gr.Tab("Attention (Exact Numerators)"):
|
| 276 |
+
aseed = gr.Number(value=0, precision=0, label="Seed")
|
| 277 |
+
runA = gr.Button("Run Attention Proof", variant="secondary")
|
| 278 |
+
AT = gr.Textbox(label="Console", lines=8)
|
| 279 |
+
AJ = gr.JSON(label="JSON Report")
|
| 280 |
+
runA.click(fn=lambda s: attn_numerators_proof(int(s)),
|
| 281 |
+
inputs=[aseed], outputs=[AT,AJ])
|
| 282 |
+
gr.Markdown("**Use via API**")
|
| 283 |
+
gr.Code("""from gradio_client import Client
|
| 284 |
+
client = Client('https://huggingface.co/spaces/jackal79/fieldspace-prime-only')
|
| 285 |
+
txt, rep = client.predict('/run_cnn_proof', 0, 31, 31, 'relu', 7)""")
|
| 286 |
+
|
| 287 |
+
# Named routes for programmatic calls
|
| 288 |
+
demo.load(None, None, None)
|
| 289 |
+
demo.add_named_endpoint("/run_cnn_proof", cnn_proof, inputs=[
|
| 290 |
+
gr.Number(precision=0), gr.Number(precision=0), gr.Number(precision=0),
|
| 291 |
+
gr.Textbox(), gr.Number(precision=0)
|
| 292 |
+
], outputs=[gr.Textbox(), gr.JSON()])
|
| 293 |
+
demo.add_named_endpoint("/run_attn_proof", attn_numerators_proof,
|
| 294 |
+
inputs=[gr.Number(precision=0)],
|
| 295 |
+
outputs=[gr.Textbox(), gr.JSON()])
|
| 296 |
|
| 297 |
if __name__ == "__main__":
|
| 298 |
+
demo.queue().launch()
|
requirements.txt
CHANGED
|
@@ -1,4 +1,2 @@
|
|
| 1 |
-
gradio>=4.44
|
| 2 |
numpy>=1.24
|
| 3 |
-
torch==2.3.1
|
| 4 |
-
--extra-index-url https://download.pytorch.org/whl/cpu
|
|
|
|
| 1 |
+
gradio>=4.44.0
|
| 2 |
numpy>=1.24
|
|
|
|
|
|