lujangusface commited on
Commit
9dc0d9e
·
verified ·
1 Parent(s): 73b48a5

docs: standardize model card for public release

Browse files
Files changed (1) hide show
  1. README.md +94 -79
README.md CHANGED
@@ -1,65 +1,94 @@
1
  ---
2
- language:
3
- - en
4
- license: mit
5
  library_name: transformers
6
- tags:
7
- - speculative-decoding
8
- - eagle3
9
- - draft-model
10
- - jax
11
- - tpu
12
  base_model: deepseek-ai/DeepSeek-R1-Distill-Qwen-14B
 
 
 
 
 
 
 
 
13
  ---
14
 
15
  # EAGLE3 Draft Head — DeepSeek-R1-Distill-Qwen-14B
16
 
17
- An [EAGLE3](https://github.com/SafeAILab/EAGLE) speculative decoding draft head for
18
- [`deepseek-ai/DeepSeek-R1-Distill-Qwen-14B`](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-14B),
19
- trained on TPU v4 with [SpecJAX](https://github.com/thoughtworks/specjax) a pure JAX port of the EAGLE3 training pipeline.
20
 
21
  ## Usage
22
 
23
- Load with [SGLang](https://github.com/sgl-project/sglang) (recommended):
 
 
24
 
25
  ```bash
26
  python -m sglang.launch_server \
27
- --model deepseek-ai/DeepSeek-R1-Distill-Qwen-14B \
28
- --speculative-algorithm EAGLE3 \
29
- --speculative-draft-model-path thoughtworks/DeepSeek-R1-Distill-Qwen-14B-Eagle3 \
30
- --speculative-num-steps 5 \
31
- --speculative-eagle-topk 4 \
32
- --speculative-num-draft-tokens 16
33
  ```
34
 
35
- Or with [vLLM](https://github.com/vllm-project/vllm):
 
 
36
 
37
  ```bash
38
- python -m vllm.entrypoints.openai.api_server \
39
- --model deepseek-ai/DeepSeek-R1-Distill-Qwen-14B \
40
- --speculative-model thoughtworks/DeepSeek-R1-Distill-Qwen-14B-Eagle3 \
41
- --num-speculative-tokens 5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  ```
43
 
44
  ## Training Details
45
 
46
- | | |
47
- |---|---|
48
- | **Target model** | `deepseek-ai/DeepSeek-R1-Distill-Qwen-14B` |
49
- | **Training framework** | [SpecJAX](https://github.com/thoughtworks/specjax) JAX/TPU port of EAGLE3 |
50
- | **Hardware** | TPU v4-32 (4 hosts × 4 chips, tp=4, dp=4) |
51
- | **Dataset** | 54K mixed (45% ShareGPT / 35% UltraChat-200K / 20% Open-PerfectBlend) |
52
- | **Epochs** | 3 |
53
- | **Steps** | 2,490 |
54
- | **Wall time** | ~84 min |
55
- | **Learning rate** | 8e-4, cosine decay, 3% warmup |
56
- | **Batch size** | 2 (grad accum 8, effective batch 16 per DP rank) |
57
- | **Max length** | 512 tokens |
58
- | **TTT length** | 7 (test-time training rollout positions) |
59
 
60
- ## Results
61
 
62
- Token acceptance rates measured on the training distribution (54K mixed):
 
 
 
 
63
 
64
  | Position | Acceptance Rate |
65
  |----------|----------------|
@@ -71,54 +100,40 @@ Token acceptance rates measured on the training distribution (54K mixed):
71
  | acc_5 | 55.7% |
72
  | acc_6 | 54.1% |
73
 
74
- **Epoch progression** (no overfitting detected):
75
 
76
- | Checkpoint | acc_0 | Loss |
77
- |------------|-------|------|
78
- | epoch_1 | 52.0% | ~10.6 |
79
- | epoch_2 | 60.9% | 7.64 |
80
- | epoch_3 | 65.8% | 6.76 |
81
-
82
- Full training curves: [W&B run `li7xhsk7`](https://wandb.ai/gustavo-lujan-thoughtworks/ds-r1-qwen-14b-eagle3-experiments/runs/li7xhsk7)
83
 
84
  ## Model Architecture
85
 
86
- The draft head is a single-layer transformer that takes the target model's hidden states
87
- as input and predicts the next token using EAGLE3's feature-fusion approach.
88
-
89
- ```json
90
- {
91
- "architectures": ["LlamaForCausalLMEagle3"],
92
- "model_type": "llama",
93
- "hidden_size": 5120,
94
- "intermediate_size": 13824,
95
- "num_hidden_layers": 1,
96
- "num_attention_heads": 40,
97
- "num_key_value_heads": 8,
98
- "head_dim": 128,
99
- "vocab_size": 152064,
100
- "draft_vocab_size": 32000,
101
- "rope_theta": 1000000.0,
102
- "rms_norm_eps": 1e-05,
103
- "torch_dtype": "bfloat16"
104
- }
105
- ```
106
 
107
- ## Citation
108
 
109
- If you use this model, please cite the EAGLE3 paper:
 
 
 
 
 
 
 
 
110
 
111
  ```bibtex
112
- @article{li2024eagle3,
113
- title={EAGLE-3: Scaling up Inference Acceleration of Large Language Models via Training-Time Test},
114
  author={Li, Yuhui and Wei, Fangyun and Zhang, Chao and Zhang, Hongyang},
115
- journal={arXiv},
116
- year={2024}
117
  }
118
  ```
119
-
120
- ## License
121
-
122
- This draft head is released under the MIT license.
123
- The base model ([DeepSeek-R1-Distill-Qwen-14B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-14B))
124
- is subject to its own license terms.
 
1
  ---
 
 
 
2
  library_name: transformers
3
+ license: mit
4
+ language:
5
+ - en
 
 
 
6
  base_model: deepseek-ai/DeepSeek-R1-Distill-Qwen-14B
7
+ tags:
8
+ - eagle3
9
+ - speculative-decoding
10
+ - sglang
11
+ - draft-model
12
+ - jax
13
+ - tpu
14
+ pipeline_tag: text-generation
15
  ---
16
 
17
  # EAGLE3 Draft Head — DeepSeek-R1-Distill-Qwen-14B
18
 
19
+ A speculative decoding draft head for [deepseek-ai/DeepSeek-R1-Distill-Qwen-14B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-14B), trained using the [EAGLE3](https://arxiv.org/abs/2503.01840) method on Google Cloud TPU with the [SpecJAX](https://github.com/tails-mpt/SpecJAX) framework.
20
+
21
+ EAGLE3 draft heads accelerate autoregressive generation by proposing multiple tokens per step that a target model then verifies in parallel typically achieving 2-3x throughput gains with no change in output quality.
22
 
23
  ## Usage
24
 
25
+ ### SGLang (GPU)
26
+
27
+ > **Note**: DeepSeek-R1-Distill-Qwen uses the Qwen2 architecture. EAGLE3 support requires a small patch to SGLang (adding `set_eagle3_layers_to_capture()` to the Qwen2 model). See the [SpecJAX inference guide](https://github.com/tails-mpt/SpecJAX/tree/main/inference) for details.
28
 
29
  ```bash
30
  python -m sglang.launch_server \
31
+ --model deepseek-ai/DeepSeek-R1-Distill-Qwen-14B \
32
+ --speculative-algorithm EAGLE3 \
33
+ --speculative-draft-model-path thoughtworks/DeepSeek-R1-Distill-Qwen-14B-Eagle3 \
34
+ --speculative-num-steps 5 \
35
+ --speculative-eagle-topk 4 \
36
+ --dtype bfloat16
37
  ```
38
 
39
+ ### sglang-jax (TPU)
40
+
41
+ > **Note**: Requires the same Qwen2 EAGLE3 patch applied to sglang-jax. The sglang-jax EAGLE3 pipeline is functional but not yet performance-optimized.
42
 
43
  ```bash
44
+ python -m sgl_jax.launch_server \
45
+ --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-14B \
46
+ --speculative-algorithm EAGLE3 \
47
+ --speculative-draft-model-path thoughtworks/DeepSeek-R1-Distill-Qwen-14B-Eagle3 \
48
+ --speculative-eagle-topk 1 \
49
+ --speculative-num-steps 3 \
50
+ --speculative-num-draft-tokens 4 \
51
+ --tp-size 4 --dtype bfloat16
52
+ ```
53
+
54
+ ### Python (SGLang client)
55
+
56
+ ```python
57
+ import sglang as sgl
58
+
59
+ llm = sgl.LLM(
60
+ model="deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
61
+ speculative_algorithm="EAGLE3",
62
+ speculative_draft_model_path="thoughtworks/DeepSeek-R1-Distill-Qwen-14B-Eagle3",
63
+ speculative_num_steps=5,
64
+ speculative_eagle_topk=4,
65
+ dtype="bfloat16",
66
+ )
67
  ```
68
 
69
  ## Training Details
70
 
71
+ | Parameter | Value |
72
+ |-----------|-------|
73
+ | Framework | [SpecJAX](https://github.com/tails-mpt/SpecJAX) — pure JAX, no Flax/PyTorch |
74
+ | Hardware | Google Cloud TPU v4-32 (4 hosts x 4 chips, TP=4, DP=4) |
75
+ | Dataset | 54K mixed: ShareGPT (45%) + UltraChat-200K (35%) + Open-PerfectBlend (20%) |
76
+ | Epochs | 3 |
77
+ | Steps | 2,490 total |
78
+ | Optimizer | AdamW, cosine LR decay, 3% warmup |
79
+ | Learning rate | 8e-4 |
80
+ | Batch size | B=2, sequence length T=512, gradient accumulation 8 |
81
+ | TTT length | 7 (multi-step speculative rollout) |
82
+ | Training time | ~84 minutes |
83
+ | Precision | bfloat16 |
84
 
85
+ ### Training Method
86
 
87
+ This model uses [EAGLE3](https://arxiv.org/abs/2503.01840)'s Test-Time Training (TTT) objective with a rollout length of 7. At each training step, the draft head autoregressively proposes 7 tokens; the target model provides ground-truth hidden states and logits for all positions; a geometric loss (0.8^k weighting) trains the draft to match the target at each position.
88
+
89
+ ## Performance
90
+
91
+ Token acceptance rates on generic instruction-following data (ShareGPT-style prompts):
92
 
93
  | Position | Acceptance Rate |
94
  |----------|----------------|
 
100
  | acc_5 | 55.7% |
101
  | acc_6 | 54.1% |
102
 
103
+ This model achieves the highest acc_0 among all SpecJAX-trained EAGLE3 draft heads.
104
 
105
+ *Measured on held-out evaluation data. Actual throughput gains depend on hardware, prompt distribution, and runtime version.*
 
 
 
 
 
 
106
 
107
  ## Model Architecture
108
 
109
+ The draft head is a single-layer transformer that operates on the target model's hidden states:
110
+
111
+ | Parameter | Value |
112
+ |-----------|-------|
113
+ | Architecture | `LlamaForCausalLM` (1 decoder layer) |
114
+ | Hidden size | 5120 |
115
+ | Attention heads | 40 (GQA: 8 KV heads) |
116
+ | Vocabulary size | 152,064 (full target vocab) |
117
+ | Draft vocab size | 32,000 (top tokens by training frequency) |
118
+ | Parameters | ~530M |
 
 
 
 
 
 
 
 
 
 
119
 
120
+ ## Limitations
121
 
122
+ - Trained on English-dominant instruction data; performance may degrade on non-English inputs or highly domain-specific content.
123
+ - Acceptance rates are measured on generic chat data and will vary by prompt distribution.
124
+ - This is a v1 checkpoint trained on generic data. A v2 with target-model-regenerated training data is planned.
125
+
126
+ ## License
127
+
128
+ This model is released under the [MIT License](https://opensource.org/licenses/MIT). The base model ([DeepSeek-R1-Distill-Qwen-14B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-14B)) is subject to its own license terms.
129
+
130
+ ## References
131
 
132
  ```bibtex
133
+ @article{li2025eagle3,
134
+ title={EAGLE3: Scalable Speculative Decoding with Training-Free Multi-Draft Speculation},
135
  author={Li, Yuhui and Wei, Fangyun and Zhang, Chao and Zhang, Hongyang},
136
+ journal={arXiv preprint arXiv:2503.01840},
137
+ year={2025}
138
  }
139
  ```