| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| from typing import Union |
|
|
| import numpy as np |
| import torch |
| from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature |
| from lerobot.constants import OBS_STATE |
| from lerobot.datasets.utils import cast_stats_to_numpy |
| from lerobot.policies.normalize import Normalize, Unnormalize |
| from transformers.feature_extraction_utils import BatchFeature |
| from transformers.image_utils import ImageInput |
| from transformers.processing_utils import ( |
| ImagesKwargs, |
| ProcessingKwargs, |
| ProcessorMixin, |
| TextKwargs, |
| Unpack, |
| VideosKwargs, |
| ) |
| from transformers.tokenization_utils_base import PreTokenizedInput, TextInput |
| from transformers.video_utils import VideoInput |
|
|
# Disable HF tokenizers' thread parallelism to avoid fork-related warnings/deadlocks.
os.environ["TOKENIZERS_PARALLELISM"] = "0"


"""constants"""
# Vision special tokens (Qwen2-VL style pad/delimiter tokens).
DEFAULT_IMAGE_TOKEN = "<|image_pad|>"
DEFAULT_VIDEO_TOKEN = "<|video_pad|>"
VISION_START_TOKEN = "<|vision_start|>"
VISION_END_TOKEN = "<|vision_end|>"

# Robot-action special tokens: an action segment is delimited by start/end,
# padded with action_pad, and action_pass marks pass-through positions.
ACTION_START_TOKEN = "<|action_start|>"
DEFAULT_ACTION_TOKEN = "<|action_pad|>"
PASS_ACTION_TOKEN = "<|action_pass|>"
ACTION_END_TOKEN = "<|action_end|>"

# Robot-state special tokens and the VLA task marker appended to task text.
STATE_START_TOKEN = "<|state_start|>"
DEFAULT_STATE_TOKEN = "<|state_pad|>"
STATE_END_TOKEN = "<|state_end|>"
TASK_VLA_TOKEN = "<|vla|>"


# Accepted containers for robot state/action inputs: single arrays/tensors
# or lists thereof.
RobotInput = Union[np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]]
|
|
|
|
class EO1VisionVideosProcessorKwargs(VideosKwargs, total=False):
    """Video-processor kwargs accepted by EO1VisionProcessor."""

    # Frames-per-second of the input video(s): one value for all videos or a
    # per-video list; used to derive seconds-per-temporal-grid in __call__.
    fps: list[float] | float
|
|
|
|
class EO1VisionImagesKwargs(ImagesKwargs):
    """Image-processor kwargs accepted by EO1VisionProcessor."""

    # Bounds on the total pixel count used by the image processor's resizing.
    min_pixels: int | None
    max_pixels: int | None
    # Vision-transformer patching parameters (spatial patch size, temporal
    # patch size for videos, and the spatial merge factor).
    patch_size: int | None
    temporal_patch_size: int | None
    merge_size: int | None
|
|
|
|
class EO1VisionTextKwargs(TextKwargs):
    """Text kwargs accepted by EO1VisionProcessor."""

    # Number of tokens each action pad token expands to (action chunk length).
    noise_token_num: int | None
    # Prompt snippet inserted by the chat template for action sampling.
    noise_prompt: str | None
|
|
|
|
class EO1VisionProcessorKwargs(ProcessingKwargs, total=False):
    """Combined kwargs schema for EO1VisionProcessor.__call__."""

    text_kwargs: EO1VisionTextKwargs
    images_kwargs: EO1VisionImagesKwargs
    videos_kwargs: EO1VisionVideosProcessorKwargs
    # Defaults merged in by ProcessorMixin._merge_kwargs.
    _defaults = {
        "text_kwargs": {
            "padding": False,
            "return_mm_token_type_ids": False,
        },
    }
