r11_rw_finalstep_v1

Reproducibility control: R9 config with v1 data (80.4%)

  • Reproduces R9 best config (reward-weighted, final-step only, v1 data)
  • R9 achieved 81.9% for the same config, consistent with roughly 1.5pp run-to-run variance
  • Uses original v1 teacher states for comparison

Overview

This model implements QThink (Parallel Latent Reasoning via Per-Step Distillation of Multiple Rollouts), an autoregressive latent reasoning loop that processes K=6 continuous thought steps before generating a text answer. Teacher hidden states are extracted from multiple chain-of-thought rollouts generated by the base model and distilled into the latent representations. In this run, the distillation loss is applied only at the final latent step (final-step only; per-step distillation off).

Key idea: Instead of distilling from a single reasoning trace, we aggregate hidden states from multiple rollouts (16 per problem). The teacher signal is the reward-weighted average of CORRECT rollouts only.
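The reward-weighted aggregation can be sketched in plain Python. This is a minimal illustration of the idea, not the actual training code; the helper name `reward_weighted_teacher` and the toy 2-dimensional states are hypothetical:

```python
def reward_weighted_teacher(states, rewards):
    """Aggregate teacher hidden states from multiple rollouts.

    states:  list of hidden-state vectors, one per rollout
    rewards: per-rollout rewards; incorrect rollouts get reward 0
    Returns the reward-weighted average over CORRECT rollouts only.
    """
    correct = [(h, r) for h, r in zip(states, rewards) if r > 0]
    total = sum(r for _, r in correct)
    dim = len(correct[0][0])
    return [sum(r * h[j] for h, r in correct) / total for j in range(dim)]

# Two correct rollouts (reward 1) and one incorrect (reward 0):
teacher = reward_weighted_teacher(
    [[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]],
    [1.0, 1.0, 0.0],
)
# The incorrect rollout's state is excluded entirely -> [2.0, 0.0]
```

With binary rewards this reduces to a plain average over correct rollouts; non-uniform rewards would tilt the teacher toward higher-reward traces.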

Architecture

  • Base model: Qwen3-1.7B (1.7B parameters)
  • Fine-tuning: LoRA (rank=32, alpha=16) on q/k/v/o/gate/up/down_proj
  • Projection head: Linear(2048, 2048) → GELU → Linear(2048, 2048) → LayerNorm(2048)
  • Latent steps: K=6 autoregressive continuous thought steps
  • Inference: Process prompt → K latent steps via ProjectionHead + KV cache → greedy text generation

Training Details

Parameter               Value
----------------------  -------------------------------
Mode                    codi_reward_weighted
Per-step distillation   False
Distillation γ          1.0
Learning rate           0.0002
Epochs                  3
Batch size (per GPU)    2
Gradient accumulation   8
Effective batch size    128 (across 8 GPUs)
Max answer length       128
Latent steps (K)        6
Task                    GSM8k (7,473 training problems)
Rollouts per problem    16
GSM8k test accuracy     80.4%
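The effective batch size follows directly from the per-GPU batch, gradient accumulation, and GPU count above; a quick sanity check (variable names are illustrative):

```python
per_gpu_batch = 2   # batch size per GPU
grad_accum = 8      # gradient accumulation steps
num_gpus = 8        # data-parallel workers

effective_batch = per_gpu_batch * grad_accum * num_gpus
# 2 * 8 * 8 = 128, matching the table
```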

How to Use

Requirements

pip install torch transformers huggingface_hub

Inference Code

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model
model_name = "LakshyAAAgrawal/continuous-thought-r11_rw_finalstep_v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
)
model.eval()

# Load projection head
class ProjectionHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
        )
    def forward(self, x):
        return self.mlp(x)

proj = ProjectionHead(model.config.hidden_size)
proj.load_state_dict(torch.load(
    hf_hub_download(model_name, "projection_head.pt"), map_location="cpu"
))
proj = proj.to(model.dtype).to(model.device).eval()

# Generate with latent reasoning
question = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
messages = [{"role": "user", "content": f"Solve the following math problem step by step. Show your work and put your final numerical answer after ####.\n\n{question}"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Step 1: Process prompt
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True, use_cache=True)
    past_kv = out.past_key_values
    latent = out.hidden_states[-1][:, -1, :]  # last token hidden state

    # Step 2: K latent reasoning steps
    mask = inputs["attention_mask"].clone()
    for _ in range(6):  # K = 6 latent steps
        latent = proj(latent)
        mask = torch.cat([mask, torch.ones(1, 1, device=mask.device, dtype=mask.dtype)], dim=1)
        out = model(inputs_embeds=latent.unsqueeze(1), attention_mask=mask,
                    past_key_values=past_kv, output_hidden_states=True, use_cache=True)
        past_kv = out.past_key_values
        latent = out.hidden_states[-1][:, -1, :]

    # Step 3: Greedy text generation
    next_token = out.logits[:, -1, :].argmax(dim=-1)
    generated = [next_token]
    for _ in range(2047):
        if next_token.item() == tokenizer.eos_token_id:
            break
        mask = torch.cat([mask, torch.ones(1, 1, device=mask.device, dtype=mask.dtype)], dim=1)
        out = model(input_ids=next_token.unsqueeze(0), attention_mask=mask,
                    past_key_values=past_kv, use_cache=True)
        past_kv = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1)
        generated.append(next_token)

    response = tokenizer.decode(torch.cat(generated), skip_special_tokens=True)
    print(response)
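Since the prompt instructs the model to put its final numerical answer after `####`, the answer can be pulled out of the generated text with a small regex. This is a sketch; `extract_answer` is a hypothetical helper, not part of the released code:

```python
import re

def extract_answer(text):
    """Return the number after the last '####' marker, or None if absent."""
    matches = re.findall(r"####\s*(-?[\d,]*\.?\d+)", text)
    return matches[-1].replace(",", "") if matches else None

# e.g. on a completed GSM8k response:
print(extract_answer("She sold 24 in May, so 48 + 24 = 72. #### 72"))  # -> 72
```

Taking the last match guards against the model repeating the `####` marker mid-reasoning.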

Evaluation

To evaluate on the full GSM8k test set, use the evaluation script from our repository:

python evaluate.py \
    --model_dir LakshyAAAgrawal/continuous-thought-r11_rw_finalstep_v1 \
    --mode codi_reward_weighted \
    --output results/r11_rw_finalstep_v1.json \
    --max_new_tokens 2048 \
    --num_latent 6

Important: Use max_new_tokens=2048 for evaluation. The model generates verbose chain-of-thought text after the latent steps, requiring more tokens than standard models.
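GSM8k accuracy is conventionally exact match on the extracted final number. A minimal scoring sketch, not the repository's `evaluate.py` (the helper name and float-tolerance comparison are assumptions):

```python
def exact_match_accuracy(predictions, golds):
    """Fraction of predictions matching gold answers, compared as floats."""
    hits = sum(
        p is not None and abs(float(p) - float(g)) < 1e-6
        for p, g in zip(predictions, golds)
    )
    return hits / len(golds)

# Predictions of None (no '####' marker found) count as incorrect:
acc = exact_match_accuracy(["72", "18", None], ["72", "10", "5"])  # 1 of 3
```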

Results Comparison

Model                           Mode                  Per-step  γ    ans_len  GSM8k Accuracy
------------------------------  --------------------  --------  ---  -------  --------------
This model                      codi_reward_weighted  no        1.0  128      80.4%
Qwen3-1.7B (base)               -                     -         -    -        77.3%
Discrete SFT                    sft                   -         -    -        80.7%
QThink RW final-step            rw                    no        1.0  128      81.0%
QThink Uniform per-step ans256  uniform               yes       2.0  256      83.2%
QThink RW per-step ans256       rw                    yes       1.0  256      82.7%

Citation

If you use this model, please cite:

@misc{continuous-thought-2025,
  title={QThink: Parallel Latent Reasoning via Per-Step Distillation of Multiple Rollouts},
  author={Lakshya Agrawal},
  year={2025},
  url={https://huggingface.co/LakshyAAAgrawal/continuous-thought-r11_rw_finalstep_v1}
}