# CTM Experiments

Personal experiments with [Continuous Thought Machines](https://github.com/SakanaAI/continuous-thought-machines) (SakanaAI).

**Interactive Demo**: https://pub.sakana.ai/ctm/

## Core Insight: Thinking Takes Time

CTM's key innovation: **accuracy improves with more internal iterations**. The model "thinks longer" to reach better answers. This enables CTM to learn algorithmic reasoning that feedforward networks struggle with:

| Task | Challenge | What CTM Learns |
|------|-----------|-----------------|
| **Parity** | Count bits across a sequence | Iterative accumulation |
| **Brackets** | Track nested structure | Stack-like memory (LIFO) |
| **Object Tracking** | Extrapolate motion | Physics simulation |
| **Mazes** | Navigate 2D paths | Sequential decision making |
| **Jigsaw** | Classify shuffled patches | Part-whole integration |

## Results Summary

| Experiment | Accuracy | Notes |
|------------|----------|-------|
| **MNIST** | **97.9%** | Digit classification, 5 min training |
| **Parity-16** | **99.0%** | 16-bit cumulative parity |
| **QAMNIST** | **100%** | Multi-step arithmetic (3-5 digits, 3-5 ops) |
| **Brackets** | **94.7%** | Stack-like reasoning for `(()[])` vs `([)]` |
| **Object Tracking** | **100%** | Quadrant prediction from motion (4 classes) |
| **Velocity Prediction** | **100%** | Direction prediction (9 classes) |
| **Position Prediction** | **93.8%** | Exact position (256 classes, 16x16 grid) |
| **Transfer Learning** | **94.5%** | Parity→Brackets (core frozen) |
| **Maze Solving** | **Visualized** | Pretrained model inference on 15x15 mazes |
| **Jigsaw MNIST** | **92%** | Classify digits from shuffled patches (no positional encoding) |

## Key Findings

### 1. Architecture Matters More Than Scale

Early experiments plateaued at 50% accuracy on parity (random guessing).
The fix wasn't more parameters, it was using the **correct architecture**:

| Parameter | Wrong | Correct (Official) |
|-----------|-------|-------------------|
| `n_synch_out` | 512 | **32** |
| `n_synch_action` | 512 | **32** |
| `synapse_depth` | 4 (U-Net) | **1** (linear) |

The official parity implementation uses surprisingly small synchronization dimensions with a linear synapse; this is critical for learning.

### 2. "Thinking Longer" = Higher Accuracy

![MNIST Inference per Tick](continuous-thought-machines/experiments/results/mnist_inference.png)

CTM accuracy improves with more internal iterations:

- **Tick 0**: 7% (random)
- **Ticks 10-11**: 100% (peak)
- **Final tick**: 98%

Harder tasks need more "thinking time": parity peaks at tick 35.

### 3. Transfer Learning Works

A pretrained parity model transfers to brackets:

- **Baseline**: 52.5% (random)
- **After transfer**: 94.5% (core frozen, only backbone/output trained)

The iterative counting learned for parity transfers to stack tracking for brackets, matching from-scratch performance with only 37.7% of parameters trainable.

### 4. Maze Solving "The Hard Way"

CTM solves mazes by outputting action trajectories (Up/Down/Left/Right/Wait), not pixel masks:

- **Step accuracy**: 60%+ after 2000 iterations
- Uses an auto-extending curriculum (loss only on the trajectory up to the first error)
- Demonstrates sequential reasoning capability

![Maze Attention Overlay](continuous-thought-machines/experiments/results/maze_attention.gif)

*CTM "thinking" through a 15x15 maze: blue = predicted path, red = attention focus, green = start position. The attention heatmap shows where CTM looks at each internal tick (T=75 iterations).*

## Detailed Results

### MNIST Digit Classification (97.9%)

![MNIST Training Accuracy](continuous-thought-machines/experiments/results/mnist-ctm_smoothed.png)

CTM learns digit classification in ~5 minutes on an RTX 4070 Ti.
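The per-tick accuracy pattern is simple to measure once you have per-tick logits. A minimal sketch with synthetic data standing in for real CTM outputs (we assume logits of shape `(ticks, batch, classes)`; the entropy-based tick selection mirrors CTM's certainty-weighted readout only loosely):

```python
import numpy as np

def per_tick_accuracy(logits, labels):
    """Accuracy of the argmax prediction at every internal tick.

    logits: (T, B, C) array of per-tick class scores.
    labels: (B,) array of ground-truth class indices.
    """
    preds = logits.argmax(axis=-1)          # (T, B)
    return (preds == labels).mean(axis=-1)  # (T,)

def most_certain_tick(logits):
    """Pick the tick with the lowest mean predictive entropy."""
    z = logits - logits.max(axis=-1, keepdims=True)
    p = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)  # softmax per tick
    entropy = -(p * np.log(p + 1e-12)).sum(axis=-1)        # (T, B)
    return int(entropy.mean(axis=-1).argmin())

# Synthetic stand-in: confidence in the true class grows with each tick.
rng = np.random.default_rng(0)
T, B, C = 20, 64, 10
labels = rng.integers(0, C, size=B)
logits = rng.normal(size=(T, B, C))
for t in range(T):
    logits[t, np.arange(B), labels] += 0.5 * t

acc = per_tick_accuracy(logits, labels)
print(f"tick 0: {acc[0]:.2f}, final tick: {acc[-1]:.2f}")
print("most certain tick:", most_certain_tick(logits))
```

Plotting `acc` against the tick index reproduces the shape of the inference-per-tick figures above: near-random early, saturating once the model has "thought" long enough.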
### Parity-16 Cumulative Parity (99.0%)

![Parity Inference per Tick](continuous-thought-machines/experiments/results/parity_inference.png)

16-bit parity with cumulative outputs; the harder task shows a clearer "thinking" benefit.

### QAMNIST Multi-Step Arithmetic (100%)

![QAMNIST Training Accuracy](continuous-thought-machines/experiments/results/qamnist-ctm-10_smoothed.png)

100% accuracy on multi-step arithmetic (3-5 MNIST digits, 3-5 operations) after 300k iterations.

### Maze Navigation (Pretrained Model)

Using the authors' pretrained checkpoint (`ctm_mazeslarge_D=2048_T=75_M=25.pt`), we ran inference on the small-mazes dataset:

- **Model**: D=2048 neurons, T=75 thinking steps, M=25 max trajectory length
- **Dataset**: 1000 test mazes (15x15 grid)
- **Output**: Action trajectories (Up/Down/Left/Right/Wait)

The visualization shows CTM's attention patterns as it navigates:

1. **Red heatmap**: Where CTM "looks" at each thinking step
2. **Blue path**: Predicted solution trajectory
3. **Green marker**: Start position

Key insight: CTM learns sequential decision-making through iterative internal computation, not memorization.

### Object Tracking - Position Prediction (93.8%)

![Position Tracking Training](continuous-thought-machines/experiments/results/tracking_position.png)

The hardest tracking task: predict the exact cell (256 classes) from 5 frames of motion. CTM reaches 93.8% test accuracy, demonstrating temporal reasoning across video frames.
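The tracking data behind the three tasks is straightforward to generate. A minimal sketch (our own toy generator, not the repo's loader, assuming a single dot moving with constant velocity on a 16x16 grid):

```python
import numpy as np

GRID, FRAMES = 16, 5

def make_sample(rng):
    """Render 5 frames of a dot with constant velocity, plus the three labels."""
    vx, vy = rng.integers(-1, 2, size=2)  # velocity components in {-1, 0, 1}
    # Choose a start position so the dot stays in bounds for all frames.
    x = rng.integers(max(0, -vx * (FRAMES - 1)), GRID - max(0, vx * (FRAMES - 1)))
    y = rng.integers(max(0, -vy * (FRAMES - 1)), GRID - max(0, vy * (FRAMES - 1)))

    frames = np.zeros((FRAMES, GRID, GRID), dtype=np.float32)
    for t in range(FRAMES):
        frames[t, y + vy * t, x + vx * t] = 1.0

    fx, fy = x + vx * (FRAMES - 1), y + vy * (FRAMES - 1)  # final position
    quadrant = (fy >= GRID // 2) * 2 + (fx >= GRID // 2)   # 4 classes
    velocity = (vy + 1) * 3 + (vx + 1)                     # 9 classes
    position = fy * GRID + fx                              # 256 classes
    return frames, quadrant, velocity, position

rng = np.random.default_rng(0)
frames, q, v, p = make_sample(rng)
print(frames.shape, q, v, p)
```

The three label encodings match the class counts in the results table (4, 9, and 256); the position target is just the flattened index of the final cell.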
## Experiment Tracking

- **Configs**: [`experiments/experiments.json`](continuous-thought-machines/experiments/experiments.json)
- **Training Scripts**: [`experiments/training/`](continuous-thought-machines/experiments/training/)
- **Inference Scripts**: [`experiments/inference/`](continuous-thought-machines/experiments/inference/)
- **Results**: [`experiments/results/`](continuous-thought-machines/experiments/results/)

## Custom Experiments

### Bracket Matching

Classify bracket strings as valid or invalid: `(()[])` vs `([)]`. Requires tracking nesting depth and bracket types, i.e. implementing a stack through iterative thinking.

### Object Tracking

Predict properties of a moving dot from 5 video frames (16x16 grid):

```
Frame 0    Frame 1    Frame 2    Frame 3    Frame 4
. . . .    . . . .    . . . .    . . . .    . . . .
. * . .    . . * .    . . . *    . . . .    . . . .
. . . .    . . . .    . . . .    . . . *    . . . .
. . . .    . . . .    . . . .    . . . .    . . . *
```

Three prediction tasks tested:

| Task | Classes | Accuracy | Notes |
|------|---------|----------|-------|
| **Quadrant** | 4 | 100% | TL/TR/BL/BR - easiest |
| **Velocity** | 9 | 100% | 8 directions + stationary |
| **Position** | 256 | 93.8% | Exact cell (16x16) - hardest |

All tasks converged, demonstrating CTM's ability to learn temporal/spatial reasoning.

### Transfer Learning

Freeze the core CTM dynamics from parity-16 and train only the backbone/output for brackets.

### Maze Inference

Run the pretrained maze model on the small-mazes dataset to visualize CTM's "thinking" process:

```bash
python -m tasks.mazes.analysis.run \
    --actions viz \
    --checkpoint checkpoints/mazes/ctm_mazeslarge_D=2048_T=75_M=25.pt \
    --dataset_for_viz small-mazes
```

Outputs attention overlay GIFs to `tasks/mazes/analysis/outputs/viz/`.

### Jigsaw MNIST

Classify MNIST digits from **randomly shuffled patches** without positional encoding.
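Constructing the shuffled-patch input can be sketched as follows (our own illustration, not the experiment's actual data pipeline, assuming 28x28 images split into a 4x4 grid of 7x7 patches):

```python
import numpy as np

def shuffle_patches(image, patch=7, rng=None):
    """Split an image into non-overlapping patches and return them in random order.

    Returns (patches, perm): for a 28x28 input, patches has shape (16, 7, 7)
    and perm[i] is the original index of the patch now at position i.
    """
    rng = rng or np.random.default_rng()
    h, w = image.shape
    grid = h // patch  # 4 for 28x28 with 7x7 patches
    patches = (image.reshape(grid, patch, grid, patch)
                    .transpose(0, 2, 1, 3)
                    .reshape(grid * grid, patch, patch))
    perm = rng.permutation(grid * grid)
    return patches[perm], perm

# Toy "image" whose pixels encode their patch index, so the shuffle is visible.
img = np.repeat(np.repeat(np.arange(16).reshape(4, 4), 7, axis=0), 7, axis=1)
shuffled, perm = shuffle_patches(img.astype(np.float32), rng=np.random.default_rng(0))
print(perm)                           # the order the patches were dealt in
print(shuffled[:, 0, 0].astype(int))  # each patch is constant = its original index
```

Because the permutation is discarded before the model sees the patches, any positional information has to be reconstructed by the network itself.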
```
Original:              Shuffled (input):
┌───┬───┬───┬───┐      ┌───┬───┬───┬───┐
│ 1 │ 2 │ 3 │ 4 │      │12 │ 7 │ 2 │15 │
├───┼───┼───┼───┤      ├───┼───┼───┼───┤
│ 5 │ 6 │ 7 │ 8 │  =>  │ 4 │11 │ 9 │ 1 │
├───┼───┼───┼───┤      ├───┼───┼───┼───┤
│ 9 │10 │11 │12 │      │ 6 │ 3 │14 │ 5 │
├───┼───┼───┼───┤      ├───┼───┼───┼───┤
│13 │14 │15 │16 │      │16 │ 8 │10 │13 │
└───┴───┴───┴───┘      └───┴───┴───┴───┘
```

**Task**: Given 16 shuffled 7x7 patches, predict the digit class (0-9).

**Challenge**: No positional encoding; CTM must learn to recognize digit parts and integrate them correctly through its internal synchronization dynamics.

**Result**: **92% test accuracy**. CTM successfully learns part-whole relationships without explicit position information.

![Jigsaw Training](continuous-thought-machines/experiments/results/jigsaw_training.png)

## Resources

- [CTM Paper](2505.05522v4.pdf)
- [Original SakanaAI Repo](https://github.com/SakanaAI/continuous-thought-machines)