kschuerholt committed
Commit 47cb867 · verified · Parent: fd68243

Upload folder using huggingface_hub

.DS_Store ADDED
Binary file (6.15 kB)
 
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ arc_v1_public/step_518071 filter=lfs diff=lfs merge=lfs -text
+ arc_v2_public/step_723914 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,104 @@
- ---
- license: mit
- ---
+ This repository contains the TinyRecursiveModels checkpoints for the ARC v1 and ARC v2 public evaluation sets, trained for performance verification. They were trained using the code and recipe of the [official TRM repository](https://github.com/SamsungSAILMontreal/TinyRecursiveModels); we only had to adapt the environment setup, as detailed below. We provide these checkpoints for transparency and to facilitate further research. We did not contribute to the TRM research, nor do we maintain the TRM code. For any questions, please reach out to the TRM maintainers.
+
+ TRM writes checkpoints as torch `state_dict`s. The subdirectories `arc_v1_public` and `arc_v2_public` contain the final checkpoints `step_<final-step>`, which can be loaded via the `load_checkpoint` option by providing the checkpoint path, i.e. `load_checkpoint=path/to/checkpoint`. For reference, see the `PretrainConfig` in `pretrain.py`.
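+
+ For example, a minimal sketch for inspecting one of these checkpoints directly (this assumes the file is a plain `torch.save`d state dict, as TRM writes them):
+
+ ```python
+ import torch
+
+ # load on CPU so no GPU is required for inspection
+ state_dict = torch.load("arc_v1_public/step_518071", map_location="cpu")
+ print(f"{len(state_dict)} tensors")
+ for name, tensor in list(state_dict.items())[:5]:
+     print(name, tuple(tensor.shape))
+ ```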
+
+ ## Replication Results
+
+ ## Environment Setup
+ ```bash
+ # use uv for the venv
+ sudo snap install astral-uv --classic
+ uv venv .venv -p 3.12
+ source .venv/bin/activate
+
+ # install python dev headers for adam-atan2
+ sudo apt install python3-dev -y
+ # install torch
+ PYTORCH_INDEX_URL=https://download.pytorch.org/whl/cu128
+ uv pip install torch torchvision torchaudio --index-url $PYTORCH_INDEX_URL
+ # install build dependencies + adam-atan2
+ uv pip install packaging ninja wheel setuptools setuptools-scm
+ uv pip install --no-cache-dir --no-build-isolation adam-atan2
+
+ # smoke-test torch, CUDA, and AdamATan2
+ python - <<'EOF'
+ import torch
+ from adam_atan2 import AdamATan2
+ t = torch.tensor([0, 1, 2]).to('cuda')
+ print("torch, CUDA, and AdamATan2 OK")
+ EOF
+
+ # install remaining dependencies
+ uv pip install -r requirements.txt
+ ```
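+
+ Note: `adam-atan2` builds a native extension at install time, which is why `python3-dev` and the build tools (`ninja`, `setuptools`, ...) are installed first. The short Python snippet above is only a smoke test that the CUDA device is usable and that `AdamATan2` imports.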
+
+ ## Dataset preprocessing
+ The repository already contains the raw data, but it needs to be preprocessed. Run the following commands to preprocess the v1 and v2 datasets; these are used to make predictions for the public eval sets.
+
+ ### ARC-AGI-1
+ ```bash
+ python -m dataset.build_arc_dataset \
+     --input-file-prefix kaggle/combined/arc-agi \
+     --output-dir data/arc1concept-aug-1000 \
+     --subsets training evaluation concept \
+     --test-set-name evaluation
+ ```
+
+ ### ARC-AGI-2
+ ```bash
+ python -m dataset.build_arc_dataset \
+     --input-file-prefix kaggle/combined/arc-agi \
+     --output-dir data/arc2concept-aug-1000 \
+     --subsets training2 evaluation2 concept \
+     --test-set-name evaluation2
+ ```
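+
+ Both commands build augmented datasets from the corresponding training and evaluation subsets plus the ConceptARC (`concept`) subset; the `aug-1000` suffix of the output directories refers to the number of augmentations used in the TRM recipe.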
+
+ ## Training
+ To reproduce the checkpoints, run the following two training runs, each on a single node with 8×H100 GPUs. Each run takes ~20-30 h. To speed this up, instructions for multi-node training are given below.
+
+ ### ARC-AGI-1
+ ```bash
+ run_name="trm_arc_v1_public"
+ torchrun --nproc-per-node 8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
+     arch=trm \
+     data_paths="[data/arc1concept-aug-1000]" \
+     arch.L_layers=2 \
+     arch.H_cycles=3 arch.L_cycles=4 \
+     +run_name=${run_name} ema=True
+ ```
+
+ ### ARC-AGI-2
+ ```bash
+ run_name="trm_arc_v2_public"
+ torchrun --nproc-per-node 8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
+     arch=trm \
+     data_paths="[data/arc2concept-aug-1000]" \
+     arch.L_layers=2 \
+     arch.H_cycles=3 arch.L_cycles=4 \
+     +run_name=${run_name} ema=True
+ ```
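+
+ Checkpoints are written to `checkpoints/<project_name>/<run_name>` (compare `checkpoint_path` in the `all_config.yaml` of each subdirectory); the final checkpoint is the `step_<final-step>` file provided in this repository.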
+
+ ### Multi-node training
+ ```bash
+ export MAIN_ADDR=<MAIN_IP>
+ export MAIN_PORT=29500
+ export NNODES=2
+ export GPUS_PER_NODE=8
+ export OMP_NUM_THREADS=8
+ export NCCL_PORT_RANGE=40000-40050
+ run_name="arc_v1_public_2_nodes"
+ # on each node, set NODE_RANK accordingly (0 on the main node, 1..NNODES-1 on the others):
+ export NODE_RANK=0
+ torchrun \
+     --nnodes $NNODES \
+     --node_rank $NODE_RANK \
+     --nproc_per_node $GPUS_PER_NODE \
+     --rdzv_backend c10d \
+     --rdzv_endpoint $MAIN_ADDR:$MAIN_PORT \
+     pretrain.py \
+     arch=trm \
+     data_paths="[data/arc1concept-aug-1000]" \
+     arch.L_layers=2 \
+     arch.H_cycles=3 arch.L_cycles=4 \
+     +run_name=${run_name} ema=True \
+     eval_interval=50000
+ ```
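+
+ For example, with `NNODES=2`, run the same `torchrun` command on both machines: keep `NODE_RANK=0` on the main node and `export NODE_RANK=1` on the second. Both nodes must be able to reach `$MAIN_ADDR:$MAIN_PORT`.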
arc_v1_public/all_config.yaml ADDED
@@ -0,0 +1,47 @@
+ arch:
+   H_cycles: 3
+   H_layers: 0
+   L_cycles: 4
+   L_layers: 2
+   expansion: 4
+   forward_dtype: bfloat16
+   halt_exploration_prob: 0.1
+   halt_max_steps: 16
+   hidden_size: 512
+   loss:
+     loss_type: stablemax_cross_entropy
+     name: losses@ACTLossHead
+   mlp_t: false
+   name: recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1
+   no_ACT_continue: true
+   num_heads: 8
+   pos_encodings: rope
+   puzzle_emb_len: 16
+   puzzle_emb_ndim: 512
+ beta1: 0.9
+ beta2: 0.95
+ checkpoint_every_eval: true
+ checkpoint_path: checkpoints/Arc1concept-aug-1000-ACT-torch/arc_v1_public_eval
+ data_paths:
+ - data/arc1concept-aug-1000
+ data_paths_test: []
+ ema: true
+ ema_rate: 0.999
+ epochs: 100000
+ eval_interval: 10000
+ eval_save_outputs: []
+ evaluators:
+ - name: arc@ARC
+ freeze_weights: false
+ global_batch_size: 768
+ load_checkpoint: null
+ lr: 0.0001
+ lr_min_ratio: 1.0
+ lr_warmup_steps: 2000
+ min_eval_interval: 0
+ project_name: Arc1concept-aug-1000-ACT-torch
+ puzzle_emb_lr: 0.01
+ puzzle_emb_weight_decay: 0.1
+ run_name: arc_v1_public_eval
+ seed: 0
+ weight_decay: 0.1
arc_v1_public/losses.py ADDED
@@ -0,0 +1,103 @@
+ from typing import Any, Tuple, Dict, Sequence, Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ import math
+
+ IGNORE_LABEL_ID = -100
+
+
+ def s(x, epsilon=1e-30):  # stablemax transform: 1/(1-x) for x < 0, x + 1 for x >= 0 (always positive)
+     return torch.where(
+         x < 0,
+         1 / (1 - x + epsilon),
+         x + 1
+     )
+
+
+ def log_stablemax(x, dim=-1):
+     s_x = s(x)
+     return torch.log(s_x / torch.sum(s_x, dim=dim, keepdim=True))
+
+
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
+     logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
+
+     if valid_mask is None:
+         valid_mask = (labels != ignore_index)
+     transformed_labels = torch.where(valid_mask, labels, 0)
+     prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
+
+     return -torch.where(valid_mask, prediction_logprobs, 0)
+
+
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
+     # Cast logits to f32
+     # Flatten logits
+     return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
+
+
+ class ACTLossHead(nn.Module):
+     def __init__(self, model: nn.Module, loss_type: str):
+         super().__init__()
+         self.model = model
+         self.loss_fn = globals()[loss_type]
+
+     def initial_carry(self, *args, **kwargs):
+         return self.model.initial_carry(*args, **kwargs)  # type: ignore
+
+     def forward(
+         self,
+         return_keys: Sequence[str],
+         # Model args
+         **model_kwargs,
+     ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
+         # Model logits
+         # B x SeqLen x D
+         new_carry, outputs = self.model(**model_kwargs)
+         labels = new_carry.current_data["labels"]
+
+         with torch.no_grad():
+             # Preds
+             outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
+
+             # Correctness
+             mask = (labels != IGNORE_LABEL_ID)
+             loss_counts = mask.sum(-1)
+             loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division
+
+             is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
+             seq_is_correct = is_correct.sum(-1) == loss_counts
+
+             # Metrics (halted)
+             valid_metrics = new_carry.halted & (loss_counts > 0)
+             metrics = {
+                 "count": valid_metrics.sum(),
+
+                 "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
+                 "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
+
+                 "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
+                 "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
+             }
+
+         # Losses
+
+         lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
+         q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
+         metrics.update({
+             "lm_loss": lm_loss.detach(),
+             "q_halt_loss": q_halt_loss.detach(),
+         })
+         # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unnecessary
+         q_continue_loss = 0
+         if "target_q_continue" in outputs:
+             q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
+
+             metrics["q_continue_loss"] = q_continue_loss.detach()
+         # Filter outputs for return
+         detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
+
+         return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
+
arc_v1_public/step_518071 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:53689643ad1606d7c22c758f8af0a71b3b66275dea074f214d2f1048d9a01fb0
+ size 1822205258
arc_v1_public/trm.py ADDED
@@ -0,0 +1,297 @@
+ from typing import Tuple, List, Dict, Optional
+ from dataclasses import dataclass
+ import math
+ import torch
+ import copy
+ import torch.nn.functional as F
+ from torch import nn
+ from pydantic import BaseModel
+ import random
+ from models.common import trunc_normal_init_
+ from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
+ from models.sparse_embedding import CastedSparseEmbedding
+
+ IGNORE_LABEL_ID = -100
+
+ @dataclass
+ class TinyRecursiveReasoningModel_ACTV1InnerCarry:
+     z_H: torch.Tensor
+     z_L: torch.Tensor
+
+
+ @dataclass
+ class TinyRecursiveReasoningModel_ACTV1Carry:
+     inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
+
+     steps: torch.Tensor
+     halted: torch.Tensor
+
+     current_data: Dict[str, torch.Tensor]
+
+
+ class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
+     batch_size: int
+     seq_len: int
+     puzzle_emb_ndim: int = 0
+     num_puzzle_identifiers: int
+     vocab_size: int
+
+     H_cycles: int
+     L_cycles: int
+
+     H_layers: int  # ignored
+     L_layers: int
+
+     # Transformer config
+     hidden_size: int
+     expansion: float
+     num_heads: int
+     pos_encodings: str
+
+     rms_norm_eps: float = 1e-5
+     rope_theta: float = 10000.0
+
+     # Halting Q-learning config
+     halt_max_steps: int
+     halt_exploration_prob: float
+
+     forward_dtype: str = "bfloat16"
+
+     # Alexia: added
+     mlp_t: bool = False  # use an MLP over the sequence dim (L) instead of a transformer block
+     puzzle_emb_len: int = 16  # if non-zero, the puzzle embedding length is set to this value
+     no_ACT_continue: bool = True  # no continue ACT loss; only use the sigmoid of the halt logit, which makes much more sense
+
+ class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
+     def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
+         super().__init__()
+
+         self.config = config
+         if self.config.mlp_t:
+             self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
+             self.mlp_t = SwiGLU(
+                 hidden_size=self.config.seq_len + self.puzzle_emb_len,  # L
+                 expansion=config.expansion,
+             )
+         else:
+             self.self_attn = Attention(
+                 hidden_size=config.hidden_size,
+                 head_dim=config.hidden_size // config.num_heads,
+                 num_heads=config.num_heads,
+                 num_key_value_heads=config.num_heads,
+                 causal=False
+             )
+             self.mlp = SwiGLU(
+                 hidden_size=config.hidden_size,
+                 expansion=config.expansion,
+             )
+         self.norm_eps = config.rms_norm_eps
+
+     def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
+         # B, L, D = hidden_states.shape
+         # Post Norm
+         if self.config.mlp_t:
+             hidden_states = hidden_states.transpose(1, 2)
+             out = self.mlp_t(hidden_states)
+             hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
+             hidden_states = hidden_states.transpose(1, 2)
+         else:
+             # Self Attention
+             hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
+             # Fully Connected
+             out = self.mlp(hidden_states)
+             hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
+         return hidden_states
+
+ class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
+     def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
+         super().__init__()
+         self.layers = torch.nn.ModuleList(layers)
+
+     def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
+         hidden_states = hidden_states + input_injection
+         for layer in self.layers:
+             hidden_states = layer(hidden_states=hidden_states, **kwargs)
+         return hidden_states
+
+
+ class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
+     def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
+         super().__init__()
+         self.config = config
+         self.forward_dtype = getattr(torch, self.config.forward_dtype)
+
+         # I/O
+
+         self.embed_scale = math.sqrt(self.config.hidden_size)
+         embed_init_std = 1.0 / self.embed_scale
+
+         self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
+         self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
+         self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
+
+         self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len  # ceil div
+         if self.config.puzzle_emb_ndim > 0:
+             # Zero init puzzle embeddings
+             self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
+                                                     batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
+
+         # LM Blocks
+         if self.config.pos_encodings == "rope":
+             self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
+                                               max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
+                                               base=self.config.rope_theta)
+         elif self.config.pos_encodings == "learned":
+             self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
+         else:
+             pass
+
+         # Reasoning Layers
+         self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
+
+         # Initial states
+         self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
+         self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
+
+         # Q head special init
+         # Init Q to (almost) zero for faster learning during bootstrapping
+         with torch.no_grad():
+             self.q_head.weight.zero_()
+             self.q_head.bias.fill_(-5)  # type: ignore
+
+     def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
+         # Token embedding
+         embedding = self.embed_tokens(input.to(torch.int32))
+
+         # Puzzle embeddings
+         if self.config.puzzle_emb_ndim > 0:
+             puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
+
+             pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
+             if pad_count > 0:
+                 puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
+
+             embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
+
+         # Position embeddings
+         if self.config.pos_encodings == "learned":
+             # scale by 1/sqrt(2) to maintain forward variance
+             embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
+
+         # Scale
+         return self.embed_scale * embedding
+
+     def empty_carry(self, batch_size: int):
+         return TinyRecursiveReasoningModel_ACTV1InnerCarry(
+             z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
+             z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
+         )
+
+     def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
+         return TinyRecursiveReasoningModel_ACTV1InnerCarry(
+             z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
+             z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
+         )
+
+     def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+         seq_info = dict(
+             cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
+         )
+
+         # Input encoding
+         input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
+
+         # Forward iterations
+         it = 0
+         z_H, z_L = carry.z_H, carry.z_L
+         # H_cycles-1 without grad
+         with torch.no_grad():
+             for _H_step in range(self.config.H_cycles - 1):
+                 for _L_step in range(self.config.L_cycles):
+                     z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
+                 z_H = self.L_level(z_H, z_L, **seq_info)
+         # 1 with grad
+         for _L_step in range(self.config.L_cycles):
+             z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
+         z_H = self.L_level(z_H, z_L, **seq_info)
+
+         # LM Outputs
+         new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())  # New carry no grad
+         output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
+         q_logits = self.q_head(z_H[:, 0]).to(torch.float32)  # Q-head; uses the first puzzle_emb position
+         return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
+
+
+ class TinyRecursiveReasoningModel_ACTV1(nn.Module):
+     """ACT wrapper."""
+
+     def __init__(self, config_dict: dict):
+         super().__init__()
+         self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
+         self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)
+
+     @property
+     def puzzle_emb(self):
+         return self.inner.puzzle_emb
+
+     def initial_carry(self, batch: Dict[str, torch.Tensor]):
+         batch_size = batch["inputs"].shape[0]
+
+         return TinyRecursiveReasoningModel_ACTV1Carry(
+             inner_carry=self.inner.empty_carry(batch_size),  # Empty is expected; it will be reset in the first pass, as all sequences start halted.
+
+             steps=torch.zeros((batch_size, ), dtype=torch.int32),
+             halted=torch.ones((batch_size, ), dtype=torch.bool),  # Default to halted
+
+             current_data={k: torch.empty_like(v) for k, v in batch.items()}
+         )
+
+     def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
+
+         # Update data, carry (removing halted sequences)
+         new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
+
+         new_steps = torch.where(carry.halted, 0, carry.steps)
+
+         new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
+
+         # Forward inner model
+         new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
+
+         outputs = {
+             "logits": logits,
+             "q_halt_logits": q_halt_logits,
+             "q_continue_logits": q_continue_logits
+         }
+
+         with torch.no_grad():
+             # Step
+             new_steps = new_steps + 1
+             is_last_step = new_steps >= self.config.halt_max_steps
+
+             halted = is_last_step
+
+             # if training, and ACT is enabled
+             if self.training and (self.config.halt_max_steps > 1):
+
+                 # Halt signal
+                 # NOTE: During evaluation, always use max steps; this guarantees the same number of halting steps inside a batch, for batching purposes
+
+                 if self.config.no_ACT_continue:
+                     halted = halted | (q_halt_logits > 0)
+                 else:
+                     halted = halted | (q_halt_logits > q_continue_logits)
+
+                 # Exploration
+                 min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
+                 halted = halted & (new_steps >= min_halt_steps)
+
+                 if not self.config.no_ACT_continue:
+                     # Compute target Q
+                     # NOTE: No replay buffer and target networks for computing target Q-value.
+                     # As batch_size is large, there are many parallel envs.
+                     # Similar concept as PQN https://arxiv.org/abs/2407.04811
+                     _, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
+                     outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
+
+         return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
arc_v2_public/all_config.yaml ADDED
@@ -0,0 +1,47 @@
+ arch:
+   H_cycles: 3
+   H_layers: 0
+   L_cycles: 4
+   L_layers: 2
+   expansion: 4
+   forward_dtype: bfloat16
+   halt_exploration_prob: 0.1
+   halt_max_steps: 16
+   hidden_size: 512
+   loss:
+     loss_type: stablemax_cross_entropy
+     name: losses@ACTLossHead
+   mlp_t: false
+   name: recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1
+   no_ACT_continue: true
+   num_heads: 8
+   pos_encodings: rope
+   puzzle_emb_len: 16
+   puzzle_emb_ndim: 512
+ beta1: 0.9
+ beta2: 0.95
+ checkpoint_every_eval: true
+ checkpoint_path: checkpoints/Arc2concept-aug-1000-ACT-torch/arc_v2_pub
+ data_paths:
+ - data/arc2concept-aug-1000
+ data_paths_test: []
+ ema: true
+ ema_rate: 0.999
+ epochs: 100000
+ eval_interval: 10000
+ eval_save_outputs: []
+ evaluators:
+ - name: arc@ARC
+ freeze_weights: false
+ global_batch_size: 768
+ load_checkpoint: null
+ lr: 0.0001
+ lr_min_ratio: 1.0
+ lr_warmup_steps: 2000
+ min_eval_interval: 0
+ project_name: Arc2concept-aug-1000-ACT-torch
+ puzzle_emb_lr: 0.01
+ puzzle_emb_weight_decay: 0.1
+ run_name: arc_v2_pub
+ seed: 0
+ weight_decay: 0.1
arc_v2_public/losses.py ADDED
@@ -0,0 +1,103 @@
+ from typing import Any, Tuple, Dict, Sequence, Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ import math
+
+ IGNORE_LABEL_ID = -100
+
+
+ def s(x, epsilon=1e-30):  # stablemax transform: 1/(1-x) for x < 0, x + 1 for x >= 0 (always positive)
+     return torch.where(
+         x < 0,
+         1 / (1 - x + epsilon),
+         x + 1
+     )
+
+
+ def log_stablemax(x, dim=-1):
+     s_x = s(x)
+     return torch.log(s_x / torch.sum(s_x, dim=dim, keepdim=True))
+
+
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
+     logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
+
+     if valid_mask is None:
+         valid_mask = (labels != ignore_index)
+     transformed_labels = torch.where(valid_mask, labels, 0)
+     prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
+
+     return -torch.where(valid_mask, prediction_logprobs, 0)
+
+
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
+     # Cast logits to f32
+     # Flatten logits
+     return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
+
+
+ class ACTLossHead(nn.Module):
+     def __init__(self, model: nn.Module, loss_type: str):
+         super().__init__()
+         self.model = model
+         self.loss_fn = globals()[loss_type]
+
+     def initial_carry(self, *args, **kwargs):
+         return self.model.initial_carry(*args, **kwargs)  # type: ignore
+
+     def forward(
+         self,
+         return_keys: Sequence[str],
+         # Model args
+         **model_kwargs,
+     ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
+         # Model logits
+         # B x SeqLen x D
+         new_carry, outputs = self.model(**model_kwargs)
+         labels = new_carry.current_data["labels"]
+
+         with torch.no_grad():
+             # Preds
+             outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
+
+             # Correctness
+             mask = (labels != IGNORE_LABEL_ID)
+             loss_counts = mask.sum(-1)
+             loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division
+
+             is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
+             seq_is_correct = is_correct.sum(-1) == loss_counts
+
+             # Metrics (halted)
+             valid_metrics = new_carry.halted & (loss_counts > 0)
+             metrics = {
+                 "count": valid_metrics.sum(),
+
+                 "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
+                 "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
+
+                 "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
+                 "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
+             }
+
+         # Losses
+
+         lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
+         q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
+         metrics.update({
+             "lm_loss": lm_loss.detach(),
+             "q_halt_loss": q_halt_loss.detach(),
+         })
+         # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unnecessary
+         q_continue_loss = 0
+         if "target_q_continue" in outputs:
+             q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
+
+             metrics["q_continue_loss"] = q_continue_loss.detach()
+         # Filter outputs for return
+         detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
+
+         return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
+
arc_v2_public/step_723914 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8d7036b97e7ea38c7dd29d01216bfcfc4e212af3024d5233fe40dd3059e8f4a9
+ size 2467988810
arc_v2_public/trm.py ADDED
@@ -0,0 +1,297 @@
+ from typing import Tuple, List, Dict, Optional
+ from dataclasses import dataclass
+ import math
+ import torch
+ import copy
+ import torch.nn.functional as F
+ from torch import nn
+ from pydantic import BaseModel
+ import random
+ from models.common import trunc_normal_init_
+ from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
+ from models.sparse_embedding import CastedSparseEmbedding
+
+ IGNORE_LABEL_ID = -100
+
+ @dataclass
+ class TinyRecursiveReasoningModel_ACTV1InnerCarry:
+     z_H: torch.Tensor
+     z_L: torch.Tensor
+
+
+ @dataclass
+ class TinyRecursiveReasoningModel_ACTV1Carry:
+     inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
+
+     steps: torch.Tensor
+     halted: torch.Tensor
+
+     current_data: Dict[str, torch.Tensor]
+
+
+ class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
+     batch_size: int
+     seq_len: int
+     puzzle_emb_ndim: int = 0
+     num_puzzle_identifiers: int
+     vocab_size: int
+
+     H_cycles: int
+     L_cycles: int
+
+     H_layers: int  # ignored
+     L_layers: int
+
+     # Transformer config
+     hidden_size: int
+     expansion: float
+     num_heads: int
+     pos_encodings: str
+
+     rms_norm_eps: float = 1e-5
+     rope_theta: float = 10000.0
+
+     # Halting Q-learning config
+     halt_max_steps: int
+     halt_exploration_prob: float
+
+     forward_dtype: str = "bfloat16"
+
+     # Alexia: added
+     mlp_t: bool = False  # use an MLP over the sequence dim (L) instead of a transformer block
+     puzzle_emb_len: int = 16  # if non-zero, the puzzle embedding length is set to this value
+     no_ACT_continue: bool = True  # no continue ACT loss; only use the sigmoid of the halt logit, which makes much more sense
+
+ class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
+     def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
+         super().__init__()
+
+         self.config = config
+         if self.config.mlp_t:
+             self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
+             self.mlp_t = SwiGLU(
+                 hidden_size=self.config.seq_len + self.puzzle_emb_len,  # L
+                 expansion=config.expansion,
+             )
+         else:
+             self.self_attn = Attention(
+                 hidden_size=config.hidden_size,
+                 head_dim=config.hidden_size // config.num_heads,
+                 num_heads=config.num_heads,
+                 num_key_value_heads=config.num_heads,
+                 causal=False
+             )
+             self.mlp = SwiGLU(
+                 hidden_size=config.hidden_size,
+                 expansion=config.expansion,
+             )
+         self.norm_eps = config.rms_norm_eps
+
+     def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
+         # B, L, D = hidden_states.shape
+         # Post Norm
+         if self.config.mlp_t:
+             hidden_states = hidden_states.transpose(1, 2)
+             out = self.mlp_t(hidden_states)
+             hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
+             hidden_states = hidden_states.transpose(1, 2)
+         else:
+             # Self Attention
+             hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
+             # Fully Connected
+             out = self.mlp(hidden_states)
+             hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
+         return hidden_states
+
+ class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
+     def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
+         super().__init__()
+         self.layers = torch.nn.ModuleList(layers)
+
+     def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
+         hidden_states = hidden_states + input_injection
+         for layer in self.layers:
+             hidden_states = layer(hidden_states=hidden_states, **kwargs)
+         return hidden_states
+
+
+ class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
+     def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
+         super().__init__()
+         self.config = config
+         self.forward_dtype = getattr(torch, self.config.forward_dtype)
+
+         # I/O
+
+         self.embed_scale = math.sqrt(self.config.hidden_size)
+         embed_init_std = 1.0 / self.embed_scale
+
+         self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
+         self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
+         self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
+
+         self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len  # ceil div
+         if self.config.puzzle_emb_ndim > 0:
+             # Zero init puzzle embeddings
+             self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
+                                                     batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
+
+         # LM Blocks
+         if self.config.pos_encodings == "rope":
+             self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
+                                               max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
+                                               base=self.config.rope_theta)
+         elif self.config.pos_encodings == "learned":
+             self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
+         else:
+             pass
+
+         # Reasoning Layers
+         self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
+
+         # Initial states
+         self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
+         self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
+
+         # Q head special init
+         # Init Q to (almost) zero for faster learning during bootstrapping
+         with torch.no_grad():
+             self.q_head.weight.zero_()
+             self.q_head.bias.fill_(-5)  # type: ignore
+
+     def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
+         # Token embedding
+         embedding = self.embed_tokens(input.to(torch.int32))
+
+         # Puzzle embeddings
+         if self.config.puzzle_emb_ndim > 0:
+             puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
+
+             pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
+             if pad_count > 0:
+                 puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
+
+             embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
+
+         # Position embeddings
+         if self.config.pos_encodings == "learned":
+             # scale by 1/sqrt(2) to maintain forward variance
+             embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
+
+         # Scale
+         return self.embed_scale * embedding
+
+     def empty_carry(self, batch_size: int):
+         return TinyRecursiveReasoningModel_ACTV1InnerCarry(
+             z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
+             z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
+         )
+
+     def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
+         return TinyRecursiveReasoningModel_ACTV1InnerCarry(
+             z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
+             z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
+         )
+
+     def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+         seq_info = dict(
+             cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
+         )
+
+         # Input encoding
+         input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
+
+         # Forward iterations
+         it = 0
+         z_H, z_L = carry.z_H, carry.z_L
+         # H_cycles-1 without grad
+         with torch.no_grad():
+             for _H_step in range(self.config.H_cycles - 1):
+                 for _L_step in range(self.config.L_cycles):
+                     z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
+                 z_H = self.L_level(z_H, z_L, **seq_info)
+         # 1 with grad
+         for _L_step in range(self.config.L_cycles):
+             z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
+         z_H = self.L_level(z_H, z_L, **seq_info)
+
+         # LM Outputs
+         new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())  # New carry no grad
+         output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
+         q_logits = self.q_head(z_H[:, 0]).to(torch.float32)  # Q-head; uses the first puzzle_emb position
+         return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
+
+
+ class TinyRecursiveReasoningModel_ACTV1(nn.Module):
+     """ACT wrapper."""
+
+     def __init__(self, config_dict: dict):
+         super().__init__()
+         self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
+         self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)
+
+     @property
+     def puzzle_emb(self):
+         return self.inner.puzzle_emb
+
+     def initial_carry(self, batch: Dict[str, torch.Tensor]):
+         batch_size = batch["inputs"].shape[0]
+
+         return TinyRecursiveReasoningModel_ACTV1Carry(
+             inner_carry=self.inner.empty_carry(batch_size),  # Empty is expected; it will be reset in the first pass, as all sequences start halted.
+
+             steps=torch.zeros((batch_size, ), dtype=torch.int32),
+             halted=torch.ones((batch_size, ), dtype=torch.bool),  # Default to halted
+
+             current_data={k: torch.empty_like(v) for k, v in batch.items()}
+         )
+
+     def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
+
+         # Update data, carry (removing halted sequences)
+         new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
+
+         new_steps = torch.where(carry.halted, 0, carry.steps)
+
+         new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
+
+         # Forward inner model
+         new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
+
+         outputs = {
+             "logits": logits,
+             "q_halt_logits": q_halt_logits,
+             "q_continue_logits": q_continue_logits
+         }
+
+         with torch.no_grad():
+             # Step
+             new_steps = new_steps + 1
+             is_last_step = new_steps >= self.config.halt_max_steps
+
+             halted = is_last_step
+
+             # if training, and ACT is enabled
+             if self.training and (self.config.halt_max_steps > 1):
+
+                 # Halt signal
+                 # NOTE: During evaluation, always use max steps; this guarantees the same number of halting steps inside a batch, for batching purposes
+
+                 if self.config.no_ACT_continue:
+                     halted = halted | (q_halt_logits > 0)
+                 else:
+                     halted = halted | (q_halt_logits > q_continue_logits)
+
+                 # Exploration
+                 min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
+                 halted = halted & (new_steps >= min_halt_steps)
+
+                 if not self.config.no_ACT_continue:
+                     # Compute target Q
+                     # NOTE: No replay buffer and target networks for computing target Q-value.
+                     # As batch_size is large, there are many parallel envs.
+                     # Similar concept as PQN https://arxiv.org/abs/2407.04811
+                     _, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
+                     outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
+
+         return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs