How to Train Recursive Models?
A frontier in adaptive compute
TL;DR
- I swapped out the transformer in @karpathy’s Nanochat for a recursive transformer
- It performs similarly when trained at iso-FLOPs, but needs roughly 40% fewer parameters (about 300M vs 500M)
- Using recursion opens up a path to adaptive compute without the need for long reasoning traces
SETUP:
- Nanochat is a 20-layer model
- I swapped out the middle 16 layers for a recursive 4-layer block. So, 2 prelude layers + 4 recursive layers (recursed 4 times on average) + 2 coda layers = 20 effective layers.
- Training is iso-data & iso-FLOPs
RESULTS:
- The core/aggregate score is a little lower than that of the 20-layer model
- GSM8K is a little higher, and shows a meaningful effect of varying the number of recursions at inference time
WHY IS RECURSION INTERESTING?
- A lower-parameter model with similar performance => inference on smaller devices or on fewer GPUs (fewer pipeline stages, so smaller pipeline bubbles)
- It opens up adaptive compute -> you can vary inference compute based on problem difficulty
IDEAS FROM LITERATURE ON ADAPTIVE COMPUTE
- You can vary compute per token by recursing until the predicted probability distribution starts to look the same! No need for a trained halting head.
- You can save on KV cache by storing only the state from the latest recursion (discarding info from earlier recursions), and it works!
- For continuous batching with adaptive compute, you probably need to separate the recursive block from the prelude/coda, or perhaps drop the prelude/coda entirely.
OUTLOOK
- Big labs probably are incorporating, or should incorporate, adaptive compute?
- Expect inference engines like SGLang, vLLM, and perhaps NVIDIA Dynamo to start supporting it.
- This is big for edge-device inference.
Many thanks to Sean McLeish for discussions in preparing this video/tweet.
References:
- Recursion & Adaptive Compute: https://arxiv.org/abs/2502.05171
- TRM Paper: https://arxiv.org/abs/2510.04871
GitHub repo (see the `recursive` branch)
💡 Done-for-you Custom Fine-tuning Services
Learn More: https://trelis.com/fine-tuning-services/
💸 Starting a New Project/Venture?
Apply for a Trelis Grant: https://trelis.com/trelis-ai-grants/
📧 Get Trelis AI Tutorials by Email
Subscribe on Substack: https://trelis.substack.com
TIMESTAMPS:
0:00 Recursive Nanochat - a comparison with Karpathy’s 500M parameter model
3:07 Benchmark Results on Recursive Nanochat
5:54 What are the benefits of recursive models?
6:52 Recursive Models allow inference on smaller devices and fewer GPUs
8:07 Recursive Models open a pathway to adaptive compute
9:39 Recursive vs Non-recursive Architecture
13:48 How to handle the recursive stream via an adapter
16:49 Training for adaptive compute / recursions - Poisson log-normal recursion sampling
18:14 Handling torch.compile with recursive models
20:07 Implementing adaptive compute (stopping recursions early)
22:14 kv cache strategies for recursive models
24:16 Inference engine (vLLM) implications for recursive models
26:20 Training dynamics of recursive models (Wandb overview), incl. flops utilisation
31:21 Code Review of Trelis/nanochat
32:05 Truncated backpropagation through time
33:56 Recursive loop adapter initialisation
35:25 Dynamic torch compile
37:06 Wrap up
Recursive Transformers: Training a 300M Parameter Model with Performance Close to 500M Parameters
I trained a recursive transformer model with 300 million parameters that performs similarly to a standard 500 million parameter model when both are trained with equal compute. The recursive model uses 4 layers that repeat an average of 4 times, replacing 16 layers in the standard architecture.
Starting Point: NanoChat
The baseline is NanoChat, Andrej Karpathy’s training repository, which produces a 500 million parameter, 20-layer model for approximately $100. This model scores at roughly GPT-2 level on standard benchmarks.
I modified this architecture by keeping 2 initial layers (prelude), replacing the middle 16 layers with 4 recursive layers that repeat 4 times on average, and keeping 2 final layers (coda). This yields 8 total layers but an effective depth of 20 layers: 2 + (4 × 4) + 2 = 20.
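The sketch below makes the 2 + (4 × 4) + 2 arithmetic concrete. It lists the layers that a single forward pass visits. The function name and layer labels are hypothetical and are not taken from the repo. Only 8 distinct layers carry weights, while the pass itself runs through 20.

```python
from typing import List


def execution_schedule(n_prelude: int, n_recur: int, n_coda: int,
                       num_recursions: int) -> List[str]:
    """List the layers one forward pass runs through, in order."""
    prelude = [f"prelude.{i}" for i in range(n_prelude)]
    # The same recursive block (same weights) is reused on every pass.
    recur = [f"recur.{i}" for _ in range(num_recursions) for i in range(n_recur)]
    coda = [f"coda.{i}" for i in range(n_coda)]
    return prelude + recur + coda


schedule = execution_schedule(2, 4, 2, num_recursions=4)
print(len(schedule))       # effective depth: 2 + 4 * 4 + 2
print(len(set(schedule)))  # distinct layers that actually hold weights
```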
Training Configuration
Both models were trained on identical data using the same number of floating point operations. The recursive model has fewer unique parameters but the same effective depth, so it uses roughly the same FLOPs per token. Keeping data and compute fixed therefore raised its ratio from 20 tokens per parameter (the Chinchilla-optimal ratio) to 34 tokens per parameter.
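Here is a quick check of the tokens-per-parameter shift, using the rounded 500M and 300M figures from this post. With rounded counts it gives about 33 rather than the 34 quoted above. The exact parameter counts, which are not listed here, presumably account for the difference.

```python
# Rounded parameter counts from this post; exact counts would shift the result slightly.
baseline_params = 500e6
recursive_params = 300e6

# Chinchilla-style budget for the baseline: 20 training tokens per parameter.
train_tokens = 20 * baseline_params

# Iso-data: the recursive model sees the same tokens, spread over fewer parameters.
tokens_per_param_recursive = train_tokens / recursive_params

print(f"{train_tokens / 1e9:.1f}B tokens")
print(f"{tokens_per_param_recursive:.1f} tokens per parameter")
```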
During training, I sampled the number of recursions from a Poisson log-normal distribution with mean 4, minimum 1, and maximum 16. This exposes the model to varying recursion depths, though it complicates torch compile since each recursion count creates a different computational graph that must be cached.
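Below is a minimal sketch of a Poisson log-normal sampler, loosely in the spirit of the recurrent-depth paper listed in References. It draws a log-normal rate, draws a Poisson count at that rate, shifts by 1 and clips to [1, 16]. Several choices here are assumptions, not the repo's exact settings: the value σ = 0.5, the "mean minus one" offset, and all function names. The last line shows why compilation gets expensive: every distinct depth that appears would need its own traced graph.

```python
import math
import random
from typing import Optional


def sample_poisson(lam: float, rng: random.Random) -> int:
    """Knuth's multiplication method; adequate for the small rates used here."""
    threshold = math.exp(-lam)
    k, p = 0, 1.0
    while True:
        p *= rng.random()
        if p <= threshold:
            return k
        k += 1


def sample_num_recursions(mean: float = 4.0, sigma: float = 0.5,
                          min_r: int = 1, max_r: int = 16,
                          rng: Optional[random.Random] = None) -> int:
    """Log-normal rate -> Poisson count -> shift by 1 -> clip to [min_r, max_r].

    Assumes mean > 1. The rate's log-normal has expectation mean - 1, so the
    unclipped recursion count has expectation `mean`.
    """
    rng = rng or random.Random()
    mu = math.log(mean - 1) - 0.5 * sigma ** 2
    rate = math.exp(rng.gauss(mu, sigma))
    r = sample_poisson(rate, rng) + 1
    return max(min_r, min(max_r, r))


rng = random.Random(0)
samples = [sample_num_recursions(rng=rng) for _ in range(10_000)]
print(min(samples), max(samples), sum(samples) / len(samples))
# Each distinct depth would be its own traced graph under torch.compile.
print(sorted(set(samples)))
```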
Training speed reached 75-80% of the standard model’s throughput after compilation. The initial compilation phase was slower because 16 different graph configurations had to be compiled and cached. Model FLOPs utilization (MFU) was also lower than for the dense transformer, partly because torch compile must run in dynamic mode rather than with a fixed graph.
Benchmark Results
On the composite chat core metric (combining ARC-Easy, ARC-Challenge, MMLU, GSM8K, HumanEval, and SpellingBee), the recursive model scored 0.27 compared to the standard model’s slightly higher score. All measurements used 4 recursions at inference time.
The GSM8K benchmark (grade school mathematics) showed more substantial differences. Performance improved from 3.5% with 2 recursions to approximately 6% with 4 recursions. This suggests reasoning tasks may benefit more from additional recursion depth, though I would need multiple training runs to establish statistical significance.
Validation loss tracked similarly between both models, slightly higher for the recursive version. Gradient norms were higher in the recursive model but remained stable below 0.5, indicating no training stability issues at the same learning rate.
Architecture Details
The recursive mechanism works by concatenating the input vector with the recurrence vector, then using a trainable linear adapter layer to project back to the original dimension. On the first recursion, the recurrence vector initializes to zeros. On subsequent recursions, it contains information from the previous pass.
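The recurrence described above can be sketched in plain Python, with lists standing in for tensors. On every pass, the prelude output `e` is concatenated with the current recurrence state, projected back to width d by the adapter, and run through the shared block. The adapter and block here are toy stand-ins, and all names are hypothetical.

```python
from typing import Callable, List

Vector = List[float]


def run_recursive_block(e: Vector, num_recursions: int,
                        adapter: Callable[[Vector], Vector],
                        block: Callable[[Vector], Vector]) -> Vector:
    """Re-inject the prelude output e on every pass through the shared block."""
    state = [0.0] * len(e)            # recurrence vector starts at zeros
    for _ in range(num_recursions):
        x = adapter(e + state)        # list concatenation builds [e; state] (length 2d)
        state = block(x)              # same block (same weights) every pass
    return state                      # handed to the coda


# Toy stand-ins: the adapter adds the two halves, the block halves its input.
def toy_adapter(v: Vector) -> Vector:
    d = len(v) // 2
    return [a + b for a, b in zip(v[:d], v[d:])]


def toy_block(x: Vector) -> Vector:
    return [0.5 * xi for xi in x]


print(run_recursive_block([1.0, 2.0], 2, toy_adapter, toy_block))  # [0.75, 1.5]
```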
The adapter layer requires careful initialization. It must act as an identity function for the input portion while zeroing the recurrence portion initially. Without this, the input doesn’t pass through cleanly and convergence suffers.
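Here is a minimal sketch of that initialization, with the d × 2d weight written as nested lists. The left half is the identity, which passes the input through. The right half is zeros, so the recurrence state is ignored at the start of training. In PyTorch, this would correspond to setting the weight of an `nn.Linear(2 * d, d)`, whose weight has shape `(d, 2d)`, to [I | 0]. Whether the repo uses a bias term is not stated here, so the sketch omits one.

```python
from typing import List

Matrix = List[List[float]]


def init_adapter(d: int) -> Matrix:
    """d x 2d weight: identity on the input half, zeros on the recurrence half."""
    return [[1.0 if j == i else 0.0 for j in range(2 * d)] for i in range(d)]


def apply(weight: Matrix, v: List[float]) -> List[float]:
    """Matrix-vector product: one output per weight row."""
    return [sum(w * x for w, x in zip(row, v)) for row in weight]


W = init_adapter(3)
e = [0.1, -2.0, 3.5]
state = [9.0, 9.0, 9.0]
print(apply(W, e + state))  # [0.1, -2.0, 3.5]: the recurrence half is ignored at init
```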
At inference time, I implemented a kickstart mechanism: instead of initializing the recurrence vector to zeros for each new token, I use the final recurrence values from the previous token. This provides a performance boost.
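The kickstart idea is sketched below as a toy decode loop. It returns the recurrence state each token starts from. With `kickstart=True`, a token starts from the previous token's final state; otherwise every token starts from zeros. The step function and all names are hypothetical stand-ins for the adapter plus recursive block.

```python
from typing import Callable, List

Vector = List[float]


def initial_states_per_token(token_inputs: List[Vector], num_recursions: int,
                             step: Callable[[Vector, Vector], Vector],
                             kickstart: bool = True) -> List[Vector]:
    """Return the recurrence state each token starts from during decoding."""
    d = len(token_inputs[0])
    prev_final = [0.0] * d
    starts = []
    for e in token_inputs:
        # Kickstart: reuse the previous token's final recurrence state.
        state = list(prev_final) if kickstart else [0.0] * d
        starts.append(state)
        for _ in range(num_recursions):
            state = step(e, state)
        prev_final = state
    return starts


def toy_step(e: Vector, s: Vector) -> Vector:
    return [0.5 * (a + b) for a, b in zip(e, s)]


print(initial_states_per_token([[1.0], [2.0]], 1, toy_step))                   # [[0.0], [0.5]]
print(initial_states_per_token([[1.0], [2.0]], 1, toy_step, kickstart=False))  # [[0.0], [0.0]]
```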
Training Implementation Notes
I limited backpropagation through time to a maximum of 4 recursions. When the model uses more than 4 recursions (up to the maximum of 16), gradients only flow through the final 4. This trades some theoretical accuracy for memory efficiency. Backpropagating through all 16 recursions would require substantially more memory.
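Truncated backpropagation through time is easiest to see on a scalar toy recurrence with a hand-written backward pass. This is plain Python, not the repo's code. Early recursions run without storing history, and gradients flow only through the last `k` recursions. In PyTorch, the same effect is typically obtained by running the early recursions under `torch.no_grad()` (or detaching the state) and only the final `k` with autograd; how the repo does it is an assumption here.

```python
import math


def forward_and_grad(w: float, x: float, n: int, k: int):
    """Scalar recurrence s_r = tanh(w * s_{r-1} + x), loss = s_n.

    Gradients flow only through the last k recursions; the state before them
    is treated as a constant, like running it under torch.no_grad().
    """
    s = 0.0
    # Early recursions: nothing stored, so no activation memory for backprop.
    for _ in range(max(n - k, 0)):
        s = math.tanh(w * s + x)
    # Final recursions: keep the states the backward pass needs.
    states = [s]
    for _ in range(min(k, n)):
        s = math.tanh(w * s + x)
        states.append(s)
    # Backward pass through the stored recursions only.
    grad_s, grad_w = 1.0, 0.0  # dL/ds_n = 1 because loss = s_n
    for r in range(len(states) - 1, 0, -1):
        local = 1.0 - states[r] ** 2          # derivative of tanh at step r
        grad_w += grad_s * local * states[r - 1]
        grad_s *= local * w                   # becomes dL/ds_{r-1}
    return s, grad_w


w, x, n = 0.7, 0.3, 16
loss_full, grad_full = forward_and_grad(w, x, n, k=n)
loss_trunc, grad_trunc = forward_and_grad(w, x, n, k=4)

# Full BPTT matches a finite-difference estimate; truncation changes only the gradient.
eps = 1e-6
fd = (forward_and_grad(w + eps, x, n, n)[0] - forward_and_grad(w - eps, x, n, n)[0]) / (2 * eps)
print(grad_full, fd, grad_trunc)
```

Memory scales with `k` rather than `n`, because only `k + 1` states are kept for the backward pass.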
For mid-training, I disabled torch compile entirely. Mid-training in NanoChat is brief, so the compilation overhead would dominate actual training time. Disabling compile increases memory usage, requiring a reduced batch size, but eliminates the compilation bottleneck.
Adaptive Compute Potential
The recursive architecture enables adaptive compute: using more recursions for difficult problems and fewer for simple ones. This differs from reasoning models that generate more tokens, because the extra computation happens in latent space rather than through sampling additional tokens from the output distribution.
One approach to adaptive compute monitors the KL divergence between probability distributions across consecutive recursions. When the distribution stabilizes (low KL divergence), the model can stop recurring. This avoids the need to train a separate halting head.
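Here is a minimal sketch of KL-based early stopping. It assumes a hypothetical callable `logits_after(r)` that returns next-token logits after `r` recursions; in the model this would presumably mean running the coda and LM head on the intermediate state. The threshold value is also an assumption. Recursion stops once KL(current || previous) falls below the threshold.

```python
import math
from typing import Callable, List, Tuple


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [v / total for v in exps]


def kl_divergence(p: List[float], q: List[float]) -> float:
    """KL(p || q) in nats; softmax outputs are strictly positive, so q > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def recurse_until_stable(logits_after: Callable[[int], List[float]],
                         max_recursions: int = 16,
                         threshold: float = 1e-3) -> Tuple[int, List[float]]:
    """Stop once consecutive next-token distributions stop changing."""
    prev = softmax(logits_after(1))
    for r in range(2, max_recursions + 1):
        cur = softmax(logits_after(r))
        if kl_divergence(cur, prev) < threshold:
            return r, cur
        prev = cur
    return max_recursions, prev


def converging(r: int) -> List[float]:
    # Toy logits that settle as the number of recursions grows.
    return [3.0 * (1 - 0.5 ** r), 0.0, 0.0]


steps, dist = recurse_until_stable(converging)
print(steps, [round(p, 3) for p in dist])
```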
However, adaptive compute with continuous batching poses challenges. Different sequences in a batch may require different recursion depths. Sequences that finish early must idle while others continue, reducing GPU utilization. The maximum recursion depth in any batch effectively determines the compute time for all sequences in that batch.
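The idling cost can be put as a simple ratio. If the whole batch keeps running until its deepest sequence finishes, the fraction of recursive-block compute doing useful work is the sum of per-sequence recursions divided by batch size × maximum recursions. This is a hypothetical helper that assumes every recursion costs the same and ignores the prelude and coda.

```python
from typing import Sequence


def batch_recursion_utilization(recursions: Sequence[int]) -> float:
    """Useful fraction of recursive-block compute when finished sequences idle."""
    steps = max(recursions)           # the deepest sequence sets the batch's step count
    useful = sum(recursions)          # recursions each sequence actually needed
    return useful / (len(recursions) * steps)


print(batch_recursion_utilization([4, 4, 4, 4]))   # 1.0
print(batch_recursion_utilization([2, 4, 16, 2]))  # 0.375
```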
This differs from standard transformers where all sequences pass through the same fixed number of layers. Solving this would require splitting inference engine logic to handle the embedding layer, recursive block, and language model head separately, somewhat analogous to how pre-fill and decode are handled differently.
KV Cache Considerations
In a standard 20-layer transformer, the KV cache stores keys and values for every layer. In the recursive model, I only store the most recent recurrence state rather than maintaining history for all recursions. When a token attends to previous tokens that used different recursion depths, it uses the latest available state.
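A minimal sketch of the "latest recurrence only" cache follows, using a dict keyed by (recursive-layer index, token position). Each recursion overwrites the previous entry instead of appending, so a later token attending to an earlier one sees that token's final state, whatever depth it used. The class and method names are hypothetical.

```python
from typing import Dict, List, Tuple

Key = Tuple[int, int]          # (layer index within the recursive block, token position)
Entry = Tuple[List[float], List[float]]


class LatestRecurrenceKVCache:
    """One (k, v) entry per recursive layer and position; recursions overwrite."""

    def __init__(self) -> None:
        self.store: Dict[Key, Entry] = {}

    def write(self, layer: int, pos: int, k: List[float], v: List[float]) -> None:
        self.store[(layer, pos)] = (k, v)   # earlier recursions' entries are dropped

    def read(self, layer: int, upto_pos: int) -> List[Entry]:
        """Entries visible (causally) to a token at position upto_pos."""
        return [self.store[(layer, p)] for p in range(upto_pos + 1)
                if (layer, p) in self.store]


cache = LatestRecurrenceKVCache()
for pos, n_rec in enumerate([4, 2]):        # token 0 used 4 recursions, token 1 used 2
    for r in range(n_rec):
        for layer in range(4):              # 4 layers in the recursive block
            cache.write(layer, pos, k=[float(r)], v=[float(r)])

print(len(cache.store))                     # 8 entries: 4 layers x 2 positions
print(cache.read(0, upto_pos=1))            # token 0 from recursion index 3, token 1 from index 1
```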
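Compared with the 20-layer baseline, the cache then holds keys and values for 8 layers (2 prelude + 4 recursive + 2 coda) instead of 20, assuming the same per-layer key/value dimensions. Empirically, this simplified caching approach maintains good performance.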
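The size comparison can be checked with the standard per-layer KV cache formula: keys plus values, times positions, heads and head dimension. The shapes below are illustrative assumptions, not the nanochat config. Only the layer counts come from this post, so the ratio is the part that matters.

```python
def kv_cache_bytes(num_cached_layers: int, seq_len: int, n_kv_heads: int,
                   head_dim: int, bytes_per_elem: int = 2) -> int:
    """Keys + values for every cached layer and position (bf16 by default)."""
    return 2 * num_cached_layers * seq_len * n_kv_heads * head_dim * bytes_per_elem


# Illustrative shapes only, not taken from the nanochat config.
shape = dict(seq_len=2048, n_kv_heads=10, head_dim=128)

baseline = kv_cache_bytes(20, **shape)                     # 20 standard layers
every_recursion = kv_cache_bytes(2 + 4 * 4 + 2, **shape)   # keep all 4 recursions
latest_only = kv_cache_bytes(2 + 4 + 2, **shape)           # overwrite each recursion

print(baseline / 2**20, every_recursion / 2**20, latest_only / 2**20)  # MiB
print(latest_only / baseline)
```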
Practical Tradeoffs
The recursive model offers two main benefits: smaller parameter count (300M vs 500M) and potential for adaptive compute. The smaller size means fitting on fewer GPUs during training, reducing pipeline parallelism and associated bubble overhead. It also makes deployment on consumer hardware more feasible.
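Two rough numbers help size these benefits. The first is weight memory at 2 bytes per parameter (bf16). The second uses about 16 bytes per parameter, a commonly quoted figure for mixed-precision Adam training that covers bf16 weights and gradients, fp32 master weights and two optimizer moments. The pipeline bubble uses the GPipe-style idle fraction (p - 1) / (m + p - 1) for p stages and m micro-batches. The stage and micro-batch counts below are illustrative assumptions, not settings from this run.

```python
def weight_memory_gb(num_params: float, bytes_per_param: float) -> float:
    """Memory in GB (1e9 bytes) for a given per-parameter footprint."""
    return num_params * bytes_per_param / 1e9


def pipeline_bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Idle fraction of a GPipe-style schedule: (p - 1) / (m + p - 1)."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


for name, params in [("baseline", 500e6), ("recursive", 300e6)]:
    # bf16 weights only, then ~16 bytes/param for mixed-precision Adam training.
    print(name, weight_memory_gb(params, 2), weight_memory_gb(params, 16))

# Fewer pipeline stages (e.g. 4 -> 2) at a fixed 8 micro-batches shrinks the bubble.
print(pipeline_bubble_fraction(4, 8), pipeline_bubble_fraction(2, 8))
```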
The performance is comparable on an iso-flops basis, though not definitively better. Without error bars from multiple training runs, the differences I observed are likely within noise for most benchmarks. The GSM8K improvement appears more substantial but would require replication to confirm.
The training complexity increases: dynamic torch compile, slower compilation, and reduced token throughput. For longer training runs, the compilation overhead would be amortized. For this 4-hour training run, it approximately doubled the wall-clock time.
Code Changes
The implementation modifies Karpathy’s NanoChat repository on the `recursive` branch. Key changes include:
- Configuration parameters for prelude, recursion, and coda block sizes
- Poisson log-normal sampling for recursion depth during training
- Adapter layer initialization to identity for input passthrough
- Dynamic torch compile mode instead of static
- Truncated backpropagation through time, limited to at most 4 recursions
- KV cache management using only the latest recurrence state
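The knobs in the list above could be gathered into one validated config object, as in the sketch below. The field names are hypothetical and are not the repo's actual configuration. The default values mirror the ones described in this post.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RecursiveConfig:
    # Hypothetical field names; defaults mirror the values described in this post.
    n_prelude: int = 2
    n_recur: int = 4
    n_coda: int = 2
    mean_recursions: float = 4.0
    min_recursions: int = 1
    max_recursions: int = 16
    bptt_window: int = 4              # gradients flow through at most this many recursions
    kv_cache_latest_only: bool = True
    compile_dynamic: bool = True

    def __post_init__(self) -> None:
        if not 1 <= self.min_recursions <= self.mean_recursions <= self.max_recursions:
            raise ValueError("need 1 <= min <= mean <= max recursions")
        if not 1 <= self.bptt_window <= self.max_recursions:
            raise ValueError("bptt_window must be in [1, max_recursions]")

    def effective_depth(self, recursions: int) -> int:
        """Layers executed in one forward pass at the given recursion count."""
        return self.n_prelude + self.n_recur * recursions + self.n_coda

    @property
    def unique_layers(self) -> int:
        """Layers that actually hold weights."""
        return self.n_prelude + self.n_recur + self.n_coda


cfg = RecursiveConfig()
print(cfg.unique_layers, cfg.effective_depth(4), cfg.effective_depth(cfg.max_recursions))
```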
The repository includes Weights & Biases logs showing training curves, compilation overhead, and benchmark progression. All trained model weights are available on Hugging Face under the Trelis Research organization.

