# CTM Experiments - Continuous Thought Machine Models
Experimental checkpoints trained on the [Continuous Thought Machine](https://github.com/SakanaAI/continuous-thought-machines) architecture by Sakana AI.
**These are community experiments on the original work - not official SakanaAI models.**
## Paper Reference
> **Continuous Thought Machines**
>
> Sakana AI
>
> [arXiv:2505.05522](https://arxiv.org/abs/2505.05522)
>
> [Interactive Demo](https://pub.sakana.ai/ctm/) | [Blog Post](https://sakana.ai/ctm/)
```bibtex
@article{sakana2025ctm,
  title={Continuous Thought Machines},
  author={Sakana AI},
  journal={arXiv preprint arXiv:2505.05522},
  year={2025}
}
```
## Core Insight
CTM's key innovation: **accuracy improves with more internal iterations**. The model "thinks longer" to reach better answers. This enables CTM to learn algorithmic reasoning that feedforward networks struggle with.
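A rough way to see this property is to measure accuracy at every internal tick. The sketch below assumes, as in the SakanaAI reference implementation, that the forward pass returns per-tick predictions of shape `(batch, classes, ticks)` as its first output; `model` and `loader` are placeholders.

```python
import torch

@torch.no_grad()
def accuracy_per_tick(model, loader):
    """Accuracy after each internal tick; later ticks typically score higher."""
    correct, total = None, 0
    for x, y in loader:
        predictions, _, _ = model(x)                 # (B, C, T) per-tick logits
        hits = (predictions.argmax(dim=1) == y.unsqueeze(1)).sum(dim=0)  # (T,)
        correct = hits if correct is None else correct + hits
        total += y.size(0)
    return correct.float() / total                   # one accuracy value per tick
```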
## Models
| Model | File | Size | Task | Accuracy | Description |
|-------|------|------|------|----------|-------------|
| MNIST | `ctm-mnist.pt` | 1.3M | Digit classification | 97.9% | 10-class MNIST |
| Parity-16 | `ctm-parity-16.pt` | 2.5M | Cumulative parity | 99.0% | 16-bit sequences |
| Parity-64 | `ctm-parity-64.pt` | 66M | Cumulative parity | 58.6% | 64-bit sequences (custom config) |
| Parity-64 Official | `ctm-parity-64-official.pt` | 21M | Cumulative parity | 57.7% | 64-bit sequences (official config) |
| QAMNIST | `ctm-qamnist.pt` | 39M | Multi-step arithmetic | 100% | 3-5 digits, 3-5 ops |
| Brackets | `ctm-brackets.pt` | 6.1M | Bracket matching | 94.7% | Valid/invalid `(()[])` |
| Tracking-Quadrant | `ctm-tracking-quadrant.pt` | 6.7M | Motion quadrant | 100% | 4-class prediction |
| Tracking-Position | `ctm-tracking-position.pt` | 6.7M | Exact position | 93.8% | 256-class (16x16 grid) |
| Transfer | `ctm-transfer-parity-brackets.pt` | 2.5M | Transfer learning | 94.5% | Parity core to brackets |
| Jigsaw MNIST | `ctm-jigsaw-mnist.pt` | 19M | Jigsaw puzzle solving | 92.3% | Reassemble 2x2 shuffled MNIST |
| Rotation MNIST | `ctm-rotation-mnist.pt` | 4.2M | Rotation prediction | 89.1% | Predict rotation angle (4 classes) |
| Brackets Transfer | `ctm-brackets-transfer-depth4.pt` | 6.1M | Transfer learning | 95.1% | Parity→Brackets (depth 4 synapse) |
| Dual-Task | `ctm-dual-task-brackets-parity.pt` | 2.8M | Multi-task | 86.1% | Brackets (94%) + Parity (78%) jointly |
| Parity-64 (8x8) | `ctm-parity-64-8x8.pt` | 4.1M | Long parity | 58.6% | 64-bit (8x8) cumulative parity |
| Parity-144 | `ctm-parity-144-12x12.pt` | 4.1M | Long parity | 51.7% | 144-bit (12x12) cumulative parity |
## Model Configurations
### MNIST CTM
```python
config = {
    "iterations": 15,
    "memory_length": 10,
    "d_model": 128,
    "d_input": 128,
    "heads": 2,
    "n_synch_out": 16,
    "n_synch_action": 16,
    "memory_hidden_dims": 8,
    "out_dims": 10,
    "synapse_depth": 1,
}
```
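A quick sanity check is to instantiate the model and count parameters, comparing against the Size column in the models table (a sketch; it assumes the class accepts exactly these kwargs):

```python
from models.ctm import ContinuousThoughtMachine

model = ContinuousThoughtMachine(**config)
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")  # compare against the Size column above
```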
### Parity-16 CTM
```python
config = {
    "iterations": 50,
    "memory_length": 25,
    "d_model": 256,
    "d_input": 32,
    "heads": 8,
    "synapse_depth": 8,
    "out_dims": 16,  # cumulative parity
}
```
### Parity-64 Official CTM
```python
config = {
    "iterations": 75,
    "memory_length": 25,
    "d_model": 1024,
    "d_input": 64,
    "heads": 8,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "synapse_depth": 1,  # linear synapse (official)
    "out_dims": 64,  # cumulative parity
}
```
### QAMNIST CTM
```python
config = {
    "iterations": 10,
    "memory_length": 30,
    "d_model": 1024,
    "d_input": 64,
    "synapse_depth": 1,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
}
```
### Brackets CTM
```python
config = {
    "iterations": 30,
    "memory_length": 15,
    "d_model": 256,
    "d_input": 64,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "out_dims": 2,  # valid/invalid
}
```
### Tracking CTM
```python
config = {
    "iterations": 20,
    "memory_length": 15,
    "d_model": 256,
    "d_input": 64,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
}
```
### Jigsaw MNIST CTM
```python
config = {
    "iterations": 30,
    "memory_length": 20,
    "d_model": 512,
    "d_input": 128,
    "heads": 8,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "synapse_depth": 1,
    "out_dims": 24,  # 4 tiles x 6 permutation options
    "backbone_type": "jigsaw",
}
```
### Rotation MNIST CTM
```python
config = {
    "iterations": 20,
    "memory_length": 15,
    "d_model": 256,
    "d_input": 64,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "synapse_depth": 1,
    "out_dims": 4,  # 0°, 90°, 180°, 270°
    "backbone_type": "rotation",
}
```
## Usage
```python
import torch
from huggingface_hub import hf_hub_download
from models.ctm import ContinuousThoughtMachine  # from the SakanaAI CTM repo

# Download a checkpoint from the Hub
model_path = hf_hub_download(
    repo_id="vincentoh/ctm-experiments",
    filename="ctm-mnist.pt",
)

# Load the checkpoint
checkpoint = torch.load(model_path, map_location="cpu")

# Initialize CTM with the config that matches the checkpoint
# (here: the MNIST config listed above)
config = {
    "iterations": 15,
    "memory_length": 10,
    "d_model": 128,
    "d_input": 128,
    "heads": 2,
    "n_synch_out": 16,
    "n_synch_action": 16,
    "memory_hidden_dims": 8,
    "out_dims": 10,
    "synapse_depth": 1,
}
model = ContinuousThoughtMachine(**config)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Inference (input_tensor: your preprocessed input batch)
with torch.no_grad():
    output = model(input_tensor)
```
## Training Details
- **Hardware**: NVIDIA RTX 4070 Ti SUPER
- **Framework**: PyTorch
- **Optimizer**: AdamW
- **Training time**: 5 minutes (MNIST) to 17 hours (QAMNIST)
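For context, here is a minimal sketch of the training setup described above. It supervises only the final tick's prediction for simplicity (the reference implementation aggregates the loss across internal ticks); `model` and `train_loader` are placeholders, and the hyperparameters are illustrative, not the ones used for these checkpoints.

```python
import torch
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
criterion = torch.nn.CrossEntropyLoss()

model.train()
for x, y in train_loader:
    predictions, _, _ = model(x)                # predictions: (B, C, T)
    loss = criterion(predictions[..., -1], y)   # supervise the final tick only
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```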
## Key Findings
1. **Architecture > Scale**: Small sync dimensions (32) with linear synapses work better than large/deep variants
2. **"Thinking Longer" = Higher Accuracy**: CTM accuracy improves with more internal iterations
3. **Transfer Learning Works**: Parity-trained core transfers to brackets with 94.5% accuracy (see the sketch after this list)
4. **Architectural Limits**: CTM has a ~58% ceiling on 64-bit parity regardless of hyperparameters
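A rough illustration of finding 3: load a parity checkpoint into a brackets model, copying every tensor whose name and shape match and keeping the fresh initialization for the rest (the output head, with its different `out_dims`, is skipped automatically). The key-and-shape filter is a generic PyTorch pattern, not an interface of the official code; `brackets_config` is a placeholder.

```python
import torch
from models.ctm import ContinuousThoughtMachine

ckpt = torch.load("ctm-parity-16.pt", map_location="cpu")["model_state_dict"]
brackets_model = ContinuousThoughtMachine(**brackets_config)  # out_dims=2

own = brackets_model.state_dict()
compatible = {k: v for k, v in ckpt.items() if k in own and v.shape == own[k].shape}
own.update(compatible)
brackets_model.load_state_dict(own)
print(f"transferred {len(compatible)}/{len(own)} tensors")
```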
## Parity Scaling Experiments
We tested CTM on increasingly long parity sequences to find where it breaks down:
| Sequence | Grid | Accuracy | vs Random | Status |
|----------|------|----------|-----------|--------|
| 16 | 4x4 | **99.0%** | +49.0% | ✅ Solved |
| 36 | 6x6 | **66.3%** | +16.3% | ⚠️ Degraded |
| 64 | 8x8 | **58.6%** | +8.6% | ❌ Struggling |
| 64 (official) | 8x8 | **57.7%** | +7.7% | ❌ Same ceiling |
| 144 | 12x12 | **51.7%** | +1.7% | ❌ Random |
**Key insight**: The ~58% ceiling for parity-64 is an **architectural limit**, not a hyperparameter issue. Both custom config (d_model=512, synapse_depth=4) and official config (d_model=1024, synapse_depth=1) achieve essentially the same accuracy.
### Why CTM Fails on Long Parity
Parity requires **strict sequential computation**: bit 1 must be processed before bit 2, bit 2 before bit 3, and so on. CTM's attention-based "thinking" is fundamentally parallel: all positions attend simultaneously. The model can learn approximate sequential shortcuts for short sequences, but these degrade by 64 bits and collapse to chance at 144.
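For reference, a sketch of the cumulative-parity task used in the table above, where each target bit is the parity of the input prefix up to that position. The 0/1 bit encoding and grid reshaping here are our assumptions for illustration; the official repo has its own data pipeline.

```python
import torch

def make_parity_batch(batch_size: int, seq_len: int = 64):
    """Random bit strings reshaped to a square grid, with cumulative-parity targets."""
    bits = torch.randint(0, 2, (batch_size, seq_len))
    targets = bits.cumsum(dim=1) % 2            # targets[:, i] = parity of bits[:, :i+1]
    side = int(seq_len ** 0.5)                  # 64 -> 8x8, 144 -> 12x12
    grid = bits.view(batch_size, 1, side, side).float()
    return grid, targets
```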
**CTM excels at:**
- Moderate sequence lengths (< 64 elements)
- Local dependencies (brackets: track depth, not full history)
- Parallelizable structure (MNIST: patches contribute independently)
**CTM struggles with:**
- Long strict sequential dependencies (parity-144)
- Tasks requiring O(n) sequential steps where n > ~64
## License
MIT License (same as original CTM repository)
## Acknowledgments
- [Sakana AI](https://sakana.ai/) for the Continuous Thought Machine architecture
- Original [CTM Repository](https://github.com/SakanaAI/continuous-thought-machines)
## Links
- [Original Paper](https://arxiv.org/abs/2505.05522)
- [Interactive Demo](https://pub.sakana.ai/ctm/)