jackal79 committed on
Commit dc7fd9e · verified · 1 Parent(s): 96d0fc4

Deploy FieldSpace app: CNN + Attention numerators (RNS+CRT)

Files changed (3)
  1. README.md +9 -21
  2. app.py +271 -227
  3. requirements.txt +1 -3
README.md CHANGED
@@ -1,29 +1,17 @@
  ---
  title: FieldSpace — Prime-Only Machine (RNS+CRT) — Exactness Proofs
- emoji: "🧮"
- colorFrom: red
- colorTo: indigo
  sdk: gradio
  sdk_version: 4.44.0
- python_version: "3.10"
  license: mit
- tags:
- - number-theory
- - residues
- - crt
- - integer-arithmetic
- - exactness
- - cnn
- - gradio
  ---

- This Space demonstrates exact *integer* CNN computed entirely in residues mod `{2^K, primes}` with a single CRT reconstruction.

- ### Use via API
- ```python
- from gradio_client import Client
- client = Client('https://huggingface.co/spaces/jackal79/fieldspace-prime-only')
- txt, rep = client.predict('/run_cnn_proof', 0, 31, 31, 'relu', 7)
- print(txt)
- print(rep)
- ```
  ---
  title: FieldSpace — Prime-Only Machine (RNS+CRT) — Exactness Proofs
+ emoji: 🧮
+ colorFrom: blue
+ colorTo: green
  sdk: gradio
  sdk_version: 4.44.0
+ app_file: app.py
+ python_version: "3.10"
  license: mit
  ---

+ This Space proves exact execution for:
+ - **CNN**: Conv → ReLU/Poly → FC
+ - **Attention (one head)**: exact numerators (QK >> t, then SV)
+
+ All ops run in residues mod `{2^K, p_i}` and reconstruct once via a batched CRT.
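+
+ A minimal sketch of the idea (illustrative only; hypothetical values, not this app's code):
+
+ ```python
+ MOD2, primes = 1 << 20, (257, 263)            # one power-of-two channel plus odd primes
+ x, w = -37, 91
+ chans = [((x % m) * (w % m)) % m for m in (MOD2,) + primes]   # multiply per channel
+ M = MOD2 * primes[0] * primes[1]
+ acc = 0
+ for r, m in zip(chans, (MOD2,) + primes):     # single batched-CRT reconstruction
+     Mi = M // m
+     acc = (acc + r * pow(Mi, -1, m) * Mi) % M
+ print(acc - M if acc > M // 2 else acc)       # prints -3367 == x*w, exactly
+ ```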
app.py CHANGED
@@ -1,254 +1,298 @@
- import time, math, numpy as np, torch, gradio as gr
-
- # deterministic CPU; keep integer-only math
- torch.set_num_threads(max(1, torch.get_num_threads()))
- DEVICE = "cpu"
-
- # ===== Core helpers =====
- def im2col_nchw_int(x, kH, kW, stride=1, pad=0):
-     N, C, H, W = x.shape
-     Hout = (H + 2*pad - kH)//stride + 1
-     Wout = (W + 2*pad - kW)//stride + 1
-     if pad:
-         x_pad = torch.zeros((N, C, H + 2*pad, W + 2*pad), dtype=x.dtype, device=x.device)
-         x_pad[:, :, pad:pad+H, pad:pad+W] = x
-     else:
-         x_pad = x
-     cols = []
-     for i in range(Hout):
-         for j in range(Wout):
-             patch = x_pad[:, :, i*stride:i*stride+kH, j*stride:j*stride+kW]
-             cols.append(patch.reshape(N, -1))
-     out = torch.stack(cols, dim=1)      # [N, Hout*Wout, C*kH*kW]
-     out = out.reshape(N*Hout*Wout, -1)  # [N*Hout*Wout, C*kH*kW]
-     return out, Hout, Wout
-
- def conv_int64_im2col(X, W, kH, kW, stride=1, pad=0):
-     Xcol, Hout, Wout = im2col_nchw_int(X, kH, kW, stride=stride, pad=pad)
-     Ycol = Xcol @ W.reshape(W.shape[0], -1).t()
-     return Ycol.reshape(X.shape[0], Hout, Wout, W.shape[0]).permute(0,3,1,2).contiguous()
-
- def poly_act_floor_int64(x, s):  # exact q = floor(x^2 / 2^s)
-     return (x * x) >> s
-
- def flatten_nchw(x): return x.reshape(x.shape[0], -1)
- def linear_int64(X, W): return X @ W
-
- # ===== RNS / CRT =====
- def choose_moduli(MAX_ABS, base_primes=None, K_cap=62):
-     # minimal K so 2^(K-1) > MAX_ABS, cap at 62 for int64 safety
-     K = min(K_cap, max(2, MAX_ABS.bit_length() + 1))
      MOD2 = 1 << K
-     primes = list(base_primes or [257, 263, 269, 271, 277, 281])
-     M_total = MOD2
-     for p in primes: M_total *= p
-     # ensure range
-     i = 283
-     while M_total <= 2*MAX_ABS:
-         primes.append(i)
-         M_total *= i
-         i = {283:293, 293:307, 307:311, 311:313, 313:317}.get(i, i+2)
-     return K, MOD2, primes

  def encode_rns(x, MOD2, primes):
-     outs = []
-     mask = MOD2 - 1
-     x_list = x.reshape(-1).cpu().tolist()
-     mod2_res = [(v & mask) for v in x_list]  # mod 2^K via mask (python int)
-     outs.append(torch.tensor(mod2_res, dtype=torch.int64).reshape(x.shape))
      for p in primes:
-         outs.append((x % p).to(torch.int64))
-     return torch.stack(outs, dim=0)  # [L, *shape]

- def precompute_crt(moduli):
      M = 1
      for m in moduli: M *= m
-     Mi = [M // m for m in moduli]
-     invMi = [pow(int(Mi[i] % moduli[i]), -1, moduli[i]) for i in range(len(moduli))]
-     return M, np.array(Mi, dtype=object), np.array(invMi, dtype=object), np.array(moduli, dtype=object)
-
- def batched_crt(res, M, Mi_np, inv_np, mods_np):
-     L = res.shape[0]
-     flat = res.reshape(L, -1).cpu().tolist()
-     N = len(flat[0])
-     out = [0]*N
-     half = M // 2
-     for k in range(N):
          acc = 0
          for i in range(L):
-             acc = (acc + (flat[i][k] * inv_np[i] * Mi_np[i])) % M
          if acc > half: acc -= M
-         out[k] = int(acc)
-     return torch.tensor(out, dtype=torch.int64).reshape(res.shape[1:])
- # ===== RNS ops =====
- def conv_rns(Xr, Wr, kH, kW, stride, pad, mods_np):
      L = Xr.shape[0]
-     Ys = []
-     for i in range(L):
-         mod = int(mods_np[i])
-         X_i = Xr[i]
-         W_i = Wr[i].reshape(Wr[i].shape[0], -1)
-         Xcol, Hout, Wout = im2col_nchw_int(X_i.to(torch.int64), kH, kW, stride=stride, pad=pad)
-         Ycol = (Xcol @ W_i.t()) % mod
-         Y = Ycol.reshape(X_i.shape[0], Hout, Wout, W_i.shape[0]).permute(0,3,1,2).contiguous()
-         Ys.append(Y.to(torch.int64))
-     return torch.stack(Ys, dim=0)  # [L,B,Cout,H,W]
-
- def relu_rns(Yr, MOD2, primes, K):
-     # Use sign from 2^K residue: negative iff top bit set.
-     L = Yr.shape[0]
-     y2k = Yr[0] & (MOD2 - 1)
-     neg = ((y2k >> (K-1)) & 1).to(torch.int64)  # 1 if negative else 0
-     keep = 1 - neg  # 0 or 1
      out = []
-     out.append((y2k * keep) & (MOD2 - 1))
-     for i, p in enumerate(primes, start=1):
-         out.append((Yr[i] * keep) % p)
-     return torch.stack(out, dim=0)

- def poly_act_floor_rns(Yr, shift, MOD2, primes):
-     mask = (1 << shift) - 1
      # 2^K channel
-     y2_2 = (Yr[0] * Yr[0]) & (MOD2 - 1)
-     q2 = (y2_2 >> shift) & (MOD2 - 1)
-     # remainder from low bits only
-     y_low = (Yr[0] & mask).to(torch.int64)
-     r_low = (y_low * y_low) & mask
-     out = [q2.to(torch.int64)]
-     for p_idx, p in enumerate(primes, start=1):
-         y2p = (Yr[p_idx] * Yr[p_idx]) % p
-         rp = (r_low % p)
-         inv_pow2 = pow(2, -shift, p)
-         qi = ((y2p - rp) * inv_pow2) % p
-         out.append(qi.to(torch.int64))
-     return torch.stack(out, dim=0)
-
- def linear_rns(Xr, Wr, mods_np):
      L = Xr.shape[0]
-     Ys = []
      for i in range(L):
-         mod = int(mods_np[i])
-         Ys.append((Xr[i] @ Wr[i]) % mod)
-     return torch.stack(Ys, dim=0)
-
- # ===== CNN proof (Conv -> Act -> Conv -> Act -> FC) =====
- def run_cnn_proof(seed:int=0, x_abs:int=31, w_abs:int=31, act_kind:str="relu", shift:int=7):
-     B, Cin, H, W = 4, 1, 16, 16
-     C1, C2 = 8, 8
-     KH, KW = 3, 3
-     STRIDE, PAD = 1, 1
-     CLS = 10
-
-     g = torch.Generator().manual_seed(int(seed))
-     X = torch.randint(-x_abs, x_abs+1, (B, Cin, H, W), dtype=torch.int64, generator=g)
-     W1 = torch.randint(-w_abs, w_abs+1, (C1, Cin, KH, KW), dtype=torch.int64, generator=g)
-     W2 = torch.randint(-w_abs, w_abs+1, (C2, C1, KH, KW), dtype=torch.int64, generator=g)
-     Wfc = torch.randint(-w_abs, w_abs+1, (C2*H*W, CLS), dtype=torch.int64, generator=g)
-
-     t0 = time.time()
-     Y1_ref = conv_int64_im2col(X, W1, KH, KW, stride=STRIDE, pad=PAD)
-     A1_ref = (Y1_ref.clamp_min(0) if act_kind=="relu" else poly_act_floor_int64(Y1_ref, shift))
-     Y2_ref = conv_int64_im2col(A1_ref, W2, KH, KW, stride=STRIDE, pad=PAD)
-     A2_ref = (Y2_ref.clamp_min(0) if act_kind=="relu" else poly_act_floor_int64(Y2_ref, shift))
-     Z_ref = linear_int64(A2_ref.reshape(B, -1), Wfc)
-     t1 = time.time()
-     ref_ms = (t1 - t0)*1000
-
-     # bounds
-     def bconv(cin, k, xmax, wmax): return int(cin*k*k * xmax * wmax)
-     b1 = bconv(Cin, KH, x_abs, w_abs)
-     a1 = (b1*b1) >> (shift if act_kind=="poly" else 0)
-     b2 = int(C1*KH*KH * a1 * w_abs)
-     a2 = (b2*b2) >> (shift if act_kind=="poly" else 0)
-     fc = int((C2*H*W) * a2 * w_abs)
-     MAX_ABS = max(b1,a1,b2,a2,fc)
-
-     K, MOD2, primes = choose_moduli(MAX_ABS)
-     moduli = [MOD2] + primes
-     M, Mi_np, inv_np, mods_np = precompute_crt(moduli)
-
-     # encode & run RNS
-     X_r = encode_rns(X, MOD2, primes)
-     W1_r = encode_rns(W1, MOD2, primes)
-     W2_r = encode_rns(W2, MOD2, primes)
-     Wfc_r = encode_rns(Wfc, MOD2, primes)
-
-     Y1_r = conv_rns(X_r, W1_r, KH, KW, STRIDE, PAD, mods_np)
-     A1_r = (relu_rns(Y1_r, MOD2, primes, K) if act_kind=="relu" else poly_act_floor_rns(Y1_r, shift, MOD2, primes))
-     Y2_r = conv_rns(A1_r, W2_r, KH, KW, STRIDE, PAD, mods_np)
-     A2_r = (relu_rns(Y2_r, MOD2, primes, K) if act_kind=="relu" else poly_act_floor_rns(Y2_r, shift, MOD2, primes))
-     t2 = time.time()
-     Z_r = linear_rns(A2_r.reshape(len(moduli), B, -1), Wfc_r, mods_np)
-     Z_rec= batched_crt(Z_r, M, Mi_np, inv_np, mods_np)
-     t3 = time.time()
-     rns_ms = (t3 - t2)*1000
-
-     # checks
-     flags = {
-         "conv1": bool(torch.equal(batched_crt(Y1_r, M, Mi_np, inv_np, mods_np), Y1_ref)),
-         "act1": bool(torch.equal(batched_crt(A1_r, M, Mi_np, inv_np, mods_np), A1_ref)),
-         "conv2": bool(torch.equal(batched_crt(Y2_r, M, Mi_np, inv_np, mods_np), Y2_ref)),
-         "act2": bool(torch.equal(batched_crt(A2_r, M, Mi_np, inv_np, mods_np), A2_ref)),
-         "final": bool(torch.equal(Z_rec, Z_ref)),
-         "argmax": bool(torch.equal(Z_rec.argmax(1), Z_ref.argmax(1))),
-     }

-     console = []
-     console.append("Stage-by-stage equality (exact):")
-     console.append(f"conv1 : {'✅' if flags['conv1'] else '❌'}")
-     console.append(f"act1  : {'✅' if flags['act1'] else '❌'}")
-     console.append(f"conv2 : {'✅' if flags['conv2'] else '❌'}")
-     console.append(f"act2  : {'✅' if flags['act2'] else '❌'}")
-     console.append(f"final : {'✅' if flags['final'] else '❌'}")
-     console.append(f"Argmax match: {flags['argmax']} (ref={Z_ref.argmax(1).tolist()}, prime={Z_rec.argmax(1).tolist()})")
-     console.append("")
-     console.append("Timing (ms):")
-     console.append(f"  Reference int64 path: {ref_ms:.2f} ms")
-     console.append(f"  Prime-only RNS+CRT  : {rns_ms:.2f} ms")
-     console.append("")
-     console.append("Bit budget:")
-     console.append(f"  K={K}  (2^(K-1)={1<<(K-1):,} > MAX_ABS={MAX_ABS:,})")
-     console.append(f"  Odd primes: {primes}")
-     console.append(f"  Composite modulus M_total = {int(M):,}")
-
      report = {
-         "ok": flags["final"],
-         "argmax_ok": flags["argmax"],
-         "act": act_kind,
          "shift": int(shift),
-         "seed": int(seed),
-         "x_abs": int(x_abs),
-         "w_abs": int(w_abs),
          "K": int(K),
-         "primes": primes,
-         "ref_ms": ref_ms,
-         "rns_ms": rns_ms,
      }
-     return "\n".join(console), report
-
- # ===== Gradio UI =====
- with gr.Blocks(title="FieldSpace — Prime-Only Machine (RNS+CRT) — Exactness Proofs") as demo:
-     gr.Markdown("## FieldSpace — Prime-Only Machine (RNS+CRT) — Exactness Proofs")
-     gr.Markdown(
-         "This demo runs **exact integer CNN** computations entirely in residues mod `{2^K, primes}` "
-         "and reconstructs with a single CRT, verifying bit-for-bit equality against an int64 reference."
-     )
-
-     with gr.Row():
-         seed = gr.Number(value=0, label="Seed", precision=0)
-         xabs = gr.Number(value=31, label="|X|max", precision=0)
-         wabs = gr.Number(value=31, label="|W|max", precision=0)
-     with gr.Row():
-         act = gr.Radio(choices=["relu","poly"], value="relu", label="Activation")
-         shift = gr.Slider(1, 14, value=7, step=1, label="SHIFT (for poly x^2 >> SHIFT)")
-
-     run_btn = gr.Button("Run CNN Proof", variant="primary")
-     console = gr.Textbox(label="Console", lines=16)
-     report = gr.JSON(label="JSON Report")
-
-     run_btn.click(fn=run_cnn_proof, inputs=[seed, xabs, wabs, act, shift], outputs=[console, report], api_name="/run_cnn_proof")

  if __name__ == "__main__":
-     demo.launch()
+ import gradio as gr
+ import numpy as np
+
+ # -----------------------------
+ # Utilities: RNS + CRT
+ # -----------------------------
+ def pick_moduli(max_abs, primes=(257,263,269,271,277,281)):
+     # choose K so 2^(K-1) > max_abs (the floor of 20 leaves headroom at toy sizes)
+     K = max(20, int(max_abs).bit_length()+1)
      MOD2 = 1 << K
+     # ensure product > 2*max_abs so the signed CRT result is unambiguous
+     M = MOD2
+     ps = []
+     for p in primes:
+         ps.append(p); M *= p
+         if M > 2*max_abs: break
+     return K, MOD2, tuple(ps)

  def encode_rns(x, MOD2, primes):
+     # object dtype keeps residues as Python ints, so later products never overflow
+     outs = [np.asarray(x, dtype=object) & (MOD2-1)]
      for p in primes:
+         outs.append(np.asarray(x, dtype=object) % p)
+     return np.stack(outs, axis=0)  # [L, *shape]

+ def batched_crt(res, MOD2, primes):
+     # res: [L, *shape], L=1+len(primes), dtype=object (python ints)
+     moduli = (MOD2,)+tuple(primes)
+     L = res.shape[0]
      M = 1
      for m in moduli: M *= m
+     Mi = [M//m for m in moduli]
+     inv = [pow(int(Mi[i] % moduli[i]), -1, moduli[i]) for i in range(L)]
+     flat = res.reshape(L, -1)
+     out = []
+     half = M//2
+     for k in range(flat.shape[1]):
          acc = 0
          for i in range(L):
+             acc = (acc + (flat[i,k] * inv[i] * Mi[i])) % M
          if acc > half: acc -= M
+         out.append(int(acc))
+     return np.array(out, dtype=np.int64).reshape(res.shape[1:])
+
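+ # Hedged sanity check (illustrative; not wired into the app): a signed value
+ # survives the encode -> CRT round trip whenever the modulus product exceeds 2*|x|:
+ #   K, MOD2, ps = pick_moduli(10**6)
+ #   assert batched_crt(encode_rns(np.array([-123456]), MOD2, ps), MOD2, ps)[0] == -123456
+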
+ # -----------------------------
+ # CNN proof (tiny shapes)
+ # -----------------------------
+ def im2col_nchw_int(x, kH, kW, stride=1, pad=0):
+     N,C,H,W = x.shape
+     Hp = H+2*pad; Wp = W+2*pad
+     xp = np.zeros((N,C,Hp,Wp), dtype=x.dtype)  # x.dtype (not int64) so object residues pass through
+     xp[:,:,pad:pad+H, pad:pad+W] = x
+     Hout = (Hp - kH)//stride + 1
+     Wout = (Wp - kW)//stride + 1
+     cols = []
+     for i in range(Hout):
+         for j in range(Wout):
+             patch = xp[:,:,i*stride:i*stride+kH, j*stride:j*stride+kW]
+             cols.append(patch.reshape(N,-1))
+     out = np.stack(cols, axis=1)  # [N, Hout*Wout, C*kH*kW]
+     out = out.reshape(N*Hout*Wout, -1)
+     return out, Hout, Wout
+
+ def conv_ref(X, Wt, kH, kW, stride=1, pad=0):
+     # weights are Wt: calling them W would be clobbered by the shape unpack below
+     N,C,H,W = X.shape
+     Cout,Cin,KH,KW = Wt.shape
+     Xcol,Hout,Wout = im2col_nchw_int(X,kH,kW,stride,pad)
+     Ycol = Xcol @ Wt.reshape(Cout,-1).T
+     Y = Ycol.reshape(N,Hout,Wout,Cout).transpose(0,3,1,2).copy()
+     return Y
+
+ def relu_ref(x): return np.where(x<0, 0, x)
+ def poly_ref(x, s): return (x*x) >> s
+ def linear_ref(X, W): return X @ W
+
+ def relu_rns(Xr, MOD2, primes):
+     # Xr: [L, *shape]; the 2^K channel carries the sign: residues below
+     # 2^(K-1) encode non-negative values
+     L = Xr.shape[0]
+     K = MOD2.bit_length()-1
+     x2 = Xr[0]
+     mask = (x2 < (1<<(K-1))).astype(np.int64)  # 1 if non-negative, else 0
      out = []
+     out.append((x2 * mask) & (MOD2-1))
+     for i,p in enumerate(primes, start=1):
+         out.append((Xr[i] * mask) % p)
+     return np.stack(out, axis=0)
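+
+ # Example (hedged): with K=20, x=-3 encodes as 2**20-3 >= 2**19 in the 2^K channel,
+ # so mask=0 and every channel is zeroed, matching relu_ref; x=+5 stays 5 everywhere.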
+
+ def poly_rns(Xr, s, MOD2, primes):
+     # q = floor(x^2 / 2^s)
+     mask_s = (1<<s)-1
+     out = []
      # 2^K channel
+     y2 = (Xr[0]*Xr[0]) & (MOD2-1)
+     q2 = (y2 >> s) & (MOD2-1)
+     out.append(q2)
+     # r from low s bits of x only
+     low = Xr[0] & mask_s
+     r = (low*low) & mask_s
+     for i,p in enumerate(primes, start=1):
+         y2p = (Xr[i]*Xr[i]) % p
+         inv = pow(2, -s, p)
+         qi = ((y2p - (r % p)) * inv) % p
+         out.append(qi)
+     return np.stack(out, axis=0)
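+
+ # Why poly_rns is exact: x^2 = q*2^s + r with r = ((x mod 2^s)^2) mod 2^s, so each
+ # odd-prime channel solves q = (x^2 - r) * inv(2^s) mod p; the shifted 2^K channel
+ # holds q mod 2^(K-s), which suffices because K is sized for the much larger FC
+ # bound, keeping q below 2^(K-s).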
+
+ def matmul_rns(Ar, Br, MOD2, primes):
+     # Ar: [L, M, K], Br: [L, K, N] -> [L, M, N]
+     outs=[]
+     outs.append((Ar[0] @ Br[0]) & (MOD2-1))
+     for i,p in enumerate(primes, start=1):
+         outs.append((Ar[i] @ Br[i]) % p)
+     return np.stack(outs, axis=0)
+
+ def conv_rns(Xr, Wr, kH, kW, stride=1, pad=0, MOD2=0, primes=()):
+     # per-modulus im2col + matmul; residues stay in object dtype throughout,
+     # so no int64 round trip can overflow when K is large
      L = Xr.shape[0]
+     N,C,H,W = Xr[0].shape
+     Cout = Wr[0].shape[0]
+     Ymods=[]
      for i in range(L):
+         Xcol,Hout,Wout = im2col_nchw_int(Xr[i], kH,kW,stride,pad)
+         Wcol = Wr[i].reshape(Cout,-1)
+         if i==0:
+             Ycol = (Xcol @ Wcol.T) & (MOD2-1)
+         else:
+             p = primes[i-1]
+             Ycol = (Xcol @ Wcol.T) % p
+         Y = Ycol.reshape(N,Hout,Wout,Cout).transpose(0,3,1,2).copy()
+         Ymods.append(Y)
+     return np.stack(Ymods, axis=0)
+
+ def cnn_proof(seed, xmax, wmax, activation, shift):
+     # Tiny shapes to keep it snappy
+     B,Cin,H,W = 2,1,8,8
+     Cout = 2
+     KH=KW=3
+     CLS=8
+     rng = np.random.default_rng(seed)
+     X   = rng.integers(-xmax, xmax+1, size=(B,Cin,H,W), dtype=np.int64)
+     Wc  = rng.integers(-wmax, wmax+1, size=(Cout,Cin,KH,KW), dtype=np.int64)
+     Wfc = rng.integers(-wmax, wmax+1, size=(Cout*H*W, CLS), dtype=np.int64)
+
+     # Reference (int64)
+     Y = conv_ref(X,Wc,KH,KW,1,1)
+     A = relu_ref(Y) if activation=="relu" else poly_ref(Y, shift)
+     Z = linear_ref(A.reshape(B,-1), Wfc)
+
+     # Bounds (very conservative: sized as if there were a second conv stage)
+     b1 = int(Cin*KH*KW * xmax * wmax)
+     a1 = b1 if activation=="relu" else (b1*b1)>>shift
+     b2 = int(Cout*KH*KW * a1 * wmax)
+     a2 = b2 if activation=="relu" else (b2*b2)>>shift
+     fc = int((Cout*H*W) * a2 * wmax)
+     MAX_ABS = max(abs(b1),abs(a1),abs(b2),abs(a2),abs(fc), 1)
+
+     K, MOD2, primes = pick_moduli(max_abs=MAX_ABS)
+     # Encode
+     Xr  = encode_rns(X, MOD2, primes)
+     Wcr = encode_rns(Wc, MOD2, primes)
+     Wfr = encode_rns(Wfc, MOD2, primes)
+
+     # RNS path
+     Yr  = conv_rns(Xr,Wcr,KH,KW,1,1,MOD2,primes)
+     Ar  = relu_rns(Yr, MOD2, primes) if activation=="relu" else poly_rns(Yr, shift, MOD2, primes)
+     Arf = Ar.reshape(Ar.shape[0], B, -1)
+     Zr  = matmul_rns(Arf, Wfr, MOD2, primes)
+     Zrec= batched_crt(Zr, MOD2, primes)
+
+     ok  = np.array_equal(Z, Zrec)
+     arg = np.array_equal(Z.argmax(1), Zrec.argmax(1))
+
+     log = []
+     log.append("Stage-by-stage equality (exact):")
+     log.append(f"conv : {'✅' if np.array_equal(batched_crt(Yr,MOD2,primes), Y) else '❌'}")
+     log.append(f"act  : {'✅' if np.array_equal(batched_crt(Ar,MOD2,primes), A) else '❌'}")
+     log.append(f"fc   : {'✅' if ok else '❌'}")
+     log.append(f"Argmax match: {arg} (ref={Z.argmax(1).tolist()}, prime={Zrec.argmax(1).tolist()})")

      report = {
+         "activation": activation,
          "shift": int(shift),
          "K": int(K),
+         "primes": list(map(int,primes)),
+         "exact": bool(ok),
+         "argmax_exact": bool(arg),
+         "ref_logits_sample": Z[0,:min(8,CLS)].tolist(),
+         "rec_logits_sample": Zrec[0,:min(8,CLS)].tolist(),
+     }
+     return "\n".join(log), report
+
+ # -----------------------------
+ # Attention (exact numerators)
+ # -----------------------------
+ def attn_numerators_proof(seed):
+     rng = np.random.default_rng(seed)
+     T,d = 2,4
+     xmax = 15
+     t = 7  # scale shift
+
+     Q  = rng.integers(-xmax,xmax+1,size=(T,d),dtype=np.int64)
+     K_ = rng.integers(-xmax,xmax+1,size=(T,d),dtype=np.int64)
+     V  = rng.integers(-xmax,xmax+1,size=(T,d),dtype=np.int64)
+
+     Sraw = Q @ K_.T
+     S = Sraw >> t  # exact floor
+     N = S @ V      # numerators
+
+     # conservative bound; +1 covers floor on negatives, and t extra bits are
+     # reserved because the 2^K channel shrinks to 2^(K-t) after the floor-shift
+     s_max = int((d * xmax * xmax) >> t) + 1
+     n_max = int(T * s_max * xmax)
+     Kpow, MOD2, primes = pick_moduli(max_abs=max(abs(s_max),abs(n_max),1) << t)
+
+     Qr = encode_rns(Q,MOD2,primes)
+     Kr = encode_rns(K_,MOD2,primes)
+     Vr = encode_rns(V,MOD2,primes)
+
+     Sraw_r = matmul_rns(Qr, Kr.transpose(0,2,1), MOD2, primes)
+     # Exact floor-shift in RNS: write Sraw = q*2^t + r with 0 <= r < 2^t.
+     # The low t bits of the 2^K residue give r; each odd prime then solves
+     # q = (Sraw - r) * inv(2^t) mod p, and the 2^K channel becomes a 2^(K-t)
+     # channel holding q mod 2^(K-t). (A bare shift-and-inverse is only valid
+     # when 2^t divides Sraw exactly, which it generally does not.)
+     MOD2s = MOD2 >> t
+     r_low = Sraw_r[0] & ((1<<t)-1)
+     Sr = [ (Sraw_r[0] >> t) & (MOD2s-1) ]
+     for i,p in enumerate(primes, start=1):
+         Sr.append( ((Sraw_r[i] - (r_low % p)) * pow(2,-t,p)) % p )
+     Sr = np.stack(Sr, axis=0)
+
+     Nr = matmul_rns(Sr, Vr, MOD2s, primes)
+     Srec = batched_crt(Sr, MOD2s, primes)
+     Nrec = batched_crt(Nr, MOD2s, primes)
+
+     okS = np.array_equal(S, Srec)
+     okN = np.array_equal(N, Nrec)
+
+     log = []
+     log.append("Attention (exact numerators):")
+     log.append(f"S  : {'✅' if okS else '❌'}")
+     log.append(f"SV : {'✅' if okN else '❌'}")
+
+     report = {
+         "K": int(Kpow),
+         "primes": list(map(int,primes)),
+         "S_exact": bool(okS),
+         "N_exact": bool(okN),
+         "S_sample": S.tolist(),
+         "Srec_sample": Srec.tolist(),
+         "N_sample": N.tolist(),
+         "Nrec_sample": Nrec.tolist(),
      }
+     return "\n".join(log), report
+
+ # -----------------------------
+ # Gradio UI
+ # -----------------------------
+ with gr.Blocks(title="FieldSpace — Prime-Only Machine (RNS+CRT)") as demo:
+     gr.Markdown("# FieldSpace Prime-Only Machine (RNS+CRT) Exactness Proofs")
+     gr.Markdown("This demo proves exactness for a tiny CNN and a one-head attention numerator block using only residues mod $\\{2^K, p_i\\}$ and a single CRT.")
+     with gr.Tab("CNN (Conv + ReLU/Poly + FC)"):
+         with gr.Row():
+             seed = gr.Number(value=0, precision=0, label="Seed")
+             xmax = gr.Number(value=31, precision=0, label="|X|max")
+             wmax = gr.Number(value=31, precision=0, label="|W|max")
+         activation = gr.Radio(choices=["relu","poly"], value="relu", label="Activation")
+         shift = gr.Slider(1,12,value=7,step=1,label="SHIFT for poly (x^2 >> SHIFT)")
+         run_btn = gr.Button("Run CNN Proof", variant="primary")
+         out_text = gr.Textbox(label="Console", lines=10)
+         out_json = gr.JSON(label="JSON Report")
+         # named route for programmatic calls is registered via api_name
+         run_btn.click(fn=lambda a,b,c,d,e: cnn_proof(int(a),int(b),int(c),d,int(e)),
+                       inputs=[seed,xmax,wmax,activation,shift],
+                       outputs=[out_text,out_json],
+                       api_name="/run_cnn_proof")
+     with gr.Tab("Attention (Exact Numerators)"):
+         aseed = gr.Number(value=0, precision=0, label="Seed")
+         runA = gr.Button("Run Attention Proof", variant="secondary")
+         AT = gr.Textbox(label="Console", lines=8)
+         AJ = gr.JSON(label="JSON Report")
+         runA.click(fn=lambda s: attn_numerators_proof(int(s)),
+                    inputs=[aseed], outputs=[AT,AJ],
+                    api_name="/run_attn_proof")
+     gr.Markdown("**Use via API**")
+     gr.Code("""from gradio_client import Client
+ client = Client('jackal79/fieldspace-prime-only')
+ txt, rep = client.predict(0, 31, 31, 'relu', 7, api_name='/run_cnn_proof')""")

  if __name__ == "__main__":
+     demo.queue().launch()
requirements.txt CHANGED
@@ -1,4 +1,2 @@
- gradio>=4.44,<5
  numpy>=1.24
- torch==2.3.1
- --extra-index-url https://download.pytorch.org/whl/cpu

+ gradio>=4.44.0
  numpy>=1.24