Upload folder using huggingface_hub
- README.md +31 -7
- app.py +412 -0
- requirements.txt +3 -0
- runtime.txt +1 -0
README.md
CHANGED
@@ -1,12 +1,36 @@
 ---
-title:
-emoji:
+title: FieldSpace — Prime-Only Machine (RNS+CRT) — Exactness Proofs
+emoji: "🧮"
 colorFrom: red
-colorTo:
+colorTo: indigo
 sdk: gradio
-sdk_version:
-
-
+sdk_version: 4.44.0
+python_version: 3.10
+license: mit
+tags:
+- number-theory
+- residues
+- crt
+- integer-arithmetic
+- exactness
+- cnn
+- attention
+- gradio
 ---
 
-
+This Space demonstrates exact *integer* CNN and Attention blocks computed entirely in residues mod {2^K, primes},
+followed by a single CRT reconstruction that **matches bit-for-bit** an integer reference pipeline.
+
+**Tabs**
+- **CNN**: Conv + ReLU/Poly + FC
+- **Attention**: single-head, integer weights; exact floor division by 2^E; clamped scores; LUT-based positive weights; final integer division after CRT.
+
+### Use via API
+
+```python
+from gradio_client import Client
+client = Client("jackal79/fieldspace-prime-only")
+txt, rep = client.predict(0, 31, 31, "relu", 7, api_name="/run_cnn_proof")
+print(txt)
+print(rep)
+```
app.py
ADDED
@@ -0,0 +1,412 @@
+# FieldSpace — Prime-Only Machine (RNS+CRT) — Exactness Proofs
+# Exact integer CNN and Attention blocks computed entirely in residues mod {2^K, primes},
+# then reconstructed with a single CRT, matching a pure-int reference exactly.
+
+import os, math, time, json, numpy as np
+import torch
+import gradio as gr
+
+torch.set_num_threads(max(1, torch.get_num_threads()))
+DEVICE = "cpu"
+
+# ========= Common RNS/CRT utilities =========
+def choose_moduli_2k_plus_primes(max_abs:int, min_K:int=34, max_K:int=62, base_primes=None):
+    # Pick a power-of-two modulus 2^K plus enough odd primes that the composite
+    # modulus exceeds 2*max_abs, so signed values fit in [-M/2, M/2).
+    if base_primes is None:
+        base_primes = [257, 263, 269, 271, 277, 281]
+    K = min(max_K, max(min_K, max_abs.bit_length() + 1))
+    MOD2 = 1 << K
+    primes = base_primes.copy()
+    M_total = MOD2
+    for p in primes: M_total *= p
+    pool = base_primes + [283,293,307,311,313,317,331,337,347,349,353,359,367,373,379,383,389,397,401]
+    it = 0
+    while M_total <= 2*max_abs and it < len(pool):
+        q = pool[it]
+        if q not in primes:
+            primes.append(q)
+            M_total *= q
+        it += 1
+    if M_total <= 2*max_abs:
+        # Prime pool exhausted: grow K as a last resort.
+        while M_total <= 2*max_abs and K < max_K:
+            K += 1
+            MOD2 = 1 << K
+            M_total = MOD2
+            for p in primes: M_total *= p
+    assert M_total > 2*max_abs, "Composite modulus too small; reduce ranges or extend prime pool."
+    return K, MOD2, primes, M_total
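+
+# Example: max_abs = 10**12 has bit_length() = 40, so K = 41 and
+# M_total = 2**41 * 257*263*269*271*277*281, far exceeding 2*max_abs.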
+
+def crt_precompute(moduli):
+    # Precompute M = prod(moduli), the cofactors Mi = M/m_i, and the modular
+    # inverses inv_i = Mi^{-1} mod m_i used by CRT reconstruction.
+    M = 1
+    for m in moduli: M *= m
+    Mi = [M // m for m in moduli]
+    invM = [pow(int(Mi[i] % moduli[i]), -1, moduli[i]) for i in range(len(moduli))]
+    return M, np.array(Mi, dtype=object), np.array(invM, dtype=object), np.array(moduli, dtype=object)
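+
+# CRT: x = sum_i r_i * inv_i * Mi (mod M), centered into [-M/2, M/2).
+# Worked example with moduli (4, 3): M = 12, Mi = (3, 4), inv = (3, 1);
+# residues (3, 1) give 3*3*3 + 1*1*4 = 31 ≡ 7 (mod 12), centered to -5,
+# and indeed -5 ≡ 3 (mod 4) and -5 ≡ 1 (mod 3).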
+
+def encode_rns_int64(x: torch.Tensor, MOD2:int, primes):
+    # Residue encoding: channel 0 is x mod 2^K (two's-complement friendly via
+    # bit masking), followed by one channel per odd prime.
+    out = []
+    mask = MOD2 - 1
+    xs = x.reshape(-1).cpu().tolist()
+    mod2_vals = [(v & mask) for v in xs]
+    out.append(torch.tensor(mod2_vals, dtype=torch.int64).reshape(x.shape))
+    for p in primes:
+        # torch's % follows Python semantics, so residues are non-negative.
+        out.append((x % p).to(torch.int64))
+    return torch.stack(out, dim=0)
+
+def batched_crt_python_bigint(res: torch.Tensor, moduli, Mi_np, inv_np, M_obj:int):
+    # Reconstruct signed integers from residues with Python big ints (no overflow),
+    # centering each result into [-M/2, M/2).
+    L = res.shape[0]
+    flat = res.reshape(L, -1).cpu().tolist()
+    N = len(flat[0])
+    out = [0]*N
+    half = M_obj // 2
+    for k in range(N):
+        acc = 0
+        for i in range(L):
+            acc = (acc + (flat[i][k] * inv_np[i] * Mi_np[i])) % M_obj
+        if acc > half:
+            acc -= M_obj
+        out[k] = int(acc)
+    return torch.tensor(out, dtype=torch.int64).reshape(res.shape[1:])
+
+# ========= CNN (Conv + ReLU/Poly + FC) =========
+def im2col_nchw_int(x, kH, kW, stride=1, pad=0):
+    # Unfold an NCHW tensor into (N*Hout*Wout, C*kH*kW) patch rows so the
+    # integer convolution becomes a single matmul.
+    N, C, H, W = x.shape
+    Hout = (H + 2*pad - kH)//stride + 1
+    Wout = (W + 2*pad - kW)//stride + 1
+    if pad:
+        x_pad = torch.zeros((N, C, H + 2*pad, W + 2*pad), dtype=x.dtype, device=x.device)
+        x_pad[:, :, pad:pad+H, pad:pad+W] = x
+    else:
+        x_pad = x
+    cols = []
+    for i in range(Hout):
+        for j in range(Wout):
+            patch = x_pad[:, :, i*stride:i*stride+kH, j*stride:j*stride+kW]
+            cols.append(patch.reshape(N, -1))
+    out = torch.stack(cols, dim=1).reshape(N*Hout*Wout, -1)
+    return out, Hout, Wout
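+
+# Shape check: N=4, C=1, H=W=16, k=3, stride=1, pad=1 gives Hout=Wout=16,
+# so Xcol is (4*16*16, 1*3*3) = (1024, 9).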
+
+def conv_int64_im2col(X, W, kH, kW, stride=1, pad=0):
+    # Reference integer convolution: im2col followed by one matmul.
+    Xcol, Hout, Wout = im2col_nchw_int(X, kH, kW, stride=stride, pad=pad)
+    Ycol = Xcol @ W.reshape(W.shape[0], -1).t()
+    return Ycol.reshape(X.shape[0], Hout, Wout, W.shape[0]).permute(0,3,1,2).contiguous()
+
+def relu_int64(x): return torch.clamp(x, min=0)
+def poly_floor_sq_int64(x, shift): return (x * x) >> shift
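+
+# x*x is non-negative, so the right shift in poly_floor_sq_int64 is an exact
+# floor division by 2^shift.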
+
+def bounds_cnn(B,Cin,H,W,Cout,k, Xabs, Wabs, activation="relu", shift=7):
+    # Worst-case magnitudes per stage, used to size the composite modulus.
+    b1 = int(Cin*k*k * Xabs * Wabs)                       # conv output
+    a1 = b1 if activation=="relu" else (b1*b1) >> shift   # activation output
+    fc = int((Cout*H*W) * a1 * Wabs)                      # fully-connected output
+    return b1, a1, fc, max(b1, a1, fc)
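+
+# Demo defaults (Cin=1, k=3, Xabs=Wabs=31): b1 = 9*31*31 = 8,649; with relu,
+# fc = (8*16*16) * 8,649 * 31 = 549,107,712, comfortably below M_total/2.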
+
+def conv_rns(Xr, Wr, Cout, kH, kW, stride, pad, moduli):
+    # Per-modulus convolution: im2col + matmul, reduced mod m afterwards.
+    Ys=[]
+    for i,m in enumerate(moduli):
+        mod=int(m)
+        X_i = Xr[i]
+        W_i = Wr[i].reshape(Wr[i].shape[0], -1)
+        Xcol, Hout, Wout = im2col_nchw_int(X_i, kH, kW, stride=stride, pad=pad)
+        Ycol = (Xcol @ W_i.t()) % mod
+        Y = Ycol.reshape(X_i.shape[0], Hout, Wout, W_i.shape[0]).permute(0,3,1,2).contiguous()
+        Ys.append(Y.to(torch.int64))
+    return torch.stack(Ys, dim=0)
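+
+# The 2^K channel may wrap int64 inside the matmul, but int64 wraps mod 2^64 and
+# 2^K divides 2^64, so the subsequent % MOD2 is still exact. Prime residues stay
+# below 401, so their dot products never approach the int64 limit.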
+
+def relu_rns(Yr, moduli):
+    # Read the sign off the 2^K channel (valid because |value| < 2^(K-1) by
+    # construction), then apply the same 0/1 mask to every residue channel.
+    MOD2 = int(moduli[0]); half = MOD2//2
+    s2 = Yr[0]
+    centered = torch.where(s2 <= half, s2, s2 - MOD2)
+    mask = (centered >= 0).to(torch.int64)
+    out = [ (Yr[0] * mask) % MOD2 ]
+    for i,p in enumerate(moduli[1:], start=1):
+        p = int(p)
+        out.append((Yr[i] * mask) % p)
+    return torch.stack(out, dim=0)
+
+def poly_floor_sq_rns(Yr, shift, moduli):
+    # Exact floor((y^2) >> shift) in RNS: write y^2 = q*2^shift + r and recover
+    # q in every channel; r comes from the low bits of the 2^K channel.
+    MOD2 = int(moduli[0])
+    mask = (1 << shift) - 1
+    out = []
+    y2_2 = (Yr[0]*Yr[0]) & (MOD2 - 1)
+    q2 = (y2_2 >> shift) & (MOD2 - 1)  # low K-shift bits of q; enough under our bounds
+    out.append(q2.to(torch.int64))
+    y_low = (Yr[0] & mask).to(torch.int64)
+    r_low = (y_low * y_low) & mask     # r = y^2 mod 2^shift
+    for i,p in enumerate(moduli[1:], start=1):
+        p = int(p)
+        inv2s = pow(2, -shift, p)
+        y2p = (Yr[i]*Yr[i]) % p
+        rp = (r_low % p)
+        out.append(((y2p - rp) * inv2s) % p)
+    return torch.stack(out, dim=0)
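+
+# Why this is exact: r = y^2 mod 2^shift depends only on y mod 2^shift, so
+# y^2 - r is divisible by 2^shift and q = (y^2 - r) / 2^shift is an integer.
+# Example with y = 5, shift = 3: y^2 = 25 = 3*8 + 1, so q = 3 and r = 1.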
+
+def linear_rns(Xr, Wr, moduli):
+    Ys=[]
+    for i,m in enumerate(moduli):
+        mod=int(m)
+        Ys.append((Xr[i] @ Wr[i]) % mod)
+    return torch.stack(Ys, dim=0)
+
+def run_cnn(seed:int=0, Xabs:int=31, Wabs:int=31, activation:str="relu", shift:int=7):
+    B,Cin,H,W = 4,1,16,16
+    Cout, kH, kW, STRIDE, PAD = 8,3,3,1,1
+    CLS = 10
+    g = torch.Generator().manual_seed(int(seed))
+    X   = torch.randint(-Xabs, Xabs+1, (B,Cin,H,W), dtype=torch.int64, generator=g, device=DEVICE)
+    Wc  = torch.randint(-Wabs, Wabs+1, (Cout,Cin,kH,kW), dtype=torch.int64, generator=g, device=DEVICE)
+    Wfc = torch.randint(-Wabs, Wabs+1, (Cout*H*W, CLS), dtype=torch.int64, generator=g, device=DEVICE)
+    # Reference int64 path
+    t0=time.time()
+    Y = conv_int64_im2col(X, Wc, kH, kW, STRIDE, PAD)
+    A = relu_int64(Y) if activation=="relu" else poly_floor_sq_int64(Y, shift)
+    Z = (A.reshape(B,-1) @ Wfc)
+    t1=time.time(); ref_ms=(t1-t0)*1000.0
+    # Bounds -> moduli
+    b1,a1,fc, MAX_ABS = bounds_cnn(B,Cin,H,W,Cout,kH, Xabs, Wabs, activation, shift)
+    K, MOD2, primes, Mtot = choose_moduli_2k_plus_primes(MAX_ABS, min_K=max(34, shift+2))
+    moduli = [MOD2]+primes
+    M, Mi_np, inv_np, mods_np = crt_precompute(moduli)
+    # Encode RNS
+    Xr   = encode_rns_int64(X, MOD2, primes)
+    Wcr  = encode_rns_int64(Wc, MOD2, primes)
+    Wfcr = encode_rns_int64(Wfc, MOD2, primes)
+    # RNS pipeline
+    t2=time.time()
+    Yr = conv_rns(Xr, Wcr, Cout, kH, kW, STRIDE, PAD, moduli)
+    Ar = relu_rns(Yr, moduli) if activation=="relu" else poly_floor_sq_rns(Yr, shift, moduli)
+    Zr = linear_rns(Ar.reshape(len(moduli), B, -1), Wfcr, moduli)
+    Z_rec = batched_crt_python_bigint(Zr, moduli, Mi_np, inv_np, int(M))
+    t3=time.time(); rns_ms=(t3-t2)*1000.0
+    # Stage-by-stage checks (outside the timed region)
+    ok_conv = bool(torch.equal(Y, batched_crt_python_bigint(Yr, moduli, Mi_np, inv_np, int(M))))
+    ok_act  = bool(torch.equal(A, batched_crt_python_bigint(Ar, moduli, Mi_np, inv_np, int(M))))
+    ok_all  = bool(torch.equal(Z, Z_rec))
+    ok_arg  = bool(torch.equal(Z.argmax(1), Z_rec.argmax(1)))
+    txt = (
+        f"Stage-by-stage equality (exact):\n"
+        f"conv (Y): {'✅' if ok_conv else '❌'}\n"
+        f"{'relu ' if activation=='relu' else 'poly '}(A): {'✅' if ok_act else '❌'}\n"
+        f"fc/out (Z): {'✅' if ok_all else '❌'}\n"
+        f"Argmax match: {ok_arg} (ref={Z.argmax(1).tolist()}, prime={Z_rec.argmax(1).tolist()})\n\n"
+        f"Timing (ms):\n"
+        f"  Reference int64 path: {ref_ms:.2f} ms\n"
+        f"  Prime-only RNS+CRT  : {rns_ms:.2f} ms\n\n"
+        f"Bit budget:\n"
+        f"  K={K} (2^(K-1)={1<<(K-1):,} > MAX_ABS={MAX_ABS:,})\n"
+        f"  Odd primes: {primes}\n"
+        f"  Composite modulus M_total = {int(Mtot):,}\n"
+        f"\nDone. Exact equality: {ok_all}"
+    )
+    report = {
+        "ok_equal": ok_all, "ok_conv": ok_conv, "ok_act": ok_act, "ok_argmax": ok_arg,
+        "time_ms": {"ref": ref_ms, "rns": rns_ms},
+        "bounds": {"conv": b1, "act": a1, "fc": fc, "MAX_ABS": MAX_ABS},
+        "moduli": {"K": K, "primes": primes, "M_total": int(Mtot)},
+        "seed": int(seed), "X_abs": int(Xabs), "W_abs": int(Wabs),
+        "activation": activation, "poly_shift": int(shift),
+    }
+    return txt, json.dumps(report, indent=2)
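+
+# Local sanity check (same call the Gradio button makes):
+#   txt, rep = run_cnn(seed=0, Xabs=31, Wabs=31, activation="relu", shift=7)
+#   print(txt)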
+
+# ========= Attention (integer weights & numerators) =========
+def linear_int64(X, W): return X @ W
+def floor_div_pow2_int64(x, s): return x >> s
+
+def crt_div_pow2_rns(Xr, s, moduli):
+    # Exact floor division by 2^s in RNS: x = q*2^s + r with r = x mod 2^s read
+    # from the 2^K channel; prime channels recover q via the inverse of 2^s.
+    MOD2 = int(moduli[0]); mask=(1<<s)-1
+    out=[]
+    x2 = Xr[0] & (MOD2-1)
+    q2 = (x2 >> s) & (MOD2-1)
+    out.append(q2.to(torch.int64))
+    r = (x2 & mask).to(torch.int64)
+    for i,p in enumerate(moduli[1:], start=1):
+        p = int(p); inv2s = pow(2, -s, p)
+        xp = Xr[i] % p
+        out.append(((xp - (r % p)) * inv2s) % p)
+    return torch.stack(out, dim=0)
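+
+# Floor semantics for signed x: r = x mod 2^s is non-negative, so
+# (x - r) / 2^s = floor(x / 2^s). Example: x = -9, s = 2 gives r = 3 and
+# (-9 - 3) / 4 = -3 = floor(-2.25).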
+
+def clamp_via_2k_to_residues(Sr, clamp:int, moduli):
+    # Clamp in the centered signed domain (read from the 2^K channel), then
+    # re-encode the clamped values into every residue channel.
+    MOD2 = int(moduli[0]); half = MOD2//2
+    s2 = Sr[0]
+    centered = torch.where(s2 <= half, s2, s2 - MOD2)
+    clamped = torch.clamp(centered, -clamp, clamp).to(torch.int64)
+    outs=[(clamped % MOD2).to(torch.int64)]
+    for p in moduli[1:]:
+        p = int(p)
+        outs.append((clamped % p).to(torch.int64))
+    return torch.stack(outs, dim=0)
+
+def gather_lut_residues(idx: torch.Tensor, lut_values: list, moduli):
+    # Index a small table of positive weights; each channel stores the table mod m.
+    lut = torch.tensor(lut_values, dtype=torch.int64)
+    outs=[]
+    for m in moduli:
+        p = int(m)
+        outs.append((lut % p)[idx])
+    return torch.stack(outs, dim=0)
+
+def attn_bounds(T:int, d:int, Xabs:int, Wabs:int, clamp:int, Ebits:int, lut_base:int=2):
+    # Worst-case magnitudes per stage of the attention pipeline.
+    qkv_max = int(d * Xabs * Wabs)
+    sraw_max = int(d * (qkv_max**2))
+    s_max = sraw_max >> Ebits
+    wt_max = int(lut_base ** (2*clamp))
+    num_max = int(T * wt_max * qkv_max)
+    out_max = int(qkv_max * 6)
+    return {
+        "qkv_max": qkv_max, "sraw_max": sraw_max, "s_max": s_max,
+        "wt_max": wt_max, "num_max": num_max, "out_max": out_max,
+        "MAX_ABS": max(qkv_max, sraw_max, s_max, wt_max, num_max, out_max)
+    }
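+
+# wt_max = 2^(2*clamp) is the largest LUT entry and num_max bounds one row of
+# Wt @ V (T terms of at most wt_max * qkv_max). out_max is a loose cap on
+# Num // Den, which is a weighted average of V rows.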
+
+def run_attention(seed:int=0, T:int=8, d:int=16, Xabs:int=7, Wabs:int=7, E_bits:int=10, clamp:int=6):
+    g = torch.Generator().manual_seed(int(seed))
+    X  = torch.randint(-Xabs, Xabs+1, (T,d), dtype=torch.int64, generator=g, device=DEVICE)
+    Wq = torch.randint(-Wabs, Wabs+1, (d,d), dtype=torch.int64, generator=g, device=DEVICE)
+    Wk = torch.randint(-Wabs, Wabs+1, (d,d), dtype=torch.int64, generator=g, device=DEVICE)
+    Wv = torch.randint(-Wabs, Wabs+1, (d,d), dtype=torch.int64, generator=g, device=DEVICE)
+
+    b = attn_bounds(T,d,Xabs,Wabs, clamp, E_bits, lut_base=2)
+    K, MOD2, primes, Mtot = choose_moduli_2k_plus_primes(b["MAX_ABS"], min_K=max(34, E_bits+2))
+    moduli = [MOD2]+primes
+    M, Mi_np, inv_np, mods_np = crt_precompute(moduli)
+
+    # Reference int64 path
+    t0=time.time()
+    Q = linear_int64(X, Wq); Kk = linear_int64(X, Wk); V = linear_int64(X, Wv)
+    Sraw = Q @ Kk.t()
+    S = floor_div_pow2_int64(Sraw, E_bits)
+    S = torch.clamp(S, -clamp, clamp)
+    idx = (S + clamp).to(torch.int64)
+    LUT = torch.tensor([1 << i for i in range(2*clamp + 1)], dtype=torch.int64)
+    Wt = LUT[idx]; Den = Wt.sum(dim=1, keepdim=True); Num = Wt @ V
+    Oref= (Num // torch.clamp(Den, min=1))
+    t1=time.time(); ref_ms=(t1-t0)*1000.0
+
+    # RNS encode
+    Xr  = encode_rns_int64(X, MOD2, primes)
+    Wqr = encode_rns_int64(Wq, MOD2, primes)
+    Wkr = encode_rns_int64(Wk, MOD2, primes)
+    Wvr = encode_rns_int64(Wv, MOD2, primes)
+
+    def lin_rns(Ar, Wr):
+        outs=[]
+        for i,m in enumerate(moduli):
+            mod=int(m)
+            outs.append((Ar[i] @ Wr[i]) % mod)
+        return torch.stack(outs, dim=0)
+
+    # RNS pipeline (timed separately from encoding and from the equality checks)
+    t2=time.time()
+    Qr = lin_rns(Xr, Wqr); Kr = lin_rns(Xr, Wkr); Vr = lin_rns(Xr, Wvr)
+    Sr_raw=[]
+    for i,m in enumerate(moduli):
+        mod=int(m)
+        Sr_raw.append((Qr[i] @ Kr[i].t()) % mod)
+    Sr_raw = torch.stack(Sr_raw, dim=0)
+    Sr = crt_div_pow2_rns(Sr_raw, E_bits, moduli)
+    Sr = clamp_via_2k_to_residues(Sr, clamp, moduli)
+
+    # Centered ints to index the LUT
+    S_center = batched_crt_python_bigint(Sr, moduli, Mi_np, inv_np, int(M))
+    idx2 = (S_center + clamp).to(torch.int64)
+    LUT_vals = [1 << i for i in range(2*clamp + 1)]
+    Wtr = gather_lut_residues(idx2, LUT_vals, moduli)
+
+    Dr=[]; Nr=[]
+    for i,m in enumerate(moduli):
+        mod=int(m)
+        Dr.append(Wtr[i].sum(dim=1, keepdim=True) % mod)
+        Nr.append((Wtr[i] @ Vr[i]) % mod)
+    Dr = torch.stack(Dr, dim=0); Nr = torch.stack(Nr, dim=0)
+
+    Den2 = batched_crt_python_bigint(Dr, moduli, Mi_np, inv_np, int(M))
+    Num2 = batched_crt_python_bigint(Nr, moduli, Mi_np, inv_np, int(M))
+    Orec = (Num2 // torch.clamp(Den2, min=1))
+    t3=time.time(); rns_ms=(t3-t2)*1000.0
+
+    # Stage-by-stage checks (outside the timed region)
+    Qok = bool(torch.equal(Q, batched_crt_python_bigint(Qr, moduli, Mi_np, inv_np, int(M))))
+    Kok = bool(torch.equal(Kk, batched_crt_python_bigint(Kr, moduli, Mi_np, inv_np, int(M))))
+    Vok = bool(torch.equal(V, batched_crt_python_bigint(Vr, moduli, Mi_np, inv_np, int(M))))
+    Sok = bool(torch.equal(S, S_center))
+    Wok = bool(torch.equal(Wt, batched_crt_python_bigint(Wtr, moduli, Mi_np, inv_np, int(M))))
+    Ook = bool(torch.equal(Oref, Orec))
+
+    txt = (
+        "Bounds:\n" + "".join([f"{k:>11}: {v:,}\n" for k,v in b.items()]) + "\n"
+        f"Chosen moduli: K={K}, primes={primes}\n"
+        f"M_total = {int(Mtot):,}\n\n"
+        "Stage-by-stage equality (exact):\n"
+        f"Q : {'✅' if Qok else '❌'}\n"
+        f"K : {'✅' if Kok else '❌'}\n"
+        f"V : {'✅' if Vok else '❌'}\n"
+        f"S : {'✅' if Sok else '❌'}\n"
+        f"Wt : {'✅' if Wok else '❌'}\n"
+        f"O(out): {'✅' if Ook else '❌'}\n"
+        f"Argmax match: {bool(torch.equal(Oref.argmax(1), Orec.argmax(1)))} "
+        f"(ref={Oref.argmax(1).tolist()}, prime={Orec.argmax(1).tolist()})\n\n"
+        "Timing (ms):\n"
+        f"  Reference int64 path: {ref_ms:.2f} ms\n"
+        f"  Prime-only RNS+CRT  : {rns_ms:.2f} ms\n\n"
+        "Bit budget:\n"
+        f"  K={K} (2^(K-1)={1<<(K-1):,} > MAX_ABS={b['MAX_ABS']:,})\n"
+        f"  Odd primes: {primes}\n"
+        f"  Composite modulus M_total = {int(Mtot):,}"
+    )
+    report = {
+        "ok_equal": {"Q":Qok, "K":Kok, "V":Vok, "S":Sok, "Wt":Wok, "Out":Ook},
+        "time_ms": {"ref": ref_ms, "rns": rns_ms},
+        "bounds": b, "moduli": {"K":K, "primes":primes, "M_total": int(Mtot)},
+        "seed": int(seed), "T": int(T), "d": int(d),
+        "X_abs": int(Xabs), "W_abs": int(Wabs),
+        "E_bits": int(E_bits), "clamp": int(clamp),
+    }
+    return txt, json.dumps(report, indent=2)
+
+# Gradio wrappers
+def gr_run_cnn(seed, xabs, wabs, act, shift):
+    return run_cnn(int(seed), int(xabs), int(wabs), act, int(shift))
+def gr_run_attn(seed, T, d, xabs, wabs, Ebits, clamp):
+    return run_attention(int(seed), int(T), int(d), int(xabs), int(wabs), int(Ebits), int(clamp))
+
+with gr.Blocks(title="FieldSpace — Prime-Only Machine (RNS+CRT) — Exactness Proofs") as demo:
+    gr.Markdown(
+        "## FieldSpace — Prime-Only Machine (RNS+CRT) — Exactness Proofs\n"
+        "This demo runs exact integer CNN and Attention blocks entirely in residues mod {2^K, primes}, "
+        "then uses a single CRT to reconstruct and verify bit-for-bit equality against an integer reference."
+    )
+    with gr.Tabs():
+        with gr.Tab("CNN (Conv + ReLU/Poly + FC)"):
+            with gr.Row():
+                seed = gr.Number(value=0, label="Seed", precision=0)
+                xabs = gr.Number(value=31, label="|X|max", precision=0)
+                wabs = gr.Number(value=31, label="|W|max", precision=0)
+                act  = gr.Radio(choices=["relu","poly"], value="relu", label="Activation")
+                shift= gr.Slider(1,12, value=7, step=1, label="SHIFT (for poly x^2 >> SHIFT)")
+            run_btn = gr.Button("Run CNN Proof", variant="primary")
+            out_txt = gr.Textbox(label="Console", lines=18)
+            out_json= gr.JSON(label="JSON Report")
+            run_btn.click(gr_run_cnn, [seed,xabs,wabs,act,shift], [out_txt,out_json], api_name="run_cnn_proof")
+        with gr.Tab("Attention (1-head, integer weights)"):
+            with gr.Row():
+                seedA = gr.Number(value=0, label="Seed", precision=0)
+                T     = gr.Number(value=8, label="Tokens T", precision=0)
+                d     = gr.Number(value=16, label="Dim d", precision=0)
+                xabsA = gr.Number(value=7, label="|X|max", precision=0)
+                wabsA = gr.Number(value=7, label="|W|max", precision=0)
+                Ebits = gr.Number(value=10, label="E_bits (div by 2^E)", precision=0)
+                clamp = gr.Number(value=6, label="Clamp (±)", precision=0)
+            runA = gr.Button("Run Attention Proof", variant="primary")
+            out_txtA = gr.Textbox(label="Console", lines=18)
+            out_jsonA= gr.JSON(label="JSON Report")
+            runA.click(gr_run_attn, [seedA,T,d,xabsA,wabsA,Ebits,clamp], [out_txtA,out_jsonA], api_name="run_attention_proof")
+
+    gr.Markdown(
+        "### Use via API\n"
+        "```python\n"
+        "from gradio_client import Client\n"
+        f"client = Client('{os.environ.get('SPACE_ID', '<your-space>')}')\n"
+        "txt, rep = client.predict(0, 31, 31, 'relu', 7, api_name='/run_cnn_proof')\n"
+        "print(txt)\n"
+        "print(rep)\n"
+        "```"
+    )
+
+app = demo
+if __name__ == '__main__':
+    app.launch()
requirements.txt
ADDED
@@ -0,0 +1,3 @@
+gradio==4.44.0
+numpy
+torch
runtime.txt
ADDED
@@ -0,0 +1 @@
+python-3.10