# CTM Experiments - Continuous Thought Machine Models

Experimental checkpoints trained on the [Continuous Thought Machine](https://github.com/SakanaAI/continuous-thought-machines) architecture by Sakana AI.

**These are community experiments on the original work - not official SakanaAI models.**
## Paper Reference

> **Continuous Thought Machines**
>
> Sakana AI
>
> [arXiv:2505.05522](https://arxiv.org/abs/2505.05522)
>
> [Interactive Demo](https://pub.sakana.ai/ctm/) | [Blog Post](https://sakana.ai/ctm/)

```bibtex
@article{sakana2025ctm,
  title={Continuous Thought Machines},
  author={{Sakana AI}},
  journal={arXiv preprint arXiv:2505.05522},
  year={2025}
}
```
## Core Insight

CTM's central idea is that **accuracy improves with more internal iterations**: the model "thinks longer" to reach better answers. This lets CTM learn algorithmic reasoning that feedforward networks struggle with.
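
One way to check this claim on a checkpoint is to score the prediction at every internal iteration separately and see whether accuracy rises. The sketch below assumes per-iteration class scores are available as nested lists indexed `[sample][iteration][class]`. That layout, the helper name `accuracy_per_iteration`, and the toy scores are illustrative assumptions, not the CTM API or measured results.

```python
from typing import List, Sequence


def accuracy_per_iteration(
    scores: Sequence[Sequence[Sequence[float]]],
    labels: Sequence[int],
) -> List[float]:
    """Return the accuracy obtained at each internal iteration.

    scores[s][t][c] is the score of class c for sample s at iteration t
    (an assumed layout, chosen only for this illustration).
    """
    n_iters = len(scores[0])
    accuracies = []
    for t in range(n_iters):
        correct = 0
        for sample_scores, label in zip(scores, labels):
            step = sample_scores[t]
            # argmax over classes at iteration t
            pred = max(range(len(step)), key=step.__getitem__)
            correct += int(pred == label)
        accuracies.append(correct / len(labels))
    return accuracies


# Hand-made toy scores: 2 samples, 3 iterations, 2 classes
toy_scores = [
    [[0.6, 0.4], [0.4, 0.6], [0.2, 0.8]],    # true label 1
    [[0.3, 0.7], [0.45, 0.55], [0.9, 0.1]],  # true label 0
]
toy_labels = [1, 0]
curve = accuracy_per_iteration(toy_scores, toy_labels)
print(curve)
```

A curve that rises with the iteration index is what "thinking longer helps" would look like on real outputs.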
## Models

| Model | File | Size | Task | Accuracy | Description |
|-------|------|------|------|----------|-------------|
| MNIST | `ctm-mnist.pt` | 1.3M | Digit classification | 97.9% | 10-class MNIST |
| Parity-16 | `ctm-parity-16.pt` | 2.5M | Cumulative parity | 99.0% | 16-bit sequences |
| Parity-64 | `ctm-parity-64.pt` | 66M | Cumulative parity | 58.6% | 64-bit sequences (custom config) |
| Parity-64 Official | `ctm-parity-64-official.pt` | 21M | Cumulative parity | 57.7% | 64-bit sequences (official config) |
| QAMNIST | `ctm-qamnist.pt` | 39M | Multi-step arithmetic | 100% | 3-5 digits, 3-5 ops |
| Brackets | `ctm-brackets.pt` | 6.1M | Bracket matching | 94.7% | Valid/invalid `(()[])` |
| Tracking-Quadrant | `ctm-tracking-quadrant.pt` | 6.7M | Motion quadrant | 100% | 4-class prediction |
| Tracking-Position | `ctm-tracking-position.pt` | 6.7M | Exact position | 93.8% | 256-class (16x16 grid) |
| Transfer | `ctm-transfer-parity-brackets.pt` | 2.5M | Transfer learning | 94.5% | Parity core to brackets |
| Jigsaw MNIST | `ctm-jigsaw-mnist.pt` | 19M | Jigsaw puzzle solving | 92.3% | Reassemble 2x2 shuffled MNIST |
| Rotation MNIST | `ctm-rotation-mnist.pt` | 4.2M | Rotation prediction | 89.1% | Predict rotation angle (4 classes) |
| Brackets Transfer | `ctm-brackets-transfer-depth4.pt` | 6.1M | Transfer learning | 95.1% | Parity→Brackets (depth 4 synapse) |
| Dual-Task | `ctm-dual-task-brackets-parity.pt` | 2.8M | Multi-task | 86.1% | Brackets (94%) + Parity (78%) jointly |
| Parity-64 (8x8) | `ctm-parity-64-8x8.pt` | 4.1M | Long parity | 58.6% | 64-bit (8x8) cumulative parity |
| Parity-144 | `ctm-parity-144-12x12.pt` | 4.1M | Long parity | 51.7% | 144-bit (12x12) cumulative parity |
## Model Configurations

### MNIST CTM

```python
config = {
    "iterations": 15,
    "memory_length": 10,
    "d_model": 128,
    "d_input": 128,
    "heads": 2,
    "n_synch_out": 16,
    "n_synch_action": 16,
    "memory_hidden_dims": 8,
    "out_dims": 10,
    "synapse_depth": 1,
}
```
### Parity-16 CTM

```python
config = {
    "iterations": 50,
    "memory_length": 25,
    "d_model": 256,
    "d_input": 32,
    "heads": 8,
    "synapse_depth": 8,
    "out_dims": 16,  # cumulative parity
}
```
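
For reference, "cumulative parity" means the target at position `i` is the parity of all bits up to and including `i`, which is consistent with `out_dims` equal to the sequence length. The sketch below assumes bits are encoded as 0/1; the training data for these checkpoints may use a different encoding, and `cumulative_parity` is a name chosen for this illustration.

```python
from typing import List


def cumulative_parity(bits: List[int]) -> List[int]:
    """Return targets where targets[i] is the XOR of bits[0..i]."""
    targets = []
    running = 0
    for b in bits:
        running ^= b  # parity of the prefix seen so far
        targets.append(running)
    return targets


example = cumulative_parity([1, 0, 1, 1])
print(example)  # [1, 1, 0, 1]
```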
### Parity-64 Official CTM

```python
config = {
    "iterations": 75,
    "memory_length": 25,
    "d_model": 1024,
    "d_input": 64,
    "heads": 8,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "synapse_depth": 1,  # linear synapse (official)
    "out_dims": 64,  # cumulative parity
}
```
### QAMNIST CTM

```python
config = {
    "iterations": 10,
    "memory_length": 30,
    "d_model": 1024,
    "d_input": 64,
    "synapse_depth": 1,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
}
```
### Brackets CTM

```python
config = {
    "iterations": 30,
    "memory_length": 15,
    "d_model": 256,
    "d_input": 64,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "out_dims": 2,  # valid/invalid
}
```
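
The two output classes correspond to whether a bracket string is well formed. A minimal reference labeler is sketched below, assuming the alphabet is `()` and `[]` as in the example string `(()[])`. The function name `is_balanced` and the choice to reject other characters are assumptions for illustration, not the data pipeline used for this checkpoint.

```python
def is_balanced(s: str) -> bool:
    """Return True if every bracket in s is closed in the right order."""
    closer_to_opener = {")": "(", "]": "["}
    stack = []
    for ch in s:
        if ch in "([":
            stack.append(ch)
        elif ch in closer_to_opener:
            # A closer must match the most recent unmatched opener
            if not stack or stack.pop() != closer_to_opener[ch]:
                return False
        else:
            raise ValueError(f"unexpected character: {ch!r}")
    # Leftover openers mean the string is invalid
    return not stack


print(is_balanced("(()[])"))  # True
print(is_balanced("(]"))      # False
```

Only the stack of currently open brackets matters at each step, not the full history of the string.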
### Tracking CTM

```python
config = {
    "iterations": 20,
    "memory_length": 15,
    "d_model": 256,
    "d_input": 64,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
}
```
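
For the position variant, the 256 classes cover a 16x16 grid. One plausible way to map a grid cell to a class index is row-major order, sketched below. The row-major convention and the helper names are assumptions; the checkpoint may use a different ordering.

```python
from typing import Tuple

GRID = 16  # 16x16 grid gives 256 classes


def position_to_class(row: int, col: int, grid: int = GRID) -> int:
    """Map a (row, col) cell to a class index in row-major order (assumed)."""
    if not (0 <= row < grid and 0 <= col < grid):
        raise ValueError(f"cell ({row}, {col}) is outside a {grid}x{grid} grid")
    return row * grid + col


def class_to_position(index: int, grid: int = GRID) -> Tuple[int, int]:
    """Invert position_to_class."""
    if not 0 <= index < grid * grid:
        raise ValueError(f"class {index} is outside 0..{grid * grid - 1}")
    return divmod(index, grid)


print(position_to_class(2, 5))  # 37
print(class_to_position(37))    # (2, 5)
```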
### Jigsaw MNIST CTM

```python
config = {
    "iterations": 30,
    "memory_length": 20,
    "d_model": 512,
    "d_input": 128,
    "heads": 8,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "synapse_depth": 1,
    "out_dims": 24,  # 4 tiles x 6 permutation options
    "backbone_type": "jigsaw",
}
```
### Rotation MNIST CTM

```python
config = {
    "iterations": 20,
    "memory_length": 15,
    "d_model": 256,
    "d_input": 64,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "synapse_depth": 1,
    "out_dims": 4,  # 0°, 90°, 180°, 270°
    "backbone_type": "rotation",
}
```
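
The four classes can be read as the number of quarter turns applied to the input image. The sketch below builds one `(rotated_image, label)` pair from a 2D list. Clockwise rotation and the label convention `k` = number of quarter turns are assumptions made for illustration; the actual preprocessing for this checkpoint is not documented here.

```python
from typing import List, Tuple

Image = List[List[int]]


def rotate90(img: Image) -> Image:
    """Rotate a 2D list by 90 degrees clockwise."""
    return [list(row) for row in zip(*img[::-1])]


def make_rotation_example(img: Image, k: int) -> Tuple[Image, int]:
    """Rotate img by k quarter turns and return (rotated, label in 0..3)."""
    label = k % 4
    out = [list(row) for row in img]  # copy so the input is not modified
    for _ in range(label):
        out = rotate90(out)
    return out, label


tiny = [[1, 2],
        [3, 4]]
rotated, label = make_rotation_example(tiny, 1)
print(rotated, label)  # [[3, 1], [4, 2]] 1
```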
## Usage

```python
import torch
from huggingface_hub import hf_hub_download

# Assumes the original CTM repository is importable (provides `models.ctm`)
from models.ctm import ContinuousThoughtMachine

# Download the checkpoint
model_path = hf_hub_download(
    repo_id="vincentoh/ctm-experiments",
    filename="ctm-mnist.pt",
)

# Load the checkpoint on CPU
checkpoint = torch.load(model_path, map_location="cpu")

# Initialize CTM with the matching `config` dict from "Model Configurations"
model = ContinuousThoughtMachine(**config)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Inference: `input_tensor` is a batch you prepare for the chosen task
with torch.no_grad():
    output = model(input_tensor)
```
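
The snippet above expects a `config` dict that matches the checkpoint being loaded. One way to keep that pairing explicit is a small lookup keyed by filename, sketched below with the MNIST and Parity-16 values from the "Model Configurations" section. The names `CONFIGS` and `config_for` are hypothetical helpers, and pairing each filename with its section is an assumption based on the model table.

```python
from typing import Any, Dict

# Values copied from the "Model Configurations" section
CONFIGS: Dict[str, Dict[str, Any]] = {
    "ctm-mnist.pt": {
        "iterations": 15,
        "memory_length": 10,
        "d_model": 128,
        "d_input": 128,
        "heads": 2,
        "n_synch_out": 16,
        "n_synch_action": 16,
        "memory_hidden_dims": 8,
        "out_dims": 10,
        "synapse_depth": 1,
    },
    "ctm-parity-16.pt": {
        "iterations": 50,
        "memory_length": 25,
        "d_model": 256,
        "d_input": 32,
        "heads": 8,
        "synapse_depth": 8,
        "out_dims": 16,
    },
}


def config_for(filename: str) -> Dict[str, Any]:
    """Return a copy of the config recorded for a checkpoint filename."""
    try:
        return dict(CONFIGS[filename])
    except KeyError:
        known = ", ".join(sorted(CONFIGS))
        raise KeyError(f"no config recorded for {filename!r}; known: {known}") from None


config = config_for("ctm-mnist.pt")
print(config["iterations"])  # 15
```

Returning a copy keeps later edits to `config` from silently changing the shared table.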
## Training Details

- **Hardware**: NVIDIA RTX 4070 Ti SUPER
- **Framework**: PyTorch
- **Optimizer**: AdamW
- **Training time**: 5 minutes (MNIST) to 17 hours (QAMNIST)
## Key Findings

1. **Architecture over scale**: Small synchronization dimensions (32) with linear synapses (`synapse_depth=1`) worked better than larger or deeper variants.
2. **"Thinking longer" = higher accuracy**: CTM accuracy improves with more internal iterations.
3. **Transfer learning works**: A parity-trained core transfers to brackets with 94.5% accuracy.
4. **Architectural limits**: In these experiments CTM plateaued at ~58% on 64-bit parity under both the custom and the official configuration.
## Parity Scaling Experiments

We tested CTM on increasingly long parity sequences to find where it breaks down:

| Sequence length (bits) | Grid | Accuracy | vs Random (50%) | Status |
|------------------------|------|----------|-----------------|--------|
| 16 | 4x4 | **99.0%** | +49.0 pp | ✅ Solved |
| 36 | 6x6 | **66.3%** | +16.3 pp | ⚠️ Degraded |
| 64 | 8x8 | **58.6%** | +8.6 pp | ❌ Struggling |
| 64 (official) | 8x8 | **57.7%** | +7.7 pp | ❌ Same ceiling |
| 144 | 12x12 | **51.7%** | +1.7 pp | ❌ Random |

**Key insight**: The ~58% ceiling for parity-64 appears to be an **architectural limit** rather than a hyperparameter issue. Both the custom config (`d_model=512`, `synapse_depth=4`) and the official config (`d_model=1024`, `synapse_depth=1`) reach essentially the same accuracy (58.6% vs 57.7%).
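
The "vs Random" column is the accuracy minus the 50% chance level of a binary parity guess, expressed in percentage points. The short sketch below recomputes it from the accuracies in the table; the names `RESULTS` and `margin_over_chance` are chosen for this illustration.

```python
from typing import Dict

CHANCE = 50.0  # percent; chance level for a binary parity prediction

# Accuracies (percent) copied from the table above
RESULTS: Dict[str, float] = {
    "16": 99.0,
    "36": 66.3,
    "64": 58.6,
    "64 (official)": 57.7,
    "144": 51.7,
}


def margin_over_chance(accuracy: float, chance: float = CHANCE) -> float:
    """Return accuracy minus chance, in percentage points, to one decimal."""
    return round(accuracy - chance, 1)


margins = {name: margin_over_chance(acc) for name, acc in RESULTS.items()}
for name, margin in margins.items():
    print(f"{name:>14} bits: +{margin} pp over chance")
```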
### Why CTM Fails on Long Parity

Cumulative parity is naturally a **strictly sequential computation**: the output at position *i* depends on every bit up to *i*, so bit 1 is processed before bit 2, before bit 3, and so on. CTM's attention-based "thinking" is fundamentally parallel - all positions attend simultaneously. In these experiments the model learned the task for short sequences (16 bits), degraded at 36 bits, stalled near 58% at 64 bits, and fell to chance at 144 bits.

**CTM excels at:**

- Short sequences (16 elements solved; accuracy already drops at 36)
- Local dependencies (brackets: track depth, not full history)
- Parallelizable structure (MNIST: patches contribute independently)

**CTM struggles with:**

- Long strict sequential dependencies (parity-144)
- Tasks requiring O(n) sequential steps once n reaches roughly 64
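
To make the sequential dependency concrete, the sketch below flips a single input bit and reports which cumulative-parity outputs change. Flipping bit `j` changes every output from position `j` onward, so an error early in the sequence propagates to the end. It assumes a 0/1 bit encoding, and the helper names are illustrative only.

```python
from itertools import accumulate
from operator import xor
from typing import List


def cumulative_parity(bits: List[int]) -> List[int]:
    """Return the running XOR of bits (parity of each prefix)."""
    return list(accumulate(bits, xor))


def changed_positions(bits: List[int], flip_index: int) -> List[int]:
    """Return the output positions that change when one input bit is flipped."""
    flipped = list(bits)
    flipped[flip_index] ^= 1
    before = cumulative_parity(bits)
    after = cumulative_parity(flipped)
    return [i for i, (a, b) in enumerate(zip(before, after)) if a != b]


bits = [1, 0, 1, 1, 0, 0, 1, 0]
print(changed_positions(bits, 0))  # [0, 1, 2, 3, 4, 5, 6, 7]
print(changed_positions(bits, 3))  # [3, 4, 5, 6, 7]
```

By contrast, the bracket task only needs the current nesting state, which is one reason it may be easier for a model that attends to all positions at once.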
## License

MIT License (same as original CTM repository)

## Acknowledgments

- [Sakana AI](https://sakana.ai/) for the Continuous Thought Machine architecture
- Original [CTM Repository](https://github.com/SakanaAI/continuous-thought-machines)

## Links

- [Original Paper](https://arxiv.org/abs/2505.05522)
- [Interactive Demo](https://pub.sakana.ai/ctm/)