Upload folder using huggingface_hub
- __pycache__/sharegpt_polar.cpython-312.pyc +0 -0
- build/lib/sharegpt_polar.py +113 -0
- outputs/.gitkeep +0 -0
- pyproject.toml +21 -0
- sharegpt_polar.egg-info/PKG-INFO +67 -0
- sharegpt_polar.egg-info/SOURCES.txt +9 -0
- sharegpt_polar.egg-info/dependency_links.txt +1 -0
- sharegpt_polar.egg-info/entry_points.txt +2 -0
- sharegpt_polar.egg-info/requires.txt +2 -0
- sharegpt_polar.egg-info/top_level.txt +1 -0
- sharegpt_polar.py +129 -0
- train_sharegpt_polar.py +108 -0
__pycache__/sharegpt_polar.cpython-312.pyc
ADDED
Binary file (6.79 kB).
build/lib/sharegpt_polar.py
ADDED
@@ -0,0 +1,113 @@
"""ShareGPT + POLAR reward environment."""

from __future__ import annotations

import json
from http import HTTPStatus
from pathlib import Path
from typing import Any

from datasets import Dataset, load_dataset
import httpx

import verifiers as vf
from verifiers.types import Messages


DEFAULT_SERVER = "wealth-intent-submissions-range.trycloudflare.com"
DEFAULT_MODEL = "internlm/POLAR-7B"


def _load_sharegpt_dataset(path: str | Path) -> Dataset:
    dataset = load_dataset("json", data_files=str(path), split="train")

    def to_single_turn(example: dict[str, Any]) -> dict[str, Any]:
        human_turn = next(
            turn["value"] for turn in example["conversations"] if turn["from"] == "human"
        )
        assistant_turn = next(
            turn["value"] for turn in example["conversations"] if turn["from"] == "gpt"
        )
        return {
            "prompt": [{"role": "user", "content": human_turn}],
            "info": {
                "reference": [{"role": "assistant", "content": assistant_turn}],
            },
        }

    return dataset.map(to_single_turn, remove_columns=dataset.column_names)


async def polar_reward(
    prompt: Messages,
    completion: Messages,
    info: dict[str, Any],
    reward_client: "PolarClient",
    **_: Any,
) -> float:
    assistant_turns = [msg for msg in completion if msg.get("role") == "assistant"]
    if not assistant_turns:
        return 0.0

    payload = [
        {
            "prompt": prompt,
            "reference": info.get("reference", []),
            "output": [assistant_turns[-1]],
        }
    ]
    scores = await reward_client.score(payload)
    return float(scores[0]) if scores else 0.0


def load_environment(
    data_path: str | Path,
    *,
    server_address: str = DEFAULT_SERVER,
    reward_model: str = DEFAULT_MODEL,
    reward_scheme: type[vf.Rubric] | None = None,
    **env_kwargs: Any,
) -> vf.SingleTurnEnv:
    dataset = _load_sharegpt_dataset(data_path)

    client = PolarClient(
        base_url=f"https://{server_address}",
        model=reward_model,
    )

    rubric_cls = reward_scheme or vf.Rubric
    rubric = rubric_cls(funcs=[polar_reward])
    rubric.class_objects["reward_client"] = client

    return vf.SingleTurnEnv(dataset=dataset, rubric=rubric, **env_kwargs)


class PolarClient:
    """Minimal async client for POLAR reward model served via vLLM."""

    def __init__(self, *, base_url: str, model: str, timeout: float = 30.0, api_key: str | None = None):
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.timeout = timeout
        self.api_key = api_key

    async def score(self, payload: list[dict[str, Any]]) -> list[float]:
        encoded = self._encode(payload)
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.post(
                f"{self.base_url}/v1/rewards",
                json={"model": self.model, "input": encoded},
                headers={"Authorization": f"Bearer {self.api_key}"} if self.api_key else None,
            )
        if response.status_code != HTTPStatus.OK:
            raise RuntimeError(
                f"POLAR reward request failed: {response.status_code} {response.text}"
            )
        data = response.json()
        return data.get("rewards", [])

    @staticmethod
    def _encode(payload: list[dict[str, Any]]) -> list[dict[str, Any]]:
        # Ensure payload matches expected schema; keep implementation simple for now.
        return payload
outputs/.gitkeep
ADDED
File without changes
pyproject.toml
ADDED
@@ -0,0 +1,21 @@
[project]
name = "sharegpt-polar"
version = "0.1.0"
description = "ShareGPT reward environment scored by POLAR"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "datasets>=2.16.0",
    "httpx>=0.27.0",
]

[[project.authors]]
name = "Prime Intellect"
email = "infra@primeintellect.ai"

[project.entry-points."verifiers.environments"]
sharegpt-polar = "sharegpt_polar:load_environment"

[tool.setuptools]
py-modules = ["sharegpt_polar"]
sharegpt_polar.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,67 @@
Metadata-Version: 2.4
Name: sharegpt-polar
Version: 0.1.0
Summary: ShareGPT reward environment scored by POLAR
Author-email: Prime Intellect <infra@primeintellect.ai>
Requires-Python: >=3.11
Description-Content-Type: text/markdown
Requires-Dist: datasets>=2.16.0
Requires-Dist: httpx>=0.27.0

# ShareGPT POLAR Environment

This environment scores policy rollouts using the POLAR reward model served via vLLM. It expects a ShareGPT-style JSONL dataset and treats the original assistant response as a reference trajectory.

## Dataset Format

Each line in the dataset must contain a `conversations` list with alternating human/GPT turns:

```json
{
  "conversations": [
    {"from": "human", "value": "Prompt text"},
    {"from": "gpt", "value": "Reference answer"}
  ]
}
```

The loader extracts the first human message as the prompt and the first GPT message as the reference. Additional turns can be present and are ignored.
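
The per-record transform looks roughly like this (an illustrative sketch of the mapping done by `_load_sharegpt_dataset`, not part of the public API):

```python
# Sketch of the transform the loader applies to each ShareGPT record.
record = {
    "conversations": [
        {"from": "human", "value": "Prompt text"},
        {"from": "gpt", "value": "Reference answer"},
    ]
}

human = next(t["value"] for t in record["conversations"] if t["from"] == "human")
gpt = next(t["value"] for t in record["conversations"] if t["from"] == "gpt")

sample = {
    "prompt": [{"role": "user", "content": human}],                   # rollout prompt
    "info": {"reference": [{"role": "assistant", "content": gpt}]},   # kept for scoring
}
```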

## Usage

```python
import verifiers as vf

env = vf.load_environment(
    "sharegpt-polar",
    data_path="/path/to/sharegpt.jsonl",
    server_address="wealth-intent-submissions-range.trycloudflare.com",
)
```

The environment bundles an async reward function which submits `(prompt, reference, output)` to the POLAR reward model using `RewardModelClient` from xtuner.
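
Only the last assistant turn of each completion is scored; the reward function assembles a payload of the following shape before handing it to the reward client (a sketch of the structure used internally by `polar_reward`):

```python
# Shape of the payload polar_reward passes to the reward client.
payload = [
    {
        "prompt": [{"role": "user", "content": "Prompt text"}],
        "reference": [{"role": "assistant", "content": "Reference answer"}],
        "output": [{"role": "assistant", "content": "Policy model answer"}],
    }
]
# The client flattens each message list to plain text, queries POLAR, and
# returns one scalar reward per payload item.
```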

## Dependencies

- `datasets>=2.16.0`
- `httpx>=0.27.0`
- `xtuner[deepspeed]==0.2.0`

Install them in your project via:

```bash
uv add datasets
uv add httpx
uv add "xtuner[deepspeed]==0.2.0"
```

## Evaluation

Run quick evaluations with any OpenAI-compatible policy model:

```bash
uv run vf-install sharegpt-polar -p /home/Ubuntu/Mango/verifiers/environments
uv run vf-eval sharegpt-polar -m gpt-4.1-mini \
    --env-args '{"data_path": "data/sharegpt.jsonl"}' -n 10 -r 1
```

Ensure the POLAR reward model is reachable at the supplied `server_address` before evaluation.
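
A quick reachability check, mirroring the sanity probe in `train_sharegpt_polar.py` (the server address below is a placeholder; this assumes the reward model is served behind lmdeploy):

```python
from xtuner.utils import RewardModelClient

client = RewardModelClient(
    "internlm/POLAR-7B",
    server_type="lmdeploy",
    server_address="your-polar-server.example.com",  # placeholder address
)
samples = [
    {"prompt": "What is the capital of China?", "reference": "Beijing.", "output": "Beijing."},
    {"prompt": "What is the capital of China?", "reference": "Beijing.", "output": "Shanghai."},
]
rewards = client.lmdeploy_request_reward(client.encode(samples))
print(rewards)  # the correct answer should receive the higher reward
```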
sharegpt_polar.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,9 @@
README.md
pyproject.toml
sharegpt_polar.py
sharegpt_polar.egg-info/PKG-INFO
sharegpt_polar.egg-info/SOURCES.txt
sharegpt_polar.egg-info/dependency_links.txt
sharegpt_polar.egg-info/entry_points.txt
sharegpt_polar.egg-info/requires.txt
sharegpt_polar.egg-info/top_level.txt
sharegpt_polar.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@

sharegpt_polar.egg-info/entry_points.txt
ADDED
@@ -0,0 +1,2 @@
[verifiers.environments]
sharegpt-polar = sharegpt_polar:load_environment
sharegpt_polar.egg-info/requires.txt
ADDED
@@ -0,0 +1,2 @@
datasets>=2.16.0
httpx>=0.27.0
sharegpt_polar.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
sharegpt_polar
sharegpt_polar.py
ADDED
@@ -0,0 +1,129 @@
"""ShareGPT + POLAR reward environment."""

from __future__ import annotations

import asyncio
from pathlib import Path
from typing import Any

from datasets import Dataset, load_dataset

import verifiers as vf
from verifiers.types import Messages
from xtuner.utils import RewardModelClient

DEFAULT_MODEL = "internlm/POLAR-7B"


def _load_sharegpt_dataset(path: str | Path) -> Dataset:
    dataset = load_dataset("json", data_files=str(path), split="train")

    def to_single_turn(example: dict[str, Any]) -> dict[str, Any]:
        human_turn = next(
            turn["value"] for turn in example["conversations"] if turn["from"] == "human"
        )
        assistant_turn = next(
            turn["value"] for turn in example["conversations"] if turn["from"] == "gpt"
        )
        return {
            "prompt": [{"role": "user", "content": human_turn}],
            "info": {
                "reference": [{"role": "assistant", "content": assistant_turn}],
            },
        }

    return dataset.map(to_single_turn, remove_columns=dataset.column_names)


class PoolingClient:
    def __init__(
        self,
        model_path: str,
        server_address: str,
        server_type: str = "lmdeploy",
        max_length: int = 16384,
        max_response_length: int = 4096,
        response_cut_side: str = "left",
    ):
        self.client = RewardModelClient(
            model_path,
            max_length=max_length,
            max_response_length=max_response_length,
            response_cut_side=response_cut_side,
            server_type=server_type,
            server_address=server_address,
        )

    def encode(self, sample: dict[str, Any]) -> str:
        # Flatten each message list to plain text and build the reward prompt:
        # reference trajectory to the left of <|reward|>, candidate output to the right.
        prompt_text = "\n".join(
            message["content"] for message in sample.get("prompt", [])
        )
        reference_text = "\n".join(
            message["content"] for message in sample.get("reference", [])
        )
        output_text = "\n".join(
            message["content"] for message in sample.get("output", [])
        )
        return f"{prompt_text}\n{reference_text}<|reward|>{prompt_text}\n{output_text}[UNUSED_TOKEN_130]"

    def score(self, payload: list[dict[str, Any]]) -> list[float]:
        encoded_payload = [self.encode(item) for item in payload]
        rewards = self.client.lmdeploy_request_reward(encoded_payload)
        if rewards is None:
            raise RuntimeError("Failed to get rewards from lmdeploy server")
        return rewards


async def polar_reward(
    prompt: Messages,
    completion: Messages,
    info: dict[str, Any],
    reward_client: PoolingClient,
    pooling_semaphore: asyncio.Semaphore,
    **_: Any,
) -> float:
    assistant_turns = [msg for msg in completion if msg.get("role") == "assistant"]
    if not assistant_turns:
        return 0.0

    payload = [
        {
            "prompt": prompt,
            "reference": info.get("reference", []),
            "output": [assistant_turns[-1]],
        }
    ]
    async with pooling_semaphore:
        # The xtuner client is synchronous, so run it in a worker thread to avoid
        # blocking the event loop while waiting for the reward server.
        loop = asyncio.get_running_loop()
        rewards = await loop.run_in_executor(None, reward_client.score, payload)
    if rewards:
        # Scale the raw POLAR score before returning it as the rubric reward.
        return float(rewards[-1]) * 10.0
    raise RuntimeError(f"Unexpected reward response: {rewards}")


def load_environment(
    data_path: str | Path,
    *,
    server_address: str,
    reward_model: str = DEFAULT_MODEL,
    reward_scheme: type[vf.Rubric] | None = None,
    server_type: str = "lmdeploy",
    **env_kwargs: Any,
) -> vf.SingleTurnEnv:
    dataset = _load_sharegpt_dataset(data_path)

    client = PoolingClient(
        model_path=reward_model,
        server_address=server_address,
        server_type=server_type,
    )

    rubric_cls = reward_scheme or vf.Rubric
    rubric = rubric_cls(funcs=[polar_reward])
    rubric.class_objects["reward_client"] = client
    rubric.class_objects.setdefault("pooling_semaphore", asyncio.Semaphore(4))

    return vf.SingleTurnEnv(dataset=dataset, rubric=rubric, **env_kwargs)
train_sharegpt_polar.py
ADDED
@@ -0,0 +1,108 @@
"""GRPO training entrypoint for ShareGPT POLAR environment."""

from __future__ import annotations

import os

import verifiers as vf
from xtuner.utils import RewardModelClient

RUN_NAME = "sharegpt-polar"
MODEL_NAME = "NewEden/Snwy-SFT-GRPO-base"
DATA_PATH = "/home/Ubuntu/Mango/verifiers/new.jsonl"
SERVER_ADDRESS = "greene-cannon-republic-expect.trycloudflare.com"
REWARD_MODEL = "internlm/POLAR-7B"

# Training hyperparameters
PER_DEVICE_TRAIN_BATCH_SIZE = 2
NUM_GENERATIONS = 8
GRADIENT_ACCUMULATION_STEPS = 2
LEARNING_RATE = 1e-6
BETA = 0.1
MAX_STEPS = 1000
MAX_GRAD_NORM = 1.0
NUM_ITERATIONS = 1
MAX_TOKENS = 512
TEMPERATURE = 1.0
TOP_P = 1.0
SAVE_EVERY_STEPS = 50
LOGGING_STEPS = 1
REPORT_TO = ["wandb"]
LOG_COMPLETIONS = True
LOG_ON_EACH_NODE = False
ASYNC_GENERATION_TIMEOUT = 60000
MAX_CONCURRENT = 1024
WANDB_PROJECT = "14B-GRPO"
WANDB_NAME = RUN_NAME

if WANDB_PROJECT:
    os.environ.setdefault("WANDB_PROJECT", WANDB_PROJECT)
if WANDB_NAME:
    os.environ.setdefault("WANDB_NAME", WANDB_NAME)


def _check_reward_server() -> None:
    client = RewardModelClient(
        REWARD_MODEL,
        server_type="lmdeploy",
        server_address=SERVER_ADDRESS,
    )
    sanity_samples = [
        {
            "prompt": "What is the capital of China?",
            "reference": "Beijing.",
            "output": "Beijing.",
        },
        {
            "prompt": "What is the capital of China?",
            "reference": "Beijing.",
            "output": "Shanghai.",
        },
    ]
    encoded = client.encode(sanity_samples)
    rewards = client.lmdeploy_request_reward(encoded)
    print("[sanity] lmdeploy rewards:", rewards)


_check_reward_server()

vf_env = vf.load_environment(
    env_id="sharegpt-polar",
    data_path=DATA_PATH,
    server_address=SERVER_ADDRESS,
)

model, tokenizer = vf.get_model_and_tokenizer(MODEL_NAME)

training_args = vf.grpo_defaults(run_name=RUN_NAME)
training_args.per_device_train_batch_size = PER_DEVICE_TRAIN_BATCH_SIZE
training_args.num_generations = NUM_GENERATIONS
training_args.gradient_accumulation_steps = GRADIENT_ACCUMULATION_STEPS
training_args.learning_rate = LEARNING_RATE
training_args.beta = BETA
training_args.max_steps = MAX_STEPS
training_args.max_grad_norm = MAX_GRAD_NORM
training_args.num_iterations = NUM_ITERATIONS
training_args.max_tokens = MAX_TOKENS
training_args.temperature = TEMPERATURE
training_args.top_p = TOP_P
training_args.save_strategy = "steps"
training_args.save_steps = SAVE_EVERY_STEPS
training_args.logging_steps = LOGGING_STEPS
training_args.report_to = REPORT_TO
training_args.log_completions = LOG_COMPLETIONS
training_args.log_on_each_node = LOG_ON_EACH_NODE
training_args.async_generation_timeout = ASYNC_GENERATION_TIMEOUT
training_args.max_concurrent = MAX_CONCURRENT

trainer = vf.GRPOTrainer(
    env=vf_env,
    model=model,
    processing_class=tokenizer,
    args=training_args,
    peft_config=vf.lora_defaults(r=128, alpha=64),
)

if __name__ == "__main__":
    trainer.train()