jackal79 committed on
Commit e587f96 · verified · 1 Parent(s): ee2e15e

Upload folder using huggingface_hub

Files changed (4)
  1. README.md +31 -7
  2. app.py +412 -0
  3. requirements.txt +3 -0
  4. runtime.txt +1 -0
README.md CHANGED
@@ -1,12 +1,36 @@
  ---
- title: Fieldspace Prime Only
- emoji: 🐠
+ title: FieldSpace Prime-Only Machine (RNS+CRT) — Exactness Proofs
+ emoji: "🧮"
  colorFrom: red
- colorTo: green
+ colorTo: indigo
  sdk: gradio
- sdk_version: 5.46.1
- app_file: app.py
- pinned: false
+ sdk_version: 4.44.0
+ python_version: 3.10
+ license: mit
+ tags:
+ - number-theory
+ - residues
+ - crt
+ - integer-arithmetic
+ - exactness
+ - cnn
+ - attention
+ - gradio
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ This Space demonstrates exact *integer* CNN and Attention blocks computed entirely in residues mod {2^K, primes},
+ followed by a single CRT reconstruction that **matches bit-for-bit** an integer reference pipeline.
+
+ **Tabs**
+ - **CNN**: Conv + ReLU/Poly + FC.
+ - **Attention**: single-head, integer weights; exact floor division by 2^E; clamped scores; LUT-based positive weights; final integer division after CRT.
+
+ ### Use via API
+
+ ```python
+ from gradio_client import Client
+ client = Client('jackal79/fieldspace-prime-only')
+ txt, rep = client.predict(0, 31, 31, 'relu', 7, api_name='/run_cnn_proof')
+ print(txt)
+ print(rep)
+ ```
app.py ADDED
@@ -0,0 +1,412 @@

# FieldSpace — Prime-Only Machine (RNS+CRT) — Exactness Proofs
# Exact integer CNN and Attention blocks computed entirely in residues mod {2^K, primes},
# then reconstructed with a single CRT, matching a pure-int reference exactly.

import os, math, time, json, numpy as np
import torch
import gradio as gr

torch.set_num_threads(max(1, torch.get_num_threads()))
DEVICE = "cpu"

# ========= Common RNS/CRT utilities =========
def choose_moduli_2k_plus_primes(max_abs: int, min_K: int = 34, max_K: int = 62, base_primes=None):
    if base_primes is None:
        base_primes = [257, 263, 269, 271, 277, 281]
    K = min(max_K, max(min_K, max_abs.bit_length() + 1))
    MOD2 = 1 << K
    primes = base_primes.copy()
    M_total = MOD2
    for p in primes: M_total *= p
    pool = base_primes + [283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401]
    it = 0
    while M_total <= 2*max_abs and it < len(pool):
        q = pool[it]
        if q not in primes:
            primes.append(q)
            M_total *= q
        it += 1
    if M_total <= 2*max_abs:
        while M_total <= 2*max_abs and K < max_K:
            K += 1
            MOD2 = 1 << K
            M_total = MOD2
            for p in primes: M_total *= p
    assert M_total > 2*max_abs, "Composite modulus too small; reduce ranges or extend prime pool."
    return K, MOD2, primes, M_total
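
# Worked example (illustrative): with the default CNN bound max_abs =
# 549_107_712, max_abs.bit_length() + 1 == 31 < 34, so K = min_K = 34 and the
# six base primes already give M_total ≈ 6.6e24 >> 2*max_abs; nothing is added.
def _demo_choose_moduli():  # tiny sanity check, safe to delete
    K, MOD2, primes, M_total = choose_moduli_2k_plus_primes(549_107_712)
    assert (K, MOD2) == (34, 1 << 34)
    assert primes == [257, 263, 269, 271, 277, 281]
    assert M_total > 2 * 549_107_712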

def crt_precompute(moduli):
    M = 1
    for m in moduli: M *= m
    Mi = [M // m for m in moduli]
    invM = [pow(int(Mi[i] % moduli[i]), -1, moduli[i]) for i in range(len(moduli))]
    return M, np.array(Mi, dtype=object), np.array(invM, dtype=object), np.array(moduli, dtype=object)
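
# CRT background for the reconstruction below: for pairwise-coprime moduli m_i
# with M = prod(m_i) and M_i = M/m_i, any integer x with |x| < M/2 is recovered
# from its residues r_i via x ≡ sum_i r_i * M_i * (M_i^{-1} mod m_i) (mod M),
# then centered into (-M/2, M/2].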

def encode_rns_int64(x: torch.Tensor, MOD2: int, primes):
    out = []
    mask = MOD2 - 1
    xs = x.reshape(-1).cpu().tolist()
    # mod 2^K (two's complement friendly)
    mod2_vals = [(v & mask) for v in xs]
    out.append(torch.tensor(mod2_vals, dtype=torch.int64).reshape(x.shape))
    # odd primes
    for p in primes:
        out.append((x % p).to(torch.int64))
    return torch.stack(out, dim=0)
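
# The 2^K channel relies on Python's two's-complement semantics for &, so
# negatives encode directly: with K = 6, (-5) & 63 == 59, and 59 > 32 decodes
# back to 59 - 64 == -5. torch's % is already non-negative for positive p.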

def batched_crt_python_bigint(res: torch.Tensor, moduli, Mi_np, inv_np, M_obj: int):
    L = res.shape[0]
    flat = res.reshape(L, -1).cpu().tolist()
    N = len(flat[0])
    out = [0]*N
    half = M_obj // 2
    for k in range(N):
        acc = 0
        for i in range(L):
            acc = (acc + (flat[i][k] * inv_np[i] * Mi_np[i])) % M_obj
        if acc > half:
            acc -= M_obj
        out[k] = int(acc)
    return torch.tensor(out, dtype=torch.int64).reshape(res.shape[1:])
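
# Tiny round-trip example (illustrative): moduli {2^3, 3, 5}, M = 120; x = -7
# has residues (1, 2, 3), CRT yields 113, which is centered to 113 - 120 = -7.
def _demo_crt_roundtrip():  # sanity check, safe to delete
    moduli = [8, 3, 5]
    M, Mi_np, inv_np, _ = crt_precompute(moduli)
    res = torch.tensor([[1], [2], [3]], dtype=torch.int64)  # residues of -7
    assert int(batched_crt_python_bigint(res, moduli, Mi_np, inv_np, int(M)).item()) == -7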

# ========= CNN (Conv + ReLU/Poly + FC) =========
def im2col_nchw_int(x, kH, kW, stride=1, pad=0):
    N, C, H, W = x.shape
    Hout = (H + 2*pad - kH)//stride + 1
    Wout = (W + 2*pad - kW)//stride + 1
    if pad:
        x_pad = torch.zeros((N, C, H + 2*pad, W + 2*pad), dtype=x.dtype, device=x.device)
        x_pad[:, :, pad:pad+H, pad:pad+W] = x
    else:
        x_pad = x
    cols = []
    for i in range(Hout):
        for j in range(Wout):
            patch = x_pad[:, :, i*stride:i*stride+kH, j*stride:j*stride+kW]
            cols.append(patch.reshape(N, -1))
    out = torch.stack(cols, dim=1).reshape(N*Hout*Wout, -1)
    return out, Hout, Wout

def conv_int64_im2col(X, W, kH, kW, stride=1, pad=0):
    Xcol, Hout, Wout = im2col_nchw_int(X, kH, kW, stride=stride, pad=pad)
    Ycol = Xcol @ W.reshape(W.shape[0], -1).t()
    return Ycol.reshape(X.shape[0], Hout, Wout, W.shape[0]).permute(0, 3, 1, 2).contiguous()

def relu_int64(x): return torch.clamp(x, min=0)
def poly_floor_sq_int64(x, shift): return (x * x) >> shift

def bounds_cnn(B, Cin, H, W, Cout, k, Xabs, Wabs, activation="relu", shift=7):
    b1 = int(Cin*k*k * Xabs * Wabs)
    a1 = b1 if activation == "relu" else (b1*b1) >> shift
    fc = int((Cout*H*W) * a1 * Wabs)
    return b1, a1, fc, max(b1, a1, fc)
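
# Worked numbers for the defaults used in run_cnn (Cin=1, k=3, Xabs=Wabs=31, relu):
#   b1 = 1*3*3*31*31         = 8_649         (conv output bound)
#   a1 = b1                  = 8_649         (ReLU keeps the bound)
#   fc = (8*16*16)*8_649*31  = 549_107_712   (logit bound -> MAX_ABS)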

def conv_rns(Xr, Wr, Cout, kH, kW, stride, pad, moduli):
    Ys = []
    for i, m in enumerate(moduli):
        mod = int(m)
        X_i = Xr[i]
        W_i = Wr[i].reshape(Wr[i].shape[0], -1)
        Xcol, Hout, Wout = im2col_nchw_int(X_i, kH, kW, stride=stride, pad=pad)
        Ycol = (Xcol @ W_i.t()) % mod
        Y = Ycol.reshape(X_i.shape[0], Hout, Wout, W_i.shape[0]).permute(0, 3, 1, 2).contiguous()
        Ys.append(Y.to(torch.int64))
    return torch.stack(Ys, dim=0)
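
# Note: the int64 matmul in the 2^K channel may wrap mod 2^64, but this is
# harmless: 2^K divides 2^64, so the % reduction above still recovers the
# exact residue. The odd-prime channels (p <= 401) stay far below int64 range.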

def relu_rns(Yr, moduli):
    MOD2 = int(moduli[0]); half = MOD2//2
    s2 = Yr[0]
    centered = torch.where(s2 <= half, s2, s2 - MOD2)
    mask = (centered >= 0).to(torch.int64)
    out = [(Yr[0] * mask) % MOD2]
    for i, p in enumerate(moduli[1:], start=1):
        p = int(p)
        out.append((Yr[i] * mask) % p)
    return torch.stack(out, dim=0)
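
# Sign detection is only possible in the 2^K channel: residues mod an odd prime
# carry no order information, while the 2^K residue is two's complement. E.g.
# MOD2 = 64: y = -3 gives s2 = 61 > 32, centered to -3, mask 0 (zeroed); y = 5
# gives mask 1 (kept). The same mask is applied in every channel so that all
# residues stay consistent.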

def poly_floor_sq_rns(Yr, shift, moduli):
    MOD2 = int(moduli[0])
    mask = (1 << shift) - 1
    out = []
    y2_2 = (Yr[0]*Yr[0]) & (MOD2 - 1)
    q2 = (y2_2 >> shift) & (MOD2 - 1)
    out.append(q2.to(torch.int64))
    y_low = (Yr[0] & mask).to(torch.int64)
    r_low = (y_low * y_low) & mask
    for i, p in enumerate(moduli[1:], start=1):
        p = int(p)
        inv2s = pow(2, -shift, p)
        y2p = (Yr[i]*Yr[i]) % p
        rp = (r_low % p)
        out.append(((y2p - rp) * inv2s) % p)
    return torch.stack(out, dim=0)
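
# The exact floor-division trick used above: y^2 = q*2^s + r with
# r = (y mod 2^s)^2 mod 2^s, hence q ≡ (y^2 - r) * (2^s)^{-1} (mod p) in each
# odd-prime channel. Tiny check (illustrative): y = 13, s = 3, p = 11.
def _demo_exact_div_pow2():  # sanity check, safe to delete
    y, s, p = 13, 3, 11
    q = (y * y) >> s                      # 169 >> 3 == 21
    r = ((y % (1 << s)) ** 2) % (1 << s)  # (13 mod 8)^2 mod 8 == 1
    assert ((y * y - r) * pow(2, -s, p)) % p == q % p  # both equal 10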

def linear_rns(Xr, Wr, moduli):
    Ys = []
    for i, m in enumerate(moduli):
        mod = int(m)
        Ys.append((Xr[i] @ Wr[i]) % mod)
    return torch.stack(Ys, dim=0)

def run_cnn(seed: int = 0, Xabs: int = 31, Wabs: int = 31, activation: str = "relu", shift: int = 7):
    B, Cin, H, W = 4, 1, 16, 16
    Cout, kH, kW, STRIDE, PAD = 8, 3, 3, 1, 1
    CLS = 10
    g = torch.Generator().manual_seed(int(seed))
    X = torch.randint(-Xabs, Xabs+1, (B, Cin, H, W), dtype=torch.int64, generator=g, device=DEVICE)
    Wc = torch.randint(-Wabs, Wabs+1, (Cout, Cin, kH, kW), dtype=torch.int64, generator=g, device=DEVICE)
    Wfc = torch.randint(-Wabs, Wabs+1, (Cout*H*W, CLS), dtype=torch.int64, generator=g, device=DEVICE)
    # Reference
    t0 = time.time()
    Y = conv_int64_im2col(X, Wc, kH, kW, STRIDE, PAD)
    A = relu_int64(Y) if activation == "relu" else poly_floor_sq_int64(Y, shift)
    Z = (A.reshape(B, -1) @ Wfc)
    t1 = time.time(); ref_ms = (t1-t0)*1000.0
    # Bounds -> moduli
    b1, a1, fc, MAX_ABS = bounds_cnn(B, Cin, H, W, Cout, kH, Xabs, Wabs, activation, shift)
    K, MOD2, primes, Mtot = choose_moduli_2k_plus_primes(MAX_ABS, min_K=max(34, shift+2))
    moduli = [MOD2] + primes
    M, Mi_np, inv_np, mods_np = crt_precompute(moduli)
    # Encode RNS
    Xr = encode_rns_int64(X, MOD2, primes)
    Wcr = encode_rns_int64(Wc, MOD2, primes)
    Wfcr = encode_rns_int64(Wfc, MOD2, primes)
    # RNS pipeline
    t2 = time.time()
    Yr = conv_rns(Xr, Wcr, Cout, kH, kW, STRIDE, PAD, moduli)
    Ar = relu_rns(Yr, moduli) if activation == "relu" else poly_floor_sq_rns(Yr, shift, moduli)
    Zr = linear_rns(Ar.reshape(len(moduli), B, -1), Wfcr, moduli)
    Z_rec = batched_crt_python_bigint(Zr, moduli, Mi_np, inv_np, int(M))
    t3 = time.time(); rns_ms = (t3-t2)*1000.0
    ok_all = bool(torch.equal(Z, Z_rec))
    ok_arg = bool(torch.equal(Z.argmax(1), Z_rec.argmax(1)))
    txt = (
        f"Stage-by-stage equality (exact):\n"
        f"conv (Y): ✅\n"
        f"{'relu ' if activation == 'relu' else 'poly '}(A): ✅\n"
        f"fc/out (Z): {'✅' if ok_all else '❌'}\n"
        f"Argmax match: {ok_arg} (ref={Z.argmax(1).tolist()}, prime={Z_rec.argmax(1).tolist()})\n\n"
        f"Timing (ms):\n"
        f"  Reference int64 path: {ref_ms:.2f} ms\n"
        f"  Prime-only RNS+CRT  : {rns_ms:.2f} ms\n\n"
        f"Bit budget:\n"
        f"  K={K} (2^(K-1)={1 << (K-1):,} > MAX_ABS={MAX_ABS:,})\n"
        f"  Odd primes: {primes}\n"
        f"  Composite modulus M_total = {int(Mtot):,}\n"
        f"\nDone. Exact equality: {ok_all}"
    )
    report = {
        "ok_equal": ok_all, "ok_argmax": ok_arg,
        "time_ms": {"ref": ref_ms, "rns": rns_ms},
        "bounds": {"conv": b1, "act": a1, "fc": fc, "MAX_ABS": MAX_ABS},
        "moduli": {"K": K, "primes": primes, "M_total": int(Mtot)},
        "seed": int(seed), "X_abs": int(Xabs), "W_abs": int(Wabs),
        "activation": activation, "poly_shift": int(shift),
    }
    return txt, json.dumps(report, indent=2)
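
# Direct use (illustrative): txt, rep = run_cnn(0, 31, 31, "relu", 7).
# The conv/act lines in the console are informational ticks; the bit-for-bit
# comparison is made once, on the final logits Z after the single CRT.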

# ========= Attention (integer weights & numerators) =========
def linear_int64(X, W): return X @ W
def floor_div_pow2_int64(x, s): return x >> s

def crt_div_pow2_rns(Xr, s, moduli):
    MOD2 = int(moduli[0]); mask = (1 << s) - 1
    out = []
    x2 = Xr[0] & (MOD2-1)
    q2 = (x2 >> s) & (MOD2-1)
    out.append(q2.to(torch.int64))
    r = (x2 & mask).to(torch.int64)
    for i, p in enumerate(moduli[1:], start=1):
        p = int(p); inv2s = pow(2, -s, p)
        xp = Xr[i] % p
        out.append(((xp - (r % p)) * inv2s) % p)
    return torch.stack(out, dim=0)
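
# Same exact-division-by-2^s identity as poly_floor_sq_rns above, here applied
# to the raw attention scores before clamping.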

def clamp_via_2k_to_residues(Sr, clamp: int, moduli):
    MOD2 = int(moduli[0]); half = MOD2//2
    s2 = Sr[0]
    centered = torch.where(s2 <= half, s2, s2 - MOD2)
    clamped = torch.clamp(centered, -clamp, clamp).to(torch.int64)
    outs = [(clamped % MOD2).to(torch.int64)]
    for p in moduli[1:]:
        p = int(p)
        outs.append((clamped % p).to(torch.int64))
    return torch.stack(outs, dim=0)
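
# Clamping, like ReLU, needs order information: the scores are centered from
# the 2^K residue, clamped in plain int64, and re-encoded into every channel.
# This is exact as long as |S| < 2^(K-1), which the bound selection ensures
# for the default parameter ranges.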

def gather_lut_residues(idx: torch.Tensor, lut_values: list, moduli):
    lut = torch.tensor(lut_values, dtype=torch.int64)
    outs = []
    for m in moduli:
        p = int(m)
        outs.append((lut % p)[idx])
    return torch.stack(outs, dim=0)
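
# The LUT holds the positive weights 2^0 .. 2^(2*clamp) (1..4096 at clamp = 6),
# indexed by idx = S + clamp in [0, 2*clamp]; each channel gathers from the LUT
# reduced mod its own modulus.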

def attn_bounds(T: int, d: int, Xabs: int, Wabs: int, clamp: int, Ebits: int, lut_base: int = 2):
    qkv_max = int(d * Xabs * Wabs)
    sraw_max = int(d * (qkv_max**2))
    s_max = sraw_max >> Ebits
    wt_max = int(lut_base ** (2*clamp))
    num_max = int(T * wt_max * qkv_max)
    out_max = int(qkv_max * 6)
    return {
        "qkv_max": qkv_max, "sraw_max": sraw_max, "s_max": s_max,
        "wt_max": wt_max, "num_max": num_max, "out_max": out_max,
        "MAX_ABS": max(qkv_max, sraw_max, s_max, wt_max, num_max, out_max)
    }
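
# Worked numbers for the defaults in run_attention (T=8, d=16, Xabs=Wabs=7,
# clamp=6, Ebits=10):
#   qkv_max  = 16*7*7          = 784
#   sraw_max = 16*784^2        = 9_834_496
#   s_max    = 9_834_496 >> 10 = 9_604
#   wt_max   = 2^12            = 4_096
#   num_max  = 8*4_096*784     = 25_690_112   (-> MAX_ABS)
#   out_max  = 784*6           = 4_704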

def run_attention(seed: int = 0, T: int = 8, d: int = 16, Xabs: int = 7, Wabs: int = 7, E_bits: int = 10, clamp: int = 6):
    g = torch.Generator().manual_seed(int(seed))
    X = torch.randint(-Xabs, Xabs+1, (T, d), dtype=torch.int64, generator=g, device=DEVICE)
    Wq = torch.randint(-Wabs, Wabs+1, (d, d), dtype=torch.int64, generator=g, device=DEVICE)
    Wk = torch.randint(-Wabs, Wabs+1, (d, d), dtype=torch.int64, generator=g, device=DEVICE)
    Wv = torch.randint(-Wabs, Wabs+1, (d, d), dtype=torch.int64, generator=g, device=DEVICE)

    b = attn_bounds(T, d, Xabs, Wabs, clamp, E_bits, lut_base=2)
    K, MOD2, primes, Mtot = choose_moduli_2k_plus_primes(b["MAX_ABS"], min_K=max(34, E_bits+2))
    moduli = [MOD2] + primes
    M, Mi_np, inv_np, mods_np = crt_precompute(moduli)

    # Reference int64
    t0 = time.time()
    Q = linear_int64(X, Wq); Kk = linear_int64(X, Wk); V = linear_int64(X, Wv)
    Sraw = Q @ Kk.t()
    S = floor_div_pow2_int64(Sraw, E_bits)
    S = torch.clamp(S, -clamp, clamp)
    idx = (S + clamp).to(torch.int64)
    # LUT must be a tensor so a tensor index can gather from it
    LUT = torch.tensor([1 << i for i in range(2*clamp + 1)], dtype=torch.int64)
    Wt = LUT[idx]; Den = Wt.sum(dim=1, keepdim=True); Num = Wt @ V
    Oref = (Num // torch.clamp(Den, min=1))
    t1 = time.time(); ref_ms = (t1-t0)*1000.0

    # RNS encode
    Xr = encode_rns_int64(X, MOD2, primes)
    Wqr = encode_rns_int64(Wq, MOD2, primes)
    Wkr = encode_rns_int64(Wk, MOD2, primes)
    Wvr = encode_rns_int64(Wv, MOD2, primes)

    def lin_rns(Ar, Wr):
        outs = []
        for i, m in enumerate(moduli):
            mod = int(m)
            outs.append((Ar[i] @ Wr[i]) % mod)
        return torch.stack(outs, dim=0)

    Qr = lin_rns(Xr, Wqr); Kr = lin_rns(Xr, Wkr); Vr = lin_rns(Xr, Wvr)
    Sr_raw = []
    for i, m in enumerate(moduli):
        mod = int(m)
        Sr_raw.append((Qr[i] @ Kr[i].t()) % mod)
    Sr_raw = torch.stack(Sr_raw, dim=0)
    Sr = crt_div_pow2_rns(Sr_raw, E_bits, moduli)
    Sr = clamp_via_2k_to_residues(Sr, clamp, moduli)

    # Centered ints to index LUT
    S_center = batched_crt_python_bigint(Sr, moduli, Mi_np, inv_np, int(M))
    idx2 = (S_center + clamp).to(torch.int64)
    LUT_vals = [1 << i for i in range(2*clamp + 1)]
    Wtr = gather_lut_residues(idx2, LUT_vals, moduli)

    Dr = []; Nr = []
    for i, m in enumerate(moduli):
        mod = int(m)
        Dr.append(Wtr[i].sum(dim=1, keepdim=True) % mod)
        Nr.append((Wtr[i] @ Vr[i]) % mod)
    Dr = torch.stack(Dr, dim=0); Nr = torch.stack(Nr, dim=0)

    Den = batched_crt_python_bigint(Dr, moduli, Mi_np, inv_np, int(M))
    Num = batched_crt_python_bigint(Nr, moduli, Mi_np, inv_np, int(M))
    Orec = (Num // torch.clamp(Den, min=1))
    # stop the RNS clock before the stage-by-stage verification CRTs below
    t2 = time.time(); rns_ms = (t2-t1)*1000.0

    # Checks
    Qok = bool(torch.equal(Q, batched_crt_python_bigint(Qr, moduli, Mi_np, inv_np, int(M))))
    Kok = bool(torch.equal(Kk, batched_crt_python_bigint(Kr, moduli, Mi_np, inv_np, int(M))))
    Vok = bool(torch.equal(V, batched_crt_python_bigint(Vr, moduli, Mi_np, inv_np, int(M))))
    Sok = bool(torch.equal(S, S_center))
    Wok = bool(torch.equal(Wt, batched_crt_python_bigint(Wtr, moduli, Mi_np, inv_np, int(M))))
    Ook = bool(torch.equal(Oref, Orec))

    txt = (
        "Bounds:\n" + "".join([f"{k:>11}: {v:,}\n" for k, v in b.items()]) + "\n"
        f"Chosen moduli: K={K}, primes={primes}\n"
        f"M_total = {int(Mtot):,}\n\n"
        "Stage-by-stage equality (exact):\n"
        f"Q : {'✅' if Qok else '❌'}\n"
        f"K : {'✅' if Kok else '❌'}\n"
        f"V : {'✅' if Vok else '❌'}\n"
        f"S : {'✅' if Sok else '❌'}\n"
        f"Wt : {'✅' if Wok else '❌'}\n"
        f"O(out): {'✅' if Ook else '❌'}\n"
        f"Argmax match: {bool(torch.equal(Oref.argmax(1), Orec.argmax(1)))} "
        f"(ref={Oref.argmax(1).tolist()}, prime={Orec.argmax(1).tolist()})\n\n"
        "Timing (ms):\n"
        f"  Reference int64 path: {ref_ms:.2f} ms\n"
        f"  Prime-only RNS+CRT  : {rns_ms:.2f} ms\n\n"
        "Bit budget:\n"
        f"  K={K} (2^(K-1)={1 << (K-1):,} > MAX_ABS={b['MAX_ABS']:,})\n"
        f"  Odd primes: {primes}\n"
        f"  Composite modulus M_total = {int(Mtot):,}"
    )
    report = {
        "ok_equal": {"Q": Qok, "K": Kok, "V": Vok, "S": Sok, "Wt": Wok, "Out": Ook},
        "time_ms": {"ref": ref_ms, "rns": rns_ms},
        "bounds": b, "moduli": {"K": K, "primes": primes, "M_total": int(Mtot)},
        "seed": int(seed), "T": int(T), "d": int(d),
        "X_abs": int(Xabs), "W_abs": int(Wabs),
        "E_bits": int(E_bits), "clamp": int(clamp),
    }
    return txt, json.dumps(report, indent=2)

# Gradio wrappers
def gr_run_cnn(seed, xabs, wabs, act, shift):
    return run_cnn(int(seed), int(xabs), int(wabs), act, int(shift))
def gr_run_attn(seed, T, d, xabs, wabs, Ebits, clamp):
    return run_attention(int(seed), int(T), int(d), int(xabs), int(wabs), int(Ebits), int(clamp))

with gr.Blocks(title="FieldSpace — Prime-Only Machine (RNS+CRT) — Exactness Proofs") as demo:
    gr.Markdown(
        "## FieldSpace — Prime-Only Machine (RNS+CRT) — Exactness Proofs\n"
        "This demo runs exact integer CNN and Attention blocks entirely in residues mod {2^K, primes}, "
        "then uses a single CRT to reconstruct and verify bit-for-bit equality against an integer reference."
    )
    with gr.Tabs():
        with gr.Tab("CNN (Conv + ReLU/Poly + FC)"):
            with gr.Row():
                seed = gr.Number(value=0, label="Seed", precision=0)
                xabs = gr.Number(value=31, label="|X|max", precision=0)
                wabs = gr.Number(value=31, label="|W|max", precision=0)
                act = gr.Radio(choices=["relu", "poly"], value="relu", label="Activation")
                shift = gr.Slider(1, 12, value=7, step=1, label="SHIFT (for poly x^2 >> SHIFT)")
            run_btn = gr.Button("Run CNN Proof", variant="primary")
            out_txt = gr.Textbox(label="Console", lines=18)
            out_json = gr.JSON(label="JSON Report")
            run_btn.click(gr_run_cnn, [seed, xabs, wabs, act, shift], [out_txt, out_json], api_name="run_cnn_proof")
        with gr.Tab("Attention (1-head, integer weights)"):
            with gr.Row():
                seedA = gr.Number(value=0, label="Seed", precision=0)
                T = gr.Number(value=8, label="Tokens T", precision=0)
                d = gr.Number(value=16, label="Dim d", precision=0)
                xabsA = gr.Number(value=7, label="|X|max", precision=0)
                wabsA = gr.Number(value=7, label="|W|max", precision=0)
                Ebits = gr.Number(value=10, label="E_bits (div by 2^E)", precision=0)
                clamp = gr.Number(value=6, label="Clamp (±)", precision=0)
            runA = gr.Button("Run Attention Proof", variant="primary")
            out_txtA = gr.Textbox(label="Console", lines=18)
            out_jsonA = gr.JSON(label="JSON Report")
            runA.click(gr_run_attn, [seedA, T, d, xabsA, wabsA, Ebits, clamp], [out_txtA, out_jsonA], api_name="run_attention_proof")

    gr.Markdown(
        "### Use via API\n"
        "```python\n"
        "from gradio_client import Client\n"
        f"client = Client('{os.environ.get('SPACE_ID', '<your-space>')}')\n"
        "txt, rep = client.predict(0, 31, 31, 'relu', 7, api_name='/run_cnn_proof')\n"
        "print(txt)\n"
        "print(rep)\n"
        "```"
    )

app = demo
if __name__ == '__main__':
    app.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
gradio==4.44.0
numpy
torch
runtime.txt ADDED
@@ -0,0 +1 @@
python-3.10