gsm-gpt2-rff

Model Overview

This repository demonstrates that model-based data selection can significantly improve data efficiency under iso-compute training budgets. See the associated gsm-gpt2-rff Dataset for details on the training data.

This repository contains LoRA adapters for gpt2 trained on the kurtos-ai/gsm-gpt2-rff dataset family under multiple data-selection configurations (gamma_0.01, gamma_0.02, gamma_0.05, gamma_0.1, and full).

Each configuration includes two checkpoints:

  • *_best: the checkpoint with the lowest validation loss observed during training.
  • *_last: the final checkpoint at the end of training.

All adapters are PEFT LoRA adapters and are intended to be loaded on top of the GPT-2 base model.

Table of Available Adapters

All adapters use gpt2 as the base model; each adapter is 3,253,104 bytes.

| Adapter Path | Config | Checkpoint | Eval Loss | Eval Accuracy |
|---|---|---|---|---|
| gamma_0.01_best | gamma_0.01 | Best (by eval loss) | 5.8880 | 0.72% |
| gamma_0.01_last | gamma_0.01 | Last | 6.3623 | 4.45% |
| gamma_0.02_best | gamma_0.02 | Best (by eval loss) | 5.8692 | 1.60% |
| gamma_0.02_last | gamma_0.02 | Last | 6.0615 | 3.25% |
| gamma_0.05_best | gamma_0.05 | Best (by eval loss) | 5.8248 | 0.92% |
| gamma_0.05_last | gamma_0.05 | Last | 5.8541 | 1.45% |
| gamma_0.1_best | gamma_0.1 | Best (by eval loss) | 5.8283 | 1.01% |
| gamma_0.1_last | gamma_0.1 | Last | 5.8506 | 1.43% |
| full_best | full | Best (by eval loss) | 6.1236 | 0.53% |
| full_last | full | Last | 6.1467 | 0.50% |

Loading Instructions

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

repo_id = "kurtos-ai/gsm-gpt2-rff"
adapter_subfolder = "gamma_0.05_best"  # any adapter path from the table above

# Load the GPT-2 base model, then apply the LoRA adapter on top of it.
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder=adapter_subfolder)
model = PeftModel.from_pretrained(base_model, repo_id, subfolder=adapter_subfolder)
model.eval()

prompt = "Question: If 2x + 3 = 11, what is x?\nAnswer:"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Training Configuration Summary

Training was done on 4× RTX 5090 GPUs, with the following configuration:

  • Batch size: 40
  • Gradient accumulation: 1
  • Learning rate: 0.0002
  • Mixed precision: bf16
  • Max number of tokens per example (truncation): 512

From adapter configs (adapter_config.json):

  • Base model: gpt2
  • PEFT method: LoRA (peft_type = LORA)
  • Task type: CAUSAL_LM
  • Target modules: c_attn, c_proj
  • Rank (r): 8
  • LoRA alpha: 16
  • LoRA dropout: 0.05
  • Bias setting: none
  • PEFT version: 0.18.1
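
The settings above correspond to a PEFT `LoraConfig` along these lines (a sketch; argument names follow the peft API, values taken from the list above):

```python
from peft import LoraConfig

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,                 # LoRA rank
    lora_alpha=16,       # scaling factor (alpha / r = 2.0)
    lora_dropout=0.05,
    target_modules=["c_attn", "c_proj"],  # GPT-2 attention and projection layers
    bias="none",
)
```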

Training runs were kept at (near) iso-compute by scaling the number of epochs inversely with the selection fraction gamma:

| Gamma | Epochs | Exa-FLOPs | Total Time | Time to Best |
|---|---|---|---|---|
| gamma_0.01 | 3000 | 3.8198 | 10:15:05 | 00:11:05 |
| gamma_0.02 | 1500 | 4.1164 | 10:41:53 | 00:12:27 |
| gamma_0.05 | 600 | 4.3628 | 10:56:49 | 00:54:23 |
| gamma_0.1 | 300 | 4.7578 | 11:21:53 | 02:46:47 |
| full | 30 | 5.7090 | 11:57:40 | 06:04:32 |
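The epoch counts follow directly from the selection fraction: each run trains for 30/gamma epochs (30 being the epoch count of the full run), so every configuration processes roughly the same number of training examples. A quick check against the table above:

```python
# Iso-compute schedule: epochs scale inversely with the kept fraction gamma,
# so epochs * gamma is constant (30, the epoch count of the full run).
full_epochs = 30
for gamma in (0.01, 0.02, 0.05, 0.1):
    epochs = round(full_epochs / gamma)
    print(f"gamma={gamma}: {epochs} epochs")
# gamma=0.01 -> 3000, gamma=0.02 -> 1500, gamma=0.05 -> 600, gamma=0.1 -> 300
```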

Evaluation was performed every 200 training steps, measuring both cross-entropy loss and exact-match accuracy, on the test splits of gsm8k, math_qa, and aqua_rat.

Full training records are included (with more detailed per-set evaluation metrics).

Evaluation Results

Best and last metrics are taken from training_records.<config>.json. Dollar costs assume RunPod's price of $3.56/h.

| Config | Eval Loss | Eval Acc | Total Time | Dollar Cost | Best Eval Loss | Eval Acc @ Best Loss | Time to Best | Dollar Cost to Best |
|---|---|---|---|---|---|---|---|---|
| gamma_0.01 | 6.3623 | 4.45% | 10:15:05 | $36.49 | 5.8880 | 0.72% | 00:11:05 | $0.66 |
| gamma_0.02 | 6.0615 | 3.25% | 10:41:53 | $38.09 | 5.8692 | 1.60% | 00:12:27 | $0.74 |
| gamma_0.05 | 5.8541 | 1.45% | 10:56:49 | $38.97 | 5.8248 | 0.92% | 00:54:23 | $3.23 |
| gamma_0.1 | 5.8506 | 1.43% | 11:21:53 | $40.46 | 5.8283 | 1.01% | 02:46:47 | $9.90 |
| full | 6.1467 | 0.50% | 11:57:40 | $42.58 | 6.1236 | 0.53% | 06:04:32 | $21.63 |
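The dollar figures are simply wall-clock time multiplied by the $3.56/h rate. For example:

```python
def cost_usd(hms: str, rate_per_hour: float = 3.56) -> float:
    """Convert an H:MM:SS duration into a dollar cost at the given hourly rate."""
    h, m, s = (int(x) for x in hms.split(":"))
    hours = h + m / 60 + s / 3600
    return round(hours * rate_per_hour, 2)

print(cost_usd("10:15:05"))  # total time for gamma_0.01 -> 36.49
print(cost_usd("00:11:05"))  # time to best for gamma_0.01 -> 0.66
```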

Best checkpoints are selected by validation loss (cross entropy). Exact-match accuracy may peak at different training steps; indeed, we observe a grokking-like effect for exact-match accuracy, with the best accuracies reached much later in the training run.
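The gap between loss-selected and accuracy-peak checkpoints can be seen by scanning the training records. A minimal sketch, assuming a hypothetical list of per-eval-step records (the actual training_records.<config>.json schema may differ):

```python
# Hypothetical per-eval-step records (evaluation runs every 200 steps);
# the real training_records.<config>.json schema may differ.
records = [
    {"step": 200, "eval_loss": 6.40, "eval_acc": 0.004},
    {"step": 400, "eval_loss": 5.90, "eval_acc": 0.007},  # loss bottoms out early...
    {"step": 600, "eval_loss": 6.10, "eval_acc": 0.020},
    {"step": 800, "eval_loss": 6.35, "eval_acc": 0.045},  # ...accuracy peaks much later
]

best_by_loss = min(records, key=lambda r: r["eval_loss"])
best_by_acc = max(records, key=lambda r: r["eval_acc"])
print(best_by_loss["step"], best_by_acc["step"])  # 400 800
```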

Savings

Selecting each gamma fraction took (see gsm-gpt2-rff Dataset):

  • 29 min
  • $1.72

For each training run, savings and metric deltas are reported relative to the full run at its best checkpoint.

| Selection Fraction | Eval Loss Diff | Acc Diff | Time Savings | Dollar Savings | Dollar Savings (%) |
|---|---|---|---|---|---|
| 0.01 | -0.2356 | +0.19% | 05:53:27 | $20.97 | 96.95% |
| 0.02 | -0.2544 | +1.07% | 05:52:05 | $20.89 | 96.58% |
| 0.05 | -0.2988 | +0.39% | 05:10:09 | $18.40 | 85.07% |
| 0.1 | -0.2953 | +0.48% | 03:17:45 | $11.73 | 54.23% |
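These columns can be reproduced from the earlier tables: each delta compares a gamma run's best checkpoint against the full run's best checkpoint. A sketch for the gamma_0.01 row (the $1.72 selection cost is not included in the percentage):

```python
def to_hours(hms: str) -> float:
    h, m, s = (int(x) for x in hms.split(":"))
    return h + m / 60 + s / 3600

RATE = 3.56  # $/h (RunPod)

full_cost = round(to_hours("06:04:32") * RATE, 2)  # full run, time to best
g001_cost = round(to_hours("00:11:05") * RATE, 2)  # gamma_0.01, time to best

loss_diff = round(5.8880 - 6.1236, 4)              # gamma_0.01 best vs full best
dollar_savings = round(full_cost - g001_cost, 2)
savings_pct = round(100 * dollar_savings / full_cost, 2)
print(loss_diff, dollar_savings, savings_pct)      # -0.2356 20.97 96.95
```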

We observe consistent evaluation loss improvements and notable gains in exact match accuracy when training on the filtered datasets. The compute savings are massive, even when accounting for the $1.72 spent on the selection.

These results suggest that data selection improves scaling laws under fixed compute budgets.

Intended Use

This model is intended for:

  • Research on data selection under iso-compute budgets
  • Studying scaling behavior under filtered training sets
  • Small-scale reasoning experiments with LoRA adapters

This model is not intended for:

  • Production deployment
  • High-accuracy mathematical reasoning
  • Safety-critical applications

Limitations

  • GPT-2 is a small model with limited reasoning ability.
  • Exact match accuracy remains low in absolute terms.
  • Selection procedure (see gsm-gpt2-rff Dataset) overrepresented examples from aqua_rat and math_qa, whose test splits are both used for evaluation. However:
    • Results (esp. exact match accuracy) also improve for gsm8k, which is underrepresented
    • Selection was not aware of the evaluation datasets.
  • The eval set is heavily multiple-choice (math_qa and aqua_rat are multiple-choice datasets). This may partially explain the accuracy improvements, as the model learns to return a letter as an answer. However, accuracy improvements are also notable for gsm8k, which suggests our data-efficient method is doing more than merely teaching the model to answer multiple-choice questions.

License (Apache 2.0)

This model repository is released under the Apache License 2.0.

  • License text: http://www.apache.org/licenses/LICENSE-2.0
  • Please also review the licenses of any datasets used during training.