|
|
|
|
class EO1VisionProcessor(ProcessorMixin):
    """EO1Vision Processor for Image, Text, Video, and Robotic Action Processing.

    Combines an image processor, a video processor and a Qwen2 tokenizer into
    one multimodal processor. It additionally carries LeRobot normalization
    statistics (``robot_config``) so robot states are normalized on input and
    predicted actions are un-normalized on output.
    """

    attributes = ["image_processor", "tokenizer", "video_processor"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "AutoImageProcessor"
    video_processor_class = "AutoVideoProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        video_processor=None,
        chat_template=None,
        robot_config=None,
        **kwargs,
    ):
        # Prefer special tokens registered on the tokenizer; fall back to the
        # module-level defaults.
        self.image_token = getattr(tokenizer, "image_token", DEFAULT_IMAGE_TOKEN)
        self.video_token = getattr(tokenizer, "video_token", DEFAULT_VIDEO_TOKEN)
        self.action_token = getattr(tokenizer, "action_token", DEFAULT_ACTION_TOKEN)
        self.state_token = getattr(tokenizer, "state_token", DEFAULT_STATE_TOKEN)

        # Hard-coded fallback ids cover tokenizers that have not registered the
        # action tokens (convert_tokens_to_ids then returns None).
        self.action_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_ACTION_TOKEN) or 151666
        self.action_pass_id = tokenizer.convert_tokens_to_ids(PASS_ACTION_TOKEN) or 151672
        # Bug fix: image_token_id was never set, but __call__ reads it when
        # return_mm_token_type_ids is requested (AttributeError before).
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        self.robot_config = robot_config or {}
        self.set_normalization(self.robot_config)

        super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)

    def set_normalization(self, robot_config: dict):
        """Build per-dataset Normalize/Unnormalize modules from *robot_config*.

        Expects ``features``, ``stats`` and ``state_mode`` keys; does nothing
        when any of them is missing (normalization stays unconfigured).
        """
        features, stats, state_mode = (
            robot_config.get("features"),
            robot_config.get("stats"),
            robot_config.get("state_mode"),
        )
        if None in [features, stats, state_mode]:
            return

        # States and actions share the same normalization mode.
        normalization_mapping = {
            "STATE": NormalizationMode(state_mode),
            "ACTION": NormalizationMode(state_mode),
        }
        normalize_inputs, unnormalize_outputs = {}, {}
        for repo_id, fea in features.items():
            stat = cast_stats_to_numpy(stats[repo_id])
            fea = dataset_to_policy_features(fea)

            input_features = {k: v for k, v in fea.items() if v.type == FeatureType.STATE}
            output_features = {k: v for k, v in fea.items() if v.type == FeatureType.ACTION}

            normalize_inputs[repo_id] = Normalize(input_features, normalization_mapping, stat)
            unnormalize_outputs[repo_id] = Unnormalize(output_features, normalization_mapping, stat)

        self.robot_config = dict(robot_config)
        self.normalize_inputs, self.unnormalize_outputs = normalize_inputs, unnormalize_outputs

    def __call__(
        self,
        images: ImageInput = None,
        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
        videos: VideoInput = None,
        states: RobotInput = None,
        actions: RobotInput = None,
        **kwargs: Unpack[EO1VisionProcessorKwargs],
    ) -> BatchFeature:
        """Tokenize *text* and preprocess images, videos and robot tensors.

        Each image/video pad token in *text* is expanded to one token per
        merged vision patch (``grid_thw.prod() // merge_size**2``) so the text
        length matches the vision features; each action pad token is expanded
        to ``noise_token_num`` copies.

        Returns:
            BatchFeature merging tokenizer outputs, vision-processor outputs
            and optional ``states``/``actions`` tensors.
        """
        output_kwargs = self._merge_kwargs(
            EO1VisionProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        noise_token_num = output_kwargs["text_kwargs"].pop("noise_token_num", None)
        # noise_prompt is consumed by the chat template, not the tokenizer.
        output_kwargs["text_kwargs"].pop("noise_prompt", None)

        if images is not None:
            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if videos is not None:
            fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
            videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
            video_grid_thw = videos_inputs["video_grid_thw"]
            if isinstance(fps, (int, float)):
                second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
            elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
                second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
            else:
                raise ValueError(
                    f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to \
                    the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
                )
            videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
        else:
            videos_inputs = {}
            video_grid_thw = None

        if not isinstance(text, list):
            text = [text]

        if images is not None:
            # Expand each image pad token to one placeholder per merged patch,
            # then restore the placeholder to the image token.
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.image_token in text[i]:
                    text[i] = text[i].replace(
                        self.image_token,
                        "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length),
                        1,
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.image_token)

        if videos is not None:
            # Same expansion for video pad tokens.
            merge_length = self.video_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.video_token in text[i]:
                    text[i] = text[i].replace(
                        self.video_token,
                        "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length),
                        1,
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.video_token)

        # Expand action pad tokens to the action-chunk length.
        # NOTE(review): if neither noise_token_num nor
        # robot_config["action_chunk_size"] is set and text contains an action
        # token, this raises TypeError — confirm callers always provide one.
        noise_token_num = noise_token_num or self.robot_config.get("action_chunk_size")
        for i in range(len(text)):
            while self.action_token in text[i]:
                text[i] = text[i].replace(
                    self.action_token,
                    "<|placeholder|>" * noise_token_num,
                    1,
                )
            text[i] = text[i].replace("<|placeholder|>", self.action_token)

        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        if return_mm_token_type_ids:
            # Mark image-token positions with type id 1, everything else 0.
            array_ids = np.array(text_inputs["input_ids"])
            mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
            mm_token_type_ids[array_ids == self.image_token_id] = 1
            text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()

        robot_inputs = {}

        if states is not None:
            if isinstance(states, list):
                states = torch.stack(states, dim=0)
            if states.ndim == 1:
                # Promote a single state vector to a batch of one.
                states = states.unsqueeze(0)
            robot_inputs.update({"states": states})

        if actions is not None:
            if isinstance(actions, list):
                actions = torch.stack(actions, dim=0)
            if actions.ndim == 2:
                # Promote a single action chunk to a batch of one.
                actions = actions.unsqueeze(0)
            robot_inputs.update({"actions": actions})

        return BatchFeature(
            data={**text_inputs, **image_inputs, **videos_inputs, **robot_inputs},
        )

    @property
    def model_input_names(self):
        """All input names the model may receive (text + vision + robot)."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
        return names_from_processor + ["second_per_grid_ts"] + ["states", "actions"]

    @torch.no_grad
    def _prepare_robot_inputs(self, batch: dict):
        """Normalize a raw robot batch into chat messages and padded states.

        Returns ``(batch_messages, batch_states, repo_ids)`` where each message
        contains the selected camera images, a state placeholder and the task
        text terminated by the VLA marker.
        """
        batch_messages = []
        batch_states = []
        max_state_dim = self.robot_config.get("max_state_dim", 32)

        state_keys = [x for x in batch.keys() if x.startswith(OBS_STATE)]
        batch_size = len(batch[state_keys[0]])

        if "repo_id" in batch:
            repo_ids = batch.pop("repo_id")
        else:
            # Best-effort fallback: assume every sample comes from the first
            # configured dataset.
            print("no repo_id found, use the first one in normalize_inputs")
            repo_ids = list(self.normalize_inputs.keys())[0]
        repo_ids = [repo_ids] * batch_size if isinstance(repo_ids, str) else repo_ids

        for i, repo_id in enumerate(repo_ids):
            mini_batch = {k: v[i] for k, v in batch.items()}

            normalize_inputs = self.normalize_inputs[repo_id]
            select_video_keys = self.robot_config["select_video_keys"][repo_id]
            select_state_keys = self.robot_config["select_state_keys"][repo_id]

            # Normalization modules expect float tensors.
            for k in normalize_inputs.features:
                if not isinstance(mini_batch[k], torch.Tensor):
                    mini_batch[k] = torch.tensor(mini_batch[k], dtype=torch.float32)

            mini_batch = normalize_inputs(mini_batch)
            states = torch.concat([mini_batch[k] for k in select_state_keys])
            batch_states.append(pad_vector(states, max_state_dim))
            messages = [
                {
                    "role": "user",
                    "content": [
                        *({"type": "image", "image": mini_batch[k]} for k in select_video_keys),
                        {"type": "state", "state": []},
                        {"type": "text", "text": f"{mini_batch['task']}{TASK_VLA_TOKEN}"},
                    ],
                }
            ]
            batch_messages += [messages]
        return batch_messages, batch_states, repo_ids

    def _process_robot_outputs(self, repo_ids: list[str], actions: torch.Tensor):
        """Split flat predicted actions per key and un-normalize to robot space."""
        output_actions = []
        for i, repo_id in enumerate(repo_ids):
            unnormalize_outputs = self.unnormalize_outputs[repo_id]
            select_action_keys = self.robot_config["select_action_keys"][repo_id]
            features = unnormalize_outputs.features
            # Cumulative offsets of each action sub-vector inside the flat
            # action tensor; trailing padding beyond cum_dims[-1] is dropped.
            cum_dims = [0] + np.cumsum([features[k].shape[0] for k in select_action_keys]).tolist()
            origin_action = actions[i].to(torch.float32)[..., : cum_dims[-1]]
            batch = {
                k: origin_action[..., cum_dims[m] : cum_dims[m + 1]] for m, k in enumerate(select_action_keys)
            }
            unnorm_actions = unnormalize_outputs(batch)
            unnorm_actions = torch.concat([unnorm_actions[k] for k in select_action_keys], -1)
            output_actions.append(unnorm_actions)
        output_actions = torch.stack(output_actions, dim=0)
        return output_actions

    @torch.no_grad
    def select_action(self, model, batch: dict, **kwargs):
        """Run the full observation → action inference pipeline.

        Normalizes *batch*, renders the chat template with an action noise
        prompt, samples actions from *model* and returns them un-normalized as
        ``BatchFeature({"action": ...})``.
        """
        batch_messages, batch_states, repo_ids = self._prepare_robot_inputs(batch)

        noise_prompt = f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN}{ACTION_END_TOKEN}"
        inputs = self.apply_chat_template(
            batch_messages,
            states=batch_states,
            add_generation_prompt=True,
            noise_prompt=noise_prompt,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)

        actions = model.sample_actions(**inputs).cpu()
        output_actions = self._process_robot_outputs(repo_ids, actions)
        return BatchFeature({"action": output_actions})
|
|
|
|
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
    """Convert LeRobot dataset feature specs into policy features.

    Args:
        features: Mapping of feature key to a spec dict with ``shape``,
            ``dtype`` and (for visuals) ``names`` entries.

    Returns:
        Mapping of key to PolicyFeature; keys that match no known category
        are skipped.

    Raises:
        ValueError: If an image/video feature is not 3-dimensional.
    """
    policy_features = {}
    for key, ft in features.items():
        shape = ft["shape"]
        if ft["dtype"] in ["image", "video"]:
            # Renamed from `type` to avoid shadowing the builtin.
            feature_type = FeatureType.VISUAL
            if len(shape) != 3:
                raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
            names = ft["names"]
            # Dataset stores HWC; policies expect CHW.
            if names[2] in ["channel", "channels"]:
                shape = (shape[2], shape[0], shape[1])
        elif key == "observation.environment_state":
            feature_type = FeatureType.ENV
        elif key.startswith("observation"):
            feature_type = FeatureType.STATE
        elif key.startswith("action"):
            feature_type = FeatureType.ACTION
        else:
            continue
        policy_features[key] = PolicyFeature(
            type=feature_type,
            shape=shape,
        )
    return policy_features
|
|
|
|
def pad_vector(vector, new_dim=32):
    """Right-pad the last dimension of *vector* with zeros up to *new_dim*.

    Accepts (batch_size x sequence_length x features_dimension) or
    (batch_size x features_dimension); the input is returned unchanged when
    its last dimension already equals *new_dim*.
    """
    current_dim = vector.shape[-1]
    if current_dim == new_dim:
        return vector
    pad_shape = (*vector.shape[:-1], new_dim - current_dim)
    padding = torch.zeros(pad_shape, dtype=vector.dtype, device=vector.device)
    return torch.cat((vector, padding), dim=-1)
|
|
|
|
# Register the processor so AutoProcessor.from_pretrained can resolve it.
EO1VisionProcessor.register_for_auto_class()
|
|