WebGL Inference for Transformer Models
WebGL was designed for graphics, not compute, so we are "abusing" the rendering pipeline to do math. It is not as elegant as proper compute shaders, but the approach works today in every major browser and can run meaningful models (25M+ parameters) at usable speeds. This is the model this post was tested with:
Vocab size: 5,256 (padded to 8,192)
Embedding dim: 512
Attention heads: 8 (head dim = 64)
Transformer blocks: 6
Max sequence length: 2,048
MLP hidden dim: 2,048 (4x expansion)
Total parameters: ~25M
We will think of WebGL as a massively parallel compute device where:
- Weights are stored as textures (read-only)
- Activations flow through textures (read/write via framebuffers)
- Fragment shaders perform matrix operations (one output pixel = one output value)
- The full-screen quad pattern drives computation
WebGL2 Requirements
Rather than packing floats into RGBA bytes, we use floating-point textures directly. WebGL2 can sample them out of the box, but rendering into them requires the EXT_color_buffer_float extension.
const gl = canvas.getContext('webgl2');
// Required: allows R32F textures as render targets
const ext = gl.getExtension('EXT_color_buffer_float');
if (!ext) throw new Error('EXT_color_buffer_float not supported');
For inference without visible rendering, create the context on an OffscreenCanvas where it is available:
let canvas;
if (typeof OffscreenCanvas !== 'undefined') {
canvas = new OffscreenCanvas(1, 1);
} else {
canvas = document.createElement('canvas');
}
const gl = canvas.getContext('webgl2');
Texture Storage Strategy
Format: R32F (Single-Channel Float32)
We use R32F format (one 32-bit float per texel) rather than RGBA32F because:
- Simpler indexing (no channel packing/unpacking)
- More intuitive mapping to weight matrices
- Sufficient precision for inference
gl.texImage2D(
gl.TEXTURE_2D,
0,
gl.R32F, // Internal format
width,
height,
0,
gl.RED, // Format
gl.FLOAT, // Type
data // Float32Array
);
Texture Layout
All textures use power-of-two dimensions for compatibility. Gab was mostly designed with power-of-two hyperparameters, so only the vocab needs to be padded:
| Texture | Dimensions (width × height) | Axes (width × height) |
|---|---|---|
| Token Embeddings | 512 × 8192 | embedding × vocab (padded) |
| Position Embeddings | 512 × 2048 | embedding × seq |
| QKV Weights | 512 × 512 | embedding × embedding |
| MLP1 Weights | 512 × 2048 | embedding × hidden |
| MLP2 Weights | 2048 × 512 | hidden × embedding |
| Attention Scores | 2048 × 2048 | seq × seq |
| Hidden States | 512 × 2048 | embedding × seq |
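The vocab padding described above can be sketched as follows. The helper names `padRows` and `nextPowerOfTwo` are hypothetical, and zero-filling the extra rows is an assumption; the post does not say what the padded rows contain.

```javascript
// Hypothetical helper: pad a row-major [rows x cols] matrix with zero rows.
function padRows(data, rows, cols, paddedRows) {
  if (data.length !== rows * cols) {
    throw new Error(`expected ${rows * cols} values, got ${data.length}`);
  }
  if (paddedRows < rows) {
    throw new Error('paddedRows must be >= rows');
  }
  const padded = new Float32Array(paddedRows * cols); // zero-initialised
  padded.set(data); // rows are contiguous, so a single copy is enough
  return padded;
}

// Smallest power of two >= n (for n >= 1).
function nextPowerOfTwo(n) {
  let p = 1;
  while (p < n) p *= 2;
  return p;
}
```

For this model, `nextPowerOfTwo(5256)` gives 8192. Note that zero rows in the output projection still produce a logit, so padded token ids are best ignored when picking the next token.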
Data is stored row-major in Float32Arrays:
data[row * width + col] -> texelFetch(tex, ivec2(col, row), 0).r
For a weight matrix W[output][input]:
// Store as: weightData[outputIdx * inputDim + inputIdx]
for (let o = 0; o < outputDim; o++) {
for (let i = 0; i < inputDim; i++) {
weightData[o * inputDim + i] = W[o][i];
}
}
// Access as: texelFetch(weights, ivec2(inputIdx, outputIdx), 0).r
Total GPU Memory
| Category | Size |
|---|---|
| Weight textures (~50) | ~108 MB |
| Working textures (~12) | ~10 MB |
| Total | ~118 MB |
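As a sanity check on the weight figure, here is one plausible breakdown. It rests on assumptions the post does not state: no bias terms, an output projection that is not tied to the token embeddings, and two LayerNorms per block plus a final one.

```javascript
// Assumptions (not stated in the post): no biases, untied output head,
// two LayerNorms per block plus one final LayerNorm.
const cfg = { vocab: 5256, paddedVocab: 8192, emb: 512, blocks: 6, seq: 2048, hidden: 2048 };

function countWeights(c, vocabRows) {
  const tokenEmb = vocabRows * c.emb;
  const posEmb = c.seq * c.emb;
  const attention = 4 * c.emb * c.emb;   // Q, K, V and output projection
  const mlp = 2 * c.emb * c.hidden;      // dense1 + dense2
  const layerNorms = 2 * 2 * c.emb;      // gamma + beta, twice per block
  const perBlock = attention + mlp + layerNorms;
  const finalNorm = 2 * c.emb;
  const outputHead = vocabRows * c.emb;
  return tokenEmb + posEmb + c.blocks * perBlock + finalNorm + outputHead;
}

const parameters = countWeights(cfg, cfg.vocab);          // real vocab
const storedFloats = countWeights(cfg, cfg.paddedVocab);  // as uploaded, padded
const weightMiB = (storedFloats * 4) / (1024 * 1024);     // 4 bytes per float32
console.log(parameters, weightMiB.toFixed(1));
```

Under these assumptions the count comes to about 25.3M parameters and roughly 108 MiB of float32 weight data once the vocab is padded, which lines up with the table.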
The Shader Pipeline
Vertex Shader (Shared)
All operations use the same vertex shader that renders a full-screen quad:
#version 300 es
in vec2 a_position;
out vec2 v_texCoord;
void main() {
v_texCoord = a_position * 0.5 + 0.5;
gl_Position = vec4(a_position, 0.0, 1.0);
}
The quad vertices are: [-1,-1], [1,-1], [-1,1], [-1,1], [1,-1], [1,1]
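The quad only needs to be uploaded once and can then be reused for every draw call. A minimal sketch, where `createQuadBuffer` and `QUAD_VERTICES` are hypothetical names:

```javascript
// Two triangles covering clip space, in the vertex order listed above.
const QUAD_VERTICES = new Float32Array([
  -1, -1,   1, -1,  -1,  1,
  -1,  1,   1, -1,   1,  1,
]);

// Hypothetical helper: create the shared vertex buffer for the full-screen quad.
function createQuadBuffer(gl) {
  const buffer = gl.createBuffer();
  gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
  gl.bufferData(gl.ARRAY_BUFFER, QUAD_VERTICES, gl.STATIC_DRAW);
  return buffer;
}
```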
Fragment Shader Pattern
Each fragment shader computes one output value based on gl_FragCoord:
#version 300 es
precision highp float;
out float outValue;
uniform sampler2D u_input;
uniform int u_width;
uniform int u_height;
void main() {
int x = int(gl_FragCoord.x); // Output column
int y = int(gl_FragCoord.y); // Output row
// Bounds check
if (x >= u_width || y >= u_height) {
outValue = 0.0;
return;
}
// Compute output value...
outValue = /* computation */;
}
Core Shaders
1. Embedding Lookup
void main() {
int d = int(gl_FragCoord.x); // Embedding dimension
int pos = int(gl_FragCoord.y); // Sequence position
float tokenId = texelFetch(u_tokens, ivec2(pos, 0), 0).r;
float tokEmb = texelFetch(u_tokenEmb, ivec2(d, int(tokenId)), 0).r;
float posEmb = texelFetch(u_posEmb, ivec2(d, pos), 0).r;
outValue = tokEmb + posEmb;
}
2. Layer Normalization
void main() {
int d = int(gl_FragCoord.x);
int pos = int(gl_FragCoord.y);
// Compute mean across embedding dimension
float mean = 0.0;
for (int i = 0; i < u_embDim; i++) {
mean += texelFetch(u_input, ivec2(i, pos), 0).r;
}
mean /= float(u_embDim);
// Compute variance
float variance = 0.0;
for (int i = 0; i < u_embDim; i++) {
float diff = texelFetch(u_input, ivec2(i, pos), 0).r - mean;
variance += diff * diff;
}
variance /= float(u_embDim);
// Normalize, scale, shift
float x = texelFetch(u_input, ivec2(d, pos), 0).r;
float normalized = (x - mean) / sqrt(variance + 1e-5);
float gamma = texelFetch(u_gamma, ivec2(d, 0), 0).r;
float beta = texelFetch(u_beta, ivec2(d, 0), 0).r;
outValue = gamma * normalized + beta;
}
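When debugging, it helps to compare GPU output against a plain CPU version of the same math. The sketch below mirrors the shader for a single sequence position (biased variance, epsilon of 1e-5); `layerNormRow` is a hypothetical name, not part of the original implementation.

```javascript
// CPU reference for one row (one sequence position) of LayerNorm.
// Same biased variance and epsilon as the shader above.
function layerNormRow(row, gamma, beta, eps = 1e-5) {
  const n = row.length;
  let mean = 0;
  for (let i = 0; i < n; i++) mean += row[i];
  mean /= n;

  let variance = 0;
  for (let i = 0; i < n; i++) {
    const diff = row[i] - mean;
    variance += diff * diff;
  }
  variance /= n;

  const invStd = 1 / Math.sqrt(variance + eps);
  const out = new Float32Array(n);
  for (let i = 0; i < n; i++) {
    out[i] = gamma[i] * (row[i] - mean) * invStd + beta[i];
  }
  return out;
}
```

Unlike this version, the shader recomputes the mean and variance for every output pixel, since each fragment runs both loops over the embedding dimension independently.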
3. Matrix Multiplication
The workhorse operation. Computes Output = Input @ Weight:
void main() {
int outCol = int(gl_FragCoord.x); // Output dimension
int outRow = int(gl_FragCoord.y); // Sequence position
float sum = 0.0;
for (int i = 0; i < u_inputDim; i++) {
float inputVal = texelFetch(u_input, ivec2(i, outRow), 0).r;
float weightVal = texelFetch(u_weight, ivec2(i, outCol), 0).r;
sum += inputVal * weightVal;
}
outValue = sum;
}
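Critical: the weight texture must be stored as W[outputIdx][inputIdx] (one row per output value), i.e. transposed relative to the W[inputIdx][outputIdx] layout that Input @ Weight implies, for this indexing to work.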
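A CPU version using the same indexing makes the transposed layout concrete and can serve as a reference when checking shader output. `matmulTransposed` is a hypothetical name for this sketch.

```javascript
// CPU reference for the matmul shader.
// input:  [seqLen x inputDim], row-major
// weight: [outputDim x inputDim], row-major (the transposed layout)
function matmulTransposed(input, weight, seqLen, inputDim, outputDim) {
  const out = new Float32Array(seqLen * outputDim);
  for (let row = 0; row < seqLen; row++) {
    for (let col = 0; col < outputDim; col++) {
      let sum = 0;
      for (let i = 0; i < inputDim; i++) {
        // Both operands are read along a row, exactly as in the shader.
        sum += input[row * inputDim + i] * weight[col * inputDim + i];
      }
      out[row * outputDim + col] = sum;
    }
  }
  return out;
}
```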
4. Attention Scores
void main() {
int s = int(gl_FragCoord.x); // Source position (key)
int t = int(gl_FragCoord.y); // Target position (query)
// Causal mask
if (s > t) {
outValue = -1e9;
return;
}
int headStart = u_headIdx * u_headDim;
float score = 0.0;
for (int d = 0; d < u_headDim; d++) {
float q = texelFetch(u_Q, ivec2(headStart + d, t), 0).r;
float k = texelFetch(u_K, ivec2(headStart + d, s), 0).r;
score += q * k;
}
outValue = score * u_scale; // scale = 1/sqrt(headDim)
}
5. Softmax (Row-wise)
void main() {
int x = int(gl_FragCoord.x);
int y = int(gl_FragCoord.y);
// Find max for numerical stability
float maxVal = -1e9;
for (int i = 0; i <= y; i++) { // Causal: only up to position y
maxVal = max(maxVal, texelFetch(u_input, ivec2(i, y), 0).r);
}
// Compute exp sum
float sumExp = 0.0;
for (int i = 0; i <= y; i++) {
sumExp += exp(texelFetch(u_input, ivec2(i, y), 0).r - maxVal);
}
// Output normalized probability
if (x > y) {
outValue = 0.0; // Causal mask
} else {
float val = texelFetch(u_input, ivec2(x, y), 0).r;
outValue = exp(val - maxVal) / sumExp;
}
}
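The same row-wise logic on the CPU, as a reference for validating the shader. `causalSoftmaxRow` is a hypothetical name, and it starts the running maximum at -Infinity rather than the shader's -1e9.

```javascript
// CPU reference for one row of the causal softmax shader.
// scores: the full row of attention scores; y: the query position.
function causalSoftmaxRow(scores, y) {
  let maxVal = -Infinity;
  for (let i = 0; i <= y; i++) maxVal = Math.max(maxVal, scores[i]);

  let sumExp = 0;
  for (let i = 0; i <= y; i++) sumExp += Math.exp(scores[i] - maxVal);

  const out = new Float32Array(scores.length); // masked entries stay 0
  for (let x = 0; x <= y; x++) {
    out[x] = Math.exp(scores[x] - maxVal) / sumExp;
  }
  return out;
}
```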
6. GELU Activation
// Approximation: GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
float c = 0.7978845608; // sqrt(2/pi)
outValue = 0.5 * x * (1.0 + tanh(c * (x + 0.044715 * x * x * x)));
Implementation Patterns
1. Ping-Pong Buffers
When a computation needs to both read from and write to the same logical buffer, use two physical textures and alternate:
// Hidden states use ping-pong
let currentHidden = 'hiddenA';
let otherHidden = 'hiddenB';
// After each operation that modifies hidden state:
[currentHidden, otherHidden] = [otherHidden, currentHidden];
We also use ping-pong for attention output accumulation:
let attnRead = 'attnOutA';
let attnWrite = 'attnOutB';
for (let head = 0; head < numHeads; head++) {
// Compute head output...
// Copy head to accumulated output (ping-pong)
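runProgram(copyHeadShader, framebuffers[attnWrite], embeddingDim, seqLen, { // width, height (seqLen is an assumed name)
u_headOutput: textures.headOut,
u_accumulator: textures[attnRead], // Read the running result, never the write target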
// ...
});
// Swap for next iteration
[attnRead, attnWrite] = [attnWrite, attnRead];
}
// After loop: attnRead contains final result
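The swap bookkeeping can be wrapped in a small helper so the read and write roles are never confused. This is only a sketch; the `PingPong` class is hypothetical and not part of the original code.

```javascript
// Hypothetical helper that tracks which of two named textures is readable.
class PingPong {
  constructor(nameA, nameB) {
    this.read = nameA;   // holds the latest valid data
    this.write = nameB;  // safe to render into
  }

  // Call after each draw that rendered into `write`.
  swap() {
    [this.read, this.write] = [this.write, this.read];
  }
}

// Usage sketch (textures, framebuffers and runProgram as in the post):
//   const hidden = new PingPong('hiddenA', 'hiddenB');
//   render into framebuffers[hidden.write] while sampling textures[hidden.read]
//   hidden.swap();
```

With an even number of swaps, such as the 8 heads here, the final result ends up back in the texture the loop started reading from.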
2. Running a Shader Program
function runProgram(program, outputFramebuffer, outputWidth, outputHeight, uniforms) {
gl.useProgram(program);
gl.bindFramebuffer(gl.FRAMEBUFFER, outputFramebuffer);
gl.viewport(0, 0, outputWidth, outputHeight);
// Set up quad vertex attribute
const posLoc = gl.getAttribLocation(program, 'a_position');
gl.bindBuffer(gl.ARRAY_BUFFER, quadBuffer);
gl.enableVertexAttribArray(posLoc);
gl.vertexAttribPointer(posLoc, 2, gl.FLOAT, false, 0, 0);
// Set uniforms
let textureUnit = 0;
for (const [name, value] of Object.entries(uniforms)) {
const loc = gl.getUniformLocation(program, name);
if (loc === null) continue;
if (value instanceof WebGLTexture) {
gl.activeTexture(gl.TEXTURE0 + textureUnit);
gl.bindTexture(gl.TEXTURE_2D, value);
gl.uniform1i(loc, textureUnit);
textureUnit++;
} else if (Number.isInteger(value)) {
gl.uniform1i(loc, value);
} else {
gl.uniform1f(loc, value);
}
}
// Draw full-screen quad (6 vertices = 2 triangles)
gl.drawArrays(gl.TRIANGLES, 0, 6);
}
3. Weight Matrix Transposition
The matmul shader uses this access pattern:
weightVal = texelFetch(u_weight, ivec2(inputIdx, outputIdx), 0).r;
This expects weights stored as W[outputIdx][inputIdx]. If your model stores weights as W[inputIdx][outputIdx], transpose during loading:
function transpose(data, rows, cols) {
const result = new Float32Array(rows * cols);
for (let r = 0; r < rows; r++) {
for (let c = 0; c < cols; c++) {
result[c * rows + r] = data[r * cols + c];
}
}
return result;
}
// Attention weights need transposition
const wq = transpose(wqRaw, embeddingDim, embeddingDim);
4. Reading Results Back to CPU
// Bind the framebuffer containing results
gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffers.probs);
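// Read pixels. RED/FLOAT is accepted only if it matches the implementation's
// IMPLEMENTATION_COLOR_READ_FORMAT/TYPE for this framebuffer; RGBA/FLOAT is the portable fallback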
const probsData = new Float32Array(8192);
gl.readPixels(0, 0, 8192, 1, gl.RED, gl.FLOAT, probsData);
// probsData now contains the output probabilities
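Since the buffer holds 8,192 entries but only 5,256 are real tokens, the next token should be chosen from the unpadded range. A minimal sketch, with hypothetical helper names `argmaxToken` and `sampleToken`, assuming padded entries should simply be ignored:

```javascript
// Hypothetical helper: greedy decoding over the real vocabulary only.
// Entries at index >= vocabSize are padding and are skipped.
function argmaxToken(probs, vocabSize) {
  let best = 0;
  for (let i = 1; i < vocabSize; i++) {
    if (probs[i] > probs[best]) best = i;
  }
  return best;
}

// Sample from the first vocabSize entries, renormalising because the
// padded entries may hold some probability mass.
function sampleToken(probs, vocabSize, random = Math.random) {
  let total = 0;
  for (let i = 0; i < vocabSize; i++) total += probs[i];
  let threshold = random() * total;
  for (let i = 0; i < vocabSize; i++) {
    threshold -= probs[i];
    if (threshold < 0) return i;
  }
  return vocabSize - 1; // guard against floating-point round-off
}
```

For example, `argmaxToken(probsData, 5256)` picks the most likely real token from the readback above.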
Draw Call Analysis
Each token generation requires the following draw calls:
| Operation | Draw Calls | Notes |
|---|---|---|
| Embedding lookup | 1 | |
| Per block (×6): | | |
| - LayerNorm1 | 1 | |
| - Q projection | 1 | |
| - K projection | 1 | |
| - V projection | 1 | |
| - Attention scores (×8 heads) | 8 | One per head |
| - Softmax (×8 heads) | 8 | One per head |
| - Attention output (×8 heads) | 8 | One per head |
| - Copy head (×8 heads) | 8 | One per head |
| - Output projection | 1 | |
| - LayerNorm2 | 1 | |
| - MLP dense1 + GELU | 1 | |
| - MLP dense2 + residual | 1 | |
| Final LayerNorm | 1 | |
| Output projection | 1 | |
| Final softmax | 1 | |
The attention computation dominates: 32 calls per block × 6 blocks = 192 calls, or about 79% of the total.
Per block: 1 + 3 + (8×4) + 1 + 1 + 1 + 1 = 40 draw calls
All blocks: 40 × 6 = 240 draw calls
Fixed ops: 1 (embed) + 1 (final LN) + 1 (output proj) + 1 (softmax) = 4
──────────────────────────────────────────────────────────────────────────
TOTAL: 244 draw calls per token
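The same tally as a small function, which makes it easy to see how the count scales with the number of heads and blocks. `countDrawCalls` is a hypothetical name; the per-operation counts are taken from the table above.

```javascript
function countDrawCalls({ blocks, heads }) {
  const perHead = 4; // scores, softmax, attention output, copy head
  const attentionPerBlock = heads * perHead;
  const perBlock =
    1 +                  // LayerNorm1
    3 +                  // Q, K, V projections
    attentionPerBlock +
    1 +                  // output projection
    1 +                  // LayerNorm2
    2;                   // MLP dense1 + GELU, MLP dense2 + residual
  const fixed = 4;       // embedding, final LayerNorm, output projection, softmax
  const total = blocks * perBlock + fixed;
  return {
    perBlock,
    total,
    attentionShare: (blocks * attentionPerBlock) / total,
  };
}
```

Calling it with `{ blocks: 6, heads: 8 }` reproduces the 40 calls per block and 244 calls per token shown above.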
Lessons Learned
1. The Feedback Loop Bug
Problem: Reading from and writing to the same texture produces undefined behavior.
Symptom: Model generates complete gibberish despite correct weight loading.
Solution: Use ping-pong buffers for any texture that needs to be both read and written in the same logical operation.
Detection: the browser console shows warnings like the one below, but only the first 32 are reported:
WebGL warning: Texture level 0 would be read by TEXTURE_2D unit 1,
but written by framebuffer attachment COLOR_ATTACHMENT0
2. Coordinate System
texelFetch(texture, ivec2(x, y), 0) uses:
- x = column (horizontal)
- y = row (vertical)
Data uploaded via texImage2D maps: data[row * width + col] -> texelFetch(tex, ivec2(col, row), 0)
3. Integer Uniforms
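WebGL2 distinguishes between uniform1i and uniform1f. Using the wrong one for the uniform's declared type fails without throwing: the call records an INVALID_OPERATION error and the uniform keeps its previous value. Dispatching on the JavaScript value works as long as float uniforms never receive a whole-number value (a u_scale of exactly 1 would be sent through uniform1i):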
if (Number.isInteger(value)) {
gl.uniform1i(loc, value);
} else {
gl.uniform1f(loc, value);
}
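A sturdier alternative is to ask the program for each uniform's declared type and dispatch on that instead of on the JavaScript value. The sketch below uses the standard `getProgramParameter` and `getActiveUniform` calls; the helper names `getUniformTypes` and `setScalarUniform` are hypothetical, and only the scalar types used in this post are handled.

```javascript
// Hypothetical helper: map each active uniform name to its declared GL type.
function getUniformTypes(gl, program) {
  const types = new Map();
  const count = gl.getProgramParameter(program, gl.ACTIVE_UNIFORMS);
  for (let i = 0; i < count; i++) {
    const info = gl.getActiveUniform(program, i);
    if (info) types.set(info.name, info.type);
  }
  return types;
}

// Set a scalar uniform according to its declared type, not the JS value.
function setScalarUniform(gl, loc, type, value) {
  if (type === gl.FLOAT) {
    gl.uniform1f(loc, value);
  } else if (type === gl.INT || type === gl.SAMPLER_2D) {
    gl.uniform1i(loc, value); // samplers take a texture unit index
  } else {
    throw new Error(`unhandled uniform type 0x${type.toString(16)}`);
  }
}
```

The type map can be built once per program after linking, so the lookup adds no per-draw queries.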
4. Texture Filtering
For compute workloads, always use NEAREST filtering:
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
LINEAR filtering interpolates between texels, which corrupts your data.
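Putting the filtering advice together with the R32F upload from earlier, a render target could be created roughly like this. `createFloatTarget` is a hypothetical helper, and it assumes EXT_color_buffer_float has already been enabled on the context.

```javascript
// Hypothetical helper: allocate an R32F texture with NEAREST filtering and
// wrap it in a framebuffer so a fragment shader can render into it.
// Assumes EXT_color_buffer_float has already been enabled on `gl`.
function createFloatTarget(gl, width, height, data = null) {
  const texture = gl.createTexture();
  gl.bindTexture(gl.TEXTURE_2D, texture);
  gl.texImage2D(gl.TEXTURE_2D, 0, gl.R32F, width, height, 0, gl.RED, gl.FLOAT, data);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);

  const framebuffer = gl.createFramebuffer();
  gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
  gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
  const status = gl.checkFramebufferStatus(gl.FRAMEBUFFER);
  gl.bindFramebuffer(gl.FRAMEBUFFER, null);
  if (status !== gl.FRAMEBUFFER_COMPLETE) {
    throw new Error(`framebuffer incomplete: ${status}`);
  }
  return { texture, framebuffer, width, height };
}
```

Setting the minification filter explicitly matters for another reason: the default expects mipmaps, so a texture without them is treated as incomplete until the filter is changed.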
5. Debugging Texture Uploads
Verify data roundtrips correctly:
function verifyTextureUpload(gl, texture, originalData, width, height, name) {
const fb = gl.createFramebuffer();
gl.bindFramebuffer(gl.FRAMEBUFFER, fb);
gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
const readBack = new Float32Array(width * height);
gl.readPixels(0, 0, width, height, gl.RED, gl.FLOAT, readBack);
gl.deleteFramebuffer(fb);
gl.bindFramebuffer(gl.FRAMEBUFFER, null);
let maxDiff = 0;
for (let i = 0; i < originalData.length; i++) {
maxDiff = Math.max(maxDiff, Math.abs(originalData[i] - readBack[i]));
}
console.log(`[${name}] Max diff: ${maxDiff}`);
return maxDiff < 1e-5;
}
In action
The actual kernels end up being large. Full implementation here