"""ShareGPT + POLAR reward environment."""
from __future__ import annotations
from pathlib import Path
from typing import Any
from datasets import Dataset, load_dataset
import asyncio
import verifiers as vf
from verifiers.types import Messages
from xtuner.utils import RewardModelClient
DEFAULT_MODEL = "internlm/POLAR-7B"
def _load_sharegpt_dataset(path: str | Path) -> Dataset:
    """Load a ShareGPT-style JSON file and flatten it to single-turn examples.

    Each record must carry a ``conversations`` list of ``{"from", "value"}``
    turns; the first ``human`` turn becomes the prompt and the first ``gpt``
    turn becomes the reference answer stored under ``info``.

    Args:
        path: Path to the ShareGPT JSON/JSONL file.

    Returns:
        A dataset with ``prompt`` (one user message) and ``info.reference``
        (one assistant message); every original column is dropped.

    Raises:
        ValueError: If an example lacks a ``human`` or ``gpt`` turn.
    """
    dataset = load_dataset("json", data_files=str(path), split="train")

    def to_single_turn(example: dict[str, Any]) -> dict[str, Any]:
        # Use a None default so malformed rows raise a clear ValueError
        # instead of an opaque StopIteration escaping from bare next()
        # (which datasets.map would surface as a confusing failure).
        human_turn = next(
            (turn["value"] for turn in example["conversations"] if turn["from"] == "human"),
            None,
        )
        assistant_turn = next(
            (turn["value"] for turn in example["conversations"] if turn["from"] == "gpt"),
            None,
        )
        if human_turn is None or assistant_turn is None:
            raise ValueError(
                "ShareGPT example needs at least one 'human' and one 'gpt' turn"
            )
        return {
            "prompt": [{"role": "user", "content": human_turn}],
            "info": {
                "reference": [{"role": "assistant", "content": assistant_turn}],
            },
        }

    return dataset.map(to_single_turn, remove_columns=dataset.column_names)
class PoolingClient:
    """Blocking client that scores (prompt, reference, output) triples with POLAR.

    Wraps xtuner's ``RewardModelClient`` and builds the POLAR pairwise
    encoding: the reference trajectory and the candidate trajectory are
    joined by the ``<|reward|>`` separator and terminated with the model's
    end-of-sample token.
    """

    # score() only implements the lmdeploy request path, so reject other
    # backends up front instead of silently sending mismatched requests.
    _SUPPORTED_SERVER_TYPES = frozenset({"lmdeploy"})

    def __init__(
        self,
        model_path: str,
        server_address: str,
        server_type: str = "lmdeploy",
        max_length: int = 16384,
        max_response_length: int = 4096,
        response_cut_side: str = "left",
    ):
        """Configure the underlying RewardModelClient.

        Args:
            model_path: HF path of the POLAR reward model (for tokenization).
            server_address: ``host:port`` of the running reward server.
            server_type: Serving backend; only ``"lmdeploy"`` is supported.
            max_length: Maximum total encoded length passed to the client.
            max_response_length: Maximum response length before truncation.
            response_cut_side: Which side to truncate over-long responses.

        Raises:
            ValueError: If ``server_type`` is not supported by ``score()``.
        """
        if server_type not in self._SUPPORTED_SERVER_TYPES:
            raise ValueError(
                f"Unsupported server_type {server_type!r}; "
                "PoolingClient.score() only supports 'lmdeploy'"
            )
        self.client = RewardModelClient(
            model_path,
            max_length=max_length,
            max_response_length=max_response_length,
            response_cut_side=response_cut_side,
            server_type=server_type,
            server_address=server_address,
        )

    def encode(self, sample: dict[str, Any]) -> str:
        """Serialize one sample into the POLAR reward-model input string.

        ``sample`` may carry ``prompt``, ``reference`` and ``output`` keys,
        each a list of chat messages; a missing key encodes as empty text.
        """
        prompt_text = "\n".join(
            message["content"] for message in sample.get("prompt", [])
        )
        reference_text = "\n".join(
            message["content"] for message in sample.get("reference", [])
        )
        output_text = "\n".join(
            message["content"] for message in sample.get("output", [])
        )
        # Reference-vs-candidate pair, both prefixed with the shared prompt.
        return f"{prompt_text}\n{reference_text}<|reward|>{prompt_text}\n{output_text}[UNUSED_TOKEN_130]"

    def score(self, payload: list[dict[str, Any]]) -> list[float]:
        """Encode *payload* and request one reward per sample from the server.

        Raises:
            RuntimeError: If the server returns no rewards.
        """
        encoded_payload = [self.encode(item) for item in payload]
        rewards = self.client.lmdeploy_request_reward(encoded_payload)
        if rewards is None:
            raise RuntimeError("Failed to get rewards from lmdeploy server")
        return rewards
async def polar_reward(
    prompt: Messages,
    completion: Messages,
    info: dict[str, Any],
    reward_client: PoolingClient,
    pooling_semaphore: asyncio.Semaphore,
    **_: Any,
) -> float:
    """Score the final assistant turn of *completion* with the POLAR server.

    Returns 0.0 when the completion has no assistant message; otherwise the
    server's reward for the last assistant turn, scaled by 10.

    Raises:
        RuntimeError: If the server responds with an empty reward list.
    """
    # Only assistant turns are scored; the last one is the actual answer.
    assistant_messages = [
        message for message in completion if message.get("role") == "assistant"
    ]
    if not assistant_messages:
        return 0.0

    request = {
        "prompt": prompt,
        "reference": info.get("reference", []),
        "output": [assistant_messages[-1]],
    }
    # Bound concurrent requests; score() blocks, so run it off the loop.
    async with pooling_semaphore:
        rewards = await asyncio.get_running_loop().run_in_executor(
            None, reward_client.score, [request]
        )
    if not rewards:
        raise RuntimeError(f"Unexpected reward response: {rewards}")
    return float(rewards[-1]) * 10.0
def load_environment(
    data_path: str | Path,
    *,
    server_address: str,
    reward_model: str = DEFAULT_MODEL,
    reward_scheme: type[vf.Rubric] | None = None,
    server_type: str = "lmdeploy",
    **env_kwargs: Any,
) -> vf.SingleTurnEnv:
    """Build a SingleTurnEnv over a ShareGPT dataset scored by a POLAR server.

    Args:
        data_path: ShareGPT JSON/JSONL file to train against.
        server_address: Address of the running POLAR reward server.
        reward_model: Reward-model path handed to the client.
        reward_scheme: Optional Rubric subclass; defaults to ``vf.Rubric``.
        server_type: Reward-serving backend passed to the client.
        **env_kwargs: Forwarded to ``vf.SingleTurnEnv``.
    """
    reward_client = PoolingClient(
        model_path=reward_model,
        server_address=server_address,
        server_type=server_type,
    )
    rubric_factory = vf.Rubric if reward_scheme is None else reward_scheme
    rubric = rubric_factory(funcs=[polar_reward])
    # Injected into polar_reward via the rubric's shared-object registry;
    # setdefault lets a custom rubric class pre-seat its own semaphore.
    rubric.class_objects["reward_client"] = reward_client
    rubric.class_objects.setdefault("pooling_semaphore", asyncio.Semaphore(4))
    return vf.SingleTurnEnv(
        dataset=_load_sharegpt_dataset(data_path),
        rubric=rubric,
        **env_kwargs,
    )