r11_rw_finalstep_v1
Reproducibility control: R9 config with v1 data (80.4%)
- Reproduces the R9 best config (reward-weighted, final-step only, v1 data)
- R9 achieved 81.9% with the same config, confirming ~1pp run-to-run variance
- Uses original v1 teacher states for comparison
Overview
This model implements QThink (Parallel Latent Reasoning via Per-Step Distillation of Multiple Rollouts): an autoregressive latent reasoning loop that processes K=6 continuous thought steps before generating a text answer. Teacher hidden states are extracted from multiple chain-of-thought rollouts generated by the base model and distilled into the latent representations at every step.
Key idea: Instead of distilling from a single reasoning trace, we aggregate hidden states from multiple rollouts (16 per problem) and supervise every latent step. The teacher signal is the reward-weighted average of CORRECT rollouts only.
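As a toy sketch of that aggregation (not the actual training code; the shapes, rewards, and correctness threshold here are made up for illustration), the reward-weighted teacher signal could be computed like this:

```python
import numpy as np

# Toy dimensions: 16 rollouts, hidden size 8 (the real model uses 2048).
rng = np.random.default_rng(0)
hidden = rng.standard_normal((16, 8))      # final-step hidden state per rollout
rewards = rng.uniform(0.0, 1.0, size=16)   # per-rollout reward
correct = rewards > 0.5                    # keep CORRECT rollouts only

# Normalize rewards over the correct subset and take the weighted average.
w = rewards[correct] / rewards[correct].sum()
teacher = (w[:, None] * hidden[correct]).sum(axis=0)

print(teacher.shape)  # (8,)
```

Incorrect rollouts contribute nothing; among correct ones, higher-reward traces pull the teacher state harder.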
Architecture
- Base model: Qwen3-1.7B (1.7B parameters)
- Fine-tuning: LoRA (rank=32, alpha=16) on q/k/v/o/gate/up/down_proj
- Projection head: Linear(2048, 2048) → GELU → Linear(2048, 2048) → LayerNorm(2048)
- Latent steps: K=6 autoregressive continuous thought steps
- Inference: Process prompt → K latent steps via ProjectionHead + KV cache → greedy text generation
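The autoregressive latent loop in the last bullet can be illustrated with toy NumPy stand-ins for the projection head and one transformer pass (illustrative only; the real code, with KV caching, is in the How to Use section below):

```python
import numpy as np

rng = np.random.default_rng(0)
H, K = 8, 6                                      # toy hidden size; K=6 latent steps
proj = rng.standard_normal((H, H)) * 0.1         # stand-in for the trained ProjectionHead
model_step = rng.standard_normal((H, H)) * 0.1   # stand-in for one transformer forward pass

latent = rng.standard_normal(H)                  # last-token hidden state of the prompt
for _ in range(K):
    # Project the previous hidden state, feed it back in as the next "token".
    latent = np.tanh(model_step @ np.tanh(proj @ latent))

print(latent.shape)  # (8,)
```

Each iteration feeds the projected hidden state back into the model as a continuous input embedding instead of sampling a discrete token.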
Training Details
| Parameter | Value |
|---|---|
| Mode | codi_reward_weighted |
| Per-step distillation | False |
| Distillation γ | 1.0 |
| Learning rate | 0.0002 |
| Epochs | 3 |
| Batch size (per GPU) | 2 |
| Gradient accumulation | 8 |
| Effective batch size | 128 (across 8 GPUs) |
| Max answer length | 128 |
| Latent steps (K) | 6 |
| Task | GSM8k (7,473 training problems) |
| Rollouts per problem | 16 |
| GSM8k test accuracy | 80.4% |
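The effective batch size row follows from the other batching rows:

```python
per_gpu_batch = 2   # batch size per GPU
grad_accum = 8      # gradient accumulation steps
num_gpus = 8

effective = per_gpu_batch * grad_accum * num_gpus
print(effective)  # 128
```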
How to Use
Requirements
```shell
pip install torch transformers huggingface_hub
```
Inference Code
```python
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model
model_name = "LakshyAAAgrawal/continuous-thought-r11_rw_finalstep_v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
)
model.eval()

# Load projection head
class ProjectionHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
        )

    def forward(self, x):
        return self.mlp(x)

proj = ProjectionHead(model.config.hidden_size)
proj.load_state_dict(torch.load(
    hf_hub_download(model_name, "projection_head.pt"), map_location="cpu"
))
proj = proj.to(model.dtype).to(model.device).eval()

# Generate with latent reasoning
question = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
messages = [{"role": "user", "content": f"Solve the following math problem step by step. Show your work and put your final numerical answer after ####.\n\n{question}"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    # Step 1: Process prompt
    out = model(**inputs, output_hidden_states=True, use_cache=True)
    past_kv = out.past_key_values
    latent = out.hidden_states[-1][:, -1, :]  # last token hidden state

    # Step 2: K latent reasoning steps
    mask = inputs["attention_mask"].clone()
    for k in range(6):
        latent = proj(latent)
        mask = torch.cat([mask, torch.ones(1, 1, device=mask.device, dtype=mask.dtype)], dim=1)
        out = model(inputs_embeds=latent.unsqueeze(1), attention_mask=mask,
                    past_key_values=past_kv, output_hidden_states=True, use_cache=True)
        past_kv = out.past_key_values
        latent = out.hidden_states[-1][:, -1, :]

    # Step 3: Greedy text generation
    next_token = out.logits[:, -1, :].argmax(dim=-1)
    generated = [next_token]
    for _ in range(2047):
        if next_token.item() == tokenizer.eos_token_id:
            break
        mask = torch.cat([mask, torch.ones(1, 1, device=mask.device, dtype=mask.dtype)], dim=1)
        out = model(input_ids=next_token.unsqueeze(0), attention_mask=mask,
                    past_key_values=past_kv, use_cache=True)
        past_kv = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1)
        generated.append(next_token)

response = tokenizer.decode(torch.cat(generated), skip_special_tokens=True)
print(response)
```
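The prompt asks the model to put its final answer after `####` (the GSM8k convention), so the numeric answer can be pulled out of `response` with a small helper. Note that `extract_final_answer` is our name for illustration, not part of the repository:

```python
import re

def extract_final_answer(text):
    """Return the last number following '####', or None if absent."""
    matches = re.findall(r"####\s*(-?[\d,]*\.?\d+)", text)
    if not matches:
        return None
    return matches[-1].replace(",", "")

print(extract_final_answer("48 + 24 = 72\n#### 72"))  # 72
```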
Evaluation
To evaluate on the full GSM8k test set, use the evaluation script from our repository:
```shell
python evaluate.py \
  --model_dir LakshyAAAgrawal/continuous-thought-r11_rw_finalstep_v1 \
  --mode codi_reward_weighted \
  --output results/r11_rw_finalstep_v1.json \
  --max_new_tokens 2048 \
  --num_latent 6
```
Important: Use `max_new_tokens=2048` for evaluation. The model generates verbose chain-of-thought text after the latent steps, requiring more tokens than standard models.
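Scoring is then an exact numeric match between extracted predictions and gold answers. This is a sketch of the idea; the repository's `evaluate.py` may normalize answers differently:

```python
def gsm8k_accuracy(predictions, golds):
    """Exact-match accuracy over numeric answer strings, compared as floats."""
    correct = 0
    for pred, gold in zip(predictions, golds):
        try:
            correct += float(pred) == float(gold)
        except (TypeError, ValueError):
            pass  # unparseable prediction counts as wrong
    return correct / len(golds)

print(gsm8k_accuracy(["72", "10.0", "oops"], ["72", "10", "3"]))  # two of three correct
```

Comparing as floats makes "10.0" and "10" count as the same answer, which matters for models that emit decimal points.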
Results Comparison
| Model | Mode | Per-step | γ | ans_len | GSM8k Accuracy |
|---|---|---|---|---|---|
| This model | codi_reward_weighted | no | 1.0 | 128 | 80.4% |
| Qwen3-1.7B (base) | – | – | – | – | 77.3% |
| Discrete SFT | sft | – | – | – | 80.7% |
| QThink RW final-step | rw | no | 1.0 | 128 | 81.0% |
| QThink Uniform per-step ans256 | uniform | yes | 2.0 | 256 | 83.2% |
| QThink RW per-step ans256 | rw | yes | 1.0 | 256 | 82.7% |
Citation
If you use this model, please cite:
```bibtex
@misc{continuous-thought-2025,
  title={QThink: Parallel Latent Reasoning via Per-Step Distillation of Multiple Rollouts},
  author={Lakshya Agrawal},
  year={2025},
  url={https://huggingface.co/LakshyAAAgrawal/continuous-thought-r11_rw_finalstep_v1}
}
```