Upload folder using huggingface_hub
- .DS_Store +0 -0
- .gitattributes +2 -0
- README.md +104 -3
- arc_v1_public/all_config.yaml +47 -0
- arc_v1_public/losses.py +103 -0
- arc_v1_public/step_518071 +3 -0
- arc_v1_public/trm.py +297 -0
- arc_v2_public/all_config.yaml +47 -0
- arc_v2_public/losses.py +103 -0
- arc_v2_public/step_723914 +3 -0
- arc_v2_public/trm.py +297 -0
.DS_Store
ADDED
Binary file (6.15 kB)
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+arc_v1_public/step_518071 filter=lfs diff=lfs merge=lfs -text
+arc_v2_public/step_723914 filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,3 +1,104 @@
-
-
-
+This repository contains the TinyRecursiveModels (TRM) checkpoints for the ARC-AGI-1 and ARC-AGI-2 public eval sets that were trained for the performance verification. They were trained using the code and recipe of the [official TRM repository](https://github.com/SamsungSAILMontreal/TinyRecursiveModels). We had to adapt the environment setup as detailed below. We provide these checkpoints for transparency and to facilitate further research. We did not contribute to the TRM research, nor do we maintain the TRM code. For any questions, please reach out to the TRM maintainers.
+
+TRM writes checkpoints as PyTorch `state_dict`s. The subdirectories `arc_v1_public` and `arc_v2_public` contain the final checkpoints `step_<final-step>`, which can be loaded through the `load_checkpoint` option, that is, by passing the checkpoint path as `load_checkpoint=path/to/checkpoint`. For reference, see `PretrainConfig` in `pretrain.py`.
+
+
+## Replication Results
+
+
+## Environment Setup
+```bash
+# use uv for venv
+sudo snap install astral-uv --classic
+uv venv .venv -p 3.12
+source .venv/bin/activate
+
+# install python-dev for adam atan2
+sudo apt install python3-dev -y
+# install torch
+PYTORCH_INDEX_URL=https://download.pytorch.org/whl/cu128
+uv pip install torch torchvision torchaudio --index-url $PYTORCH_INDEX_URL
+# install dependencies + adam atan2
+uv pip install packaging ninja wheel setuptools setuptools-scm
+uv pip install --no-cache-dir --no-build-isolation adam-atan2
+
+# test torch, cuda and AdamATan2 (run the Python statements from the shell)
+python -c "
+import torch; t = torch.tensor([0, 1, 2]).to('cuda')
+from adam_atan2 import AdamATan2
+"
+
+# install remaining dependencies
+uv pip install -r requirements.txt
+```
+
+## Dataset preprocessing
+The repository already contains the raw data, but it needs to be preprocessed. Run the following commands to preprocess the v1 and v2 datasets to make predictions for the public eval datasets.
+
+### ARC-AGI-1
+```bash
+python -m dataset.build_arc_dataset \
+  --input-file-prefix kaggle/combined/arc-agi \
+  --output-dir data/arc1concept-aug-1000 \
+  --subsets training evaluation concept \
+  --test-set-name evaluation
+```
+
+### ARC-AGI-2
+```bash
+python -m dataset.build_arc_dataset \
+  --input-file-prefix kaggle/combined/arc-agi \
+  --output-dir data/arc2concept-aug-1000 \
+  --subsets training2 evaluation2 concept \
+  --test-set-name evaluation2
+```
+## Training
+To reproduce the checkpoints, run the following two training runs on a single 8xH100 node. Each run takes roughly 20-30 hours. To speed this up, see the multi-node instructions below.
+
+### ARC-AGI-1
+```bash
+run_name="trm_arc_v1_public"
+torchrun --nproc-per-node 8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
+  arch=trm \
+  data_paths="[data/arc1concept-aug-1000]" \
+  arch.L_layers=2 \
+  arch.H_cycles=3 arch.L_cycles=4 \
+  +run_name=${run_name} ema=True
+```
+
+### ARC-AGI-2
+```bash
+run_name="trm_arc_v2_public"
+torchrun --nproc-per-node 8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
+  arch=trm \
+  data_paths="[data/arc2concept-aug-1000]" \
+  arch.L_layers=2 \
+  arch.H_cycles=3 arch.L_cycles=4 \
+  +run_name=${run_name} ema=True
+```
+
+### For multi-node training:
+```bash
+export MAIN_ADDR=<MAIN_IP>
+export MAIN_PORT=29500
+export NNODES=2
+export GPUS_PER_NODE=8
+export OMP_NUM_THREADS=8
+export NCCL_PORT_RANGE=40000-40050
+run_name="arc_v1_public_2_nodes"
+# on each node: set NODE_RANK to that node's index (0 on the main node)
+export NODE_RANK=0
+torchrun \
+  --nnodes $NNODES \
+  --node_rank $NODE_RANK \
+  --nproc_per_node $GPUS_PER_NODE \
+  --rdzv_backend c10d \
+  --rdzv_endpoint $MAIN_ADDR:$MAIN_PORT \
+  pretrain.py \
+  arch=trm \
+  data_paths="[data/arc1concept-aug-1000]" \
+  arch.L_layers=2 \
+  arch.H_cycles=3 arch.L_cycles=4 \
+  +run_name=${run_name} ema=True \
+  eval_interval=50000
+```
|
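The README says a checkpoint is selected by passing `load_checkpoint=path/to/checkpoint`. A minimal sketch of assembling that override, assuming checkpoints are plain files named `step_<N>` inside a directory (as with `step_518071` and `step_723914` in this commit). The helper names `latest_checkpoint` and `load_override` are hypothetical and not part of the TRM code.

```python
import re
from pathlib import Path
from typing import Optional

# Matches file names such as "step_518071" and captures the step number.
STEP_PATTERN = re.compile(r"^step_(\d+)$")


def latest_checkpoint(run_dir: Path) -> Optional[Path]:
    """Return the `step_<N>` file with the largest N, or None if there is none."""
    best_step, best_path = -1, None
    for path in run_dir.iterdir():
        match = STEP_PATTERN.match(path.name)
        if match and path.is_file() and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path


def load_override(checkpoint: Path) -> str:
    """Format the command-line override that points a run at a checkpoint."""
    return f"load_checkpoint={checkpoint.as_posix()}"
```

The resulting string can be appended to the `pretrain.py` arguments shown in the training commands. Comparing the step numbers as integers avoids the lexicographic trap where `step_9` would sort after `step_10`.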
arc_v1_public/all_config.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
| 1 |
+
arch:
|
| 2 |
+
H_cycles: 3
|
| 3 |
+
H_layers: 0
|
| 4 |
+
L_cycles: 4
|
| 5 |
+
L_layers: 2
|
| 6 |
+
expansion: 4
|
| 7 |
+
forward_dtype: bfloat16
|
| 8 |
+
halt_exploration_prob: 0.1
|
| 9 |
+
halt_max_steps: 16
|
| 10 |
+
hidden_size: 512
|
| 11 |
+
loss:
|
| 12 |
+
loss_type: stablemax_cross_entropy
|
| 13 |
+
name: losses@ACTLossHead
|
| 14 |
+
mlp_t: false
|
| 15 |
+
name: recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1
|
| 16 |
+
no_ACT_continue: true
|
| 17 |
+
num_heads: 8
|
| 18 |
+
pos_encodings: rope
|
| 19 |
+
puzzle_emb_len: 16
|
| 20 |
+
puzzle_emb_ndim: 512
|
| 21 |
+
beta1: 0.9
|
| 22 |
+
beta2: 0.95
|
| 23 |
+
checkpoint_every_eval: true
|
| 24 |
+
checkpoint_path: checkpoints/Arc1concept-aug-1000-ACT-torch/arc_v1_public_eval
|
| 25 |
+
data_paths:
|
| 26 |
+
- data/arc1concept-aug-1000
|
| 27 |
+
data_paths_test: []
|
| 28 |
+
ema: true
|
| 29 |
+
ema_rate: 0.999
|
| 30 |
+
epochs: 100000
|
| 31 |
+
eval_interval: 10000
|
| 32 |
+
eval_save_outputs: []
|
| 33 |
+
evaluators:
|
| 34 |
+
- name: arc@ARC
|
| 35 |
+
freeze_weights: false
|
| 36 |
+
global_batch_size: 768
|
| 37 |
+
load_checkpoint: null
|
| 38 |
+
lr: 0.0001
|
| 39 |
+
lr_min_ratio: 1.0
|
| 40 |
+
lr_warmup_steps: 2000
|
| 41 |
+
min_eval_interval: 0
|
| 42 |
+
project_name: Arc1concept-aug-1000-ACT-torch
|
| 43 |
+
puzzle_emb_lr: 0.01
|
| 44 |
+
puzzle_emb_weight_decay: 0.1
|
| 45 |
+
run_name: arc_v1_public_eval
|
| 46 |
+
seed: 0
|
| 47 |
+
weight_decay: 0.1
|
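The values `H_cycles: 3`, `L_cycles: 4`, `L_layers: 2` and `halt_max_steps: 16` above determine how much computation one prediction uses. Going by the forward loop in `trm.py` from this same commit, each H cycle performs `L_cycles` updates of `z_L` followed by one update of `z_H`, every update passes through all `L_layers` blocks, and only the last H cycle runs with gradients. The sketch below is just that arithmetic; the function name `block_applications` is hypothetical.

```python
from typing import Dict


def block_applications(h_cycles: int, l_cycles: int, l_layers: int) -> Dict[str, int]:
    """Count transformer-block applications in one ACT step of the inner model."""
    # One H cycle: l_cycles updates of z_L plus one update of z_H,
    # each going through l_layers blocks.
    per_h_cycle = (l_cycles + 1) * l_layers
    return {
        "per_step": h_cycles * per_h_cycle,
        "with_grad": per_h_cycle,                  # only the final H cycle
        "no_grad": (h_cycles - 1) * per_h_cycle,   # the first h_cycles - 1 cycles
    }


counts = block_applications(h_cycles=3, l_cycles=4, l_layers=2)
# Evaluation always runs halt_max_steps ACT steps (see the NOTE in trm.py).
max_eval_applications = counts["per_step"] * 16
print(counts, max_eval_applications)
```

With this configuration, one ACT step applies 30 blocks (10 of them with gradients), and an evaluation pass with 16 steps applies 480 blocks in total, even though the model only holds two distinct blocks.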
arc_v1_public/losses.py
ADDED
|
@@ -0,0 +1,103 @@
|
| 1 |
+
from typing import Any, Tuple, Dict, Sequence, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch import nn
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
IGNORE_LABEL_ID = -100
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def s(x, epsilon=1e-30):
|
| 12 |
+
return torch.where(
|
| 13 |
+
x<0,
|
| 14 |
+
1/(1-x+ epsilon),
|
| 15 |
+
x + 1
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def log_stablemax(x, dim=-1):
|
| 20 |
+
s_x = s(x)
|
| 21 |
+
return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
|
| 25 |
+
logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
|
| 26 |
+
|
| 27 |
+
if valid_mask is None:
|
| 28 |
+
valid_mask = (labels != ignore_index)
|
| 29 |
+
transformed_labels = torch.where(valid_mask, labels, 0)
|
| 30 |
+
prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
|
| 31 |
+
|
| 32 |
+
return -torch.where(valid_mask, prediction_logprobs, 0)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
|
| 36 |
+
# Cast logits to f32
|
| 37 |
+
# Flatten logits
|
| 38 |
+
return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class ACTLossHead(nn.Module):
|
| 42 |
+
def __init__(self, model: nn.Module, loss_type: str):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.model = model
|
| 45 |
+
self.loss_fn = globals()[loss_type]
|
| 46 |
+
|
| 47 |
+
def initial_carry(self, *args, **kwargs):
|
| 48 |
+
return self.model.initial_carry(*args, **kwargs) # type: ignore
|
| 49 |
+
|
| 50 |
+
def forward(
|
| 51 |
+
self,
|
| 52 |
+
return_keys: Sequence[str],
|
| 53 |
+
# Model args
|
| 54 |
+
**model_kwargs,
|
| 55 |
+
) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
|
| 56 |
+
# Model logits
|
| 57 |
+
# B x SeqLen x D
|
| 58 |
+
new_carry, outputs = self.model(**model_kwargs)
|
| 59 |
+
labels = new_carry.current_data["labels"]
|
| 60 |
+
|
| 61 |
+
with torch.no_grad():
|
| 62 |
+
# Preds
|
| 63 |
+
outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
|
| 64 |
+
|
| 65 |
+
# Correctness
|
| 66 |
+
mask = (labels != IGNORE_LABEL_ID)
|
| 67 |
+
loss_counts = mask.sum(-1)
|
| 68 |
+
loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
|
| 69 |
+
|
| 70 |
+
is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
|
| 71 |
+
seq_is_correct = is_correct.sum(-1) == loss_counts
|
| 72 |
+
|
| 73 |
+
# Metrics (halted)
|
| 74 |
+
valid_metrics = new_carry.halted & (loss_counts > 0)
|
| 75 |
+
metrics = {
|
| 76 |
+
"count": valid_metrics.sum(),
|
| 77 |
+
|
| 78 |
+
"accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
|
| 79 |
+
"exact_accuracy": (valid_metrics & seq_is_correct).sum(),
|
| 80 |
+
|
| 81 |
+
"q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
|
| 82 |
+
"steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
# Losses
|
| 86 |
+
|
| 87 |
+
lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
|
| 88 |
+
q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
|
| 89 |
+
metrics.update({
|
| 90 |
+
"lm_loss": lm_loss.detach(),
|
| 91 |
+
"q_halt_loss": q_halt_loss.detach(),
|
| 92 |
+
})
|
| 93 |
+
# Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unnecessary
|
| 94 |
+
q_continue_loss = 0
|
| 95 |
+
if "target_q_continue" in outputs:
|
| 96 |
+
q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
|
| 97 |
+
|
| 98 |
+
metrics["q_continue_loss"] = q_continue_loss.detach()
|
| 99 |
+
# Filter outputs for return
|
| 100 |
+
detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
|
| 101 |
+
|
| 102 |
+
return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
|
| 103 |
+
|
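To make the `stablemax_cross_entropy` path in `losses.py` concrete, here is a dependency-free sketch of the same arithmetic on plain Python floats: the piecewise map `s`, the normalisation into probabilities, and the per-sequence division by the number of non-ignored labels that `ACTLossHead` applies through `loss_divisor`. It is an illustration under those assumptions, not the tensor implementation above; the names `stablemax` and `sequence_loss` are introduced here.

```python
import math
from typing import List, Sequence

IGNORE_LABEL_ID = -100


def s(x: float, epsilon: float = 1e-30) -> float:
    """Piecewise map used by stablemax: 1/(1-x) for x < 0, x + 1 otherwise."""
    return 1.0 / (1.0 - x + epsilon) if x < 0 else x + 1.0


def stablemax(logits: Sequence[float]) -> List[float]:
    """Normalise s(logits) so the values sum to one."""
    values = [s(x) for x in logits]
    total = sum(values)
    return [v / total for v in values]


def sequence_loss(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean negative log-probability over the non-ignored positions of one sequence."""
    valid = [label != IGNORE_LABEL_ID for label in labels]
    divisor = max(sum(valid), 1)  # avoid dividing by zero when every label is ignored
    total = 0.0
    for row, label, ok in zip(logits, labels, valid):
        if ok:
            total -= math.log(stablemax(row)[label])
    return total / divisor


probs = stablemax([2.0, 0.0, -1.0])  # s gives [3.0, 1.0, 0.5], which sums to 4.5
loss = sequence_loss([[2.0, 0.0, -1.0], [0.0, 0.0, 0.0]], [0, IGNORE_LABEL_ID])
print(probs, loss)
```

For the logits `[2.0, 0.0, -1.0]` the probabilities are 2/3, 2/9 and 1/9, so the loss for label 0 is `-log(2/3)`. Unlike softmax, `s` grows linearly for positive inputs instead of exponentially.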
arc_v1_public/step_518071
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53689643ad1606d7c22c758f8af0a71b3b66275dea074f214d2f1048d9a01fb0
+size 1822205258
|
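The three lines above are a Git LFS pointer, not the checkpoint itself: the repository stores only the hash and byte size, and the actual file is fetched from LFS storage. A small sketch that reads those fields, assuming the simple `key value` line layout shown in this pointer; `parse_lfs_pointer` is a hypothetical helper, not part of any library.

```python
from typing import Dict


def parse_lfs_pointer(text: str) -> Dict[str, str]:
    """Split each non-empty `key value` line of a Git LFS pointer into a dict."""
    fields: Dict[str, str] = {}
    for line in text.splitlines():
        if not line.strip():
            continue
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields


POINTER = """version https://git-lfs.github.com/spec/v1
oid sha256:53689643ad1606d7c22c758f8af0a71b3b66275dea074f214d2f1048d9a01fb0
size 1822205258
"""

info = parse_lfs_pointer(POINTER)
algorithm, _, digest = info["oid"].partition(":")
size_gb = int(info["size"]) / 1e9  # decimal gigabytes
print(algorithm, digest[:12], f"{size_gb:.2f} GB")
```

The `size` field shows that the `arc_v1_public` checkpoint is about 1.82 GB, and the `oid` digest can be compared against a SHA-256 hash of the downloaded file to confirm the download is complete.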
arc_v1_public/trm.py
ADDED
|
@@ -0,0 +1,297 @@
|
| 1 |
+
from typing import Tuple, List, Dict, Optional
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import copy
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
import random
|
| 10 |
+
from models.common import trunc_normal_init_
|
| 11 |
+
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
|
| 12 |
+
from models.sparse_embedding import CastedSparseEmbedding
|
| 13 |
+
|
| 14 |
+
IGNORE_LABEL_ID = -100
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
|
| 18 |
+
z_H: torch.Tensor
|
| 19 |
+
z_L: torch.Tensor
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class TinyRecursiveReasoningModel_ACTV1Carry:
|
| 24 |
+
inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
|
| 25 |
+
|
| 26 |
+
steps: torch.Tensor
|
| 27 |
+
halted: torch.Tensor
|
| 28 |
+
|
| 29 |
+
current_data: Dict[str, torch.Tensor]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
|
| 33 |
+
batch_size: int
|
| 34 |
+
seq_len: int
|
| 35 |
+
puzzle_emb_ndim: int = 0
|
| 36 |
+
num_puzzle_identifiers: int
|
| 37 |
+
vocab_size: int
|
| 38 |
+
|
| 39 |
+
H_cycles: int
|
| 40 |
+
L_cycles: int
|
| 41 |
+
|
| 42 |
+
H_layers: int # ignored
|
| 43 |
+
L_layers: int
|
| 44 |
+
|
| 45 |
+
# Transformer config
|
| 46 |
+
hidden_size: int
|
| 47 |
+
expansion: float
|
| 48 |
+
num_heads: int
|
| 49 |
+
pos_encodings: str
|
| 50 |
+
|
| 51 |
+
rms_norm_eps: float = 1e-5
|
| 52 |
+
rope_theta: float = 10000.0
|
| 53 |
+
|
| 54 |
+
# Halting Q-learning config
|
| 55 |
+
halt_max_steps: int
|
| 56 |
+
halt_exploration_prob: float
|
| 57 |
+
|
| 58 |
+
forward_dtype: str = "bfloat16"
|
| 59 |
+
|
| 60 |
+
# Alexia: added
|
| 61 |
+
mlp_t: bool = False # use mlp on L instead of transformer
|
| 62 |
+
puzzle_emb_len: int = 16 # if non-zero, puzzle_emb_len is set to this value
|
| 63 |
+
no_ACT_continue: bool = True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
|
| 64 |
+
|
| 65 |
+
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
|
| 66 |
+
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
| 67 |
+
super().__init__()
|
| 68 |
+
|
| 69 |
+
self.config = config
|
| 70 |
+
if self.config.mlp_t:
|
| 71 |
+
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
|
| 72 |
+
self.mlp_t = SwiGLU(
|
| 73 |
+
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
|
| 74 |
+
expansion=config.expansion,
|
| 75 |
+
)
|
| 76 |
+
else:
|
| 77 |
+
self.self_attn = Attention(
|
| 78 |
+
hidden_size=config.hidden_size,
|
| 79 |
+
head_dim=config.hidden_size // config.num_heads,
|
| 80 |
+
num_heads=config.num_heads,
|
| 81 |
+
num_key_value_heads=config.num_heads,
|
| 82 |
+
causal=False
|
| 83 |
+
)
|
| 84 |
+
self.mlp = SwiGLU(
|
| 85 |
+
hidden_size=config.hidden_size,
|
| 86 |
+
expansion=config.expansion,
|
| 87 |
+
)
|
| 88 |
+
self.norm_eps = config.rms_norm_eps
|
| 89 |
+
|
| 90 |
+
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 91 |
+
# B, L, D = hidden_states.shape
|
| 92 |
+
# Post Norm
|
| 93 |
+
if self.config.mlp_t:
|
| 94 |
+
hidden_states = hidden_states.transpose(1,2)
|
| 95 |
+
out = self.mlp_t(hidden_states)
|
| 96 |
+
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
| 97 |
+
hidden_states = hidden_states.transpose(1,2)
|
| 98 |
+
else:
|
| 99 |
+
# Self Attention
|
| 100 |
+
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
|
| 101 |
+
# Fully Connected
|
| 102 |
+
out = self.mlp(hidden_states)
|
| 103 |
+
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
| 104 |
+
return hidden_states
|
| 105 |
+
|
| 106 |
+
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
|
| 107 |
+
def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
|
| 108 |
+
super().__init__()
|
| 109 |
+
self.layers = torch.nn.ModuleList(layers)
|
| 110 |
+
|
| 111 |
+
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 112 |
+
hidden_states = hidden_states + input_injection
|
| 113 |
+
for layer in self.layers:
|
| 114 |
+
hidden_states = layer(hidden_states=hidden_states, **kwargs)
|
| 115 |
+
return hidden_states
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
|
| 119 |
+
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
| 120 |
+
super().__init__()
|
| 121 |
+
self.config = config
|
| 122 |
+
self.forward_dtype = getattr(torch, self.config.forward_dtype)
|
| 123 |
+
|
| 124 |
+
# I/O
|
| 125 |
+
|
| 126 |
+
self.embed_scale = math.sqrt(self.config.hidden_size)
|
| 127 |
+
embed_init_std = 1.0 / self.embed_scale
|
| 128 |
+
|
| 129 |
+
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
| 130 |
+
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
| 131 |
+
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
|
| 132 |
+
|
| 133 |
+
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div
|
| 134 |
+
if self.config.puzzle_emb_ndim > 0:
|
| 135 |
+
# Zero init puzzle embeddings
|
| 136 |
+
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
|
| 137 |
+
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
|
| 138 |
+
|
| 139 |
+
# LM Blocks
|
| 140 |
+
if self.config.pos_encodings == "rope":
|
| 141 |
+
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
|
| 142 |
+
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
|
| 143 |
+
base=self.config.rope_theta)
|
| 144 |
+
elif self.config.pos_encodings == "learned":
|
| 145 |
+
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
| 146 |
+
else:
|
| 147 |
+
pass
|
| 148 |
+
|
| 149 |
+
# Reasoning Layers
|
| 150 |
+
self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
|
| 151 |
+
|
| 152 |
+
# Initial states
|
| 153 |
+
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 154 |
+
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 155 |
+
|
| 156 |
+
# Q head special init
|
| 157 |
+
# Init Q to (almost) zero for faster learning during bootstrapping
|
| 158 |
+
with torch.no_grad():
|
| 159 |
+
self.q_head.weight.zero_()
|
| 160 |
+
self.q_head.bias.fill_(-5) # type: ignore
|
| 161 |
+
|
| 162 |
+
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
|
| 163 |
+
# Token embedding
|
| 164 |
+
embedding = self.embed_tokens(input.to(torch.int32))
|
| 165 |
+
|
| 166 |
+
# Puzzle embeddings
|
| 167 |
+
if self.config.puzzle_emb_ndim > 0:
|
| 168 |
+
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
|
| 169 |
+
|
| 170 |
+
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
|
| 171 |
+
if pad_count > 0:
|
| 172 |
+
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
|
| 173 |
+
|
| 174 |
+
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
|
| 175 |
+
|
| 176 |
+
# Position embeddings
|
| 177 |
+
if self.config.pos_encodings == "learned":
|
| 178 |
+
# scale by 1/sqrt(2) to maintain forward variance
|
| 179 |
+
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
|
| 180 |
+
|
| 181 |
+
# Scale
|
| 182 |
+
return self.embed_scale * embedding
|
| 183 |
+
|
| 184 |
+
def empty_carry(self, batch_size: int):
|
| 185 |
+
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
|
| 186 |
+
z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 187 |
+
z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
|
| 191 |
+
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
|
| 192 |
+
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
|
| 193 |
+
z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 197 |
+
seq_info = dict(
|
| 198 |
+
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# Input encoding
|
| 202 |
+
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
|
| 203 |
+
|
| 204 |
+
# Forward iterations
|
| 205 |
+
it = 0
|
| 206 |
+
z_H, z_L = carry.z_H, carry.z_L
|
| 207 |
+
# H_cycles-1 without grad
|
| 208 |
+
with torch.no_grad():
|
| 209 |
+
for _H_step in range(self.config.H_cycles-1):
|
| 210 |
+
for _L_step in range(self.config.L_cycles):
|
| 211 |
+
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
|
| 212 |
+
z_H = self.L_level(z_H, z_L, **seq_info)
|
| 213 |
+
# 1 with grad
|
| 214 |
+
for _L_step in range(self.config.L_cycles):
|
| 215 |
+
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
|
| 216 |
+
z_H = self.L_level(z_H, z_L, **seq_info)
|
| 217 |
+
|
| 218 |
+
# LM Outputs
|
| 219 |
+
new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
|
| 220 |
+
output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
|
| 221 |
+
q_logits = self.q_head(z_H[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position
|
| 222 |
+
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
|
| 226 |
+
"""ACT wrapper."""
|
| 227 |
+
|
| 228 |
+
def __init__(self, config_dict: dict):
|
| 229 |
+
super().__init__()
|
| 230 |
+
self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
|
| 231 |
+
self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)
|
| 232 |
+
|
| 233 |
+
@property
|
| 234 |
+
def puzzle_emb(self):
|
| 235 |
+
return self.inner.puzzle_emb
|
| 236 |
+
|
| 237 |
+
def initial_carry(self, batch: Dict[str, torch.Tensor]):
|
| 238 |
+
batch_size = batch["inputs"].shape[0]
|
| 239 |
+
|
| 240 |
+
return TinyRecursiveReasoningModel_ACTV1Carry(
|
| 241 |
+
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected; it will be reset in the first pass, as all sequences start halted.
|
| 242 |
+
|
| 243 |
+
steps=torch.zeros((batch_size, ), dtype=torch.int32),
|
| 244 |
+
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
|
| 245 |
+
|
| 246 |
+
current_data={k: torch.empty_like(v) for k, v in batch.items()}
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
|
| 250 |
+
|
| 251 |
+
# Update data, carry (removing halted sequences)
|
| 252 |
+
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
|
| 253 |
+
|
| 254 |
+
new_steps = torch.where(carry.halted, 0, carry.steps)
|
| 255 |
+
|
| 256 |
+
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
|
| 257 |
+
|
| 258 |
+
# Forward inner model
|
| 259 |
+
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
|
| 260 |
+
|
| 261 |
+
outputs = {
|
| 262 |
+
"logits": logits,
|
| 263 |
+
"q_halt_logits": q_halt_logits,
|
| 264 |
+
"q_continue_logits": q_continue_logits
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
with torch.no_grad():
|
| 268 |
+
# Step
|
| 269 |
+
new_steps = new_steps + 1
|
| 270 |
+
is_last_step = new_steps >= self.config.halt_max_steps
|
| 271 |
+
|
| 272 |
+
halted = is_last_step
|
| 273 |
+
|
| 274 |
+
# if training, and ACT is enabled
|
| 275 |
+
if self.training and (self.config.halt_max_steps > 1):
|
| 276 |
+
|
| 277 |
+
# Halt signal
|
| 278 |
+
# NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
|
| 279 |
+
|
| 280 |
+
if self.config.no_ACT_continue:
|
| 281 |
+
halted = halted | (q_halt_logits > 0)
|
| 282 |
+
else:
|
| 283 |
+
halted = halted | (q_halt_logits > q_continue_logits)
|
| 284 |
+
|
| 285 |
+
# Exploration
|
| 286 |
+
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
|
| 287 |
+
halted = halted & (new_steps >= min_halt_steps)
|
| 288 |
+
|
| 289 |
+
if not self.config.no_ACT_continue:
|
| 290 |
+
# Compute target Q
|
| 291 |
+
# NOTE: No replay buffer and target networks for computing target Q-value.
|
| 292 |
+
# As batch_size is large, there're many parallel envs.
|
| 293 |
+
# Similar concept as PQN https://arxiv.org/abs/2407.04811
|
| 294 |
+
_, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)  # inner returns a 3-tuple
|
| 295 |
+
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
|
| 296 |
+
|
| 297 |
+
return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
|
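The halting logic at the end of `TinyRecursiveReasoningModel_ACTV1.forward` can be hard to follow in tensor form. Below is a scalar sketch for a single sequence, covering only the `no_ACT_continue: true` branch used by both configs in this commit. The function name `halting_decision` and the explicit `rng` argument are assumptions made for illustration; the real code draws the random numbers for the whole batch at once.

```python
import random


def halting_decision(
    new_steps: int,
    q_halt_logit: float,
    halt_max_steps: int,
    halt_exploration_prob: float,
    training: bool,
    rng: random.Random,
) -> bool:
    """Decide whether one sequence halts after the current ACT step."""
    # Every sequence halts once it reaches the step budget.
    halted = new_steps >= halt_max_steps

    # Outside training the model always runs the full budget, so that all
    # sequences in a batch halt together.
    if training and halt_max_steps > 1:
        # Learned halt signal: a positive halt logit means "stop now".
        halted = halted or (q_halt_logit > 0)

        # Exploration: with some probability, force a random minimum number of
        # steps in [2, halt_max_steps] before halting is allowed.
        explore = rng.random() < halt_exploration_prob
        min_halt_steps = rng.randint(2, halt_max_steps) if explore else 0
        halted = halted and (new_steps >= min_halt_steps)

    return halted


rng = random.Random(0)
print(halting_decision(1, 1.0, 16, 0.0, True, rng))
print(halting_decision(3, 10.0, 16, 0.1, False, rng))
```

With exploration disabled, a positive halt logit stops the sequence during training. In evaluation mode the same logit is ignored until `halt_max_steps` is reached. Because the forced minimum never exceeds `halt_max_steps`, the final step always halts.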
arc_v2_public/all_config.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
| 1 |
+
arch:
|
| 2 |
+
H_cycles: 3
|
| 3 |
+
H_layers: 0
|
| 4 |
+
L_cycles: 4
|
| 5 |
+
L_layers: 2
|
| 6 |
+
expansion: 4
|
| 7 |
+
forward_dtype: bfloat16
|
| 8 |
+
halt_exploration_prob: 0.1
|
| 9 |
+
halt_max_steps: 16
|
| 10 |
+
hidden_size: 512
|
| 11 |
+
loss:
|
| 12 |
+
loss_type: stablemax_cross_entropy
|
| 13 |
+
name: losses@ACTLossHead
|
| 14 |
+
mlp_t: false
|
| 15 |
+
name: recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1
|
| 16 |
+
no_ACT_continue: true
|
| 17 |
+
num_heads: 8
|
| 18 |
+
pos_encodings: rope
|
| 19 |
+
puzzle_emb_len: 16
|
| 20 |
+
puzzle_emb_ndim: 512
|
| 21 |
+
beta1: 0.9
|
| 22 |
+
beta2: 0.95
|
| 23 |
+
checkpoint_every_eval: true
|
| 24 |
+
checkpoint_path: checkpoints/Arc2concept-aug-1000-ACT-torch/arc_v2_pub
|
| 25 |
+
data_paths:
|
| 26 |
+
- data/arc2concept-aug-1000
|
| 27 |
+
data_paths_test: []
|
| 28 |
+
ema: true
|
| 29 |
+
ema_rate: 0.999
|
| 30 |
+
epochs: 100000
|
| 31 |
+
eval_interval: 10000
|
| 32 |
+
eval_save_outputs: []
|
| 33 |
+
evaluators:
|
| 34 |
+
- name: arc@ARC
|
| 35 |
+
freeze_weights: false
|
| 36 |
+
global_batch_size: 768
|
| 37 |
+
load_checkpoint: null
|
| 38 |
+
lr: 0.0001
|
| 39 |
+
lr_min_ratio: 1.0
|
| 40 |
+
lr_warmup_steps: 2000
|
| 41 |
+
min_eval_interval: 0
|
| 42 |
+
project_name: Arc2concept-aug-1000-ACT-torch
|
| 43 |
+
puzzle_emb_lr: 0.01
|
| 44 |
+
puzzle_emb_weight_decay: 0.1
|
| 45 |
+
run_name: arc_v2_pub
|
| 46 |
+
seed: 0
|
| 47 |
+
weight_decay: 0.1
|
arc_v2_public/losses.py
ADDED
|
@@ -0,0 +1,103 @@
|
| 1 |
+
from typing import Any, Tuple, Dict, Sequence, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch import nn
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
IGNORE_LABEL_ID = -100
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def s(x, epsilon=1e-30):
|
| 12 |
+
return torch.where(
|
| 13 |
+
x<0,
|
| 14 |
+
1/(1-x+ epsilon),
|
| 15 |
+
x + 1
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def log_stablemax(x, dim=-1):
|
| 20 |
+
s_x = s(x)
|
| 21 |
+
return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
|
| 25 |
+
logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
|
| 26 |
+
|
| 27 |
+
if valid_mask is None:
|
| 28 |
+
valid_mask = (labels != ignore_index)
|
| 29 |
+
transformed_labels = torch.where(valid_mask, labels, 0)
|
| 30 |
+
prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
|
| 31 |
+
|
| 32 |
+
return -torch.where(valid_mask, prediction_logprobs, 0)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
|
| 36 |
+
# Cast logits to f32
|
| 37 |
+
# Flatten logits
|
| 38 |
+
return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class ACTLossHead(nn.Module):
|
| 42 |
+
def __init__(self, model: nn.Module, loss_type: str):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.model = model
|
| 45 |
+
self.loss_fn = globals()[loss_type]
|
| 46 |
+
|
| 47 |
+
def initial_carry(self, *args, **kwargs):
|
| 48 |
+
return self.model.initial_carry(*args, **kwargs) # type: ignore
|
| 49 |
+
|
| 50 |
+
def forward(
|
| 51 |
+
self,
|
| 52 |
+
return_keys: Sequence[str],
|
| 53 |
+
# Model args
|
| 54 |
+
**model_kwargs,
|
| 55 |
+
) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
|
| 56 |
+
# Model logits
|
| 57 |
+
# B x SeqLen x D
|
| 58 |
+
new_carry, outputs = self.model(**model_kwargs)
|
| 59 |
+
labels = new_carry.current_data["labels"]
|
| 60 |
+
|
| 61 |
+
with torch.no_grad():
|
| 62 |
+
# Preds
|
| 63 |
+
outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
|
| 64 |
+
|
| 65 |
+
# Correctness
|
| 66 |
+
mask = (labels != IGNORE_LABEL_ID)
|
| 67 |
+
loss_counts = mask.sum(-1)
|
| 68 |
+
loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
|
| 69 |
+
|
| 70 |
+
is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
|
| 71 |
+
seq_is_correct = is_correct.sum(-1) == loss_counts
|
| 72 |
+
|
| 73 |
+
# Metrics (halted)
|
| 74 |
+
valid_metrics = new_carry.halted & (loss_counts > 0)
|
| 75 |
+
metrics = {
|
| 76 |
+
"count": valid_metrics.sum(),
|
| 77 |
+
|
| 78 |
+
"accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
|
| 79 |
+
"exact_accuracy": (valid_metrics & seq_is_correct).sum(),
|
| 80 |
+
|
| 81 |
+
"q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
|
| 82 |
+
"steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
# Losses
|
| 86 |
+
|
| 87 |
+
lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
|
| 88 |
+
q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
|
| 89 |
+
metrics.update({
|
| 90 |
+
"lm_loss": lm_loss.detach(),
|
| 91 |
+
"q_halt_loss": q_halt_loss.detach(),
|
| 92 |
+
})
|
| 93 |
+
# Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unnecessary
|
| 94 |
+
q_continue_loss = 0
|
| 95 |
+
if "target_q_continue" in outputs:
|
| 96 |
+
q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
|
| 97 |
+
|
| 98 |
+
metrics["q_continue_loss"] = q_continue_loss.detach()
|
| 99 |
+
# Filter outputs for return
|
| 100 |
+
detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
|
| 101 |
+
|
| 102 |
+
return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
|
| 103 |
+
|
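As a worked example of the `stablemax_cross_entropy` loss selected in the config, here is a plain-Python sketch of the same piecewise map `s(x)` used in `losses.py`. It deliberately avoids torch and handles a single unbatched list of logits, so it is an illustration of the arithmetic rather than a drop-in replacement; the function `stablemax` and the sample logits are chosen here for the example.

```python
import math


def s(x: float, epsilon: float = 1e-30) -> float:
    # Same piecewise map as losses.py: linear for x >= 0, reciprocal for x < 0.
    return 1.0 / (1.0 - x + epsilon) if x < 0 else x + 1.0


def stablemax(logits):
    # Normalise the mapped values so they sum to one, in place of exp() in softmax.
    values = [s(x) for x in logits]
    total = sum(values)
    return [v / total for v in values]


def stablemax_cross_entropy(logits, label: int) -> float:
    # Negative log-probability of the target class under stablemax.
    return -math.log(stablemax(logits)[label])


logits = [2.0, 0.0, -1.0]
probs = stablemax(logits)  # s maps the logits to 3.0, 1.0 and 0.5 before normalising
loss = stablemax_cross_entropy(logits, label=0)
print(probs, loss)
```

Because `s` grows linearly rather than exponentially, large positive logits change the probabilities far more gently than they would under softmax.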
arc_v2_public/step_723914
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8d7036b97e7ea38c7dd29d01216bfcfc4e212af3024d5233fe40dd3059e8f4a9
|
| 3 |
+
size 2467988810
|
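The three lines above are a Git LFS pointer, not the checkpoint itself. A minimal sketch of reading such a pointer is shown below; `parse_lfs_pointer` is a hypothetical helper written for this example, and it assumes the simple `key value` line layout seen in this file.

```python
def parse_lfs_pointer(text: str) -> dict:
    # Each pointer line is "<key> <value>"; split on the first space only.
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algorithm, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "algorithm": algorithm,
        "digest": digest,
        "size_bytes": int(fields["size"]),
    }


pointer_text = """version https://git-lfs.github.com/spec/v1
oid sha256:8d7036b97e7ea38c7dd29d01216bfcfc4e212af3024d5233fe40dd3059e8f4a9
size 2467988810
"""

info = parse_lfs_pointer(pointer_text)
print(info["algorithm"], round(info["size_bytes"] / 1e9, 2), "GB")
```

The digest can be compared against a SHA-256 hash of the downloaded `step_723914` file to check that the transfer completed correctly.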
arc_v2_public/trm.py
ADDED
|
@@ -0,0 +1,297 @@
| 1 |
+
from typing import Tuple, List, Dict, Optional
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import copy
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
import random
|
| 10 |
+
from models.common import trunc_normal_init_
|
| 11 |
+
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
|
| 12 |
+
from models.sparse_embedding import CastedSparseEmbedding
|
| 13 |
+
|
| 14 |
+
IGNORE_LABEL_ID = -100
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
|
| 18 |
+
z_H: torch.Tensor
|
| 19 |
+
z_L: torch.Tensor
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class TinyRecursiveReasoningModel_ACTV1Carry:
|
| 24 |
+
inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
|
| 25 |
+
|
| 26 |
+
steps: torch.Tensor
|
| 27 |
+
halted: torch.Tensor
|
| 28 |
+
|
| 29 |
+
current_data: Dict[str, torch.Tensor]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
|
| 33 |
+
batch_size: int
|
| 34 |
+
seq_len: int
|
| 35 |
+
puzzle_emb_ndim: int = 0
|
| 36 |
+
num_puzzle_identifiers: int
|
| 37 |
+
vocab_size: int
|
| 38 |
+
|
| 39 |
+
H_cycles: int
|
| 40 |
+
L_cycles: int
|
| 41 |
+
|
| 42 |
+
H_layers: int # ignored
|
| 43 |
+
L_layers: int
|
| 44 |
+
|
| 45 |
+
# Transformer config
|
| 46 |
+
hidden_size: int
|
| 47 |
+
expansion: float
|
| 48 |
+
num_heads: int
|
| 49 |
+
pos_encodings: str
|
| 50 |
+
|
| 51 |
+
rms_norm_eps: float = 1e-5
|
| 52 |
+
rope_theta: float = 10000.0
|
| 53 |
+
|
| 54 |
+
# Halting Q-learning config
|
| 55 |
+
halt_max_steps: int
|
| 56 |
+
halt_exploration_prob: float
|
| 57 |
+
|
| 58 |
+
forward_dtype: str = "bfloat16"
|
| 59 |
+
|
| 60 |
+
# Alexia: added
|
| 61 |
+
mlp_t: bool = False # use mlp on L instead of transformer
|
| 62 |
+
puzzle_emb_len: int = 16 # if non-zero, the puzzle embedding length is set to this value
|
| 63 |
+
no_ACT_continue: bool = True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
|
| 64 |
+
|
| 65 |
+
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
|
| 66 |
+
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
| 67 |
+
super().__init__()
|
| 68 |
+
|
| 69 |
+
self.config = config
|
| 70 |
+
if self.config.mlp_t:
|
| 71 |
+
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
|
| 72 |
+
self.mlp_t = SwiGLU(
|
| 73 |
+
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
|
| 74 |
+
expansion=config.expansion,
|
| 75 |
+
)
|
| 76 |
+
else:
|
| 77 |
+
self.self_attn = Attention(
|
| 78 |
+
hidden_size=config.hidden_size,
|
| 79 |
+
head_dim=config.hidden_size // config.num_heads,
|
| 80 |
+
num_heads=config.num_heads,
|
| 81 |
+
num_key_value_heads=config.num_heads,
|
| 82 |
+
causal=False
|
| 83 |
+
)
|
| 84 |
+
self.mlp = SwiGLU(
|
| 85 |
+
hidden_size=config.hidden_size,
|
| 86 |
+
expansion=config.expansion,
|
| 87 |
+
)
|
| 88 |
+
self.norm_eps = config.rms_norm_eps
|
| 89 |
+
|
| 90 |
+
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 91 |
+
# B, L, D = hidden_states.shape
|
| 92 |
+
# Post Norm
|
| 93 |
+
if self.config.mlp_t:
|
| 94 |
+
hidden_states = hidden_states.transpose(1,2)
|
| 95 |
+
out = self.mlp_t(hidden_states)
|
| 96 |
+
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
| 97 |
+
hidden_states = hidden_states.transpose(1,2)
|
| 98 |
+
else:
|
| 99 |
+
# Self Attention
|
| 100 |
+
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
|
| 101 |
+
# Fully Connected
|
| 102 |
+
out = self.mlp(hidden_states)
|
| 103 |
+
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
| 104 |
+
return hidden_states
|
| 105 |
+
|
| 106 |
+
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
|
| 107 |
+
def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
|
| 108 |
+
super().__init__()
|
| 109 |
+
self.layers = torch.nn.ModuleList(layers)
|
| 110 |
+
|
| 111 |
+
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 112 |
+
hidden_states = hidden_states + input_injection
|
| 113 |
+
for layer in self.layers:
|
| 114 |
+
hidden_states = layer(hidden_states=hidden_states, **kwargs)
|
| 115 |
+
return hidden_states
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
|
| 119 |
+
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
| 120 |
+
super().__init__()
|
| 121 |
+
self.config = config
|
| 122 |
+
self.forward_dtype = getattr(torch, self.config.forward_dtype)
|
| 123 |
+
|
| 124 |
+
# I/O
|
| 125 |
+
|
| 126 |
+
self.embed_scale = math.sqrt(self.config.hidden_size)
|
| 127 |
+
embed_init_std = 1.0 / self.embed_scale
|
| 128 |
+
|
| 129 |
+
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
| 130 |
+
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
| 131 |
+
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
|
| 132 |
+
|
| 133 |
+
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div
|
| 134 |
+
if self.config.puzzle_emb_ndim > 0:
|
| 135 |
+
# Zero init puzzle embeddings
|
| 136 |
+
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
|
| 137 |
+
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
|
| 138 |
+
|
| 139 |
+
# LM Blocks
|
| 140 |
+
if self.config.pos_encodings == "rope":
|
| 141 |
+
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
|
| 142 |
+
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
|
| 143 |
+
base=self.config.rope_theta)
|
| 144 |
+
elif self.config.pos_encodings == "learned":
|
| 145 |
+
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
| 146 |
+
else:
|
| 147 |
+
pass
|
| 148 |
+
|
| 149 |
+
# Reasoning Layers
|
| 150 |
+
self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
|
| 151 |
+
|
| 152 |
+
# Initial states
|
| 153 |
+
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 154 |
+
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 155 |
+
|
| 156 |
+
# Q head special init
|
| 157 |
+
# Init Q to (almost) zero for faster learning during bootstrapping
|
| 158 |
+
with torch.no_grad():
|
| 159 |
+
self.q_head.weight.zero_()
|
| 160 |
+
self.q_head.bias.fill_(-5) # type: ignore
|
| 161 |
+
|
| 162 |
+
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
|
| 163 |
+
# Token embedding
|
| 164 |
+
embedding = self.embed_tokens(input.to(torch.int32))
|
| 165 |
+
|
| 166 |
+
# Puzzle embeddings
|
| 167 |
+
if self.config.puzzle_emb_ndim > 0:
|
| 168 |
+
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
|
| 169 |
+
|
| 170 |
+
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
|
| 171 |
+
if pad_count > 0:
|
| 172 |
+
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
|
| 173 |
+
|
| 174 |
+
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
|
| 175 |
+
|
| 176 |
+
# Position embeddings
|
| 177 |
+
if self.config.pos_encodings == "learned":
|
| 178 |
+
# scale by 1/sqrt(2) to maintain forward variance
|
| 179 |
+
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
|
| 180 |
+
|
| 181 |
+
# Scale
|
| 182 |
+
return self.embed_scale * embedding
|
| 183 |
+
|
| 184 |
+
def empty_carry(self, batch_size: int):
|
| 185 |
+
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
|
| 186 |
+
z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 187 |
+
z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
|
| 191 |
+
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
|
| 192 |
+
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
|
| 193 |
+
z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 197 |
+
seq_info = dict(
|
| 198 |
+
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# Input encoding
|
| 202 |
+
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
|
| 203 |
+
|
| 204 |
+
# Forward iterations
|
| 205 |
+
it = 0
|
| 206 |
+
z_H, z_L = carry.z_H, carry.z_L
|
| 207 |
+
# H_cycles-1 without grad
|
| 208 |
+
with torch.no_grad():
|
| 209 |
+
for _H_step in range(self.config.H_cycles-1):
|
| 210 |
+
for _L_step in range(self.config.L_cycles):
|
| 211 |
+
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
|
| 212 |
+
z_H = self.L_level(z_H, z_L, **seq_info)
|
| 213 |
+
# 1 with grad
|
| 214 |
+
for _L_step in range(self.config.L_cycles):
|
| 215 |
+
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
|
| 216 |
+
z_H = self.L_level(z_H, z_L, **seq_info)
|
| 217 |
+
|
| 218 |
+
# LM Outputs
|
| 219 |
+
new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
|
| 220 |
+
output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
|
| 221 |
+
q_logits = self.q_head(z_H[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position
|
| 222 |
+
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
|
| 226 |
+
"""ACT wrapper."""
|
| 227 |
+
|
| 228 |
+
def __init__(self, config_dict: dict):
|
| 229 |
+
super().__init__()
|
| 230 |
+
self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
|
| 231 |
+
self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)
|
| 232 |
+
|
| 233 |
+
@property
|
| 234 |
+
def puzzle_emb(self):
|
| 235 |
+
return self.inner.puzzle_emb
|
| 236 |
+
|
| 237 |
+
def initial_carry(self, batch: Dict[str, torch.Tensor]):
|
| 238 |
+
batch_size = batch["inputs"].shape[0]
|
| 239 |
+
|
| 240 |
+
return TinyRecursiveReasoningModel_ACTV1Carry(
|
| 241 |
+
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected; it is reset on the first pass because all sequences start halted.
|
| 242 |
+
|
| 243 |
+
steps=torch.zeros((batch_size, ), dtype=torch.int32),
|
| 244 |
+
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
|
| 245 |
+
|
| 246 |
+
current_data={k: torch.empty_like(v) for k, v in batch.items()}
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
|
| 250 |
+
|
| 251 |
+
# Update data, carry (removing halted sequences)
|
| 252 |
+
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
|
| 253 |
+
|
| 254 |
+
new_steps = torch.where(carry.halted, 0, carry.steps)
|
| 255 |
+
|
| 256 |
+
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
|
| 257 |
+
|
| 258 |
+
# Forward inner model
|
| 259 |
+
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
|
| 260 |
+
|
| 261 |
+
outputs = {
|
| 262 |
+
"logits": logits,
|
| 263 |
+
"q_halt_logits": q_halt_logits,
|
| 264 |
+
"q_continue_logits": q_continue_logits
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
with torch.no_grad():
|
| 268 |
+
# Step
|
| 269 |
+
new_steps = new_steps + 1
|
| 270 |
+
is_last_step = new_steps >= self.config.halt_max_steps
|
| 271 |
+
|
| 272 |
+
halted = is_last_step
|
| 273 |
+
|
| 274 |
+
# if training, and ACT is enabled
|
| 275 |
+
if self.training and (self.config.halt_max_steps > 1):
|
| 276 |
+
|
| 277 |
+
# Halt signal
|
| 278 |
+
# NOTE: During evaluation, always use max steps; this guarantees the same halting steps inside a batch for batching purposes
|
| 279 |
+
|
| 280 |
+
if self.config.no_ACT_continue:
|
| 281 |
+
halted = halted | (q_halt_logits > 0)
|
| 282 |
+
else:
|
| 283 |
+
halted = halted | (q_halt_logits > q_continue_logits)
|
| 284 |
+
|
| 285 |
+
# Exploration
|
| 286 |
+
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
|
| 287 |
+
halted = halted & (new_steps >= min_halt_steps)
|
| 288 |
+
|
| 289 |
+
if not self.config.no_ACT_continue:
|
| 290 |
+
# Compute target Q
|
| 291 |
+
# NOTE: No replay buffer and target networks for computing target Q-value.
|
| 292 |
+
# As batch_size is large, there're many parallel envs.
|
| 293 |
+
# Similar concept as PQN https://arxiv.org/abs/2407.04811
|
| 294 |
+
_, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
|
| 295 |
+
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
|
| 296 |
+
|
| 297 |
+
return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
|
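The halting rule in the ACT wrapper above can be summarised for a single sequence with the scalar sketch below. It assumes `no_ACT_continue` is true, as in the shipped config, and replaces the random exploration draw with an explicit `min_halt_steps` argument so the behaviour is deterministic; `should_halt` is a name introduced here for illustration.

```python
def should_halt(step: int, q_halt_logit: float, halt_max_steps: int,
                training: bool, min_halt_steps: int = 0) -> bool:
    # `step` is the counter after incrementing (new_steps in trm.py).
    halted = step >= halt_max_steps
    if training and halt_max_steps > 1:
        # During training the Q-head may stop a sequence early.
        halted = halted or q_halt_logit > 0
        # Exploration: a sampled minimum step count can delay halting.
        halted = halted and step >= min_halt_steps
    return halted


print(should_halt(3, 2.0, 16, training=True))   # early halt allowed in training
print(should_halt(3, 2.0, 16, training=False))  # evaluation waits for the last step
```

In the tensor version, `min_halt_steps` is zero for most sequences and a random value between 2 and `halt_max_steps` for a fraction `halt_exploration_prob` of them.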