# CTM Experiments - Continuous Thought Machine Models

Experimental checkpoints trained on the Continuous Thought Machine (CTM) architecture by Sakana AI.

These are community experiments built on the original work, not official Sakana AI models.

## Paper Reference

**Continuous Thought Machines**
Sakana AI
arXiv:2505.05522

Interactive Demo | Blog Post

```bibtex
@article{sakana2025ctm,
  title={Continuous Thought Machines},
  author={Sakana AI},
  journal={arXiv preprint arXiv:2505.05522},
  year={2025}
}
```

## Core Insight

CTM's key innovation: accuracy improves with more internal iterations. The model "thinks longer" to reach better answers. This enables CTM to learn algorithmic reasoning that feedforward networks struggle with.
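As a rough illustration of this behaviour, per-tick accuracy can be read off a single forward pass. The sketch below is hedged: it assumes the forward pass returns `(predictions, certainties, synchronisation)` with predictions shaped `(batch, out_dims, iterations)`, as in the reference Sakana AI implementation, and that `model`, `x`, `y` are a loaded checkpoint and a labelled batch.

```python
import torch

def per_tick_accuracy(model, x, y):
    """Accuracy of a trained CTM at each internal tick on a labelled batch.

    Assumes the reference return convention: predictions shaped
    (batch, out_dims, iterations). Verify against the model code you use.
    """
    model.eval()
    with torch.no_grad():
        predictions, _, _ = model(x)
    return [
        (predictions[..., t].argmax(dim=1) == y).float().mean().item()
        for t in range(predictions.shape[-1])
    ]
```

On tasks such as cumulative parity, these per-tick accuracies typically rise toward the final tick, which is the "thinking longer" effect described above.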

## Models

| Model | File | Size | Task | Accuracy | Description |
|---|---|---|---|---|---|
| MNIST | `ctm-mnist.pt` | 1.3M | Digit classification | 97.9% | 10-class MNIST |
| Parity-16 | `ctm-parity-16.pt` | 2.5M | Cumulative parity | 99.0% | 16-bit sequences |
| Parity-64 | `ctm-parity-64.pt` | 66M | Cumulative parity | 75% | 64-bit sequences |
| QAMNIST | `ctm-qamnist.pt` | 39M | Multi-step arithmetic | 100% | 3-5 digits, 3-5 ops |
| Brackets | `ctm-brackets.pt` | 6.1M | Bracket matching | 94.7% | Valid/invalid, e.g. `(()[])` |
| Tracking-Quadrant | `ctm-tracking-quadrant.pt` | 6.7M | Motion quadrant | 100% | 4-class prediction |
| Tracking-Position | `ctm-tracking-position.pt` | 6.7M | Exact position | 93.8% | 256-class (16x16 grid) |
| Transfer | `ctm-transfer-parity-brackets.pt` | 2.5M | Transfer learning | 94.5% | Parity core to brackets |
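To fetch every checkpoint in the table, the same `hf_hub_download` call used in the Usage section below can be looped over these filenames (the `repo_id` is taken from that example):

```python
from huggingface_hub import hf_hub_download

# Filenames taken from the table above.
CHECKPOINTS = [
    "ctm-mnist.pt",
    "ctm-parity-16.pt",
    "ctm-parity-64.pt",
    "ctm-qamnist.pt",
    "ctm-brackets.pt",
    "ctm-tracking-quadrant.pt",
    "ctm-tracking-position.pt",
    "ctm-transfer-parity-brackets.pt",
]

# Map each filename to its local cached path.
paths = {
    name: hf_hub_download(repo_id="vincentoh/ctm-experiments", filename=name)
    for name in CHECKPOINTS
}
```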

## Model Configurations

### MNIST CTM

```python
config = {
    "iterations": 15,
    "memory_length": 10,
    "d_model": 128,
    "d_input": 128,
    "heads": 2,
    "n_synch_out": 16,
    "n_synch_action": 16,
    "memory_hidden_dims": 8,
    "out_dims": 10,
    "synapse_depth": 1,
}
```

### Parity-16 CTM

```python
config = {
    "iterations": 50,
    "memory_length": 25,
    "d_model": 256,
    "d_input": 32,
    "heads": 8,
    "synapse_depth": 8,
    "out_dims": 16,  # cumulative parity
}
```

### QAMNIST CTM

```python
config = {
    "iterations": 10,
    "memory_length": 30,
    "d_model": 1024,
    "d_input": 64,
    "synapse_depth": 1,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
}
```

### Brackets CTM

```python
config = {
    "iterations": 30,
    "memory_length": 15,
    "d_model": 256,
    "d_input": 64,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "out_dims": 2,  # valid/invalid
}
```

### Tracking CTM

```python
config = {
    "iterations": 20,
    "memory_length": 15,
    "d_model": 256,
    "d_input": 64,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
}
```
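The Usage section below expects a `config` dict that matches the checkpoint being loaded. One way to keep that association is a simple lookup table; the helper below is illustrative rather than part of the repository, shown with two entries copied from the configs above.

```python
# Illustrative mapping from checkpoint filename to its config. Extend it with the
# remaining configs listed above; fields not documented in this README (e.g.
# out_dims for the tracking models) must come from the training code.
CTM_CONFIGS = {
    "ctm-mnist.pt": {
        "iterations": 15, "memory_length": 10, "d_model": 128, "d_input": 128,
        "heads": 2, "n_synch_out": 16, "n_synch_action": 16,
        "memory_hidden_dims": 8, "out_dims": 10, "synapse_depth": 1,
    },
    "ctm-parity-16.pt": {
        "iterations": 50, "memory_length": 25, "d_model": 256, "d_input": 32,
        "heads": 8, "synapse_depth": 8, "out_dims": 16,
    },
}

# Pick the config that matches the checkpoint you downloaded.
config = CTM_CONFIGS["ctm-mnist.pt"]
```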

## Usage

```python
import torch
from huggingface_hub import hf_hub_download

# Download a checkpoint
model_path = hf_hub_download(
    repo_id="vincentoh/ctm-experiments",
    filename="ctm-mnist.pt"
)

# Load the checkpoint
checkpoint = torch.load(model_path, map_location="cpu")

# Initialize a CTM with the matching config from "Model Configurations" above.
# ContinuousThoughtMachine comes from the reference Sakana AI CTM code, which
# must be available on your Python path.
from models.ctm import ContinuousThoughtMachine

model = ContinuousThoughtMachine(**config)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Inference
with torch.no_grad():
    output = model(input_tensor)
```
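As a follow-up, a rough end-to-end check of the MNIST checkpoint might look like the sketch below. It assumes standard torchvision MNIST tensors (1x28x28, values in [0, 1]) and the per-tick return convention described under Core Insight; confirm the preprocessing and return signature against the training code before trusting the number.

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Plain ToTensor preprocessing is an assumption; match it to the training setup.
test_set = datasets.MNIST(root="./data", train=False, download=True,
                          transform=transforms.ToTensor())
loader = DataLoader(test_set, batch_size=256)

correct = total = 0
with torch.no_grad():
    for x, y in loader:
        # Assumed return signature: (predictions, certainties, synchronisation),
        # with predictions shaped (batch, 10, iterations).
        predictions, _, _ = model(x)
        final = predictions[..., -1].argmax(dim=1)  # prediction at the last tick
        correct += (final == y).sum().item()
        total += y.numel()

print(f"MNIST test accuracy: {correct / total:.3f}")
```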

## Training Details

- **Hardware:** NVIDIA RTX 4070 Ti SUPER
- **Framework:** PyTorch
- **Optimizer:** AdamW
- **Training time:** 5 minutes (MNIST) to 17 hours (QAMNIST)

## Key Findings

1. **Architecture > scale:** small synchronization dimensions (32) with linear synapses work better than larger or deeper variants.
2. **"Thinking longer" means higher accuracy:** CTM accuracy improves with more internal iterations.
3. **Transfer learning works:** a parity-trained core transfers to bracket matching with 94.5% accuracy.

## License

MIT License (same as the original CTM repository).

## Acknowledgments

These experiments build on the Continuous Thought Machine architecture and reference implementation by Sakana AI.

## Links