jackal79 committed
Commit e13fb8e · verified · 1 Parent(s): dc7fd9e

Single-file Space, app.py self-installs deps, CNN and 1-head attention proofs

Files changed (1)
  1. app.py +247 -270
app.py CHANGED
@@ -1,298 +1,275 @@
- import gradio as gr
  import numpy as np

  # -----------------------------
- # Utilities: RNS + CRT
  # -----------------------------
- def pick_moduli(max_abs, primes=(257,263,269,271,277,281)):
-     # choose K so 2^(K-1) > max_abs
-     K = max(20, int(max_abs).bit_length()+1)
-     MOD2 = 1 << K
-     # ensure product > 2*max_abs (with plenty margin for toy sizes this is fine)
-     M = MOD2
-     ps = []
-     for p in primes:
-         ps.append(p); M *= p
-         if M > 2*max_abs: break
-     return K, MOD2, tuple(ps)
-
- def encode_rns(x, MOD2, primes):
-     outs = [np.asarray(x, dtype=object) & (MOD2-1)]
-     for p in primes:
-         outs.append(np.asarray(x, dtype=object) % p)
-     return np.stack(outs, axis=0)  # [L, *shape]
-
- def batched_crt(res, MOD2, primes):
-     # res: [L, *shape], L=1+len(primes), dtype=object (python ints)
-     moduli = (MOD2,)+tuple(primes)
-     L = res.shape[0]
-     M = 1
-     for m in moduli: M *= m
-     Mi = [M//m for m in moduli]
-     inv = [pow(int(Mi[i] % moduli[i]), -1, moduli[i]) for i in range(L)]
-     flat = res.reshape(L, -1)
-     out = []
-     half = M//2
-     for k in range(flat.shape[1]):
-         acc = 0
-         for i in range(L):
-             acc = (acc + (flat[i,k] * inv[i] * Mi[i])) % M
-         if acc > half: acc -= M
-         out.append(int(acc))
-     return np.array(out, dtype=np.int64).reshape(res.shape[1:])

  # -----------------------------
- # CNN proof (tiny shapes)
  # -----------------------------
  def im2col_nchw_int(x, kH, kW, stride=1, pad=0):
-     N,C,H,W = x.shape
-     Hp = H+2*pad; Wp = W+2*pad
-     xp = np.zeros((N,C,Hp,Wp), dtype=np.int64)
-     xp[:,:,pad:pad+H, pad:pad+W] = x
-     Hout = (Hp - kH)//stride + 1
-     Wout = (Wp - kW)//stride + 1
      cols = []
      for i in range(Hout):
          for j in range(Wout):
-             patch = xp[:,:,i*stride:i*stride+kH, j*stride:j*stride+kW]
-             cols.append(patch.reshape(N,-1))
-     out = np.stack(cols, axis=1)  # [N, Hout*Wout, C*kH*kW]
-     out = out.reshape(N*Hout*Wout, -1)
      return out, Hout, Wout

- def conv_ref(X, W, kH, kW, stride=1, pad=0):
-     N,C,H,W = X.shape
-     Cout,Cin,KH,KW = W.shape
-     Xcol,Hout,Wout = im2col_nchw_int(X,kH,kW,stride,pad)
-     Ycol = Xcol @ W.reshape(Cout,-1).T
-     Y = Ycol.reshape(N,Hout,Wout,Cout).transpose(0,3,1,2).copy()
-     return Y
-
- def relu_ref(x): return np.where(x<0, 0, x)
- def poly_ref(x, s): return (x*x) >> s
- def linear_ref(X, W): return X @ W
-
- def relu_rns(Xr, MOD2, primes):
-     # Xr: [L, *shape]
-     L = Xr.shape[0]
-     K = MOD2.bit_length()-1
-     x2 = Xr[0]
-     mask = (x2 < (1<<(K-1))).astype(np.int64)  # 1 if non-negative, else 0
-     out = []
-     out.append((x2 * mask) & (MOD2-1))
-     for i,p in enumerate(primes, start=1):
-         out.append((Xr[i] * mask) % p)
-     return np.stack(out, axis=0)
-
- def poly_rns(Xr, s, MOD2, primes):
-     # q = floor(x^2 / 2^s)
-     mask_s = (1<<s)-1
-     out = []
-     # 2^K channel
-     y2 = (Xr[0]*Xr[0]) & (MOD2-1)
-     q2 = (y2 >> s) & (MOD2-1)
-     out.append(q2)
-     # r from low s bits of x only
-     low = Xr[0] & mask_s
-     r = (low*low) & mask_s
-     for i,p in enumerate(primes, start=1):
-         y2p = (Xr[i]*Xr[i]) % p
-         inv = pow(2, -s, p)
-         qi = ((y2p - (r % p)) * inv) % p
-         out.append(qi)
-     return np.stack(out, axis=0)
-
- def matmul_rns(Ar, Br, MOD2, primes):
-     # Ar: [L, M, K], Br: [L, K, N] -> [L, M, N]
-     outs = []
-     outs.append((Ar[0] @ Br[0]) & (MOD2-1))
-     for i,p in enumerate(primes, start=1):
-         outs.append((Ar[i] @ Br[i]) % p)
-     return np.stack(outs, axis=0)
-
- def conv_rns(Xr, Wr, kH, kW, stride=1, pad=0, MOD2=0, primes=()):
-     # reshape per-modulus and reuse im2col on ints
-     L = Xr.shape[0]
-     N,C,H,W = Xr[0].shape
-     Cout = Wr[0].shape[0]
-     Ymods = []
-     for i in range(L):
-         Xi = Xr[i].astype(object)
-         Wi = Wr[i].astype(object)
-         Xcol,Hout,Wout = im2col_nchw_int(np.asarray(Xi, dtype=np.int64), kH,kW,stride,pad)
-         Wcol = np.asarray(Wi.reshape(Cout,-1), dtype=np.int64)
-         if i==0:
-             MOD = MOD2
-             Ycol = (Xcol @ Wcol.T) & (MOD-1)
-         else:
-             p = primes[i-1]
-             Ycol = (Xcol @ Wcol.T) % p
-         Y = Ycol.reshape(N,Hout,Wout,Cout).transpose(0,3,1,2).copy()
-         Ymods.append(Y.astype(object))
-     return np.stack(Ymods, axis=0)
-
- def cnn_proof(seed, xmax, wmax, activation, shift):
-     # Tiny shapes to keep it snappy
-     B,Cin,H,W = 2,1,8,8
-     Cout = 2
-     KH=KW=3
-     CLS=8
-     rng = np.random.default_rng(seed)
-     X = rng.integers(-xmax, xmax+1, size=(B,Cin,H,W), dtype=np.int64)
-     Wc = rng.integers(-wmax, wmax+1, size=(Cout,Cin,KH,KW), dtype=np.int64)
-     Wfc = rng.integers(-wmax, wmax+1, size=(Cout*H*W, CLS), dtype=np.int64)
-
-     # Reference (int64)
-     Y = conv_ref(X,Wc,KH,KW,1,1)
-     A = relu_ref(Y) if activation=="relu" else poly_ref(Y, shift)
-     Z = linear_ref(A.reshape(B,-1), Wfc)
-
-     # Bounds (very conservative)
-     b1 = int(Cin*KH*KW * xmax * wmax)
-     a1 = b1 if activation=="relu" else (b1*b1)>>shift
-     b2 = int(Cout*KH*KW * a1 * wmax)
-     a2 = b2 if activation=="relu" else (b2*b2)>>shift
-     fc = int((Cout*H*W) * a2 * wmax)
-     MAX_ABS = max(abs(b1),abs(a1),abs(b2),abs(a2),abs(fc), 1)
-
-     K, MOD2, primes = pick_moduli(max_abs=MAX_ABS)
-     # Encode
-     Xr = encode_rns(X, MOD2, primes)
-     Wcr = encode_rns(Wc, MOD2, primes)
-     Wfr = encode_rns(Wfc, MOD2, primes)
-
-     # RNS path
-     Yr = conv_rns(Xr,Wcr,KH,KW,1,1,MOD2,primes)
-     Ar = relu_rns(Yr, MOD2, primes) if activation=="relu" else poly_rns(Yr, shift, MOD2, primes)
-     Arf = Ar.reshape(Ar.shape[0], B, -1)
-     Zr = matmul_rns(Arf, Wfr, MOD2, primes)
-     Zrec = batched_crt(Zr, MOD2, primes)
-
-     ok = np.array_equal(Z, Zrec)
-     arg = np.array_equal(Z.argmax(1), Zrec.argmax(1))
-
-     log = []
-     log.append("Stage-by-stage equality (exact):")
-     log.append(f"conv : {'✅' if np.array_equal(batched_crt(Yr,MOD2,primes), Y) else '❌'}")
-     log.append(f"act  : {'✅' if np.array_equal(batched_crt(Ar,MOD2,primes), A) else '❌'}")
-     log.append(f"fc   : {'✅' if ok else '❌'}")
-     log.append(f"Argmax match: {arg} (ref={Z.argmax(1).tolist()}, prime={Zrec.argmax(1).tolist()})")

-     report = {
-         "activation": activation,
-         "shift": int(shift),
-         "K": int(K),
-         "primes": list(map(int,primes)),
-         "exact": bool(ok),
-         "argmax_exact": bool(arg),
-         "ref_logits_sample": Z[0,:min(8,CLS)].tolist(),
-         "rec_logits_sample": Zrec[0,:min(8,CLS)].tolist(),
      }
-     return "\n".join(log), report

  # -----------------------------
- # Attention (exact numerators)
  # -----------------------------
- def attn_numerators_proof(seed):
-     rng = np.random.default_rng(seed)
-     T,d = 2,4
-     xmax = 15
-     t = 7  # scale shift
-
-     Q = rng.integers(-xmax,xmax+1,size=(T,d),dtype=np.int64)
-     K_ = rng.integers(-xmax,xmax+1,size=(T,d),dtype=np.int64)
-     V = rng.integers(-xmax,xmax+1,size=(T,d),dtype=np.int64)
-
-     Sraw = Q @ K_.T
-     S = Sraw >> t  # exact floor
-     N = S @ V  # numerators
-
-     # conservative bound
-     s_max = int((d * xmax * xmax) >> t)
-     n_max = int(T * s_max * xmax)
-     Kpow, MOD2, primes = pick_moduli(max_abs=max(abs(s_max),abs(n_max),1))
-
-     Qr = encode_rns(Q,MOD2,primes)
-     Kr = encode_rns(K_,MOD2,primes)
-     Vr = encode_rns(V,MOD2,primes)
-
-     # Sraw_r and S_r
-     Sraw_r = matmul_rns(Qr, Kr.transpose(0,2,1), MOD2, primes)
-     # floor by right shift in 2^K channel + multiplicative inverse in odd primes
-     # (here scaling is power-of-two so we can just shift in 2^K, and multiply by inv(2^t) after peeling low bits = 0)
-     # since we already floored by pure shift in reference, it's safe to do:
-     Sr = []
-     Sr.append( (Sraw_r[0] >> t) & (MOD2-1) )
-     invs = [pow(2,-t,p) for p in primes]
-     for i,p in enumerate(primes, start=1):
-         Sr.append( (Sraw_r[i] * invs[i-1]) % p )
-     Sr = np.stack(Sr, axis=0)
-
-     Nr = matmul_rns(Sr, Vr, MOD2, primes)
-     Srec = batched_crt(Sr, MOD2, primes)
-     Nrec = batched_crt(Nr, MOD2, primes)
-
-     okS = np.array_equal(S, Srec)
-     okN = np.array_equal(N, Nrec)
-
-     log = []
-     log.append("Attention (exact numerators):")
-     log.append(f"S  : {'✅' if okS else '❌'}")
-     log.append(f"SV : {'✅' if okN else '❌'}")

      report = {
-         "K": int(Kpow),
-         "primes": list(map(int,primes)),
-         "S_exact": bool(okS),
-         "N_exact": bool(okN),
-         "S_sample": S.tolist(),
-         "Srec_sample": Srec.tolist(),
-         "N_sample": N.tolist(),
-         "Nrec_sample": Nrec.tolist(),
      }
-     return "\n".join(log), report

  # -----------------------------
- # Gradio UI
  # -----------------------------
- with gr.Blocks(title="FieldSpace — Prime-Only Machine (RNS+CRT)") as demo:
-     gr.Markdown("# FieldSpace Prime-Only Machine (RNS+CRT) Exactness Proofs")
-     gr.Markdown("This demo proves exactness for a tiny CNN and a one-head attention numerator block using only residues mod $\\{2^K, p_i\\}$ and a single CRT.")
-     with gr.Tab("CNN (Conv + ReLU/Poly + FC)"):
          with gr.Row():
-             seed = gr.Number(value=0, precision=0, label="Seed")
-             xmax = gr.Number(value=31, precision=0, label="|X|max")
-             wmax = gr.Number(value=31, precision=0, label="|W|max")
-             activation = gr.Radio(choices=["relu","poly"], value="relu", label="Activation")
-             shift = gr.Slider(1,12,value=7,step=1,label="SHIFT for poly (x^2 >> SHIFT)")
-         run_btn = gr.Button("Run CNN Proof", variant="primary")
-         out_text = gr.Textbox(label="Console", lines=10)
-         out_json = gr.JSON(label="JSON Report")
-         run_btn.click(fn=lambda a,b,c,d,e: cnn_proof(int(a),int(b),int(c),d,int(e)),
-                       inputs=[seed,xmax,wmax,activation,shift],
-                       outputs=[out_text,out_json])
-     with gr.Tab("Attention (Exact Numerators)"):
-         aseed = gr.Number(value=0, precision=0, label="Seed")
-         runA = gr.Button("Run Attention Proof", variant="secondary")
-         AT = gr.Textbox(label="Console", lines=8)
-         AJ = gr.JSON(label="JSON Report")
-         runA.click(fn=lambda s: attn_numerators_proof(int(s)),
-                    inputs=[aseed], outputs=[AT,AJ])
-     gr.Markdown("**Use via API**")
-     gr.Code("""from gradio_client import Client
- client = Client('https://huggingface.co/spaces/jackal79/fieldspace-prime-only')
- txt, rep = client.predict('/run_cnn_proof', 0, 31, 31, 'relu', 7)""")
-
-     # Named routes for programmatic calls
-     demo.load(None, None, None)
-     demo.add_named_endpoint("/run_cnn_proof", cnn_proof, inputs=[
-         gr.Number(precision=0), gr.Number(precision=0), gr.Number(precision=0),
-         gr.Textbox(), gr.Number(precision=0)
-     ], outputs=[gr.Textbox(), gr.JSON()])
-     demo.add_named_endpoint("/run_attn_proof", attn_numerators_proof,
-                             inputs=[gr.Number(precision=0)],
-                             outputs=[gr.Textbox(), gr.JSON()])

  if __name__ == "__main__":
-     demo.queue().launch()
 
+ # app.py, single file Space
+ # FieldSpace, Prime-Only Machine, mod 2^K exactness proofs
+ import sys, subprocess, time, math, json
+
+ def _ensure(package_spec: str):
+     try:
+         __import__(package_spec.split("==")[0].split(">=")[0].split("<")[0])
+         return
+     except Exception:
+         subprocess.check_call([sys.executable, "-m", "pip", "install", package_spec])
+
+ # Self install, single file repo, no external requirements
+ for spec in ["gradio==4.44.1", "numpy>=1.26", "torch>=2.2,<2.4"]:
+     _ensure(spec)

  import numpy as np
+ import torch
+ import gradio as gr
+
+ torch.set_grad_enabled(False)
+ torch.set_num_threads(1)
+ DEVICE = "cpu"

  # -----------------------------
+ # Utilities (mod 2^K)
  # -----------------------------
+ def seed_all(s: int):
+     return torch.Generator(device=DEVICE).manual_seed(int(s))
+
+ def choose_K(max_abs: int, min_bits: int = 24) -> int:
+     need = max(min_bits, int(max_abs).bit_length() + 1)
+     return min(62, max(need, min_bits))
+
+ def mod2_encode(x: torch.Tensor, MOD2: int) -> torch.Tensor:
+     return (x & (MOD2 - 1)).to(torch.int64)
+
+ def mod2_center(v: torch.Tensor, MOD2: int) -> torch.Tensor:
+     K = MOD2.bit_length() - 1
+     half = 1 << (K - 1)
+     out = v.clone()
+     out[out >= half] -= MOD2  # >= (not >): residues in [2^(K-1), 2^K) decode to negatives
+     return out

  # -----------------------------
+ # CNN (Conv -> Act -> Conv -> Act -> FC)
  # -----------------------------
  def im2col_nchw_int(x, kH, kW, stride=1, pad=0):
+     N, C, H, W = x.shape
+     Hout = (H + 2*pad - kH)//stride + 1
+     Wout = (W + 2*pad - kW)//stride + 1
+     if pad:
+         x_pad = torch.zeros((N, C, H + 2*pad, W + 2*pad), dtype=x.dtype, device=x.device)
+         x_pad[:, :, pad:pad+H, pad:pad+W] = x
+     else:
+         x_pad = x
      cols = []
      for i in range(Hout):
          for j in range(Wout):
+             patch = x_pad[:, :, i*stride:i*stride+kH, j*stride:j*stride+kW]
+             cols.append(patch.reshape(N, -1))
+     out = torch.stack(cols, dim=1)      # [N, Hout*Wout, C*kH*kW]
+     out = out.reshape(N*Hout*Wout, -1)  # [N*Hout*Wout, C*kH*kW]
      return out, Hout, Wout

+ def conv_int64_im2col(X, W, kH, kW, stride=1, pad=0):
+     Xcol, Hout, Wout = im2col_nchw_int(X, kH, kW, stride=stride, pad=pad)
+     Ycol = Xcol @ W.reshape(W.shape[0], -1).t()
+     return Ycol.reshape(X.shape[0], Hout, Wout, W.shape[0]).permute(0,3,1,2).contiguous()
+
+ def run_cnn_proof(seed:int=0, Xmax:int=31, Wmax:int=31, act_kind:str="relu", poly_shift:int=7):
+     g = seed_all(seed)
+     B,Cin,H,W = 4,1,16,16
+     C1,C2 = 8,8
+     kH,kW = 3,3
+     STRIDE,PAD = 1,1
+     CLS = 10
+
+     X = torch.randint(-Xmax, Xmax+1, (B,Cin,H,W), dtype=torch.int64, generator=g, device=DEVICE)
+     W1 = torch.randint(-Wmax, Wmax+1, (C1,Cin,kH,kW), dtype=torch.int64, generator=g, device=DEVICE)
+     W2 = torch.randint(-Wmax, Wmax+1, (C2,C1,kH,kW), dtype=torch.int64, generator=g, device=DEVICE)
+     Wf = torch.randint(-Wmax, Wmax+1, (C2*H*W, CLS), dtype=torch.int64, generator=g, device=DEVICE)
+
+     t0 = time.time()
+     Z1 = conv_int64_im2col(X, W1, kH,kW, STRIDE, PAD)
+     if act_kind=="relu":
+         A1 = Z1.clamp_min(0)  # out-of-place: keep Z1 intact for the bound and the check below
+     else:
+         A1 = ((Z1*Z1) >> int(poly_shift))
+     Z2 = conv_int64_im2col(A1, W2, kH,kW, STRIDE, PAD)
+     if act_kind=="relu":
+         A2 = Z2.clamp_min(0)
+     else:
+         A2 = ((Z2*Z2) >> int(poly_shift))
+     Y = (A2.reshape(B, -1) @ Wf)
+     t1 = time.time()
+
+     max_abs = int(max(Z1.abs().max().item(), A1.abs().max().item(),
+                       Z2.abs().max().item(), A2.abs().max().item(), Y.abs().max().item()))
+     K = choose_K(max_abs, min_bits=24)
+     MOD2 = 1 << K
+
+     Z1m = mod2_encode(conv_int64_im2col(mod2_encode(X,MOD2), mod2_encode(W1,MOD2), kH,kW, STRIDE, PAD), MOD2)
+     Z1i = mod2_center(Z1m, MOD2)
+     A1i = Z1i.clamp_min(0) if act_kind=="relu" else ((Z1i*Z1i) >> int(poly_shift))
+
+     Z2m = mod2_encode(conv_int64_im2col(mod2_encode(A1i,MOD2), mod2_encode(W2,MOD2), kH,kW, STRIDE, PAD), MOD2)
+     Z2i = mod2_center(Z2m, MOD2)
+     A2i = Z2i.clamp_min(0) if act_kind=="relu" else ((Z2i*Z2i) >> int(poly_shift))
+
+     Ym = mod2_encode((mod2_encode(A2i.reshape(B,-1),MOD2) @ mod2_encode(Wf,MOD2)), MOD2)
+     Yi = mod2_center(Ym, MOD2)
+
+     ok_all = bool(torch.equal(Y, Yi))
+     arg_ok = bool(torch.equal(Y.argmax(1), Yi.argmax(1)))
+
+     txt = []
+     txt.append(f"K={K} (2^(K-1)={1<<(K-1):,} > max_abs={max_abs:,}), act={act_kind} ({'ReLU' if act_kind=='relu' else f'poly SHIFT={poly_shift}'})")
+     txt.append("Stage by stage equality, exact:")
+     txt.append(f"Z1: {'OK' if torch.equal(Z1,Z1i) else 'NO'} | A1: {'OK' if torch.equal(A1,A1i) else 'NO'} | "
+                f"Z2: {'OK' if torch.equal(Z2,Z2i) else 'NO'} | Y: {'OK' if torch.equal(Y,Yi) else 'NO'}")
+     txt.append(f"Argmax match: {arg_ok} (ref={Y.argmax(1).tolist()}, mod2={Yi.argmax(1).tolist()})")
+     txt.append(f"Timing ms: ref≈{(t1-t0)*1000:.2f}")
+     return "\n".join(txt), {
+         "ok_all": ok_all, "argmax_ok": arg_ok, "K": K,
+         "act": act_kind, "poly_shift": int(poly_shift),
+         "ref_top1": Y.argmax(1).tolist(), "mod2_top1": Yi.argmax(1).tolist()
      }

  # -----------------------------
+ # 1 head Attention
  # -----------------------------
+ def split_heads(Z, B,T,C,H):
+     Dh = C//H
+     return Z.reshape(B,T,H,Dh).permute(0,2,1,3).contiguous()
+
+ def merge_heads(Zh, B,T,C,H):
+     return Zh.permute(0,2,1,3).reshape(B,T,C).contiguous()
+
+ def pick_shift_for_cap(smax:int, cap:int=8) -> int:
+     if smax <= cap: return 0
+     s = 0
+     while (smax >> s) > cap: s += 1
+     return s
+
+ def run_attn_proof(seed:int=0, Xmax:int=7, Wmax:int=7, B:int=1, T:int=8, C:int=16, H:int=1):
+     assert C % H == 0
+     Dh = C//H
+     g = seed_all(seed)
+
+     X = torch.randint(0, Xmax+1, (B,T,C), dtype=torch.int64, generator=g, device=DEVICE)
+     Wq = torch.randint(0, Wmax+1, (C,C), dtype=torch.int64, generator=g, device=DEVICE)
+     Wk = torch.randint(0, Wmax+1, (C,C), dtype=torch.int64, generator=g, device=DEVICE)
+     Wv = torch.randint(0, Wmax+1, (C,C), dtype=torch.int64, generator=g, device=DEVICE)
+     Wo = torch.randint(0, Wmax+1, (C,C), dtype=torch.int64, generator=g, device=DEVICE)
+
+     t0 = time.time()
+     Q = X @ Wq; K = X @ Wk; V = X @ Wv
+     Qh, Kh, Vh = split_heads(Q,B,T,C,H), split_heads(K,B,T,C,H), split_heads(V,B,T,C,H)
+     Sraw = torch.einsum("bhtd,bhTd->bhtT", Qh, Kh)
+     SHIFT = pick_shift_for_cap(int(Sraw.max().item()), cap=8)
+     S = (Sraw >> SHIFT).clamp_min_(0)
+     E = (torch.ones_like(S) << S)
+     Den = E.sum(-1, keepdim=True)
+     Num = torch.einsum("bhtT,bhTd->bhtd", E, Vh)
+     Oh = Num // Den
+     O = merge_heads(Oh,B,T,C,H)
+     Y = O @ Wo
+     t1 = time.time()
+
+     max_abs = int(max(Q.abs().max().item(), K.abs().max().item(), V.abs().max().item(),
+                       Sraw.abs().max().item(), S.abs().max().item(),
+                       Num.abs().max().item(), O.abs().max().item(), Y.abs().max().item()))
+     Kbits = choose_K(max_abs, min_bits=24)
+     MOD2 = 1 << Kbits
+
+     Qm = mod2_encode(X,MOD2) @ mod2_encode(Wq,MOD2); Qi = mod2_center(Qm,MOD2)
+     Km = mod2_encode(X,MOD2) @ mod2_encode(Wk,MOD2); Ki = mod2_center(Km,MOD2)
+     Vm = mod2_encode(X,MOD2) @ mod2_encode(Wv,MOD2); Vi = mod2_center(Vm,MOD2)
+     assert torch.equal(Qi,Q) and torch.equal(Ki,K) and torch.equal(Vi,V)
+
+     Qh, Kh, Vh = split_heads(Qi,B,T,C,H), split_heads(Ki,B,T,C,H), split_heads(Vi,B,T,C,H)
+
+     BH = B*H
+     Sraw_mod = torch.empty((BH, T, T), dtype=torch.int64)
+     Q2, K2 = Qh.reshape(BH,T,C//H), Kh.reshape(BH,T,C//H)
+     for b in range(BH):
+         Sraw_mod[b] = (mod2_encode(Q2[b],MOD2) @ mod2_encode(K2[b].t(),MOD2)) & (MOD2-1)
+     Sraw_i = mod2_center(Sraw_mod.reshape(B,H,T,T), MOD2)
+
+     if SHIFT == 0:
+         S_i = Sraw_i
+     else:
+         S2 = mod2_encode(Sraw_i, MOD2)
+         q2 = (S2 >> SHIFT) & (MOD2 - 1)
+         S_i = mod2_center(q2, MOD2)
+     S_i = S_i.clamp_min_(0)
+
+     E_mod = (torch.ones_like(S_i) << S_i) & (MOD2 - 1)
+     Den_m = E_mod.sum(-1)
+     Den_i = mod2_center(Den_m, MOD2)
+
+     Vh_m = mod2_encode(Vh, MOD2)
+     Num_m = torch.empty((B,H,T, C//H), dtype=torch.int64)
+     for b in range(B):
+         for h in range(H):
+             Num_m[b,h] = (E_mod[b,h] @ Vh_m[b,h]) & (MOD2 - 1)
+     Num_i = mod2_center(Num_m, MOD2)
+
+     Oh = Num_i // Den_i.unsqueeze(-1)
+     Oi = merge_heads(Oh,B,T,C,H)
+
+     Ym = (mod2_encode(Oi,MOD2) @ mod2_encode(Wo,MOD2)) & (MOD2 - 1)
+     Yi = mod2_center(Ym, MOD2)
+
+     ok_all = bool(torch.equal(Y, Yi))
+     arg_ok = bool(torch.equal(Y.argmax(-1), Yi.argmax(-1)))
+
+     lines = []
+     lines.append(f"Auto SHIFT={SHIFT} cap=8")
+     lines.append(f"K={Kbits} 2^(K-1)={1<<(Kbits-1):,} > max_abs={max_abs:,}")
+     lines.append("Stage by stage equality, exact:")
+     lines.append("Q:{0} K:{1} V:{2} S:{3} Y:{4}".format(
+         'OK' if torch.equal(Q,Qi) else 'NO',
+         'OK' if torch.equal(K,Ki) else 'NO',
+         'OK' if torch.equal(V,Vi) else 'NO',
+         'OK' if torch.equal(S,S_i) else 'NO',
+         'OK' if torch.equal(Y,Yi) else 'NO'
+     ))
+     lines.append(f"Argmax match: {arg_ok} ref={Y.argmax(-1).tolist()} mod2={Yi.argmax(-1).tolist()}")
+     lines.append(f"Timing ms: ref≈{(t1-t0)*1000:.2f}")

      report = {
+         "ok_all": ok_all, "argmax_ok": arg_ok, "K": Kbits,
+         "ref_top1": Y.argmax(-1).tolist(), "mod2_top1": Yi.argmax(-1).tolist()
      }
+     return "\n".join(lines), report

  # -----------------------------
+ # UI
  # -----------------------------
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# FieldSpace Prime-Only Machine (mod $2^K$) Exactness Proofs, single file")
+     gr.Markdown("CNN and 1-head attention, all int64, exact mod 2^K reconstruction.")
+
+     with gr.Tab("CNN Proof"):
+         with gr.Row():
+             seed = gr.Number(value=0, label="Seed", precision=0)
+             xmax = gr.Number(value=31, label="|X|max", precision=0)
+             wmax = gr.Number(value=31, label="|W|max", precision=0)
+             act_sel = gr.Radio(choices=["relu","poly"], value="relu", label="Activation")
+             shift_poly = gr.Slider(1, 12, value=7, step=1, label="SHIFT for poly (x^2 >> SHIFT)")
+         run_btn = gr.Button("Run CNN Proof", variant="primary")
+         out_txt = gr.Textbox(label="Console", lines=12)
+         out_json = gr.JSON(label="JSON Report")
+         def _run_cnn(s, xM, wM, act, sh): return run_cnn_proof(int(s), int(xM), int(wM), act, int(sh))
+         run_btn.click(_run_cnn, [seed,xmax,wmax,act_sel,shift_poly], [out_txt,out_json], api_name="run_cnn_proof")
+
+     with gr.Tab("Attention Proof (1 head)"):
+         with gr.Row():
+             seedA = gr.Number(value=0, label="Seed", precision=0)
+             XmaxA = gr.Number(value=7, label="X in [0, Xmax]", precision=0)
+             WmaxA = gr.Number(value=7, label="W in [0, Wmax]", precision=0)
          with gr.Row():
+             BA = gr.Number(value=1, label="Batch B", precision=0)
+             TA = gr.Number(value=8, label="Seq T", precision=0)
+             CA = gr.Number(value=16, label="Width C", precision=0)
+             HA = gr.Number(value=1, label="Heads H (must divide C)", precision=0)
+         runA = gr.Button("Run Attention Proof", variant="primary")
+         outA_t = gr.Textbox(label="Console", lines=12)
+         outA_j = gr.JSON(label="JSON Report")
+         def _run_attn(s, xM, wM, B,T,C,H): return run_attn_proof(int(s), int(xM), int(wM), int(B), int(T), int(C), int(H))
+         runA.click(_run_attn, [seedA,XmaxA,WmaxA,BA,TA,CA,HA], [outA_t,outA_j], api_name="run_attn_proof")

  if __name__ == "__main__":
+     demo.launch()
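
The exactness argument both proofs rely on, in isolation: reduce every operand mod 2^K, do the arithmetic with wrap-around, then re-center residues into [-2^(K-1), 2^(K-1)); the result equals the plain int64 computation whenever every true value satisfies |v| < 2^(K-1), which is exactly what choose_K guarantees. A standalone sketch of that round trip, with its own copies of mod2_encode/mod2_center and hypothetical sizes (K=24, 8x8 matrices, values in [-31, 31]):

import torch

K = 24
MOD2 = 1 << K

def mod2_encode(x):
    # residue in [0, 2^K); & works on negative int64 via two's complement
    return x & (MOD2 - 1)

def mod2_center(v):
    # lift residues back into [-2^(K-1), 2^(K-1))
    out = v.clone()
    out[out >= (MOD2 >> 1)] -= MOD2
    return out

g = torch.Generator().manual_seed(0)
A = torch.randint(-31, 32, (8, 8), dtype=torch.int64, generator=g)
B = torch.randint(-31, 32, (8, 8), dtype=torch.int64, generator=g)

ref = A @ B                                             # plain int64 reference
wrapped = mod2_encode(mod2_encode(A) @ mod2_encode(B))  # wrap-around path
assert torch.equal(ref, mod2_center(wrapped))           # exact while |ref| < 2^(K-1)

Here |ref| is at most 8 * 31 * 31 = 7688 < 2^23, so the assertion holds; push values past 2^(K-1) and the reconstruction wraps, which is why the Space sizes K from the observed max_abs.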
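Note on programmatic access: the removed in-page snippet passed the endpoint path as the first positional argument of client.predict, which is not gradio_client's signature; the endpoint is selected with the api_name keyword, and the two routes are now registered directly on the click handlers (api_name="run_cnn_proof" and api_name="run_attn_proof"). A minimal sketch, assuming the Space id from the removed example still applies:

from gradio_client import Client

# Space id taken from the removed example; adjust if the Space lives elsewhere.
client = Client("jackal79/fieldspace-prime-only")

# Outputs come back in component order: console text, then the JSON report.
console, report = client.predict(
    0, 31, 31, "relu", 7,  # seed, |X|max, |W|max, activation, SHIFT
    api_name="/run_cnn_proof",
)
print(console)

console, report = client.predict(
    0, 7, 7, 1, 8, 16, 1,  # seed, Xmax, Wmax, B, T, C, H
    api_name="/run_attn_proof",
)
print(report)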