File size: 3,995 Bytes
86dd177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""ShareGPT + POLAR reward environment."""

from __future__ import annotations

from pathlib import Path
from typing import Any

from datasets import Dataset, load_dataset
import asyncio

import verifiers as vf
from verifiers.types import Messages
from xtuner.utils import RewardModelClient

DEFAULT_MODEL = "internlm/POLAR-7B"


def _load_sharegpt_dataset(path: str | Path) -> Dataset:
    """Load a ShareGPT-format JSON file and convert it to single-turn samples.

    Each record's ``conversations`` list is reduced to its first ``human``
    turn (the prompt) and first ``gpt`` turn (the reference answer).

    Args:
        path: Path to a ShareGPT-style JSON/JSONL file.

    Returns:
        A ``Dataset`` with a chat-format ``prompt`` column and the reference
        assistant message stored under ``info["reference"]``.

    Raises:
        ValueError: If a record is missing a ``human`` or ``gpt`` turn.
    """
    dataset = load_dataset("json", data_files=str(path), split="train")

    def to_single_turn(example: dict[str, Any]) -> dict[str, Any]:
        # next(..., None) instead of bare next(): a malformed record would
        # otherwise surface as an opaque StopIteration inside Dataset.map.
        human_turn = next(
            (turn["value"] for turn in example["conversations"] if turn["from"] == "human"),
            None,
        )
        assistant_turn = next(
            (turn["value"] for turn in example["conversations"] if turn["from"] == "gpt"),
            None,
        )
        if human_turn is None or assistant_turn is None:
            raise ValueError(
                "ShareGPT record must contain both a 'human' and a 'gpt' turn: "
                f"{example['conversations']!r}"
            )
        return {
            "prompt": [{"role": "user", "content": human_turn}],
            "info": {
                "reference": [{"role": "assistant", "content": assistant_turn}],
            },
        }

    return dataset.map(to_single_turn, remove_columns=dataset.column_names)


class PoolingClient:
    """Thin wrapper around xtuner's ``RewardModelClient`` for POLAR scoring."""

    def __init__(
        self,
        model_path: str,
        server_address: str,
        server_type: str = "lmdeploy",
        max_length: int = 16384,
        max_response_length: int = 4096,
        response_cut_side: str = "left",
    ):
        # The underlying client owns truncation behavior and the round trip
        # to the reward-model server.
        self.client = RewardModelClient(
            model_path,
            max_length=max_length,
            max_response_length=max_response_length,
            response_cut_side=response_cut_side,
            server_type=server_type,
            server_address=server_address,
        )

    def encode(self, sample: dict[str, Any]) -> str:
        """Flatten one sample into the POLAR reward-model input string."""

        def flatten(key: str) -> str:
            # Join the chat messages under *key* into a single newline-
            # separated text; missing keys collapse to the empty string.
            return "\n".join(msg["content"] for msg in sample.get(key, []))

        prompt_text = flatten("prompt")
        # POLAR compares (prompt, reference) against (prompt, output).
        return (
            f"{prompt_text}\n{flatten('reference')}"
            f"<|reward|>{prompt_text}\n{flatten('output')}"
            "[UNUSED_TOKEN_130]"
        )

    def score(self, payload: list[dict[str, Any]]) -> list[float]:
        """Encode *payload* and return one reward per item from the server."""
        encoded = list(map(self.encode, payload))
        rewards = self.client.lmdeploy_request_reward(encoded)
        if rewards is None:
            raise RuntimeError("Failed to get rewards from lmdeploy server")
        return rewards


async def polar_reward(
    prompt: Messages,
    completion: Messages,
    info: dict[str, Any],
    reward_client: PoolingClient,
    pooling_semaphore: asyncio.Semaphore,
    **_: Any,
) -> float:
    """Score the final assistant turn of *completion* with the POLAR server.

    Returns 0.0 when the completion contains no assistant turn; otherwise
    the server reward for the last assistant message, scaled by 10.
    """
    last_assistant = None
    for msg in completion:
        if msg.get("role") == "assistant":
            last_assistant = msg
    if last_assistant is None:
        return 0.0

    request = [
        {
            "prompt": prompt,
            "reference": info.get("reference", []),
            "output": [last_assistant],
        }
    ]
    # Bound concurrent reward calls; the blocking HTTP client runs in the
    # default thread-pool executor so the event loop stays responsive.
    async with pooling_semaphore:
        rewards = await asyncio.get_running_loop().run_in_executor(
            None, reward_client.score, request
        )
    if not rewards:
        raise RuntimeError(f"Unexpected reward response: {rewards}")
    return float(rewards[-1]) * 10.0


def load_environment(
    data_path: str | Path,
    *,
    server_address: str,
    reward_model: str = DEFAULT_MODEL,
    reward_scheme: type[vf.Rubric] | None = None,
    server_type: str = "lmdeploy",
    max_concurrent_scores: int = 4,
    **env_kwargs: Any,
) -> vf.SingleTurnEnv:
    """Build a single-turn environment scored by a POLAR reward server.

    Args:
        data_path: ShareGPT-format JSON file with conversations.
        server_address: Address of the running reward-model server.
        reward_model: Reward-model path/name (defaults to POLAR-7B).
        reward_scheme: Optional ``vf.Rubric`` subclass to use instead of the
            default rubric class.
        server_type: Backend serving the reward model (e.g. ``"lmdeploy"``).
        max_concurrent_scores: Cap on concurrent reward requests; was
            previously hard-coded to 4.
        **env_kwargs: Forwarded verbatim to ``vf.SingleTurnEnv``.

    Returns:
        A ``vf.SingleTurnEnv`` whose rubric scores completions with
        :func:`polar_reward`.
    """
    dataset = _load_sharegpt_dataset(data_path)

    client = PoolingClient(
        model_path=reward_model,
        server_address=server_address,
        server_type=server_type,
    )

    rubric_cls = reward_scheme or vf.Rubric
    rubric = rubric_cls(funcs=[polar_reward])
    # polar_reward receives these extra arguments via the rubric's
    # class_objects injection; setdefault lets a custom reward_scheme
    # pre-seat its own semaphore.
    rubric.class_objects["reward_client"] = client
    rubric.class_objects.setdefault(
        "pooling_semaphore", asyncio.Semaphore(max_concurrent_scores)
    )

    return vf.SingleTurnEnv(dataset=dataset, rubric=rubric, **env_kwargs)