Train an LLM to Self-Correct with Verifiable Backtracking
A rejection sampling technique for improved reasoning.
Training a model to self-correct
TL;DR
1. Sample answers from a model (prompted to respond with thinking/reasoning and then the final answer)
2. Take the wrong answers and feed them back to the model, with "Wait, that's not right..." spliced in place of the closing </think> tag - the "verifiable backtracking" step
3. Keep the correct answers from steps 1 and 2 and use them for SFT fine-tuning (see the sketch below)
It appears this allows you to improve pass@k (i.e. the fraction of problems where at least one of k sampled attempts is correct).
Unlike "budget forcing" in the s1 paper, this is a training-time technique.
== The Setup ==
- Use ~8k rows of grade school maths (gsm8k) as training data
- Sample answers, and do SFT on Llama 3.2 1B.
- Compare maj@k and pass@k (both sketched below) across i) the base model, ii) the base model with SFT on correct answers from step 1 above, and iii) the base model with SFT on correct answers from steps 1 and 2 above.
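For clarity, here is a minimal sketch of the two metrics, assuming each problem comes with k sampled final answers plus a gold answer (the simple empirical version, not the unbiased pass@k estimator):

```python
from collections import Counter

def pass_at_k(samples: list[str], gold: str) -> bool:
    """Empirical pass@k: at least one of the k sampled answers matches the gold answer."""
    return any(ans == gold for ans in samples)

def maj_at_k(samples: list[str], gold: str) -> bool:
    """maj@k: the majority-vote answer across the k samples matches the gold answer."""
    return Counter(samples).most_common(1)[0][0] == gold

# e.g. 5 sampled final answers for one GSM8K-style problem
samples = ["72", "68", "72", "70", "72"]
print(pass_at_k(samples, "72"), maj_at_k(samples, "72"))  # True True
```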
== Result ==
- Adding the "verifiable backtracking" samples to the SFT mix seems to improve pass@k from about 80% up to 85%, a level that - in this case - does not appear reachable via rejection sampling and SFT alone (or via ORPO or GRPO, as per an earlier video).
== Commentary ==
- Verifiable backtracking appears to elicit a form of self-correction that is perhaps latent in the pre-training data, beyond what rejection sampling alone can elicit (at least for this setup).
== Isn't this the same as the s1 paper? ==
"Budget forcing" in the s1 paper is a related but parallel approach.
- Budget forcing is an inference time, not training/tuning technique.
- Budget forcing involves taking a thinking model, having it generate an answer, and then replacing the closing </think> tag with "Wait..." in order to continue generating for more tokens.
- There is no checking of whether the initial answer is correct, BUT prolonging thinking to a target length does empirically lead to better final answers from the model - a key finding of the s1 paper!
Verifiable backtracking differs in that wrong answers are checked - during data generation (not inference/test time) - and then sent back for an attempted correction. Only answers that are verifiably wrong AND THEN verifiably correct are added to the rejection sampling data mix.
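To make the contrast concrete, here is a rough sketch of the "extend thinking" half of budget forcing at inference time (s1 also caps thinking length, which is omitted here). continue_generation is a hypothetical hook for your model call; note there is no gold answer or checker anywhere in this loop:

```python
from typing import Callable

def budget_forced_generate(
    prompt: str,
    continue_generation: Callable[[str], str],  # hypothetical: returns model text, usually ending in </think> + answer
    num_waits: int = 2,
) -> str:
    text = continue_generation(prompt)
    for _ in range(num_waits):
        if "</think>" not in text:
            break
        # No correctness check: unconditionally swap the closing tag for "Wait"
        # so the model keeps thinking, whether or not its current answer is right.
        text = text.rsplit("</think>", 1)[0] + "Wait"
        text += continue_generation(prompt + text)
    return text
```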
== Emergent Patterns ==
1. Interestingly, after doing this training, the answers generated by the updated model include cases where the model backtracks multiple times in the same answer (even though the training data - in the case I tested - only includes verified single backtracks)
2. Sometimes the model backtracks and still gets the answer wrong. Yet, overall, the net effect on pass@k is positive.
== Big picture comment ==
Doing verifiable backtracking effectively linearises a parallel process of sampling and verifying multiple answers.
This is useful, because parallel sampling - in practice - requires the existence of a verifier (or the use of majority voting) to be useful for an end user. Linearising the reasoning allows for a better user experience, where one can keep generating (and control the amount of generation) to get better answers.
One might describe a key contribution of o3 and R1 (and o1) as having linearised and unified the search and verification process.
== On the choice of "Wait..." ==
Since the answers being fed back have already been verified to be wrong, one can use stronger language than just "Wait...":
- i.e. "Wait, that's definitely wrong..."
Interestingly (and intuitively), this results in a larger proportion of the backtracked answers coming out correct!
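Illustrative only: nothing else needs to change in the data-generation sketch above - swapping in the stronger cue is the whole modification, and (per the observation above) it yields a larger share of corrected answers.

```python
# Because the first answer has already been verified wrong, the cue can assert the mistake outright.
BACKTRACK = "Wait, that's definitely wrong..."  # vs. the neutral "Wait..."
```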
== Side-note on Inference ==
To be clear, at inference/test time there is no "verification" step. You just run inference and get whatever answer you get (although that answer is now more likely to include backtracking).
Liked this post? Send me a reply or comment below to say yes or no!
Cheers, Ronan