Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff.
- .gitattributes +32 -0
- lerobot/common/policies/vqbet/configuration_vqbet.py +200 -0
- lerobot/common/policies/vqbet/modeling_vqbet.py +911 -0
- lerobot/common/robot_devices/cameras/configs.py +114 -0
- lerobot/common/robot_devices/cameras/intelrealsense.py +538 -0
- lerobot/common/robot_devices/cameras/opencv.py +518 -0
- lerobot/common/robot_devices/cameras/utils.py +67 -0
- lerobot/common/robot_devices/control_configs.py +129 -0
- lerobot/common/robot_devices/control_utils.py +347 -0
- lerobot/common/robot_devices/motors/configs.py +41 -0
- lerobot/common/robot_devices/motors/dynamixel.py +873 -0
- lerobot/common/robot_devices/motors/feetech.py +898 -0
- lerobot/common/robot_devices/motors/utils.py +67 -0
- lerobot/common/robot_devices/robots/configs.py +613 -0
- lerobot/common/robot_devices/robots/dynamixel_calibration.py +144 -0
- lerobot/common/robot_devices/robots/feetech_calibration.py +498 -0
- lerobot/common/robot_devices/robots/lekiwi_remote.py +224 -0
- lerobot/common/robot_devices/robots/manipulator.py +627 -0
- lerobot/common/robot_devices/robots/mobile_manipulator.py +703 -0
- lerobot/common/robot_devices/robots/stretch.py +208 -0
- lerobot/common/robot_devices/robots/utils.py +86 -0
- lerobot/common/robot_devices/utils.py +65 -0
- lerobot/common/utils/benchmark.py +92 -0
- lerobot/common/utils/hub.py +202 -0
- lerobot/common/utils/import_utils.py +59 -0
- lerobot/common/utils/io_utils.py +111 -0
- lerobot/common/utils/logging_utils.py +163 -0
- lerobot/common/utils/random_utils.py +191 -0
- lerobot/common/utils/train_utils.py +161 -0
- lerobot/common/utils/utils.py +230 -0
- lerobot/common/utils/wandb_utils.py +127 -0
- lerobot/configs/default.py +70 -0
- lerobot/configs/eval.py +65 -0
- lerobot/configs/parser.py +232 -0
- lerobot/configs/policies.py +176 -0
- lerobot/configs/train.py +175 -0
- lerobot/configs/types.py +41 -0
- lerobot/scripts/configure_motor.py +176 -0
- lerobot/scripts/control_robot.py +393 -0
- lerobot/scripts/control_sim_robot.py +561 -0
- lerobot/scripts/display_sys_info.py +90 -0
- lerobot/scripts/eval.py +502 -0
- lerobot/scripts/find_motors_bus_port.py +55 -0
- lerobot/scripts/push_dataset_to_hub.py +364 -0
- lerobot/scripts/push_pretrained.py +71 -0
- lerobot/scripts/train.py +288 -0
- lerobot/scripts/visualize_dataset.py +292 -0
- lerobot/scripts/visualize_dataset_html.py +479 -0
- lerobot/scripts/visualize_image_transforms.py +130 -0
- lerobot/templates/visualize_dataset_homepage.html +68 -0
.gitattributes
CHANGED
|
@@ -18,3 +18,35 @@
|
|
| 18 |
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 19 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 20 |
*.json !text !filter !merge !diff
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 19 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 20 |
*.json !text !filter !merge !diff
|
| 21 |
+
media/lerobot-logo-light.png filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
media/aloha/follower_rotated.webp filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
media/aloha/follower_zero.webp filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
media/lerobot-logo-thumbnail.png filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
media/wandb.png filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
media/gym/pusht_diffusion.gif filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
media/aloha/leader_rest.webp filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
media/aloha/leader_zero.webp filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
media/koch/follower_rest.webp filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
media/koch/follower_zero.webp filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
media/gym/aloha_act.gif filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
media/aloha/leader_rotated.webp filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
media/koch/follower_rotated.webp filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
media/koch/leader_rest.webp filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
media/gym/simxarm_tdmpc.gif filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
media/koch/leader_rotated.webp filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
media/aloha/follower_rest.webp filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
media/koch/leader_zero.webp filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
media/lekiwi/kiwi.webp filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
media/lekiwi/mobile_calib_rest.webp filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
media/lekiwi/mobile_calib_rotated.webp filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
media/moss/follower_initial.webp filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
media/lekiwi/motor_ids.webp filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
media/moss/follower_rotated.webp filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
media/lekiwi/mobile_calib_zero.webp filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
media/moss/follower_rest.webp filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
media/moss/leader_rotated.webp filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
media/moss/follower_zero.webp filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
media/moss/leader_zero.webp filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
media/so100/follower_initial.webp filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
media/so100/follower_rest.webp filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
media/so100/leader_follower.webp filter=lfs diff=lfs merge=lfs -text
|
lerobot/common/policies/vqbet/configuration_vqbet.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
|
| 4 |
+
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
|
| 5 |
+
# and The HuggingFace Inc. team. All rights reserved.
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
|
| 21 |
+
from lerobot.common.optim.optimizers import AdamConfig
|
| 22 |
+
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
|
| 23 |
+
from lerobot.configs.policies import PreTrainedConfig
|
| 24 |
+
from lerobot.configs.types import NormalizationMode
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@PreTrainedConfig.register_subclass("vqbet")
@dataclass
class VQBeTConfig(PreTrainedConfig):
    """Configuration class for VQ-BeT.

    Defaults are configured for training with PushT providing proprioceptive and single camera
    observations.

    The parameters you will most likely need to change are the ones which depend on the environment /
    sensors, i.e. the input/output features and `normalization_mapping`.

    Notes on the inputs and outputs:
        - "observation.state" is required as an input key.
        - Exactly one key starting with "observation.image" is required as an input
          (enforced in `validate_features`).
        - "action" is required as an output key.

    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
        n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts.
        action_chunk_size: Action chunk size of each action prediction token.
        normalization_mapping: A dictionary mapping a modality group (e.g. "VISUAL", "STATE", "ACTION")
            to the `NormalizationMode` to apply. "mean_std" subtracts the mean and divides by the
            standard deviation; "min_max" rescales to a [-1, 1] range; "identity" leaves data untouched.
        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
        crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone.
            Must fit within the image size. If None, no cropping is done.
        crop_is_random: Whether the crop should be random at training time (it's always a center crop in
            eval mode).
        pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
            `None` means no pretrained weights.
        use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
            The group sizes are set to be about 16 (to be precise, feature_dim // 16).
        spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
        n_vqvae_training_steps: Number of optimization steps for training the Residual VQ.
        vqvae_n_embed: Number of embedding vectors in the RVQ dictionary (each layer).
        vqvae_embedding_dim: Dimension of each embedding vector in the RVQ dictionary.
        vqvae_enc_hidden_dim: Size of hidden dimensions of the Encoder / Decoder part of the Residual
            VQ-VAE.
        gpt_block_size: Max block size of minGPT (should be larger than the number of input tokens).
        gpt_input_dim: Size of the input dimension of GPT. This is also used as the dimension of
            observation features.
        gpt_output_dim: Size of the output dimension of GPT. This is also used as an input dimension of
            the offset / bin prediction heads.
        gpt_n_layer: Number of layers of GPT.
        gpt_n_head: Number of attention heads of GPT.
        gpt_hidden_dim: Size of hidden dimensions of GPT.
        dropout: Dropout rate for GPT.
        mlp_hidden_dim: Size of hidden dimensions of the offset / bin prediction heads of VQ-BeT.
        offset_loss_weight: A constant that is multiplied to the offset loss.
        primary_code_loss_weight: A constant that is multiplied to the primary code prediction loss.
        secondary_code_loss_weight: A constant that is multiplied to the secondary code prediction loss.
        bet_softmax_temperature: Sampling temperature of code for rollout with VQ-BeT.
        sequentially_select: Whether to select the primary / secondary codes sequentially (pick the
            primary code, and then select the secondary code), or at the same time.
    """

    # Inputs / output structure.
    n_obs_steps: int = 5
    n_action_pred_token: int = 3
    action_chunk_size: int = 5

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MIN_MAX,
            "ACTION": NormalizationMode.MIN_MAX,
        }
    )

    # Architecture / modeling.
    # Vision backbone.
    vision_backbone: str = "resnet18"
    crop_shape: tuple[int, int] | None = (84, 84)
    crop_is_random: bool = True
    pretrained_backbone_weights: str | None = None
    use_group_norm: bool = True
    spatial_softmax_num_keypoints: int = 32
    # VQ-VAE
    n_vqvae_training_steps: int = 20000
    vqvae_n_embed: int = 16
    vqvae_embedding_dim: int = 256
    vqvae_enc_hidden_dim: int = 128
    # VQ-BeT
    gpt_block_size: int = 500
    gpt_input_dim: int = 512
    gpt_output_dim: int = 512
    gpt_n_layer: int = 8
    gpt_n_head: int = 8
    gpt_hidden_dim: int = 512
    dropout: float = 0.1
    mlp_hidden_dim: int = 1024
    offset_loss_weight: float = 10000.0
    primary_code_loss_weight: float = 5.0
    secondary_code_loss_weight: float = 0.5
    bet_softmax_temperature: float = 0.1
    sequentially_select: bool = False

    # Training presets
    optimizer_lr: float = 1e-4
    optimizer_betas: tuple[float, float] = (0.95, 0.999)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-6
    optimizer_vqvae_lr: float = 1e-3
    optimizer_vqvae_weight_decay: float = 1e-4
    scheduler_warmup_steps: int = 500

    def __post_init__(self):
        super().__post_init__()

        # Input validation (not exhaustive).
        if not self.vision_backbone.startswith("resnet"):
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )

    def get_optimizer_preset(self) -> AdamConfig:
        """Return the default Adam optimizer settings for this policy."""
        return AdamConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self) -> VQBeTSchedulerConfig:
        """Return the default LR scheduler settings (warmup + VQ-VAE pretraining phase)."""
        return VQBeTSchedulerConfig(
            num_warmup_steps=self.scheduler_warmup_steps,
            num_vqvae_training_steps=self.n_vqvae_training_steps,
        )

    def validate_features(self) -> None:
        """Validate image input features: exactly one camera, crop fits, all shapes match.

        Raises:
            ValueError: if there is not exactly one image input, if `crop_shape` exceeds an image's
                spatial dimensions, or if image shapes differ across keys.
        """
        # Note: this check was previously performed inside VQBeTRgbEncoder in the form of
        # assert len(image_keys) == 1
        if not len(self.image_features) == 1:
            raise ValueError("You must provide only one image among the inputs.")

        if self.crop_shape is not None:
            for key, image_ft in self.image_features.items():
                # image_ft.shape is (C, H, W); crop must fit within (H, W).
                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                    raise ValueError(
                        f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
                        f"for `crop_shape` and {image_ft.shape} for "
                        f"`{key}`."
                    )

        # Check that all input images have the same shape.
        first_image_key, first_image_ft = next(iter(self.image_features.items()))
        for key, image_ft in self.image_features.items():
            if image_ft.shape != first_image_ft.shape:
                raise ValueError(
                    f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
                )

    @property
    def observation_delta_indices(self) -> list:
        # The current step plus (n_obs_steps - 1) steps going back.
        return list(range(1 - self.n_obs_steps, 1))

    @property
    def action_delta_indices(self) -> list:
        # Past actions aligned with observations, plus future actions covered by the prediction tokens.
        return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))

    @property
    def reward_delta_indices(self) -> None:
        # VQ-BeT does not use reward inputs.
        return None
|
lerobot/common/policies/vqbet/modeling_vqbet.py
ADDED
|
@@ -0,0 +1,911 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
|
| 4 |
+
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
|
| 5 |
+
# and The HuggingFace Inc. team. All rights reserved.
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
|
| 19 |
+
import warnings
|
| 20 |
+
from collections import deque
|
| 21 |
+
from typing import Callable, List
|
| 22 |
+
|
| 23 |
+
import einops
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn.functional as F # noqa: N812
|
| 27 |
+
import torchvision
|
| 28 |
+
from torch import Tensor, nn
|
| 29 |
+
|
| 30 |
+
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
| 31 |
+
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
| 32 |
+
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
| 33 |
+
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
| 34 |
+
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
|
| 35 |
+
|
| 36 |
+
# ruff: noqa: N806
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class VQBeTPolicy(PreTrainedPolicy):
|
| 40 |
+
"""
|
| 41 |
+
VQ-BeT Policy as per "Behavior Generation with Latent Actions"
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
config_class = VQBeTConfig
|
| 45 |
+
name = "vqbet"
|
| 46 |
+
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
config: VQBeTConfig | None = None,
|
| 50 |
+
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
| 51 |
+
):
|
| 52 |
+
"""
|
| 53 |
+
Args:
|
| 54 |
+
config: Policy configuration class instance or None, in which case the default instantiation of
|
| 55 |
+
the configuration class is used.
|
| 56 |
+
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
| 57 |
+
that they will be passed with a call to `load_state_dict` before the policy is used.
|
| 58 |
+
"""
|
| 59 |
+
super().__init__(config)
|
| 60 |
+
config.validate_features()
|
| 61 |
+
self.config = config
|
| 62 |
+
|
| 63 |
+
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
| 64 |
+
self.normalize_targets = Normalize(
|
| 65 |
+
config.output_features, config.normalization_mapping, dataset_stats
|
| 66 |
+
)
|
| 67 |
+
self.unnormalize_outputs = Unnormalize(
|
| 68 |
+
config.output_features, config.normalization_mapping, dataset_stats
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
self.vqbet = VQBeTModel(config)
|
| 72 |
+
|
| 73 |
+
self.reset()
|
| 74 |
+
|
| 75 |
+
def get_optim_params(self) -> dict:
|
| 76 |
+
vqvae_params = (
|
| 77 |
+
list(self.vqbet.action_head.vqvae_model.encoder.parameters())
|
| 78 |
+
+ list(self.vqbet.action_head.vqvae_model.decoder.parameters())
|
| 79 |
+
+ list(self.vqbet.action_head.vqvae_model.vq_layer.parameters())
|
| 80 |
+
)
|
| 81 |
+
decay_params, no_decay_params = self.vqbet.policy.configure_parameters()
|
| 82 |
+
decay_params = (
|
| 83 |
+
decay_params
|
| 84 |
+
+ list(self.vqbet.rgb_encoder.parameters())
|
| 85 |
+
+ list(self.vqbet.state_projector.parameters())
|
| 86 |
+
+ list(self.vqbet.rgb_feature_projector.parameters())
|
| 87 |
+
+ [self.vqbet.action_token]
|
| 88 |
+
+ list(self.vqbet.action_head.map_to_cbet_preds_offset.parameters())
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
if self.config.sequentially_select:
|
| 92 |
+
decay_params = (
|
| 93 |
+
decay_params
|
| 94 |
+
+ list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
|
| 95 |
+
+ list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
|
| 96 |
+
)
|
| 97 |
+
else:
|
| 98 |
+
decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters())
|
| 99 |
+
|
| 100 |
+
return [
|
| 101 |
+
{
|
| 102 |
+
"params": decay_params,
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
"params": vqvae_params,
|
| 106 |
+
"weight_decay": self.config.optimizer_vqvae_weight_decay,
|
| 107 |
+
"lr": self.config.optimizer_vqvae_lr,
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"params": no_decay_params,
|
| 111 |
+
"weight_decay": 0.0,
|
| 112 |
+
},
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
def reset(self):
|
| 116 |
+
"""
|
| 117 |
+
Clear observation and action queues. Should be called on `env.reset()`
|
| 118 |
+
queues are populated during rollout of the policy, they contain the n latest observations and actions
|
| 119 |
+
"""
|
| 120 |
+
self._queues = {
|
| 121 |
+
"observation.images": deque(maxlen=self.config.n_obs_steps),
|
| 122 |
+
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
| 123 |
+
"action": deque(maxlen=self.config.action_chunk_size),
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
@torch.no_grad
|
| 127 |
+
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
| 128 |
+
"""Select a single action given environment observations.
|
| 129 |
+
|
| 130 |
+
This method wraps `select_actions` in order to return one action at a time for execution in the
|
| 131 |
+
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
| 132 |
+
queue is empty.
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
batch = self.normalize_inputs(batch)
|
| 136 |
+
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
| 137 |
+
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
| 138 |
+
# Note: It's important that this happens after stacking the images into a single key.
|
| 139 |
+
self._queues = populate_queues(self._queues, batch)
|
| 140 |
+
|
| 141 |
+
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
| 142 |
+
warnings.warn(
|
| 143 |
+
"To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ.",
|
| 144 |
+
stacklevel=1,
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
if len(self._queues["action"]) == 0:
|
| 148 |
+
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
| 149 |
+
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
|
| 150 |
+
|
| 151 |
+
# the dimension of returned action is (batch_size, action_chunk_size, action_dim)
|
| 152 |
+
actions = self.unnormalize_outputs({"action": actions})["action"]
|
| 153 |
+
# since the data in the action queue's dimension is (action_chunk_size, batch_size, action_dim), we transpose the action and fill the queue
|
| 154 |
+
self._queues["action"].extend(actions.transpose(0, 1))
|
| 155 |
+
|
| 156 |
+
action = self._queues["action"].popleft()
|
| 157 |
+
return action
|
| 158 |
+
|
| 159 |
+
    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """Run the batch through the model and compute the loss for training or validation.

        Returns:
            A `(loss, metrics_dict)` tuple. During RVQ pre-training the metrics describe codebook usage;
            afterwards they are the BeT loss components from `VQBeTHead.loss_fn` (minus the popped "loss").
        """
        batch = self.normalize_inputs(batch)
        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
        # Stack all camera views into one tensor under a single key (matches `select_action`).
        batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
        batch = self.normalize_targets(batch)
        # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
        if not self.vqbet.action_head.vqvae_model.discretized.item():
            # loss: total loss of training RVQ
            # n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
            # n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
            loss, n_different_codes, n_different_combinations, recon_l1_error = (
                self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
            )
            return loss, {
                "n_different_codes": n_different_codes,
                "n_different_combinations": n_different_combinations,
                "recon_l1_error": recon_l1_error,
            }
        # if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
        _, loss_dict = self.vqbet(batch, rollout=False)
        loss = loss_dict.pop("loss")

        return loss, loss_dict
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class SpatialSoftmax(nn.Module):
    """Spatial Soft Argmax ("Deep Spatial Autoencoders for Visuomotor Learning", Finn et al.,
    https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.

    Given 2D feature maps from a vision backbone, this layer computes, per channel, the softmax-weighted
    "center of mass" of activations over a normalized coordinate grid spanning [-1, 1] in both x and y.
    The result is one (x, y) keypoint per channel, which gives the downstream policy a compact,
    spatially-grounded summary of the image.

    Concretely, for feature maps of shape (C, H, W) we build an (H*W, 2) grid of normalized coordinates,
    apply a channel-wise softmax over the H*W activations, and take the dot product with the grid to obtain
    the expected (x, y) location of maximal activation for each channel, yielding (C, 2) keypoints.

    Optionally, `num_kp` inserts a learnable 1x1 convolution (C, H, W) -> (num_kp, H, W) first, so the
    number of output keypoints can differ from the number of input channels.
    """

    def __init__(self, input_shape, num_kp=None):
        """
        Args:
            input_shape (list): (C, H, W) input feature map shape.
            num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
        """
        super().__init__()

        assert len(input_shape) == 3
        self._in_c, self._in_h, self._in_w = input_shape

        if num_kp is None:
            self.nets = None
            self._out_c = self._in_c
        else:
            self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
            self._out_c = num_kp

        # Build the normalized coordinate grid with numpy rather than torch.linspace: the two behave
        # slightly differently and the numpy version avoids a small pc_success degradation in
        # pre-trained models.
        xs, ys = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
        grid = np.stack([xs.ravel(), ys.ravel()], axis=1)  # (H*W, 2), C-order flatten
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer("pos_grid", torch.from_numpy(grid).float())

    def forward(self, features: Tensor) -> Tensor:
        """
        Args:
            features: (B, C, H, W) input feature maps.
        Returns:
            (B, K, 2) image-space coordinates of keypoints.
        """
        if self.nets is not None:
            features = self.nets(features)

        # Collapse to (B*K, H*W) so each channel is treated as an independent attention map.
        flat = features.reshape(-1, self._in_h * self._in_w)
        weights = F.softmax(flat, dim=-1)
        # Expected coordinates: (B*K, H*W) @ (H*W, 2) -> (B*K, 2).
        expected_xy = weights @ self.pos_grid
        # Restore the batch/keypoint split: (B, K, 2).
        return expected_xy.view(-1, self._out_c, 2)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class VQBeTModel(nn.Module):
    """VQ-BeT: The underlying neural network for VQ-BeT.

    Note: In this code we use the terms `rgb_encoder`, `policy`, `action_head`. The meanings are as follows.
        - The `rgb_encoder` processes rgb-style image observations into one-dimensional embedding vectors.
        - The `policy` is a minGPT architecture that takes observation sequences and action query tokens to
          generate `features`.
        - These `features` pass through the `action_head` (code prediction head + offset prediction head),
          which finally generates a prediction for the action chunks.

    Training proceeds in two phases (n = n_obs_steps, p = n_action_pred_token, c = action_chunk_size):

    Phase 1 — Discretize actions with a Residual VQ for `config.n_vqvae_training_steps` steps:
        RVQ encoder (a_{t} ~ a_{t+c-1}) -> residual code quantizer -> RVQ decoder.

    Phase 2 — Train the GPT `policy` and the `action_head`:
        For each of the n observation steps, the per-camera image tokens o_{t}, the state token s_{t}, and
        an action query token A_Q are interleaved into one sequence; p-1 extra A_Q tokens are appended for
        future action chunks. The GPT output at each A_Q position is mapped by the `action_head` to an RVQ
        code (decoded back to an action chunk by the RVQ decoder) plus a continuous offset; code + offset
        gives the predicted action chunk for that position. During rollout, ONLY the chunk predicted at the
        current timestep t is used.
    """

    def __init__(self, config: VQBeTConfig):
        super().__init__()
        self.config = config

        self.rgb_encoder = VQBeTRgbEncoder(config)
        self.num_images = len(self.config.image_features)
        # This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the class docstring.
        # Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results.
        self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))

        # To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
        self.state_projector = MLP(
            config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
        )
        self.rgb_feature_projector = MLP(
            self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
        )

        # GPT part of VQ-BeT
        self.policy = GPT(config)
        # bin prediction head / offset prediction head part of VQ-BeT
        self.action_head = VQBeTHead(config)

        # Action tokens for: each observation step, the current action token, and all future action tokens.
        num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
        # For each predicted token position i, the ground-truth targets are the action chunk
        # actions[i : i + action_chunk_size]; precompute these sliding-window indices once.
        self.register_buffer(
            "select_target_actions_indices",
            torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
        )

    def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
        """Run the BeT part of VQ-BeT.

        Args:
            batch: must contain "observation.state" of shape (batch, n_obs_steps, state_dim) and
                "observation.images" of shape (batch, n_obs_steps, num_cameras, C, H, W); when
                `rollout=False` it must also contain "action" targets.
            rollout: if True, return only the flattened action chunk predicted at the current timestep;
                otherwise return `(action_head_output, loss_dict)`.
        """
        # Input validation.
        assert set(batch).issuperset({"observation.state", "observation.images"})
        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
        assert n_obs_steps == self.config.n_obs_steps

        # Extract image feature (first combine batch and sequence dims).
        img_features = self.rgb_encoder(
            einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
        )
        # Separate batch and sequence dims.
        img_features = einops.rearrange(
            img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
        )

        # Arrange prior and current observation step tokens as shown in the class docstring.
        # First project features to token dimension.
        rgb_tokens = self.rgb_feature_projector(
            img_features
        )  # (batch, obs_step, number of different cameras, projection dims)
        input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))]
        input_tokens.append(
            self.state_projector(batch["observation.state"])
        )  # (batch, obs_step, projection dims)
        input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
        # Interleave tokens by stacking and rearranging: per obs step the order is
        # [cam_0, ..., cam_{n-1}, state, A_Q].
        input_tokens = torch.stack(input_tokens, dim=2)
        input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")

        len_additional_action_token = self.config.n_action_pred_token - 1
        future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)

        # add additional action query tokens for predicting future action chunks
        input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)

        # get action features (pass through GPT)
        features = self.policy(input_tokens)
        # len(self.config.input_features) is the number of different observation modes.
        # this line gets the index of action prompt tokens: within each obs step's group of
        # (num_modes + 1) tokens, the A_Q token sits last.
        historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
            self.config.input_features
        )

        # only extract the output tokens at the position of action query:
        # Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
        # mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://arxiv.org/pdf/2206.11251).
        # Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
        if len_additional_action_token > 0:
            features = torch.cat(
                [features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
            )
        else:
            features = features[:, historical_act_pred_index]
        # pass through action head
        action_head_output = self.action_head(features)
        # if rollout, VQ-BeT don't calculate loss
        if rollout:
            # Index n_obs_steps - 1 is the action chunk predicted for the CURRENT timestep.
            return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
                batch_size, self.config.action_chunk_size, -1
            )
        # else, it calculate overall loss (bin prediction loss, and offset loss)
        else:
            # Gather the sliding-window ground-truth chunks matching each predicted token position.
            output = batch["action"][:, self.select_target_actions_indices]
            loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
            return action_head_output, loss
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class VQBeTHead(nn.Module):
    def __init__(self, config: VQBeTConfig):
        """
        VQBeTHead takes output of GPT layers, and pass the feature through bin prediction head (`self.map_to_cbet_preds_bin`), and offset prediction head (`self.map_to_cbet_preds_offset`)

        self.map_to_cbet_preds_bin: outputs probability of each code (for each layer).
            The input dimension of `self.map_to_cbet_preds_bin` is same with the output of GPT,
            and the output dimension of `self.map_to_cbet_preds_bin` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed`.
            if the agent select the code sequentially, we use self.map_to_cbet_preds_primary_bin and self.map_to_cbet_preds_secondary_bin instead of self._map_to_cbet_preds_bin.

        self.map_to_cbet_preds_offset: output the predicted offsets for all the codes in all the layers.
            The input dimension of `self.map_to_cbet_preds_offset` is same with the output of GPT,
            and the output dimension of `self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.action_feature.shape[0]`.
        """

        super().__init__()
        self.config = config
        # init vqvae
        self.vqvae_model = VqVae(config)
        if config.sequentially_select:
            # Sequential selection: predict the primary code first, then condition the secondary-code
            # prediction on a one-hot encoding of the sampled primary code.
            self.map_to_cbet_preds_primary_bin = MLP(
                in_channels=config.gpt_output_dim,
                hidden_channels=[self.config.vqvae_n_embed],
            )
            self.map_to_cbet_preds_secondary_bin = MLP(
                in_channels=config.gpt_output_dim + self.config.vqvae_n_embed,
                hidden_channels=[self.config.vqvae_n_embed],
            )
        else:
            # Joint selection: one head outputs logits for all RVQ layers at once.
            self.map_to_cbet_preds_bin = MLP(
                in_channels=config.gpt_output_dim,
                hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
            )
        self.map_to_cbet_preds_offset = MLP(
            in_channels=config.gpt_output_dim,
            hidden_channels=[
                self.vqvae_model.vqvae_num_layers
                * self.config.vqvae_n_embed
                * config.action_chunk_size
                * config.action_feature.shape[0],
            ],
        )
        # loss
        self._focal_loss_fn = FocalLoss(gamma=2.0)

    def discretize(self, n_vqvae_training_steps, actions):
        """Run one RVQ pre-training step on `actions` and freeze the RVQ once enough steps have elapsed.

        Args:
            n_vqvae_training_steps: total number of RVQ training steps before freezing.
            actions: (batch, sequence, action_dim) ground-truth action sequences.
        Returns:
            (loss, n_different_codes, n_different_combinations, recon_l1_error).
        """
        # Resize the action sequence data to fit the action chunk size using a sliding window approach.
        actions = torch.cat(
            [
                actions[:, j : j + self.config.action_chunk_size, :]
                for j in range(actions.shape[1] + 1 - self.config.action_chunk_size)
            ],
            dim=0,
        )
        # `actions` is a tensor of shape (new_batch, action_chunk_size, action_dim) where new_batch is the number of possible chunks created from the original sequences using the sliding window.

        loss, metric = self.vqvae_model.vqvae_forward(actions)
        # metric[2] holds the per-sample code indices per RVQ layer; metric[0] the reconstruction L1 error.
        # NOTE(review): inferred from the usage below — confirm against `VqVae.vqvae_forward`.
        n_different_codes = sum(
            [len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)]
        )
        n_different_combinations = len(torch.unique(metric[2], dim=0))
        recon_l1_error = metric[0].detach().cpu().item()
        self.vqvae_model.optimized_steps += 1
        # if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part.
        if self.vqvae_model.optimized_steps >= n_vqvae_training_steps:
            self.vqvae_model.discretized = torch.tensor(True)
            self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True)
            print("Finished discretizing action data!")
            self.vqvae_model.eval()
            for param in self.vqvae_model.vq_layer.parameters():
                param.requires_grad = False
        return loss, n_different_codes, n_different_combinations, recon_l1_error

    def forward(self, x, **kwargs) -> dict:
        """Map GPT features to sampled RVQ codes, offsets, and the resulting predicted action chunks.

        Args:
            x: (N, T, gpt_output_dim) features, one per action query token.
        Returns:
            dict with "cbet_logits", "predicted_action" (N, T, chunk_size * action_dim),
            "sampled_centers", and "decoded_action".
        """
        # N is the batch size, and T is number of action query tokens, which are process through same GPT
        N, T, _ = x.shape
        # we calculate N and T side parallelly. Thus, the dimensions would be
        # (batch size * number of action query tokens, action chunk size, action dimension)
        x = einops.rearrange(x, "N T WA -> (N T) WA")

        # sample offsets
        cbet_offsets = self.map_to_cbet_preds_offset(x)
        cbet_offsets = einops.rearrange(
            cbet_offsets,
            "(NT) (G C WA) -> (NT) G C WA",
            G=self.vqvae_model.vqvae_num_layers,
            C=self.config.vqvae_n_embed,
        )
        # if self.config.sequentially_select is True, bin prediction head first sample the primary code, and then sample secondary code
        if self.config.sequentially_select:
            cbet_primary_logits = self.map_to_cbet_preds_primary_bin(x)

            # select primary bin first
            cbet_primary_probs = torch.softmax(
                cbet_primary_logits / self.config.bet_softmax_temperature, dim=-1
            )
            NT, choices = cbet_primary_probs.shape
            sampled_primary_centers = einops.rearrange(
                torch.multinomial(cbet_primary_probs.view(-1, choices), num_samples=1),
                "(NT) 1 -> NT",
                NT=NT,
            )

            # Condition the secondary head on the (one-hot) sampled primary code.
            cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
                torch.cat(
                    (x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
                    axis=1,
                )
            )
            cbet_secondary_probs = torch.softmax(
                cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
            )
            sampled_secondary_centers = einops.rearrange(
                torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1),
                "(NT) 1 -> NT",
                NT=NT,
            )
            sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1)
            cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1)
        # if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once.
        else:
            cbet_logits = self.map_to_cbet_preds_bin(x)
            cbet_logits = einops.rearrange(
                cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
            )
            cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
            NT, G, choices = cbet_probs.shape
            sampled_centers = einops.rearrange(
                torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
                "(NT G) 1 -> NT G",
                NT=NT,
            )

        device = get_device_from_parameters(self)
        indices = (
            torch.arange(NT, device=device).unsqueeze(1),
            torch.arange(self.vqvae_model.vqvae_num_layers, device=device).unsqueeze(0),
            sampled_centers,
        )
        # Use advanced indexing to sample the values (Extract the only offsets corresponding to the sampled codes.)
        sampled_offsets = cbet_offsets[indices]
        # Then, sum the offsets over the RVQ layers to get a net offset for the bin prediction
        sampled_offsets = sampled_offsets.sum(dim=1)
        with torch.no_grad():
            # Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder
            return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
            # pass the centroids through decoder to get actions.
            decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach()
        # reshaped extracted offset to match with decoded centroids
        sampled_offsets = einops.rearrange(
            sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
        )
        # add offset and decoded centroids
        predicted_action = decoded_action + sampled_offsets
        predicted_action = einops.rearrange(
            predicted_action,
            "(N T) W A -> N T (W A)",
            N=N,
            T=T,
            W=self.config.action_chunk_size,
        )

        return {
            "cbet_logits": cbet_logits,
            "predicted_action": predicted_action,
            "sampled_centers": sampled_centers,
            "decoded_action": decoded_action,
        }

    def loss_fn(self, pred, target, **kwargs):
        """
        for given ground truth action values (target), and prediction (pred) this function calculates the overall loss.

        predicted_action: predicted action chunk (offset + decoded centroids)
        sampled_centers: sampled centroids (code of RVQ)
        decoded_action: decoded action, which is produced by passing sampled_centers through RVQ decoder
        NT: batch size * T
        T: number of action query tokens, which are process through same GPT
        cbet_logits: probability of all codes in each layer
        """
        action_seq = target
        predicted_action = pred["predicted_action"]
        sampled_centers = pred["sampled_centers"]
        decoded_action = pred["decoded_action"]
        NT = predicted_action.shape[0] * predicted_action.shape[1]

        cbet_logits = pred["cbet_logits"]

        predicted_action = einops.rearrange(
            predicted_action, "N T (W A) -> (N T) W A", W=self.config.action_chunk_size
        )

        action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
        # Figure out the loss for the actions.
        # First, we need to find the closest cluster center for each ground truth action.
        with torch.no_grad():
            # NOTE(review): `state_vq` is unused below; only `action_bins` feeds the classification loss.
            state_vq, action_bins = self.vqvae_model.get_code(action_seq)  # action_bins: NT, G

        # Now we can compute the loss.

        # offset loss is L1 distance between the predicted action and ground truth action
        offset_loss = F.l1_loss(action_seq, predicted_action)

        # calculate primary code prediction loss
        cbet_loss1 = self._focal_loss_fn(
            cbet_logits[:, 0, :],
            action_bins[:, 0],
        )
        # calculate secondary code prediction loss
        cbet_loss2 = self._focal_loss_fn(
            cbet_logits[:, 1, :],
            action_bins[:, 1],
        )
        # add all the prediction loss
        cbet_loss = (
            cbet_loss1 * self.config.primary_code_loss_weight
            + cbet_loss2 * self.config.secondary_code_loss_weight
        )

        # Fraction of samples whose sampled code matches the ground-truth nearest code, per RVQ layer.
        equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
        equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)

        action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
        vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
        offset_action_error = torch.mean(torch.abs(action_seq - predicted_action))
        action_error_max = torch.max(torch.abs(action_seq - predicted_action))

        loss = cbet_loss + self.config.offset_loss_weight * offset_loss

        loss_dict = {
            "loss": loss,
            "classification_loss": cbet_loss.detach().cpu().item(),
            "offset_loss": offset_loss.detach().cpu().item(),
            "equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
            "equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
            "vq_action_error": vq_action_error.detach().cpu().item(),
            "offset_action_error": offset_action_error.detach().cpu().item(),
            "action_error_max": action_error_max.detach().cpu().item(),
            "action_mse_error": action_mse_error.detach().cpu().item(),
        }
        return loss_dict
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
class VQBeTRgbEncoder(nn.Module):
    """Encode an RGB image into a 1D feature vector.

    Includes the ability to normalize and crop the image first.

    Same with DiffusionRgbEncoder from modeling_diffusion.py
    """

    def __init__(self, config: VQBeTConfig):
        super().__init__()
        # Set up optional preprocessing.
        if config.crop_shape is not None:
            self.do_crop = True
            # Always use center crop for eval
            self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
            if config.crop_is_random:
                self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
            else:
                self.maybe_random_crop = self.center_crop
        else:
            self.do_crop = False

        # Set up backbone.
        backbone_model = getattr(torchvision.models, config.vision_backbone)(
            weights=config.pretrained_backbone_weights
        )
        # Note: This assumes that the layer4 feature map is children()[-3]
        # TODO(alexander-soare): Use a safer alternative.
        self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
        if config.use_group_norm:
            if config.pretrained_backbone_weights:
                raise ValueError(
                    "You can't replace BatchNorm in a pretrained model without ruining the weights!"
                )
            # Swap every BatchNorm2d for a GroupNorm with ~16 channels per group.
            self.backbone = _replace_submodules(
                root_module=self.backbone,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
            )

        # Set up pooling and final layers.
        # Use a dry run to get the feature map shape.
        # The dummy input should take the number of image channels from `config.image_features` and it should
        # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
        # height and width from `config.image_features`.

        images_shape = next(iter(config.image_features.values())).shape
        dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
        dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
        feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]

        # SpatialSoftmax yields (num_keypoints, 2) coordinates, flattened to num_keypoints * 2 features.
        self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
        self.feature_dim = config.spatial_softmax_num_keypoints * 2
        # Final projection keeps the dimension (num_keypoints * 2 -> num_keypoints * 2).
        self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: (B, C, H, W) image tensor with pixel values in [0, 1].
        Returns:
            (B, D) image feature.
        """
        # Preprocess: maybe crop (if it was set up in the __init__).
        if self.do_crop:
            if self.training:  # noqa: SIM108
                x = self.maybe_random_crop(x)
            else:
                # Always use center crop for eval.
                x = self.center_crop(x)
        # Extract backbone feature.
        x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
        # Final linear layer with non-linearity.
        x = self.relu(self.out(x))
        return x
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
def _replace_submodules(
|
| 733 |
+
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
| 734 |
+
) -> nn.Module:
|
| 735 |
+
"""
|
| 736 |
+
Args:
|
| 737 |
+
root_module: The module for which the submodules need to be replaced
|
| 738 |
+
predicate: Takes a module as an argument and must return True if the that module is to be replaced.
|
| 739 |
+
func: Takes a module as an argument and returns a new module to replace it with.
|
| 740 |
+
Returns:
|
| 741 |
+
The root module with its submodules replaced.
|
| 742 |
+
"""
|
| 743 |
+
if predicate(root_module):
|
| 744 |
+
return func(root_module)
|
| 745 |
+
|
| 746 |
+
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
| 747 |
+
for *parents, k in replace_list:
|
| 748 |
+
parent_module = root_module
|
| 749 |
+
if len(parents) > 0:
|
| 750 |
+
parent_module = root_module.get_submodule(".".join(parents))
|
| 751 |
+
if isinstance(parent_module, nn.Sequential):
|
| 752 |
+
src_module = parent_module[int(k)]
|
| 753 |
+
else:
|
| 754 |
+
src_module = getattr(parent_module, k)
|
| 755 |
+
tgt_module = func(src_module)
|
| 756 |
+
if isinstance(parent_module, nn.Sequential):
|
| 757 |
+
parent_module[int(k)] = tgt_module
|
| 758 |
+
else:
|
| 759 |
+
setattr(parent_module, k, tgt_module)
|
| 760 |
+
# verify that all BN are replaced
|
| 761 |
+
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
| 762 |
+
return root_module
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
class VqVae(nn.Module):
    def __init__(
        self,
        config: VQBeTConfig,
    ):
        """
        VQ-VAE is composed of three parts: encoder, vq_layer, and decoder.
        Encoder and decoder are MLPs consisting of an input, output layer, and hidden layer, respectively.
        The vq_layer uses residual VQs.

        This class contains functions for training the encoder and decoder along with the residual VQ layer (for training phase 1),
        as well as functions to help BeT training part in training phase 2.
        """

        super().__init__()
        self.config = config
        # 'discretized' indicates whether the Residual VQ part is trained or not. (After finishing the training, we set discretized=True)
        # Stored as a buffer so it is saved/loaded with the state dict.
        self.register_buffer("discretized", torch.tensor(False))
        self.optimized_steps = 0
        # we use the fixed number of layers for Residual VQ across all environments.
        self.vqvae_num_layers = 2

        self.vq_layer = ResidualVQ(
            dim=config.vqvae_embedding_dim,
            num_quantizers=self.vqvae_num_layers,
            codebook_size=config.vqvae_n_embed,
        )

        # Encoder maps a flattened action chunk (T * A values) to the embedding space.
        self.encoder = MLP(
            in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
            hidden_channels=[
                config.vqvae_enc_hidden_dim,
                config.vqvae_enc_hidden_dim,
                config.vqvae_embedding_dim,
            ],
        )
        # Decoder mirrors the encoder, mapping an embedding back to a flattened action chunk.
        self.decoder = MLP(
            in_channels=config.vqvae_embedding_dim,
            hidden_channels=[
                config.vqvae_enc_hidden_dim,
                config.vqvae_enc_hidden_dim,
                self.config.action_feature.shape[0] * self.config.action_chunk_size,
            ],
        )

    def get_embeddings_from_code(self, encoding_indices):
        """Map RVQ code indices to their summed codebook embedding vector."""
        with torch.no_grad():
            z_embed = self.vq_layer.get_codebook_vector_from_indices(encoding_indices)
            # since the RVQ has multiple layers, it adds the vectors in the axis of layers to provide a vector for that code combination.
            z_embed = z_embed.sum(dim=0)
        return z_embed

    def get_action_from_latent(self, latent):
        """Decode a latent vector into an action chunk of shape (N, T, A)."""
        output = self.decoder(latent)
        # NOTE: the previous `if self.config.action_chunk_size == 1` branch was byte-identical
        # to its `else` branch, so the dead conditional is collapsed into a single return.
        return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])

    def get_code(self, state):
        # in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)
        # this function outputs the `GT code` of given action using frozen encoder and quantization layers. (please refer to Figure 2. in the paper https://arxiv.org/pdf/2403.03181)
        state = einops.rearrange(state, "N T A -> N (T A)")
        with torch.no_grad():
            state_rep = self.encoder(state)
            state_rep_shape = state_rep.shape[:-1]
            state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
            # The VQ loss is irrelevant here (frozen layers, no_grad), so it is discarded.
            state_rep_flat, vq_code, _ = self.vq_layer(state_rep_flat)
            state_vq = state_rep_flat.view(*state_rep_shape, -1)
            vq_code = vq_code.view(*state_rep_shape, -1)
            return state_vq, vq_code

    def vqvae_forward(self, state):
        """Run one VQ-VAE training step on an action chunk of shape (N, T, A).

        Returns a `(rep_loss, metric)` pair where `rep_loss` is the scalar loss to
        backpropagate and `metric` is a tuple of detached diagnostics.
        See section 3.2 in the paper https://arxiv.org/pdf/2403.03181.
        """
        state = einops.rearrange(state, "N T A -> N (T A)")
        # We start with passing action (or action chunk) at:t+n through the encoder ϕ.
        state_rep = self.encoder(state)
        state_rep_shape = state_rep.shape[:-1]
        state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
        # The resulting latent embedding vector x = ϕ(at:t+n) is then mapped to an embedding vector in the codebook of the RVQ layers by the nearest neighbor look-up.
        state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
        state_vq = state_rep_flat.view(*state_rep_shape, -1)
        vq_code = vq_code.view(*state_rep_shape, -1)
        # since the RVQ has multiple layers, it adds the vectors in the axis of layers to provide a vector for that code combination.
        vq_loss_state = torch.sum(vq_loss_state)
        # Then, the discretized vector zq(x) is reconstructed as ψ(zq(x)) by passing through the decoder ψ.
        dec_out = self.decoder(state_vq)
        # Calculate L1 reconstruction loss
        encoder_loss = (state - dec_out).abs().mean()
        # add encoder reconstruction loss and commitment loss (fixed commitment weight of 5)
        rep_loss = encoder_loss + vq_loss_state * 5

        metric = (
            encoder_loss.clone().detach(),
            vq_loss_state.clone().detach(),
            vq_code,
            rep_loss.item(),
        )
        return rep_loss, metric
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
class FocalLoss(nn.Module):
    """Multi-class focal loss: `-(1 - p_t)^gamma * log(p_t)`.

    With `gamma == 0` this reduces to standard cross-entropy.
    From https://github.com/notmahi/miniBET/blob/main/behavior_transformer/bet.py
    """

    def __init__(self, gamma: float = 0, size_average: bool = True):
        super().__init__()
        # Focusing parameter: larger gamma down-weights well-classified examples.
        self.gamma = gamma
        # If True, return the mean of the per-element losses; otherwise their sum.
        self.size_average = size_average

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input: Unnormalized logits of shape (N, T, C) or (N, C).
            target: Integer class indices of shape (N, T) or (N,) respectively.
        Returns:
            A scalar tensor (mean or sum depending on `size_average`).
        Raises:
            ValueError: If `input` is neither 2D nor 3D. (Previously this case left
                `logpt` unbound and crashed with a confusing NameError.)
        """
        if len(input.shape) == 3:
            N, T, _ = input.shape
            logpt = F.log_softmax(input, dim=-1)
            # Pick out log-probability of the target class per (batch, time) element.
            logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T)
        elif len(input.shape) == 2:
            logpt = F.log_softmax(input, dim=-1)
            logpt = logpt.gather(-1, target.view(-1, 1)).view(-1)
        else:
            raise ValueError(f"FocalLoss expects a 2D or 3D input, but got a {len(input.shape)}D tensor.")
        pt = logpt.exp()

        # Modulating factor (1 - pt)^gamma focuses training on hard examples.
        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()
|
| 894 |
+
|
| 895 |
+
|
| 896 |
+
class MLP(torch.nn.Sequential):
    """A simple multilayer perceptron: Linear+ReLU pairs followed by a final Linear.

    Args:
        in_channels: Size of the input features.
        hidden_channels: Output size of each layer in order; the last entry is the
            network's output size and gets no activation.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
    ):
        dims = [in_channels, *hidden_channels]
        modules = []
        # Every transition except the last one is Linear followed by ReLU.
        for src, dst in zip(dims[:-2], dims[1:-1]):
            modules.extend((torch.nn.Linear(src, dst), torch.nn.ReLU()))
        # Output projection without a non-linearity.
        modules.append(torch.nn.Linear(dims[-2], dims[-1]))

        super().__init__(*modules)
|
lerobot/common/robot_devices/cameras/configs.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import abc
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
|
| 18 |
+
import draccus
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
    """Abstract base class for camera configurations.

    Concrete configs register themselves under a string key via
    `CameraConfig.register_subclass("...")`, which lets draccus pick the right
    subclass from the `type` field when parsing a serialized config.
    """

    @property
    def type(self) -> str:
        # The registry key this concrete subclass was registered with (e.g. "opencv").
        return self.get_choice_name(self.__class__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@CameraConfig.register_subclass("opencv")
@dataclass
class OpenCVCameraConfig(CameraConfig):
    """Configuration for a camera accessed through OpenCV's VideoCapture.

    Example of tested options:

    ```python
    OpenCVCameraConfig(0, 30, 640, 480)
    OpenCVCameraConfig(0, 60, 640, 480)
    OpenCVCameraConfig(0, 90, 640, 480)
    OpenCVCameraConfig(0, 30, 1280, 720)
    ```
    """

    camera_index: int
    fps: int | None = None
    width: int | None = None
    height: int | None = None
    color_mode: str = "rgb"
    channels: int | None = None
    rotation: int | None = None
    mock: bool = False

    def __post_init__(self):
        # Validate the requested pixel ordering before anything else.
        if self.color_mode not in ("rgb", "bgr"):
            raise ValueError(
                f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
            )

        # Frames are always decoded with 3 channels; any user-provided value is overridden.
        self.channels = 3

        # Only quarter-turn / half-turn rotations (or none) are supported.
        if self.rotation not in (-90, None, 90, 180):
            raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@CameraConfig.register_subclass("intelrealsense")
@dataclass
class IntelRealSenseCameraConfig(CameraConfig):
    """Configuration for an Intel RealSense camera, identified by serial number or unique name.

    Example of tested options for Intel Real Sense D405:

    ```python
    IntelRealSenseCameraConfig(128422271347, 30, 640, 480)
    IntelRealSenseCameraConfig(128422271347, 60, 640, 480)
    IntelRealSenseCameraConfig(128422271347, 90, 640, 480)
    IntelRealSenseCameraConfig(128422271347, 30, 1280, 720)
    IntelRealSenseCameraConfig(128422271347, 30, 640, 480, use_depth=True)
    IntelRealSenseCameraConfig(128422271347, 30, 640, 480, rotation=90)
    ```
    """

    name: str | None = None
    serial_number: int | None = None
    fps: int | None = None
    width: int | None = None
    height: int | None = None
    color_mode: str = "rgb"
    channels: int | None = None
    use_depth: bool = False
    force_hardware_reset: bool = True
    rotation: int | None = None
    mock: bool = False

    def __post_init__(self):
        # `bool` is stronger than `is None`, since it also treats empty strings as unset.
        if bool(self.name) and bool(self.serial_number):
            # FIX: the previous message said "One of them must be set" even though this
            # branch fires when BOTH are set, which was misleading.
            raise ValueError(
                f"Only one of `name` or `serial_number` can be set, but {self.name=} and {self.serial_number=} were provided."
            )

        if self.color_mode not in ["rgb", "bgr"]:
            raise ValueError(
                f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
            )

        # Frames are always decoded with 3 channels; any user-provided value is overridden.
        self.channels = 3

        # fps, width and height must be configured together (the stream profile needs all three).
        at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
        at_least_one_is_none = self.fps is None or self.width is None or self.height is None
        if at_least_one_is_not_none and at_least_one_is_none:
            raise ValueError(
                "For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
                f"but {self.fps=}, {self.width=}, {self.height=} were provided."
            )

        if self.rotation not in [-90, None, 90, 180]:
            raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
lerobot/common/robot_devices/cameras/intelrealsense.py
ADDED
|
@@ -0,0 +1,538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
This file contains utilities for recording frames from Intel Realsense cameras.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import concurrent.futures
|
| 21 |
+
import logging
|
| 22 |
+
import math
|
| 23 |
+
import shutil
|
| 24 |
+
import threading
|
| 25 |
+
import time
|
| 26 |
+
import traceback
|
| 27 |
+
from collections import Counter
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
from threading import Thread
|
| 30 |
+
|
| 31 |
+
import numpy as np
|
| 32 |
+
from PIL import Image
|
| 33 |
+
|
| 34 |
+
from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig
|
| 35 |
+
from lerobot.common.robot_devices.utils import (
|
| 36 |
+
RobotDeviceAlreadyConnectedError,
|
| 37 |
+
RobotDeviceNotConnectedError,
|
| 38 |
+
busy_wait,
|
| 39 |
+
)
|
| 40 |
+
from lerobot.common.utils.utils import capture_timestamp_utc
|
| 41 |
+
|
| 42 |
+
SERIAL_NUMBER_INDEX = 1
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def find_cameras(raise_when_empty=True, mock=False) -> list[dict]:
    """List the name and serial number of every Intel RealSense camera connected
    to the computer.

    Args:
        raise_when_empty: When True, raise OSError if no camera is detected.
        mock: When True, use the mocked pyrealsense2 module from the test suite.
    Returns:
        A list of `{"serial_number": int, "name": str}` dicts, one per camera.
    """
    if mock:
        import tests.cameras.mock_pyrealsense2 as rs
    else:
        import pyrealsense2 as rs

    cameras = [
        {
            "serial_number": int(device.get_info(rs.camera_info(SERIAL_NUMBER_INDEX))),
            "name": device.get_info(rs.camera_info.name),
        }
        for device in rs.context().query_devices()
    ]

    if raise_when_empty and not cameras:
        raise OSError(
            "Not a single camera was detected. Try re-plugging, or re-installing `librealsense` and its python wrapper `pyrealsense2`, or updating the firmware."
        )

    return cameras
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def save_image(img_array, serial_number, frame_index, images_dir):
    """Write one camera frame to disk as a PNG.

    Any failure is logged rather than raised, since this runs on a background
    executor where an exception would otherwise be silently dropped.
    """
    try:
        image = Image.fromarray(img_array)
        target = images_dir / f"camera_{serial_number}_frame_{frame_index:06d}.png"
        target.parent.mkdir(parents=True, exist_ok=True)
        image.save(str(target), quality=100)
        logging.info(f"Saved image: {target}")
    except Exception as e:
        logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def save_images_from_cameras(
    images_dir: Path,
    serial_numbers: list[int] | None = None,
    fps=None,
    width=None,
    height=None,
    record_time_s=2,
    mock=False,
):
    """
    Initializes all the cameras and saves images to the directory. Useful to visually identify the camera
    associated to a given serial number.

    Connects every requested camera (all detected cameras when `serial_numbers` is
    empty or None), wipes and recreates `images_dir`, then captures frames for
    `record_time_s` seconds, writing each frame to disk from a background thread.
    Cameras are always disconnected on exit, even if capture fails.
    """
    # Default to every connected RealSense camera when no serial numbers are given.
    if serial_numbers is None or len(serial_numbers) == 0:
        camera_infos = find_cameras(mock=mock)
        serial_numbers = [cam["serial_number"] for cam in camera_infos]

    if mock:
        import tests.cameras.mock_cv2 as cv2
    else:
        import cv2

    print("Connecting cameras")
    cameras = []
    for cam_sn in serial_numbers:
        print(f"{cam_sn=}")
        config = IntelRealSenseCameraConfig(
            serial_number=cam_sn, fps=fps, width=width, height=height, mock=mock
        )
        camera = IntelRealSenseCamera(config)
        camera.connect()
        print(
            f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})"
        )
        cameras.append(camera)

    # Start from a clean directory so leftover frames from a previous run can't mix in.
    images_dir = Path(images_dir)
    if images_dir.exists():
        shutil.rmtree(
            images_dir,
        )
    images_dir.mkdir(parents=True, exist_ok=True)

    print(f"Saving images to {images_dir}")
    frame_index = 0
    start_time = time.perf_counter()
    try:
        # A single worker serializes disk writes while keeping them off the capture loop.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            while True:
                now = time.perf_counter()

                for camera in cameras:
                    # If we use async_read when fps is None, the loop will go full speed, and we will end up
                    # saving the same images from the cameras multiple times until the RAM/disk is full.
                    image = camera.read() if fps is None else camera.async_read()
                    if image is None:
                        print("No Frame")

                    # Cameras return RGB; convert to BGR so the PNG colors come out right
                    # after the downstream (OpenCV-convention) save path.
                    bgr_converted_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

                    # Hand the write off to the background thread; don't block capture on disk I/O.
                    executor.submit(
                        save_image,
                        bgr_converted_image,
                        camera.serial_number,
                        frame_index,
                        images_dir,
                    )

                # Sleep off the remainder of the frame period to hold the requested fps.
                if fps is not None:
                    dt_s = time.perf_counter() - now
                    busy_wait(1 / fps - dt_s)

                if time.perf_counter() - start_time > record_time_s:
                    break

                print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")

                frame_index += 1
    finally:
        # Always release the cameras, even when the capture loop raised.
        print(f"Images have been saved to {images_dir}")
        for camera in cameras:
            camera.disconnect()
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class IntelRealSenseCamera:
|
| 170 |
+
"""
|
| 171 |
+
The IntelRealSenseCamera class is similar to OpenCVCamera class but adds additional features for Intel Real Sense cameras:
|
| 172 |
+
- is instantiated with the serial number of the camera - won't randomly change as it can be the case of OpenCVCamera for Linux,
|
| 173 |
+
- can also be instantiated with the camera's name — if it's unique — using IntelRealSenseCamera.init_from_name(),
|
| 174 |
+
- depth map can be returned.
|
| 175 |
+
|
| 176 |
+
To find the camera indices of your cameras, you can run our utility script that will save a few frames for each camera:
|
| 177 |
+
```bash
|
| 178 |
+
python lerobot/common/robot_devices/cameras/intelrealsense.py --images-dir outputs/images_from_intelrealsense_cameras
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
When an IntelRealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
|
| 182 |
+
of the given camera will be used.
|
| 183 |
+
|
| 184 |
+
Example of instantiating with a serial number:
|
| 185 |
+
```python
|
| 186 |
+
from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig
|
| 187 |
+
|
| 188 |
+
config = IntelRealSenseCameraConfig(serial_number=128422271347)
|
| 189 |
+
camera = IntelRealSenseCamera(config)
|
| 190 |
+
camera.connect()
|
| 191 |
+
color_image = camera.read()
|
| 192 |
+
# when done using the camera, consider disconnecting
|
| 193 |
+
camera.disconnect()
|
| 194 |
+
```
|
| 195 |
+
|
| 196 |
+
Example of instantiating with a name if it's unique:
|
| 197 |
+
```
|
| 198 |
+
config = IntelRealSenseCameraConfig(name="Intel RealSense D405")
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
Example of changing default fps, width, height and color_mode:
|
| 202 |
+
```python
|
| 203 |
+
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=30, width=1280, height=720)
|
| 204 |
+
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480)
|
| 205 |
+
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480, color_mode="bgr")
|
| 206 |
+
# Note: might error out upon `camera.connect()` if these settings are not compatible with the camera
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
Example of returning depth:
|
| 210 |
+
```python
|
| 211 |
+
config = IntelRealSenseCameraConfig(serial_number=128422271347, use_depth=True)
|
| 212 |
+
camera = IntelRealSenseCamera(config)
|
| 213 |
+
camera.connect()
|
| 214 |
+
color_image, depth_map = camera.read()
|
| 215 |
+
```
|
| 216 |
+
"""
|
| 217 |
+
|
| 218 |
+
    def __init__(
        self,
        config: IntelRealSenseCameraConfig,
    ):
        """Store the camera settings from `config`; no device I/O happens until `connect()`.

        Args:
            config: Camera settings. If `config.name` is set, the serial number is
                resolved from it (requires the name to be unique among connected cameras).
        """
        self.config = config
        if config.name is not None:
            self.serial_number = self.find_serial_number_from_name(config.name)
        else:
            self.serial_number = config.serial_number

        # Store the raw (capture) resolution from the config.
        self.capture_width = config.width
        self.capture_height = config.height

        # If rotated by ±90, swap width and height so `self.width`/`self.height`
        # describe the frame as delivered to the user after rotation.
        if config.rotation in [-90, 90]:
            self.width = config.height
            self.height = config.width
        else:
            self.width = config.width
            self.height = config.height

        self.fps = config.fps
        self.channels = config.channels
        self.color_mode = config.color_mode
        self.use_depth = config.use_depth
        self.force_hardware_reset = config.force_hardware_reset
        self.mock = config.mock

        # Runtime state, populated by connect()/async_read().
        self.camera = None
        self.is_connected = False
        self.thread = None
        self.stop_event = None
        self.color_image = None
        self.depth_map = None
        self.logs = {}

        if self.mock:
            import tests.cameras.mock_cv2 as cv2
        else:
            import cv2

        # Translate the degree value from the config into the cv2 rotation constant
        # used by cv2.rotate() at read time (None means no rotation).
        self.rotation = None
        if config.rotation == -90:
            self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
        elif config.rotation == 90:
            self.rotation = cv2.ROTATE_90_CLOCKWISE
        elif config.rotation == 180:
            self.rotation = cv2.ROTATE_180
|
| 267 |
+
|
| 268 |
+
def find_serial_number_from_name(self, name):
|
| 269 |
+
camera_infos = find_cameras()
|
| 270 |
+
camera_names = [cam["name"] for cam in camera_infos]
|
| 271 |
+
this_name_count = Counter(camera_names)[name]
|
| 272 |
+
if this_name_count > 1:
|
| 273 |
+
# TODO(aliberts): Test this with multiple identical cameras (Aloha)
|
| 274 |
+
raise ValueError(
|
| 275 |
+
f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them."
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
|
| 279 |
+
cam_sn = name_to_serial_dict[name]
|
| 280 |
+
|
| 281 |
+
return cam_sn
|
| 282 |
+
|
| 283 |
+
    def connect(self):
        """Open the RealSense pipeline and verify the negotiated stream settings.

        Enables the color stream (and the depth stream when `use_depth=True`), starts
        the pipeline, then checks that the actual fps/width/height match the requested
        values. On success the actual (rounded) values are stored back on `self`.

        Raises:
            RobotDeviceAlreadyConnectedError: If already connected.
            ValueError: If the configured serial number matches no detected camera.
            OSError: If the camera cannot be opened, or the device cannot honor the
                requested fps/width/height.
        """
        if self.is_connected:
            raise RobotDeviceAlreadyConnectedError(
                f"IntelRealSenseCamera({self.serial_number}) is already connected."
            )

        if self.mock:
            import tests.cameras.mock_pyrealsense2 as rs
        else:
            import pyrealsense2 as rs

        config = rs.config()
        config.enable_device(str(self.serial_number))

        # Request an explicit stream profile only when all three of fps/width/height
        # are set; otherwise let the device pick its default profile.
        if self.fps and self.capture_width and self.capture_height:
            # TODO(rcadene): can we set rgb8 directly?
            config.enable_stream(
                rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
            )
        else:
            config.enable_stream(rs.stream.color)

        if self.use_depth:
            if self.fps and self.capture_width and self.capture_height:
                config.enable_stream(
                    rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
                )
            else:
                config.enable_stream(rs.stream.depth)

        self.camera = rs.pipeline()
        try:
            profile = self.camera.start(config)
            is_camera_open = True
        except RuntimeError:
            is_camera_open = False
            traceback.print_exc()

        # If the camera doesn't work, display the camera indices corresponding to
        # valid cameras.
        if not is_camera_open:
            # Verify that the provided `serial_number` is valid before printing the traceback
            camera_infos = find_cameras()
            serial_numbers = [cam["serial_number"] for cam in camera_infos]
            if self.serial_number not in serial_numbers:
                raise ValueError(
                    f"`serial_number` is expected to be one of these available cameras {serial_numbers}, but {self.serial_number} is provided instead. "
                    "To find the serial number you should use, run `python lerobot/common/robot_devices/cameras/intelrealsense.py`."
                )

            raise OSError(f"Can't access IntelRealSenseCamera({self.serial_number}).")

        # Read back the profile the device actually negotiated.
        color_stream = profile.get_stream(rs.stream.color)
        color_profile = color_stream.as_video_stream_profile()
        actual_fps = color_profile.fps()
        actual_width = color_profile.width()
        actual_height = color_profile.height()

        # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
        if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
            # Using `OSError` since it's a broad that encompasses issues related to device communication
            raise OSError(
                f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
            )
        if self.capture_width is not None and self.capture_width != actual_width:
            raise OSError(
                f"Can't set {self.capture_width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
            )
        if self.capture_height is not None and self.capture_height != actual_height:
            raise OSError(
                f"Can't set {self.capture_height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
            )

        # Persist the negotiated values so later reads validate against reality.
        self.fps = round(actual_fps)
        self.capture_width = round(actual_width)
        self.capture_height = round(actual_height)

        self.is_connected = True
|
| 361 |
+
|
| 362 |
+
def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
    """Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3)
    of type `np.uint8`, contrarily to the pytorch format which is float channel first.

    When `use_depth=True`, returns a tuple `(color_image, depth_map)` with a depth map in the format
    height x width (e.g. 480 x 640) of type np.uint16.

    Args:
        temporary_color: optional one-off override of `self.color_mode` ("rgb" or "bgr") for this read.

    Raises:
        RobotDeviceNotConnectedError: if `connect()` was not called first.
        ValueError: if the requested color mode is not "rgb" or "bgr".
        OSError: if a frame cannot be captured or its size does not match the configured capture size.

    Note: Reading a frame is done every `camera.fps` times per second, and it is blocking.
    If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`.
    """
    if not self.is_connected:
        raise RobotDeviceNotConnectedError(
            f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
        )

    # cv2 is imported lazily so that the mock can be substituted in tests.
    if self.mock:
        import tests.cameras.mock_cv2 as cv2
    else:
        import cv2

    start_time = time.perf_counter()

    frame = self.camera.wait_for_frames(timeout_ms=5000)

    color_frame = frame.get_color_frame()

    if not color_frame:
        raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).")

    color_image = np.asanyarray(color_frame.get_data())

    requested_color_mode = self.color_mode if temporary_color is None else temporary_color
    if requested_color_mode not in ["rgb", "bgr"]:
        raise ValueError(
            f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
        )

    # IntelRealSense uses RGB format as default (red, green, blue).
    if requested_color_mode == "bgr":
        color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)

    h, w, _ = color_image.shape
    if h != self.capture_height or w != self.capture_width:
        # Report the capture dimensions that are actually being compared (pre-rotation),
        # not `self.height`/`self.width` which are the post-rotation dimensions.
        raise OSError(
            f"Can't capture color image with expected height and width ({self.capture_height} x {self.capture_width}). ({h} x {w}) returned instead."
        )

    # Rotation is applied after the size check, since the check targets the raw capture size.
    if self.rotation is not None:
        color_image = cv2.rotate(color_image, self.rotation)

    # log the number of seconds it took to read the image
    self.logs["delta_timestamp_s"] = time.perf_counter() - start_time

    # log the utc time at which the image was received
    self.logs["timestamp_utc"] = capture_timestamp_utc()

    if self.use_depth:
        depth_frame = frame.get_depth_frame()
        if not depth_frame:
            raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).")

        depth_map = np.asanyarray(depth_frame.get_data())

        h, w = depth_map.shape
        if h != self.capture_height or w != self.capture_width:
            raise OSError(
                f"Can't capture depth map with expected height and width ({self.capture_height} x {self.capture_width}). ({h} x {w}) returned instead."
            )

        if self.rotation is not None:
            depth_map = cv2.rotate(depth_map, self.rotation)

        return color_image, depth_map
    else:
        return color_image
|
| 437 |
+
|
| 438 |
+
def read_loop(self):
    """Poll `read()` in a loop, caching the latest frame(s), until `stop_event` is set."""
    while not self.stop_event.is_set():
        result = self.read()
        if self.use_depth:
            # `read()` returns a (color_image, depth_map) pair in depth mode.
            self.color_image, self.depth_map = result
        else:
            self.color_image = result
|
| 444 |
+
|
| 445 |
+
def async_read(self):
    """Access the latest color image (and depth map when `use_depth=True`).

    Lazily starts a daemon thread running `read_loop()` on first call, then waits until
    the first frame has been cached. Non-blocking alternative to `read()` for callers
    that poll multiple sensors.

    Raises:
        RobotDeviceNotConnectedError: if `connect()` was not called first.
        Exception: if the background thread fails to produce a first frame in time.
    """
    if not self.is_connected:
        raise RobotDeviceNotConnectedError(
            f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
        )

    # First call: spawn the background reader. Daemon mode so it never blocks interpreter exit.
    if self.thread is None:
        self.stop_event = threading.Event()
        self.thread = Thread(target=self.read_loop, args=())
        self.thread.daemon = True
        self.thread.start()

    # Busy-wait (paced at the camera fps) until the reader thread caches its first frame.
    num_tries = 0
    while self.color_image is None:
        # TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here
        num_tries += 1
        time.sleep(1 / self.fps)
        # After ~1 second of waiting, give up only if the thread also looks dead/not started.
        if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
            raise Exception(
                "The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
            )

    if self.use_depth:
        return self.color_image, self.depth_map
    else:
        return self.color_image
|
| 472 |
+
|
| 473 |
+
def disconnect(self):
    """Stop the background reader (if running), stop the pipeline, and mark the camera disconnected."""
    if not self.is_connected:
        raise RobotDeviceNotConnectedError(
            f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
        )

    reader = self.thread
    if reader is not None and reader.is_alive():
        # Signal the read loop to exit, then wait for the thread to finish.
        self.stop_event.set()
        reader.join()
        self.thread = None
        self.stop_event = None

    self.camera.stop()
    self.camera = None

    self.is_connected = False
|
| 490 |
+
|
| 491 |
+
def __del__(self):
    # Best-effort cleanup on garbage collection: only disconnect if the camera was
    # actually connected. `getattr` with a default guards against `__init__` having
    # failed before `is_connected` was ever assigned.
    if getattr(self, "is_connected", False):
        self.disconnect()
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
if __name__ == "__main__":
    # CLI entry point: grab a few frames from each detected (or selected) RealSense camera
    # and dump them to disk so the user can visually map serial numbers to cameras.
    parser = argparse.ArgumentParser(
        description="Save a few frames using `IntelRealSenseCamera` for all cameras connected to the computer, or a selected subset."
    )
    parser.add_argument(
        "--serial-numbers",
        type=int,
        nargs="*",
        default=None,
        help="List of serial numbers used to instantiate the `IntelRealSenseCamera`. If not provided, find and use all available camera indices.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=30,
        help="Set the number of frames recorded per seconds for all cameras. Defaults to 30.",
    )
    parser.add_argument(
        "--width",
        # Was `type=str`: width is a pixel count and must be parsed as an integer.
        type=int,
        default=640,
        help="Set the width for all cameras. Defaults to 640.",
    )
    parser.add_argument(
        "--height",
        # Was `type=str`: height is a pixel count and must be parsed as an integer.
        type=int,
        default=480,
        help="Set the height for all cameras. Defaults to 480.",
    )
    parser.add_argument(
        "--images-dir",
        type=Path,
        default="outputs/images_from_intelrealsense_cameras",
        help="Set directory to save a few frames for each camera.",
    )
    parser.add_argument(
        "--record-time-s",
        type=float,
        default=2.0,
        help="Set the number of seconds used to record the frames. By default, 2 seconds.",
    )
    args = parser.parse_args()
    save_images_from_cameras(**vars(args))
|
lerobot/common/robot_devices/cameras/opencv.py
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import concurrent.futures
|
| 21 |
+
import math
|
| 22 |
+
import platform
|
| 23 |
+
import shutil
|
| 24 |
+
import threading
|
| 25 |
+
import time
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from threading import Thread
|
| 28 |
+
|
| 29 |
+
import numpy as np
|
| 30 |
+
from PIL import Image
|
| 31 |
+
|
| 32 |
+
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
|
| 33 |
+
from lerobot.common.robot_devices.utils import (
|
| 34 |
+
RobotDeviceAlreadyConnectedError,
|
| 35 |
+
RobotDeviceNotConnectedError,
|
| 36 |
+
busy_wait,
|
| 37 |
+
)
|
| 38 |
+
from lerobot.common.utils.utils import capture_timestamp_utc
|
| 39 |
+
|
| 40 |
+
# The maximum opencv device index depends on your operating system. For instance,
|
| 41 |
+
# if you have 3 cameras, they should be associated to index 0, 1, and 2. This is the case
|
| 42 |
+
# on MacOS. However, on Ubuntu, the indices are different like 6, 16, 23.
|
| 43 |
+
# When you change the USB port or reboot the computer, the operating system might
|
| 44 |
+
# treat the same cameras as new devices. Thus we select a higher bound to search indices.
|
| 45 |
+
MAX_OPENCV_INDEX = 60
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
    """Detect cameras available to OpenCV and return one dict per camera.

    Each dict has a "port" key (the `/dev/video*` path on Linux, else None) and an
    "index" key (the integer index usable with `cv2.VideoCapture`).

    Args:
        raise_when_empty: when True, raise `OSError` if no camera is detected.
        max_index_search_range: highest index scanned on non-Linux systems.
        mock: use the cv2 mock from the test suite instead of real OpenCV.
    """
    cameras = []
    if platform.system() == "Linux":
        print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
        possible_ports = [str(port) for port in Path("/dev").glob("video*")]
        # Bug fix: `raise_when_empty` was previously accepted but never forwarded,
        # so an empty result could not raise as documented.
        ports = _find_cameras(possible_ports, raise_when_empty=raise_when_empty, mock=mock)
        for port in ports:
            cameras.append(
                {
                    "port": port,
                    "index": int(port.removeprefix("/dev/video")),
                }
            )
    else:
        print(
            "Mac or Windows detected. Finding available camera indices through "
            # Report the range actually scanned, not the module constant.
            f"scanning all indices from 0 to {max_index_search_range}"
        )
        possible_indices = range(max_index_search_range)
        indices = _find_cameras(possible_indices, raise_when_empty=raise_when_empty, mock=mock)
        for index in indices:
            cameras.append(
                {
                    "port": None,
                    "index": index,
                }
            )

    return cameras
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _find_cameras(
    possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
) -> list[int | str]:
    """Probe each candidate id with `cv2.VideoCapture` and return those that open successfully."""
    # cv2 is imported lazily so the test-suite mock can be substituted.
    if mock:
        import tests.cameras.mock_cv2 as cv2
    else:
        import cv2

    camera_ids = []
    for candidate in possible_camera_ids:
        probe = cv2.VideoCapture(candidate)
        opened = probe.isOpened()
        # Release immediately so the device stays available for later use.
        probe.release()
        if not opened:
            continue
        print(f"Camera found at index {candidate}")
        camera_ids.append(candidate)

    if raise_when_empty and not camera_ids:
        raise OSError(
            "Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, "
            "or your camera driver, or make sure your camera is compatible with opencv2."
        )

    return camera_ids
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def is_valid_unix_path(path: str) -> bool:
    """Return True when `path` is an absolute path that exists on disk.

    Note: if 'path' points to a symlink, this will return True only if the target exists.
    """
    candidate = Path(path)
    if not candidate.is_absolute():
        return False
    return candidate.exists()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def get_camera_index_from_unix_port(port: Path) -> int:
    """Extract the numeric camera index from a (possibly symlinked) `/dev/videoN` path."""
    resolved = str(port.resolve())
    index_text = resolved.removeprefix("/dev/video")
    return int(index_text)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def save_image(img_array, camera_index, frame_index, images_dir):
    """Write one frame as `<images_dir>/camera_XX_frame_XXXXXX.png`, creating directories as needed."""
    target = images_dir / f"camera_{camera_index:02d}_frame_{frame_index:06d}.png"
    target.parent.mkdir(parents=True, exist_ok=True)
    frame = Image.fromarray(img_array)
    frame.save(str(target), quality=100)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def save_images_from_cameras(
    images_dir: Path,
    camera_ids: list | None = None,
    fps=None,
    width=None,
    height=None,
    record_time_s=2,
    mock=False,
):
    """
    Initializes all the cameras and saves images to the directory. Useful to visually identify the camera
    associated to a given camera index.
    """
    # No explicit selection: discover every camera available to OpenCV.
    if camera_ids is None or len(camera_ids) == 0:
        camera_infos = find_cameras(mock=mock)
        camera_ids = [cam["index"] for cam in camera_infos]

    print("Connecting cameras")
    cameras = []
    for cam_idx in camera_ids:
        config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
        camera = OpenCVCamera(config)
        camera.connect()
        print(
            f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.capture_width}, "
            f"height={camera.capture_height}, color_mode={camera.color_mode})"
        )
        cameras.append(camera)

    # Start from a clean output directory so old frames never mix with new ones.
    images_dir = Path(images_dir)
    if images_dir.exists():
        shutil.rmtree(
            images_dir,
        )
    images_dir.mkdir(parents=True, exist_ok=True)

    print(f"Saving images to {images_dir}")
    frame_index = 0
    start_time = time.perf_counter()
    # Single worker: frames are handed off to a background thread for disk writes
    # so the capture loop is not blocked by PNG encoding.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        while True:
            now = time.perf_counter()

            for camera in cameras:
                # If we use async_read when fps is None, the loop will go full speed, and we will end up
                # saving the same images from the cameras multiple times until the RAM/disk is full.
                image = camera.read() if fps is None else camera.async_read()

                executor.submit(
                    save_image,
                    image,
                    camera.camera_index,
                    frame_index,
                    images_dir,
                )

            # Pace the loop to the requested fps, accounting for time already spent.
            if fps is not None:
                dt_s = time.perf_counter() - now
                busy_wait(1 / fps - dt_s)

            print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")

            if time.perf_counter() - start_time > record_time_s:
                break

            frame_index += 1

    print(f"Images have been saved to {images_dir}")
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class OpenCVCamera:
    """
    The OpenCVCamera class allows to efficiently record images from cameras. It relies on opencv2 to communicate
    with the cameras. Most cameras are compatible. For more info, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).

    An OpenCVCamera instance requires a camera index (e.g. `OpenCVCamera(camera_index=0)`). When you only have one camera
    like a webcam of a laptop, the camera index is expected to be 0, but it might also be very different, and the camera index
    might change if you reboot your computer or re-plug your camera. This behavior depends on your operating system.

    To find the camera indices of your cameras, you can run our utility script that will save a few frames for each camera:
    ```bash
    python lerobot/common/robot_devices/cameras/opencv.py --images-dir outputs/images_from_opencv_cameras
    ```

    When an OpenCVCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
    of the given camera will be used.

    Example of usage:
    ```python
    from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig

    config = OpenCVCameraConfig(camera_index=0)
    camera = OpenCVCamera(config)
    camera.connect()
    color_image = camera.read()
    # when done using the camera, consider disconnecting
    camera.disconnect()
    ```

    Example of changing default fps, width, height and color_mode:
    ```python
    config = OpenCVCameraConfig(camera_index=0, fps=30, width=1280, height=720)
    config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480)
    config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480, color_mode="bgr")
    # Note: might error out upon `camera.connect()` if these settings are not compatible with the camera
    ```
    """

    def __init__(self, config: OpenCVCameraConfig):
        self.config = config
        self.camera_index = config.camera_index
        self.port = None

        # Linux uses ports for connecting to cameras
        if platform.system() == "Linux":
            if isinstance(self.camera_index, int):
                self.port = Path(f"/dev/video{self.camera_index}")
            elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index):
                self.port = Path(self.camera_index)
                # Retrieve the camera index from a potentially symlinked path
                self.camera_index = get_camera_index_from_unix_port(self.port)
            else:
                raise ValueError(f"Please check the provided camera_index: {self.camera_index}")

        # Store the raw (capture) resolution from the config.
        self.capture_width = config.width
        self.capture_height = config.height

        # If rotated by ±90, swap width and height: `width`/`height` are the dimensions
        # of the frames returned by `read()`, after rotation is applied.
        if config.rotation in [-90, 90]:
            self.width = config.height
            self.height = config.width
        else:
            self.width = config.width
            self.height = config.height

        self.fps = config.fps
        self.channels = config.channels
        self.color_mode = config.color_mode
        self.mock = config.mock

        self.camera = None
        self.is_connected = False
        self.thread = None
        self.stop_event = None
        self.color_image = None
        self.logs = {}

        # cv2 is imported lazily so the test-suite mock can be substituted.
        if self.mock:
            import tests.cameras.mock_cv2 as cv2
        else:
            import cv2

        self.rotation = None
        if config.rotation == -90:
            self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
        elif config.rotation == 90:
            self.rotation = cv2.ROTATE_90_CLOCKWISE
        elif config.rotation == 180:
            self.rotation = cv2.ROTATE_180

    def connect(self):
        """Open the camera, apply the requested fps/width/height, and verify the settings stuck.

        Raises:
            RobotDeviceAlreadyConnectedError: if already connected.
            ValueError: if `camera_index` does not correspond to any available camera.
            OSError: if the camera cannot be opened or refuses the requested settings.
        """
        if self.is_connected:
            raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")

        if self.mock:
            import tests.cameras.mock_cv2 as cv2
        else:
            import cv2

        # Use 1 thread to avoid blocking the main thread. Especially useful during data collection
        # when other threads are used to save the images.
        cv2.setNumThreads(1)

        # Pick the platform-native capture backend.
        backend = (
            cv2.CAP_V4L2
            if platform.system() == "Linux"
            else cv2.CAP_DSHOW
            if platform.system() == "Windows"
            else cv2.CAP_AVFOUNDATION
            if platform.system() == "Darwin"
            else cv2.CAP_ANY
        )

        camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
        # First create a temporary camera trying to access `camera_index`,
        # and verify it is a valid camera by calling `isOpened`.
        tmp_camera = cv2.VideoCapture(camera_idx, backend)
        is_camera_open = tmp_camera.isOpened()
        # Release camera to make it accessible for `find_camera_indices`
        tmp_camera.release()
        del tmp_camera

        # If the camera doesn't work, display the camera indices corresponding to
        # valid cameras.
        if not is_camera_open:
            # Verify that the provided `camera_index` is valid before printing the traceback
            cameras_info = find_cameras()
            available_cam_ids = [cam["index"] for cam in cameras_info]
            if self.camera_index not in available_cam_ids:
                raise ValueError(
                    f"`camera_index` is expected to be one of these available cameras {available_cam_ids}, but {self.camera_index} is provided instead. "
                    "To find the camera index you should use, run `python lerobot/common/robot_devices/cameras/opencv.py`."
                )

            raise OSError(f"Can't access OpenCVCamera({camera_idx}).")

        # Secondly, create the camera that will be used downstream.
        # Note: For some unknown reason, calling `isOpened` blocks the camera which then
        # needs to be re-created.
        self.camera = cv2.VideoCapture(camera_idx, backend)

        if self.fps is not None:
            self.camera.set(cv2.CAP_PROP_FPS, self.fps)
        if self.capture_width is not None:
            self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.capture_width)
        if self.capture_height is not None:
            self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.capture_height)

        actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
        actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
        actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)

        # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
        if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
            # Using `OSError` since it's a broad error category that encompasses issues related to device communication
            raise OSError(
                f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
            )
        if self.capture_width is not None and not math.isclose(
            self.capture_width, actual_width, rel_tol=1e-3
        ):
            raise OSError(
                f"Can't set {self.capture_width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
            )
        if self.capture_height is not None and not math.isclose(
            self.capture_height, actual_height, rel_tol=1e-3
        ):
            raise OSError(
                f"Can't set {self.capture_height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
            )

        self.fps = round(actual_fps)
        self.capture_width = round(actual_width)
        self.capture_height = round(actual_height)
        self.is_connected = True

    def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
        """Read a frame from the camera returned in the format (height, width, channels)
        (e.g. 480 x 640 x 3), contrarily to the pytorch format which is channel first.

        Note: Reading a frame is done every `camera.fps` times per second, and it is blocking.
        If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`.
        """
        # NOTE(review): the parameter is named `temporary_color_mode` here but `temporary_color`
        # in the `Camera` protocol and IntelRealSense implementation — confirm no keyword callers
        # before unifying.
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
            )

        start_time = time.perf_counter()

        ret, color_image = self.camera.read()

        if not ret:
            raise OSError(f"Can't capture color image from camera {self.camera_index}.")

        requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode

        if requested_color_mode not in ["rgb", "bgr"]:
            raise ValueError(
                f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
            )

        # OpenCV uses BGR format as default (blue, green, red) for all operations, including displaying images.
        # However, Deep Learning framework such as LeRobot uses RGB format as default to train neural networks,
        # so we convert the image color from BGR to RGB.
        if requested_color_mode == "rgb":
            if self.mock:
                import tests.cameras.mock_cv2 as cv2
            else:
                import cv2

            color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)

        h, w, _ = color_image.shape
        if h != self.capture_height or w != self.capture_width:
            # Report the capture dimensions that are actually being compared (pre-rotation),
            # not `self.height`/`self.width` which are the post-rotation dimensions.
            raise OSError(
                f"Can't capture color image with expected height and width ({self.capture_height} x {self.capture_width}). ({h} x {w}) returned instead."
            )

        if self.rotation is not None:
            color_image = cv2.rotate(color_image, self.rotation)

        # log the number of seconds it took to read the image
        self.logs["delta_timestamp_s"] = time.perf_counter() - start_time

        # log the utc time at which the image was received
        self.logs["timestamp_utc"] = capture_timestamp_utc()

        self.color_image = color_image

        return color_image

    def read_loop(self):
        """Poll `read()` in a loop, caching the latest frame, until `stop_event` is set."""
        while not self.stop_event.is_set():
            try:
                self.color_image = self.read()
            except Exception as e:
                # Keep the loop alive on transient read errors; `async_read` serves the last good frame.
                print(f"Error reading in thread: {e}")

    def async_read(self):
        """Return the latest cached frame, lazily starting the background reader thread."""
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
            )

        if self.thread is None:
            self.stop_event = threading.Event()
            self.thread = Thread(target=self.read_loop, args=())
            self.thread.daemon = True
            self.thread.start()

        # Wait (paced at the camera fps, up to ~2 seconds) for the first frame to arrive.
        num_tries = 0
        while True:
            if self.color_image is not None:
                return self.color_image

            time.sleep(1 / self.fps)
            num_tries += 1
            if num_tries > self.fps * 2:
                raise TimeoutError("Timed out waiting for async_read() to start.")

    def disconnect(self):
        """Stop the background reader (if any), release the capture device, and mark disconnected."""
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
            )

        if self.thread is not None:
            self.stop_event.set()
            self.thread.join()  # wait for the thread to finish
            self.thread = None
            self.stop_event = None

        self.camera.release()
        self.camera = None
        self.is_connected = False

    def __del__(self):
        # Best-effort cleanup; `getattr` guards against a partially-constructed instance.
        if getattr(self, "is_connected", False):
            self.disconnect()
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
if __name__ == "__main__":
    # CLI entry point: grab a few frames from each detected (or selected) camera and
    # dump them to disk so the user can visually map indices to cameras.
    parser = argparse.ArgumentParser(
        description="Save a few frames using `OpenCVCamera` for all cameras connected to the computer, or a selected subset."
    )
    parser.add_argument(
        "--camera-ids",
        type=int,
        nargs="*",
        default=None,
        help="List of camera indices used to instantiate the `OpenCVCamera`. If not provided, find and use all available camera indices.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=None,
        help="Set the number of frames recorded per seconds for all cameras. If not provided, use the default fps of each camera.",
    )
    parser.add_argument(
        "--width",
        # Was `type=str`: width is a pixel count and must be parsed as an integer.
        type=int,
        default=None,
        help="Set the width for all cameras. If not provided, use the default width of each camera.",
    )
    parser.add_argument(
        "--height",
        # Was `type=str`: height is a pixel count and must be parsed as an integer.
        type=int,
        default=None,
        help="Set the height for all cameras. If not provided, use the default height of each camera.",
    )
    parser.add_argument(
        "--images-dir",
        type=Path,
        default="outputs/images_from_opencv_cameras",
        help="Set directory to save a few frames for each camera.",
    )
    parser.add_argument(
        "--record-time-s",
        type=float,
        default=4.0,
        # Help text previously claimed 2 seconds while the default is 4.0.
        help="Set the number of seconds used to record the frames. By default, 4 seconds.",
    )
    args = parser.parse_args()
    save_images_from_cameras(**vars(args))
|
lerobot/common/robot_devices/cameras/utils.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Protocol
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
from lerobot.common.robot_devices.cameras.configs import (
|
| 20 |
+
CameraConfig,
|
| 21 |
+
IntelRealSenseCameraConfig,
|
| 22 |
+
OpenCVCameraConfig,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Defines a camera type
class Camera(Protocol):
    """Structural interface every camera driver must satisfy.

    Any object providing these four methods (e.g. `OpenCVCamera`,
    `IntelRealSenseCamera`) is accepted wherever a `Camera` is expected;
    no inheritance is required (duck typing via `typing.Protocol`).
    """

    def connect(self): ...
    def read(self, temporary_color: str | None = None) -> np.ndarray: ...
    def async_read(self) -> np.ndarray: ...
    def disconnect(self): ...
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]:
    """Instantiate one camera per config entry.

    Args:
        camera_configs: mapping from camera name to its configuration.

    Returns:
        A dict mapping each camera name to an instantiated (not yet connected)
        camera. (Fix: the annotation previously claimed `list[Camera]`, but the
        function has always returned a dict keyed like `camera_configs`.)

    Raises:
        ValueError: if a config has an unknown `type`.
    """
    cameras = {}

    for key, cfg in camera_configs.items():
        if cfg.type == "opencv":
            # Import lazily so users without the OpenCV backend installed can
            # still use the other camera types.
            from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera

            cameras[key] = OpenCVCamera(cfg)

        elif cfg.type == "intelrealsense":
            from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera

            cameras[key] = IntelRealSenseCamera(cfg)
        else:
            raise ValueError(f"The camera type '{cfg.type}' is not valid.")

    return cameras
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def make_camera(camera_type, **kwargs) -> Camera:
    """Build a single camera of the given type from keyword config options.

    Raises:
        ValueError: if `camera_type` is not a known backend.
    """
    if camera_type == "opencv":
        # Deferred import: only pull in the OpenCV backend when requested.
        from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera

        return OpenCVCamera(OpenCVCameraConfig(**kwargs))

    if camera_type == "intelrealsense":
        from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera

        return IntelRealSenseCamera(IntelRealSenseCameraConfig(**kwargs))

    raise ValueError(f"The camera type '{camera_type}' is not valid.")
|
lerobot/common/robot_devices/control_configs.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import draccus
|
| 19 |
+
|
| 20 |
+
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
| 21 |
+
from lerobot.configs import parser
|
| 22 |
+
from lerobot.configs.policies import PreTrainedConfig
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class ControlConfig(draccus.ChoiceRegistry):
    """Base class for all robot-control subcommand configs.

    Subclasses register themselves under a name (e.g. "record", "replay")
    via `draccus.ChoiceRegistry`, which lets the CLI pick one with
    `--control.type=<name>`.
    """

    pass
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@ControlConfig.register_subclass("calibrate")
@dataclass
class CalibrateControlConfig(ControlConfig):
    """Config for the arm-calibration routine."""

    # List of arms to calibrate (e.g. `--arms='["left_follower","right_follower"]' left_leader`)
    arms: list[str] | None = None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@ControlConfig.register_subclass("teleoperate")
@dataclass
class TeleoperateControlConfig(ControlConfig):
    """Config for direct teleoperation (no data recording)."""

    # Limit the maximum frames per second. By default, no limit.
    fps: int | None = None
    # How long to teleoperate, in seconds; presumably unlimited when None —
    # NOTE(review): confirm against the teleoperate control loop.
    teleop_time_s: float | None = None
    # Display all cameras on screen
    display_cameras: bool = True
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@ControlConfig.register_subclass("record")
@dataclass
class RecordControlConfig(ControlConfig):
    """Config for recording a dataset, either via teleoperation or a policy."""

    # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
    repo_id: str
    # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
    single_task: str
    # Root directory where the dataset will be stored (e.g. 'dataset/path').
    root: str | Path | None = None
    # Policy used to drive the robot instead of teleoperation. Populated in
    # `__post_init__` when `--control.policy.path` is passed on the CLI.
    policy: PreTrainedConfig | None = None
    # Limit the frames per second. By default, uses the policy fps.
    fps: int | None = None
    # Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
    warmup_time_s: int | float = 10
    # Number of seconds for data recording for each episode.
    episode_time_s: int | float = 60
    # Number of seconds for resetting the environment after each episode.
    reset_time_s: int | float = 60
    # Number of episodes to record.
    num_episodes: int = 50
    # Encode frames in the dataset into video
    video: bool = True
    # Upload dataset to Hugging Face hub.
    push_to_hub: bool = True
    # Upload on private repository on the Hugging Face hub.
    private: bool = False
    # Add tags to your dataset on the hub.
    tags: list[str] | None = None
    # Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
    # set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
    # and threads depends on your system. We recommend 4 threads per camera with 0 processes.
    # If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
    num_image_writer_processes: int = 0
    # Number of threads writing the frames as png images on disk, per camera.
    # Too many threads might cause unstable teleoperation fps due to main thread being blocked.
    # Not enough threads might cause low camera fps.
    num_image_writer_threads_per_camera: int = 4
    # Display all cameras on screen
    display_cameras: bool = True
    # Use vocal synthesis to read events.
    play_sounds: bool = True
    # Resume recording on an existing dataset.
    resume: bool = False

    def __post_init__(self):
        # HACK: We parse again the cli args here to get the pretrained path if there was one.
        policy_path = parser.get_path_arg("control.policy")
        if policy_path:
            cli_overrides = parser.get_cli_overrides("control.policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path
|
| 99 |
+
|
| 100 |
+
@ControlConfig.register_subclass("replay")
@dataclass
class ReplayControlConfig(ControlConfig):
    """Config for replaying a previously recorded episode on the robot."""

    # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
    repo_id: str
    # Index of the episode to replay.
    episode: int
    # Root directory where the dataset will be stored (e.g. 'dataset/path').
    root: str | Path | None = None
    # Limit the frames per second. By default, uses the dataset fps.
    fps: int | None = None
    # Use vocal synthesis to read events.
    play_sounds: bool = True
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@ControlConfig.register_subclass("remote_robot")
@dataclass
class RemoteRobotConfig(ControlConfig):
    """Config for running the robot as a remote device."""

    # Presumably the number of control iterations between two log prints —
    # NOTE(review): confirm against the remote control loop that consumes it.
    log_interval: int = 100
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@dataclass
class ControlPipelineConfig:
    """Top-level CLI config: a robot plus the control subcommand to run on it."""

    # Which robot hardware to drive.
    robot: RobotConfig
    # Which control mode to run (calibrate / teleoperate / record / replay / remote_robot).
    control: ControlConfig

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["control.policy"]
|
lerobot/common/robot_devices/control_utils.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
########################################################################################
|
| 16 |
+
# Utilities
|
| 17 |
+
########################################################################################
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
import logging
|
| 21 |
+
import time
|
| 22 |
+
import traceback
|
| 23 |
+
from contextlib import nullcontext
|
| 24 |
+
from copy import copy
|
| 25 |
+
from functools import cache
|
| 26 |
+
|
| 27 |
+
import cv2
|
| 28 |
+
import torch
|
| 29 |
+
from deepdiff import DeepDiff
|
| 30 |
+
from termcolor import colored
|
| 31 |
+
|
| 32 |
+
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
| 33 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
| 34 |
+
from lerobot.common.datasets.utils import get_features_from_robot
|
| 35 |
+
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
| 36 |
+
from lerobot.common.robot_devices.robots.utils import Robot
|
| 37 |
+
from lerobot.common.robot_devices.utils import busy_wait
|
| 38 |
+
from lerobot.common.utils.utils import get_safe_torch_device, has_method
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
    """Log one line of timing info for the current control-loop iteration.

    Reports the total step duration plus per-device read/write durations found
    in `robot.logs`. When `fps` is given, entries slower than the requested
    rate (with 1 Hz tolerance) are colored yellow.
    """
    parts = []
    if episode_index is not None:
        parts.append(f"ep:{episode_index}")
    if frame_index is not None:
        parts.append(f"frame:{frame_index}")

    def log_dt(shortname, dt_val_s):
        nonlocal parts, fps
        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
        # Highlight iterations that fell below the requested rate.
        if fps is not None and (1 / dt_val_s) < fps - 1:
            info_str = colored(info_str, "yellow")
        parts.append(info_str)

    # total step time displayed in milliseconds and its frequency
    log_dt("dt", dt_s)

    # TODO(aliberts): move robot-specific logs logic in robot.print_logs()
    if not robot.robot_type.startswith("stretch"):
        for name in robot.leader_arms:
            key = f"read_leader_{name}_pos_dt_s"
            if key in robot.logs:
                log_dt("dtRlead", robot.logs[key])

        for name in robot.follower_arms:
            key = f"write_follower_{name}_goal_pos_dt_s"
            if key in robot.logs:
                log_dt("dtWfoll", robot.logs[key])

            key = f"read_follower_{name}_pos_dt_s"
            if key in robot.logs:
                log_dt("dtRfoll", robot.logs[key])

        for name in robot.cameras:
            key = f"read_camera_{name}_dt_s"
            if key in robot.logs:
                log_dt(f"dtR{name}", robot.logs[key])

    logging.info(" ".join(parts))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@cache
def is_headless():
    """Return True when python is running without a monitor (pynput unusable).

    Cached: the environment does not change within a run, so the check (and its
    potentially noisy traceback) happens at most once.
    """
    try:
        import pynput  # noqa
    except Exception:
        print(
            "Error trying to import pynput. Switching to headless mode. "
            "As a result, the video stream from the cameras won't be shown, "
            "and you won't be able to change the control flow with keyboards. "
            "For more info, see traceback below.\n"
        )
        traceback.print_exc()
        print()
        return True
    return False
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def predict_action(observation, policy, device, use_amp):
|
| 105 |
+
observation = copy(observation)
|
| 106 |
+
with (
|
| 107 |
+
torch.inference_mode(),
|
| 108 |
+
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
| 109 |
+
):
|
| 110 |
+
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
| 111 |
+
for name in observation:
|
| 112 |
+
if "image" in name:
|
| 113 |
+
observation[name] = observation[name].type(torch.float32) / 255
|
| 114 |
+
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
| 115 |
+
observation[name] = observation[name].unsqueeze(0)
|
| 116 |
+
observation[name] = observation[name].to(device)
|
| 117 |
+
|
| 118 |
+
# Compute the next action with the policy
|
| 119 |
+
# based on the current observation
|
| 120 |
+
action = policy.select_action(observation)
|
| 121 |
+
|
| 122 |
+
# Remove batch dimension
|
| 123 |
+
action = action.squeeze(0)
|
| 124 |
+
|
| 125 |
+
# Move to cpu, if not already the case
|
| 126 |
+
action = action.to("cpu")
|
| 127 |
+
|
| 128 |
+
return action
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def init_keyboard_listener():
|
| 132 |
+
# Allow to exit early while recording an episode or resetting the environment,
|
| 133 |
+
# by tapping the right arrow key '->'. This might require a sudo permission
|
| 134 |
+
# to allow your terminal to monitor keyboard events.
|
| 135 |
+
events = {}
|
| 136 |
+
events["exit_early"] = False
|
| 137 |
+
events["rerecord_episode"] = False
|
| 138 |
+
events["stop_recording"] = False
|
| 139 |
+
|
| 140 |
+
if is_headless():
|
| 141 |
+
logging.warning(
|
| 142 |
+
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
| 143 |
+
)
|
| 144 |
+
listener = None
|
| 145 |
+
return listener, events
|
| 146 |
+
|
| 147 |
+
# Only import pynput if not in a headless environment
|
| 148 |
+
from pynput import keyboard
|
| 149 |
+
|
| 150 |
+
def on_press(key):
|
| 151 |
+
try:
|
| 152 |
+
if key == keyboard.Key.right:
|
| 153 |
+
print("Right arrow key pressed. Exiting loop...")
|
| 154 |
+
events["exit_early"] = True
|
| 155 |
+
elif key == keyboard.Key.left:
|
| 156 |
+
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
| 157 |
+
events["rerecord_episode"] = True
|
| 158 |
+
events["exit_early"] = True
|
| 159 |
+
elif key == keyboard.Key.esc:
|
| 160 |
+
print("Escape key pressed. Stopping data recording...")
|
| 161 |
+
events["stop_recording"] = True
|
| 162 |
+
events["exit_early"] = True
|
| 163 |
+
except Exception as e:
|
| 164 |
+
print(f"Error handling key press: {e}")
|
| 165 |
+
|
| 166 |
+
listener = keyboard.Listener(on_press=on_press)
|
| 167 |
+
listener.start()
|
| 168 |
+
|
| 169 |
+
return listener, events
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def warmup_record(
    robot,
    events,
    enable_teleoperation,
    warmup_time_s,
    display_cameras,
    fps,
):
    """Run the control loop for `warmup_time_s` seconds without recording.

    Lets the robot devices warm up and synchronize, optionally under
    teleoperation, before actual data collection starts.
    """
    control_loop(
        robot=robot,
        events=events,
        fps=fps,
        control_time_s=warmup_time_s,
        display_cameras=display_cameras,
        teleoperate=enable_teleoperation,
    )
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def record_episode(
    robot,
    dataset,
    events,
    episode_time_s,
    display_cameras,
    policy,
    fps,
    single_task,
):
    """Record one episode of `episode_time_s` seconds into `dataset`.

    The robot is driven by `policy` when one is given, otherwise by
    teleoperation.
    """
    control_loop(
        robot=robot,
        dataset=dataset,
        events=events,
        policy=policy,
        fps=fps,
        single_task=single_task,
        control_time_s=episode_time_s,
        display_cameras=display_cameras,
        # When no policy drives the robot, the operator teleoperates it.
        teleoperate=policy is None,
    )
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
@safe_stop_image_writer
def control_loop(
    robot,
    control_time_s=None,
    teleoperate=False,
    display_cameras=False,
    dataset: LeRobotDataset | None = None,
    events=None,
    policy: PreTrainedPolicy = None,
    fps: int | None = None,
    single_task: str | None = None,
):
    """Run the robot control loop for `control_time_s` seconds.

    Each iteration either teleoperates the robot or captures an observation
    (optionally feeding it to `policy` and sending the predicted action),
    optionally appends a frame to `dataset`, displays camera streams, and
    throttles to `fps` when given. The loop exits early when
    `events["exit_early"]` is set (e.g. by the keyboard listener).

    Raises:
        ValueError: on incompatible argument combinations (teleoperation with
            a policy, recording without a task, fps mismatch with the dataset).
    """
    # TODO(rcadene): Add option to record logs
    if not robot.is_connected:
        robot.connect()

    if events is None:
        events = {"exit_early": False}

    if control_time_s is None:
        control_time_s = float("inf")

    if teleoperate and policy is not None:
        raise ValueError("When `teleoperate` is True, `policy` should be None.")

    if dataset is not None and single_task is None:
        raise ValueError("You need to provide a task as argument in `single_task`.")

    if dataset is not None and fps is not None and dataset.fps != fps:
        # Fix: the message previously interpolated `dataset['fps']`, which does
        # not report the mismatch but raises on a LeRobotDataset instead.
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")

    timestamp = 0
    start_episode_t = time.perf_counter()
    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()

        if teleoperate:
            observation, action = robot.teleop_step(record_data=True)
        else:
            observation = robot.capture_observation()

            if policy is not None:
                pred_action = predict_action(
                    observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
                )
                # Action can eventually be clipped using `max_relative_target`,
                # so action actually sent is saved in the dataset.
                action = robot.send_action(pred_action)
                action = {"action": action}

        if dataset is not None:
            frame = {**observation, **action, "task": single_task}
            dataset.add_frame(frame)

        if display_cameras and not is_headless():
            image_keys = [key for key in observation if "image" in key]
            for key in image_keys:
                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)

        if fps is not None:
            # Sleep off the remainder of the frame budget to hold a steady rate.
            dt_s = time.perf_counter() - start_loop_t
            busy_wait(1 / fps - dt_s)

        dt_s = time.perf_counter() - start_loop_t
        log_control_info(robot, dt_s, fps=fps)

        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            # Consume the flag so the next loop (e.g. env reset) starts clean.
            events["exit_early"] = False
            break
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def reset_environment(robot, events, reset_time_s, fps):
    """Teleoperate for `reset_time_s` seconds so the operator can reset the scene."""
    # TODO(rcadene): refactor warmup_record and reset_environment
    if has_method(robot, "teleop_safety_stop"):
        robot.teleop_safety_stop()

    control_loop(
        robot=robot,
        events=events,
        fps=fps,
        teleoperate=True,
        control_time_s=reset_time_s,
    )
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def stop_recording(robot, listener, display_cameras):
    """Tear down a recording session: disconnect the robot, stop the keyboard
    listener, and close any camera display windows."""
    robot.disconnect()

    # Nothing to clean up on screen or keyboard in a headless environment.
    if is_headless():
        return

    if listener is not None:
        listener.stop()

    if display_cameras:
        cv2.destroyAllWindows()
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def sanity_check_dataset_name(repo_id, policy_cfg):
    """Enforce the naming convention linking dataset names and policy usage.

    Datasets recorded via teleoperation must NOT be prefixed with 'eval_',
    while datasets recorded by running a policy MUST be.

    Args:
        repo_id: dataset identifier of the form '{hf_username}/{dataset_name}'.
        policy_cfg: the policy config, or None when teleoperating.

    Raises:
        ValueError: when the 'eval_' prefix and the presence of a policy
            disagree.
    """
    _, dataset_name = repo_id.split("/")
    # either repo_id doesnt start with "eval_" and there is no policy
    # or repo_id starts with "eval_" and there is a policy

    # Check if dataset_name starts with "eval_" but policy is missing
    if dataset_name.startswith("eval_") and policy_cfg is None:
        # Fix: the original message dereferenced `policy_cfg.type` on this
        # branch, raising AttributeError (policy_cfg is None here) instead of
        # the intended ValueError.
        raise ValueError(
            f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
        )

    # Check if dataset_name does not start with "eval_" but policy is provided
    if not dataset_name.startswith("eval_") and policy_cfg is not None:
        raise ValueError(
            f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})."
        )
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def sanity_check_dataset_robot_compatibility(
    dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
) -> None:
    """Verify an existing dataset is compatible with the current robot setup.

    Compares the dataset's recorded robot type, fps, and feature schema against
    the present configuration and raises with a summary of every mismatch.
    """
    checks = [
        ("robot_type", dataset.meta.robot_type, robot.robot_type),
        ("fps", dataset.fps, fps),
        ("features", dataset.features, get_features_from_robot(robot, use_videos)),
    ]

    mismatches = []
    for field, dataset_value, present_value in checks:
        # The nested 'info' entries are allowed to differ between runs.
        if DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]):
            mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")

    if mismatches:
        raise ValueError(
            "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
        )
|
lerobot/common/robot_devices/motors/configs.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import abc
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
|
| 18 |
+
import draccus
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class MotorsBusConfig(draccus.ChoiceRegistry, abc.ABC):
    """Base config for a bus of motors; concrete backends register a name."""

    @property
    def type(self) -> str:
        # Name under which the concrete subclass was registered (e.g. "dynamixel").
        return self.get_choice_name(self.__class__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@MotorsBusConfig.register_subclass("dynamixel")
@dataclass
class DynamixelMotorsBusConfig(MotorsBusConfig):
    """Config for a bus of Dynamixel servos."""

    # Serial port of the bus (e.g. "/dev/ttyUSB0").
    port: str
    # Mapping from motor name to an (int, str) pair — presumably
    # (motor id, model name); confirm against the Dynamixel bus driver.
    motors: dict[str, tuple[int, str]]
    # When True, use a mocked bus (no hardware required), e.g. for tests.
    mock: bool = False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@MotorsBusConfig.register_subclass("feetech")
@dataclass
class FeetechMotorsBusConfig(MotorsBusConfig):
    """Config for a bus of Feetech servos."""

    # Serial port of the bus (e.g. "/dev/ttyUSB0").
    port: str
    # Mapping from motor name to an (int, str) pair — presumably
    # (motor id, model name); confirm against the Feetech bus driver.
    motors: dict[str, tuple[int, str]]
    # When True, use a mocked bus (no hardware required), e.g. for tests.
    mock: bool = False
|
lerobot/common/robot_devices/motors/dynamixel.py
ADDED
|
@@ -0,0 +1,873 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import enum
|
| 16 |
+
import logging
|
| 17 |
+
import math
|
| 18 |
+
import time
|
| 19 |
+
import traceback
|
| 20 |
+
from copy import deepcopy
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import tqdm
|
| 24 |
+
|
| 25 |
+
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
|
| 26 |
+
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
| 27 |
+
from lerobot.common.utils.utils import capture_timestamp_utc
|
| 28 |
+
|
| 29 |
+
# Dynamixel protocol version used by all PacketHandler instances below.
PROTOCOL_VERSION = 2.0
BAUDRATE = 1_000_000
TIMEOUT_MS = 1000

# Highest motor id probed when scanning the bus (see `find_motor_indices`).
MAX_ID_RANGE = 252

# The following bounds define the lower and upper joints range (after calibration).
# For joints in degree (i.e. revolute joints), their nominal range is [-180, 180] degrees
# which corresponds to a half rotation on the left and half rotation on the right.
# Some joints might require higher range, so we allow up to [-270, 270] degrees until
# an error is raised.
LOWER_BOUND_DEGREE = -270
UPPER_BOUND_DEGREE = 270
# For joints in percentage (i.e. joints that move linearly like the prismatic joint of a gripper),
# their nominal range is [0, 100] %. For instance, for Aloha gripper, 0% is fully
# closed, and 100% is fully open. To account for slight calibration issue, we allow up to
# [-10, 110] until an error is raised.
LOWER_BOUND_LINEAR = -10
UPPER_BOUND_LINEAR = 110

# Degrees in half a rotation; `resolution` steps correspond to 2 * HALF_TURN_DEGREE.
HALF_TURN_DEGREE = 180

# https://emanual.robotis.com/docs/en/dxl/x/xl330-m077
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m288
# https://emanual.robotis.com/docs/en/dxl/x/xl430-w250
# https://emanual.robotis.com/docs/en/dxl/x/xm430-w350
# https://emanual.robotis.com/docs/en/dxl/x/xm540-w270
# https://emanual.robotis.com/docs/en/dxl/x/xc430-w150

# data_name: (address, size_byte)
X_SERIES_CONTROL_TABLE = {
    "Model_Number": (0, 2),
    "Model_Information": (2, 4),
    "Firmware_Version": (6, 1),
    "ID": (7, 1),
    "Baud_Rate": (8, 1),
    "Return_Delay_Time": (9, 1),
    "Drive_Mode": (10, 1),
    "Operating_Mode": (11, 1),
    "Secondary_ID": (12, 1),
    "Protocol_Type": (13, 1),
    "Homing_Offset": (20, 4),
    "Moving_Threshold": (24, 4),
    "Temperature_Limit": (31, 1),
    "Max_Voltage_Limit": (32, 2),
    "Min_Voltage_Limit": (34, 2),
    "PWM_Limit": (36, 2),
    "Current_Limit": (38, 2),
    "Acceleration_Limit": (40, 4),
    "Velocity_Limit": (44, 4),
    "Max_Position_Limit": (48, 4),
    "Min_Position_Limit": (52, 4),
    "Shutdown": (63, 1),
    "Torque_Enable": (64, 1),
    "LED": (65, 1),
    "Status_Return_Level": (68, 1),
    "Registered_Instruction": (69, 1),
    "Hardware_Error_Status": (70, 1),
    "Velocity_I_Gain": (76, 2),
    "Velocity_P_Gain": (78, 2),
    "Position_D_Gain": (80, 2),
    "Position_I_Gain": (82, 2),
    "Position_P_Gain": (84, 2),
    "Feedforward_2nd_Gain": (88, 2),
    "Feedforward_1st_Gain": (90, 2),
    "Bus_Watchdog": (98, 1),
    "Goal_PWM": (100, 2),
    "Goal_Current": (102, 2),
    "Goal_Velocity": (104, 4),
    "Profile_Acceleration": (108, 4),
    "Profile_Velocity": (112, 4),
    "Goal_Position": (116, 4),
    "Realtime_Tick": (120, 2),
    "Moving": (122, 1),
    "Moving_Status": (123, 1),
    "Present_PWM": (124, 2),
    "Present_Current": (126, 2),
    "Present_Velocity": (128, 4),
    "Present_Position": (132, 4),
    "Velocity_Trajectory": (136, 4),
    "Position_Trajectory": (140, 4),
    "Present_Input_Voltage": (144, 2),
    "Present_Temperature": (146, 1),
}

# Maps the value stored in the "Baud_Rate" register to the actual baud rate.
X_SERIES_BAUDRATE_TABLE = {
    0: 9_600,
    1: 57_600,
    2: 115_200,
    3: 1_000_000,
    4: 2_000_000,
    5: 3_000_000,
    6: 4_000_000,
}

# Data names whose values go through (apply|revert)_calibration on read/write.
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
# Data names read back as unsigned int32 that must be reinterpreted as signed int32.
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]

# All supported models share the X-series control table.
MODEL_CONTROL_TABLE = {
    "x_series": X_SERIES_CONTROL_TABLE,
    "xl330-m077": X_SERIES_CONTROL_TABLE,
    "xl330-m288": X_SERIES_CONTROL_TABLE,
    "xl430-w250": X_SERIES_CONTROL_TABLE,
    "xm430-w350": X_SERIES_CONTROL_TABLE,
    "xm540-w270": X_SERIES_CONTROL_TABLE,
    "xc430-w150": X_SERIES_CONTROL_TABLE,
}

# Steps per full rotation, per model.
MODEL_RESOLUTION = {
    "x_series": 4096,
    "xl330-m077": 4096,
    "xl330-m288": 4096,
    "xl430-w250": 4096,
    "xm430-w350": 4096,
    "xm540-w270": 4096,
    "xc430-w150": 4096,
}

MODEL_BAUDRATE_TABLE = {
    "x_series": X_SERIES_BAUDRATE_TABLE,
    "xl330-m077": X_SERIES_BAUDRATE_TABLE,
    "xl330-m288": X_SERIES_BAUDRATE_TABLE,
    "xl430-w250": X_SERIES_BAUDRATE_TABLE,
    "xm430-w350": X_SERIES_BAUDRATE_TABLE,
    "xm540-w270": X_SERIES_BAUDRATE_TABLE,
    "xc430-w150": X_SERIES_BAUDRATE_TABLE,
}

# Number of attempts before declaring a sync read/write a communication failure.
NUM_READ_RETRY = 10
NUM_WRITE_RETRY = 10
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
    """This function converts the degree range to the step range for indicating motors rotation.
    It assumes a motor achieves a full rotation by going from -180 degree position to +180.
    The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.

    Args:
        degrees: angle(s) in degrees, scalar or array (broadcast against `models`).
        models: a single model name or one model name per motor.

    Returns:
        Integer step counts, one per model.
    """
    # Accept a single model name, as advertised by the type hint. Previously a bare
    # string was iterated character by character, raising KeyError on the table lookup.
    if isinstance(models, str):
        models = [models]
    resolutions = [MODEL_RESOLUTION[model] for model in models]
    # degrees / 180 gives the fraction of a half turn; a half turn is resolution // 2 steps.
    steps = degrees / 180 * np.array(resolutions) / 2
    steps = steps.astype(int)
    return steps
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def convert_to_bytes(value, bytes, mock=False):
    """Split `value` into a little-endian list of `bytes` bytes for the Dynamixel SDK.

    Only 1, 2 and 4 byte payloads are supported. With `mock=True` the value is
    returned untouched (the mock SDK consumes plain integers).
    """
    if mock:
        return value

    import dynamixel_sdk as dxl

    # Note: No need to convert back into unsigned int, since this byte preprocessing
    # already handles it for us.
    low_word = dxl.DXL_LOWORD(value)
    if bytes == 1:
        return [dxl.DXL_LOBYTE(low_word)]
    if bytes == 2:
        return [dxl.DXL_LOBYTE(low_word), dxl.DXL_HIBYTE(low_word)]
    if bytes == 4:
        high_word = dxl.DXL_HIWORD(value)
        return [
            dxl.DXL_LOBYTE(low_word),
            dxl.DXL_HIBYTE(low_word),
            dxl.DXL_LOBYTE(high_word),
            dxl.DXL_HIBYTE(high_word),
        ]
    raise NotImplementedError(
        f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
        f"{bytes} is provided instead."
    )
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def get_group_sync_key(data_name, motor_names):
    """Key identifying a GroupSync reader/writer for one data field over a set of motors."""
    joined_names = "_".join(motor_names)
    return data_name + "_" + joined_names
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def get_result_name(fn_name, data_name, motor_names):
    """Name under which the result of an async read/write is stored."""
    return f"{fn_name}_{data_name}_" + "_".join(motor_names)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def get_queue_name(fn_name, data_name, motor_names):
    """Name of the queue carrying async read/write requests (same scheme as `get_result_name`)."""
    return f"{fn_name}_{data_name}_" + "_".join(motor_names)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def get_log_name(var_name, fn_name, data_name, motor_names):
    """Key under which a timing/timestamp metric is stored in `self.logs`."""
    return f"{var_name}_{fn_name}_{data_name}_" + "_".join(motor_names)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def assert_same_address(model_ctrl_table, motor_models, data_name):
|
| 228 |
+
all_addr = []
|
| 229 |
+
all_bytes = []
|
| 230 |
+
for model in motor_models:
|
| 231 |
+
addr, bytes = model_ctrl_table[model][data_name]
|
| 232 |
+
all_addr.append(addr)
|
| 233 |
+
all_bytes.append(bytes)
|
| 234 |
+
|
| 235 |
+
if len(set(all_addr)) != 1:
|
| 236 |
+
raise NotImplementedError(
|
| 237 |
+
f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
if len(set(all_bytes)) != 1:
|
| 241 |
+
raise NotImplementedError(
|
| 242 |
+
f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class TorqueMode(enum.Enum):
    """Values written to the "Torque_Enable" register."""

    # Torque on: the motor actively holds/drives its position.
    ENABLED = 1
    # Torque off: the joint can be moved freely by hand.
    DISABLED = 0
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class DriveMode(enum.Enum):
    """Rotation direction convention for a joint; INVERTED flips the sign during calibration."""

    NON_INVERTED = 0
    INVERTED = 1
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class CalibrationMode(enum.Enum):
    """How raw motor steps are mapped to a human-interpretable unit (see `apply_calibration`)."""

    # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
    DEGREE = 0
    # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
    LINEAR = 1
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class JointOutOfRangeError(Exception):
    """Raised when a calibrated joint value falls outside its permitted bounds."""

    def __init__(self, message="Joint is out of range"):
        super().__init__(message)
        # Keep the message accessible as an attribute for callers that log it.
        self.message = message
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class DynamixelMotorsBus:
|
| 270 |
+
"""
|
| 271 |
+
The DynamixelMotorsBus class allows to efficiently read and write to the attached motors. It relies on
|
| 272 |
+
the python dynamixel sdk to communicate with the motors. For more info, see the [Dynamixel SDK Documentation](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20).
|
| 273 |
+
|
| 274 |
+
A DynamixelMotorsBus instance requires a port (e.g. `DynamixelMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
| 275 |
+
To find the port, you can run our utility script:
|
| 276 |
+
```bash
|
| 277 |
+
python lerobot/scripts/find_motors_bus_port.py
|
| 278 |
+
>>> Finding all available ports for the MotorBus.
|
| 279 |
+
>>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
| 280 |
+
>>> Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
| 281 |
+
>>> The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751.
|
| 282 |
+
>>> Reconnect the usb cable.
|
| 283 |
+
```
|
| 284 |
+
|
| 285 |
+
Example of usage for 1 motor connected to the bus:
|
| 286 |
+
```python
|
| 287 |
+
motor_name = "gripper"
|
| 288 |
+
motor_index = 6
|
| 289 |
+
motor_model = "xl330-m288"
|
| 290 |
+
|
| 291 |
+
config = DynamixelMotorsBusConfig(
|
| 292 |
+
port="/dev/tty.usbmodem575E0031751",
|
| 293 |
+
motors={motor_name: (motor_index, motor_model)},
|
| 294 |
+
)
|
| 295 |
+
motors_bus = DynamixelMotorsBus(config)
|
| 296 |
+
motors_bus.connect()
|
| 297 |
+
|
| 298 |
+
position = motors_bus.read("Present_Position")
|
| 299 |
+
|
| 300 |
+
# move from a few motor steps as an example
|
| 301 |
+
few_steps = 30
|
| 302 |
+
motors_bus.write("Goal_Position", position + few_steps)
|
| 303 |
+
|
| 304 |
+
# when done, consider disconnecting
|
| 305 |
+
motors_bus.disconnect()
|
| 306 |
+
```
|
| 307 |
+
"""
|
| 308 |
+
|
| 309 |
+
    def __init__(
        self,
        config: DynamixelMotorsBusConfig,
    ):
        """Store the bus configuration; no hardware I/O happens until `connect()`.

        Args:
            config: carries the serial `port`, the `motors` mapping
                `{name: (index, model)}`, and the `mock` flag for tests.
        """
        self.port = config.port
        self.motors = config.motors
        self.mock = config.mock

        # Per-instance copies so runtime mutations never leak into the module-level tables.
        self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
        self.model_resolution = deepcopy(MODEL_RESOLUTION)

        # Communication state, populated by `connect()`.
        self.port_handler = None
        self.packet_handler = None
        # Calibration dict, populated by `set_calibration()`.
        self.calibration = None
        self.is_connected = False
        # Caches of GroupSyncRead/GroupSyncWrite objects keyed by `get_group_sync_key`.
        self.group_readers = {}
        self.group_writers = {}
        # Timing/timestamp metrics keyed by `get_log_name`.
        self.logs = {}
|
| 327 |
+
|
| 328 |
+
    def connect(self):
        """Open the serial port and prepare the SDK handlers for read/write access.

        Raises:
            RobotDeviceAlreadyConnectedError: if `connect()` was already called.
            OSError: if the serial port cannot be opened.
        """
        if self.is_connected:
            raise RobotDeviceAlreadyConnectedError(
                f"DynamixelMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice."
            )

        # The mock SDK mirrors the dynamixel_sdk API so tests can run without hardware.
        if self.mock:
            import tests.motors.mock_dynamixel_sdk as dxl
        else:
            import dynamixel_sdk as dxl

        self.port_handler = dxl.PortHandler(self.port)
        self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)

        try:
            if not self.port_handler.openPort():
                raise OSError(f"Failed to open port '{self.port}'.")
        except Exception:
            # Point the user at the port-discovery helper before re-raising.
            traceback.print_exc()
            print(
                "\nTry running `python lerobot/scripts/find_motors_bus_port.py` to make sure you are using the correct port.\n"
            )
            raise

        # Allow to read and write
        self.is_connected = True

        self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
|
| 356 |
+
|
| 357 |
+
    def reconnect(self):
        """Re-open the serial port after a disconnect, skipping the already-connected guard.

        NOTE(review): unlike `connect()`, this does not call
        `setPacketTimeoutMillis(TIMEOUT_MS)` nor print the port-discovery hint on
        failure — confirm whether the packet timeout should also be restored here.

        Raises:
            OSError: if the serial port cannot be opened.
        """
        if self.mock:
            import tests.motors.mock_dynamixel_sdk as dxl
        else:
            import dynamixel_sdk as dxl

        self.port_handler = dxl.PortHandler(self.port)
        self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)

        if not self.port_handler.openPort():
            raise OSError(f"Failed to open port '{self.port}'.")

        self.is_connected = True
|
| 370 |
+
|
| 371 |
+
def are_motors_configured(self):
|
| 372 |
+
# Only check the motor indices and not baudrate, since if the motor baudrates are incorrect,
|
| 373 |
+
# a ConnectionError will be raised anyway.
|
| 374 |
+
try:
|
| 375 |
+
return (self.motor_indices == self.read("ID")).all()
|
| 376 |
+
except ConnectionError as e:
|
| 377 |
+
print(e)
|
| 378 |
+
return False
|
| 379 |
+
|
| 380 |
+
    def find_motor_indices(self, possible_ids=None, num_retry=2):
        """Scan the bus and return the list of motor ids that answer an "ID" read.

        Args:
            possible_ids: iterable of candidate ids to probe; defaults to
                `range(MAX_ID_RANGE)`.
            num_retry: read attempts per candidate id before treating it as absent.

        Raises:
            OSError: if a motor answers with an id different from the one probed,
                which suggests corrupted motor memory.
        """
        if possible_ids is None:
            possible_ids = range(MAX_ID_RANGE)

        indices = []
        for idx in tqdm.tqdm(possible_ids):
            try:
                present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
            except ConnectionError:
                # No motor answered at this id; try the next one.
                continue

            if idx != present_idx:
                # sanity check
                raise OSError(
                    "Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged."
                )
            indices.append(idx)

        return indices
|
| 399 |
+
|
| 400 |
+
    def set_bus_baudrate(self, baudrate):
        """Change the serial port baud rate if it differs from the requested one.

        Raises:
            OSError: if the port does not report the new baud rate after writing it.
        """
        present_bus_baudrate = self.port_handler.getBaudRate()
        if present_bus_baudrate != baudrate:
            print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
            self.port_handler.setBaudRate(baudrate)

            # Read back to make sure the driver actually applied the new rate.
            if self.port_handler.getBaudRate() != baudrate:
                raise OSError("Failed to write bus baud rate.")
|
| 408 |
+
|
| 409 |
+
@property
|
| 410 |
+
def motor_names(self) -> list[str]:
|
| 411 |
+
return list(self.motors.keys())
|
| 412 |
+
|
| 413 |
+
@property
|
| 414 |
+
def motor_models(self) -> list[str]:
|
| 415 |
+
return [model for _, model in self.motors.values()]
|
| 416 |
+
|
| 417 |
+
@property
|
| 418 |
+
def motor_indices(self) -> list[int]:
|
| 419 |
+
return [idx for idx, _ in self.motors.values()]
|
| 420 |
+
|
| 421 |
+
    def set_calibration(self, calibration: dict[str, list]):
        """Store the calibration dict consumed by `apply_calibration`/`revert_calibration`.

        Expected keys (one entry per motor): "motor_names", "calib_mode", plus the
        per-mode fields — "drive_mode"/"homing_offset" for DEGREE joints and
        "start_pos"/"end_pos" for LINEAR joints.
        """
        self.calibration = calibration
|
| 423 |
+
|
| 424 |
+
    def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
        """This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.

        For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
        """
        try:
            values = self.apply_calibration(values, motor_names)
        except JointOutOfRangeError as e:
            print(e)
            # Shift the stored homing offsets by full turns, then retry once; if the
            # values are still out of range, the second call re-raises.
            self.autocorrect_calibration(values, motor_names)
            values = self.apply_calibration(values, motor_names)
        return values
|
| 436 |
+
|
| 437 |
+
    def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
        """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
        a "zero position" at 0 degree.

        Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor
        rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range.

        Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation
        when given a goal position that is + or - their resolution. For instance, dynamixel xl330-m077 have a resolution of 4096, and
        at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830,
        or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor.
        To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work
        in the centered nominal degree range ]-180, 180[.

        Raises:
            JointOutOfRangeError: if a converted value falls outside the maximum
                allowed bounds ([LOWER_BOUND_DEGREE, UPPER_BOUND_DEGREE] degrees or
                [LOWER_BOUND_LINEAR, UPPER_BOUND_LINEAR] %, depending on the joint).
        """
        if motor_names is None:
            motor_names = self.motor_names

        # Convert from unsigned int32 original range [0, 2**32] to signed float32 range
        values = values.astype(np.float32)

        for i, name in enumerate(motor_names):
            # `values` is ordered like `motor_names`; the calibration dict has its own order.
            calib_idx = self.calibration["motor_names"].index(name)
            calib_mode = self.calibration["calib_mode"][calib_idx]

            if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
                drive_mode = self.calibration["drive_mode"][calib_idx]
                homing_offset = self.calibration["homing_offset"][calib_idx]
                _, model = self.motors[name]
                resolution = self.model_resolution[model]

                # Update direction of rotation of the motor to match between leader and follower.
                # In fact, the motor of the leader for a given joint can be assembled in an
                # opposite direction in term of rotation than the motor of the follower on the same joint.
                if drive_mode:
                    values[i] *= -1

                # Convert from range [-2**31, 2**31] to
                # nominal range [-resolution//2, resolution//2] (e.g. [-2048, 2048])
                values[i] += homing_offset

                # Convert from range [-resolution//2, resolution//2] to
                # universal float32 centered degree range [-180, 180]
                # (e.g. 2048 / (4096 // 2) * 180 = 180)
                values[i] = values[i] / (resolution // 2) * HALF_TURN_DEGREE

                if (values[i] < LOWER_BOUND_DEGREE) or (values[i] > UPPER_BOUND_DEGREE):
                    raise JointOutOfRangeError(
                        f"Wrong motor position range detected for {name}. "
                        f"Expected to be in nominal range of [-{HALF_TURN_DEGREE}, {HALF_TURN_DEGREE}] degrees (a full rotation), "
                        f"with a maximum range of [{LOWER_BOUND_DEGREE}, {UPPER_BOUND_DEGREE}] degrees to account for joints that can rotate a bit more, "
                        f"but present value is {values[i]} degree. "
                        "This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. "
                        "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
                    )

            elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
                start_pos = self.calibration["start_pos"][calib_idx]
                end_pos = self.calibration["end_pos"][calib_idx]

                # Rescale the present position to a nominal range [0, 100] %,
                # useful for joints with linear motions like Aloha gripper
                values[i] = (values[i] - start_pos) / (end_pos - start_pos) * 100

                if (values[i] < LOWER_BOUND_LINEAR) or (values[i] > UPPER_BOUND_LINEAR):
                    raise JointOutOfRangeError(
                        f"Wrong motor position range detected for {name}. "
                        f"Expected to be in nominal range of [0, 100] % (a full linear translation), "
                        f"with a maximum range of [{LOWER_BOUND_LINEAR}, {UPPER_BOUND_LINEAR}] % to account for some imprecision during calibration, "
                        f"but present value is {values[i]} %. "
                        "This might be due to a cable connection issue creating an artificial jump in motor values. "
                        "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
                    )

        return values
|
| 511 |
+
|
| 512 |
+
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
| 513 |
+
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
|
| 514 |
+
|
| 515 |
+
Some motors might have values outside of expected maximum bounds after calibration.
|
| 516 |
+
For instance, for a joint in degree, its value can be outside [-270, 270] degrees, which is totally unexpected given
|
| 517 |
+
a nominal range of [-180, 180] degrees, which represents half a turn to the left or right starting from zero position.
|
| 518 |
+
|
| 519 |
+
Known issues:
|
| 520 |
+
#1: Motor value randomly shifts of a full turn, caused by hardware/connection errors.
|
| 521 |
+
#2: Motor internal homing offset is shifted by a full turn, caused by using default calibration (e.g Aloha).
|
| 522 |
+
#3: motor internal homing offset is shifted by less or more than a full turn, caused by using default calibration
|
| 523 |
+
or by human error during manual calibration.
|
| 524 |
+
|
| 525 |
+
Issues #1 and #2 can be solved by shifting the calibration homing offset by a full turn.
|
| 526 |
+
Issue #3 will be visually detected by user and potentially captured by the safety feature `max_relative_target`,
|
| 527 |
+
that will slow down the motor, raise an error asking to recalibrate. Manual recalibrating will solve the issue.
|
| 528 |
+
|
| 529 |
+
Note: A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
| 530 |
+
"""
|
| 531 |
+
if motor_names is None:
|
| 532 |
+
motor_names = self.motor_names
|
| 533 |
+
|
| 534 |
+
# Convert from unsigned int32 original range [0, 2**32] to signed float32 range
|
| 535 |
+
values = values.astype(np.float32)
|
| 536 |
+
|
| 537 |
+
for i, name in enumerate(motor_names):
|
| 538 |
+
calib_idx = self.calibration["motor_names"].index(name)
|
| 539 |
+
calib_mode = self.calibration["calib_mode"][calib_idx]
|
| 540 |
+
|
| 541 |
+
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
| 542 |
+
drive_mode = self.calibration["drive_mode"][calib_idx]
|
| 543 |
+
homing_offset = self.calibration["homing_offset"][calib_idx]
|
| 544 |
+
_, model = self.motors[name]
|
| 545 |
+
resolution = self.model_resolution[model]
|
| 546 |
+
|
| 547 |
+
# Update direction of rotation of the motor to match between leader and follower.
|
| 548 |
+
# In fact, the motor of the leader for a given joint can be assembled in an
|
| 549 |
+
# opposite direction in term of rotation than the motor of the follower on the same joint.
|
| 550 |
+
if drive_mode:
|
| 551 |
+
values[i] *= -1
|
| 552 |
+
|
| 553 |
+
# Convert from initial range to range [-180, 180] degrees
|
| 554 |
+
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
| 555 |
+
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
|
| 556 |
+
|
| 557 |
+
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
|
| 558 |
+
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
|
| 559 |
+
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
|
| 560 |
+
# (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution
|
| 561 |
+
low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution
|
| 562 |
+
upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution
|
| 563 |
+
|
| 564 |
+
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
| 565 |
+
start_pos = self.calibration["start_pos"][calib_idx]
|
| 566 |
+
end_pos = self.calibration["end_pos"][calib_idx]
|
| 567 |
+
|
| 568 |
+
# Convert from initial range to range [0, 100] in %
|
| 569 |
+
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
|
| 570 |
+
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
|
| 571 |
+
|
| 572 |
+
# Solve this inequality to find the factor to shift the range into [0, 100] %
|
| 573 |
+
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
|
| 574 |
+
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100
|
| 575 |
+
# 0 <= (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 <= 100
|
| 576 |
+
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
|
| 577 |
+
low_factor = (start_pos - values[i]) / resolution
|
| 578 |
+
upp_factor = (end_pos - values[i]) / resolution
|
| 579 |
+
|
| 580 |
+
if not in_range:
|
| 581 |
+
# Get first integer between the two bounds
|
| 582 |
+
if low_factor < upp_factor:
|
| 583 |
+
factor = math.ceil(low_factor)
|
| 584 |
+
|
| 585 |
+
if factor > upp_factor:
|
| 586 |
+
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
| 587 |
+
else:
|
| 588 |
+
factor = math.ceil(upp_factor)
|
| 589 |
+
|
| 590 |
+
if factor > low_factor:
|
| 591 |
+
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
| 592 |
+
|
| 593 |
+
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
| 594 |
+
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
| 595 |
+
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
| 596 |
+
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
| 597 |
+
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
| 598 |
+
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
| 599 |
+
|
| 600 |
+
logging.warning(
|
| 601 |
+
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
| 602 |
+
f"from '{out_of_range_str}' to '{in_range_str}'."
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
| 606 |
+
self.calibration["homing_offset"][calib_idx] += resolution * factor
|
| 607 |
+
|
| 608 |
+
    def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
        """Inverse of `apply_calibration`.

        Maps calibrated values (degrees for DEGREE joints, % for LINEAR joints)
        back to raw int32 motor steps suitable for writing to the bus.
        """
        if motor_names is None:
            motor_names = self.motor_names

        for i, name in enumerate(motor_names):
            calib_idx = self.calibration["motor_names"].index(name)
            calib_mode = self.calibration["calib_mode"][calib_idx]

            if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
                drive_mode = self.calibration["drive_mode"][calib_idx]
                homing_offset = self.calibration["homing_offset"][calib_idx]
                _, model = self.motors[name]
                resolution = self.model_resolution[model]

                # Convert from nominal 0-centered degree range [-180, 180] to
                # 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
                values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)

                # Subtract the homing offsets to come back to actual motor range of values
                # which can be arbitrary.
                values[i] -= homing_offset

                # Remove drive mode, which is the rotation direction of the motor, to come back to
                # actual motor rotation direction which can be arbitrary.
                if drive_mode:
                    values[i] *= -1

            elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
                start_pos = self.calibration["start_pos"][calib_idx]
                end_pos = self.calibration["end_pos"][calib_idx]

                # Convert from nominal linear range of [0, 100] % to
                # actual motor range of values which can be arbitrary.
                values[i] = values[i] / 100 * (end_pos - start_pos) + start_pos

        # Motors consume integer step counts; round to nearest rather than truncate.
        values = np.round(values).astype(np.int32)
        return values
|
| 646 |
+
|
| 647 |
+
    def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
        """Read one control-table entry directly by raw motor id(s), bypassing `self.motors` names.

        Args:
            motor_models: model name(s); only `motor_models[0]` is used for the
                address lookup (all supported models share one control table).
            motor_ids: a single id or a list of ids; a single id returns a scalar,
                a list returns a list in the same order.
            data_name: control-table field name (e.g. "ID", "Present_Position").
            num_retry: sync-read attempts before raising.

        Raises:
            ConnectionError: if the sync read still fails after `num_retry` attempts.

        NOTE(review): `assert_same_address` is called with `self.motor_models`
        rather than the `motor_models` argument — confirm whether that is intended.
        """
        if self.mock:
            import tests.motors.mock_dynamixel_sdk as dxl
        else:
            import dynamixel_sdk as dxl

        # Remember whether the caller passed a scalar so we can unwrap the result.
        return_list = True
        if not isinstance(motor_ids, list):
            return_list = False
            motor_ids = [motor_ids]

        assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
        addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
        group = dxl.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
        for idx in motor_ids:
            group.addParam(idx)

        for _ in range(num_retry):
            comm = group.txRxPacket()
            if comm == dxl.COMM_SUCCESS:
                break

        if comm != dxl.COMM_SUCCESS:
            raise ConnectionError(
                f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
                f"{self.packet_handler.getTxRxResult(comm)}"
            )

        values = []
        for idx in motor_ids:
            value = group.getData(idx, addr, bytes)
            values.append(value)

        if return_list:
            return values
        else:
            return values[0]
|
| 684 |
+
|
| 685 |
+
def read(self, data_name, motor_names: str | list[str] | None = None):
    """Sync-read register `data_name` from the given motors (all motors by default).

    Args:
        data_name: register name, a key of the model control table.
        motor_names: a motor name, a list of motor names, or None for all motors.

    Returns:
        np.ndarray with one value per motor, in `motor_names` order. Values of
        registers listed in CONVERT_UINT32_TO_INT32_REQUIRED are reinterpreted
        as signed int32, and calibrated registers are converted to the nominal
        range when a calibration is set.

    Raises:
        RobotDeviceNotConnectedError: if `connect()` was never called.
        ConnectionError: if the sync read fails after NUM_READ_RETRY attempts.
    """
    if not self.is_connected:
        raise RobotDeviceNotConnectedError(
            f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
        )

    start_time = time.perf_counter()

    # Deferred import so tests can swap in a mock SDK.
    if self.mock:
        import tests.motors.mock_dynamixel_sdk as dxl
    else:
        import dynamixel_sdk as dxl

    if motor_names is None:
        motor_names = self.motor_names

    if isinstance(motor_names, str):
        motor_names = [motor_names]

    motor_ids = []
    models = []
    for name in motor_names:
        motor_idx, model = self.motors[name]
        motor_ids.append(motor_idx)
        models.append(model)

    assert_same_address(self.model_ctrl_table, models, data_name)
    addr, bytes = self.model_ctrl_table[model][data_name]
    group_key = get_group_sync_key(data_name, motor_names)

    # BUGFIX: `self.group_readers` is keyed by `group_key`, but membership used
    # to be tested with `data_name`, so the condition was always true and a fresh
    # GroupSyncRead was rebuilt on every call — the cache never hit. Test the
    # actual cache key instead.
    if group_key not in self.group_readers:
        # create new group reader
        self.group_readers[group_key] = dxl.GroupSyncRead(
            self.port_handler, self.packet_handler, addr, bytes
        )
        for idx in motor_ids:
            self.group_readers[group_key].addParam(idx)

    # Retry the whole transaction; serial comm with these motors can be flaky.
    for _ in range(NUM_READ_RETRY):
        comm = self.group_readers[group_key].txRxPacket()
        if comm == dxl.COMM_SUCCESS:
            break

    if comm != dxl.COMM_SUCCESS:
        raise ConnectionError(
            f"Read failed due to communication error on port {self.port} for group_key {group_key}: "
            f"{self.packet_handler.getTxRxResult(comm)}"
        )

    values = []
    for idx in motor_ids:
        value = self.group_readers[group_key].getData(idx, addr, bytes)
        values.append(value)

    values = np.array(values)

    # Convert to signed int to use range [-2048, 2048] for our motor positions.
    if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
        values = values.astype(np.int32)

    if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
        values = self.apply_calibration_autocorrect(values, motor_names)

    # log the number of seconds it took to read the data from the motors
    delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
    self.logs[delta_ts_name] = time.perf_counter() - start_time

    # log the utc time at which the data was received
    ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names)
    self.logs[ts_utc_name] = capture_timestamp_utc()

    return values
|
| 757 |
+
|
| 758 |
+
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
    # One-shot sync write addressed by raw motor ids. Bypasses the cached group
    # writers used by `write`, so it can be used before the bus configuration is
    # known, e.g. while (re)assigning motor ids.
    if self.mock:
        import tests.motors.mock_dynamixel_sdk as dxl
    else:
        import dynamixel_sdk as dxl

    # Accept a single id / single value for convenience.
    if not isinstance(motor_ids, list):
        motor_ids = [motor_ids]
    if not isinstance(values, list):
        values = [values]

    assert_same_address(self.model_ctrl_table, motor_models, data_name)
    addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
    group = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
    # `strict=True` raises if the number of ids and values differ.
    for idx, value in zip(motor_ids, values, strict=True):
        data = convert_to_bytes(value, bytes, self.mock)
        group.addParam(idx, data)

    # Retry the whole transaction up to `num_retry` times; stop at first success.
    for _ in range(num_retry):
        comm = group.txPacket()
        if comm == dxl.COMM_SUCCESS:
            break

    if comm != dxl.COMM_SUCCESS:
        raise ConnectionError(
            f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
            f"{self.packet_handler.getTxRxResult(comm)}"
        )
|
| 786 |
+
|
| 787 |
+
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
    """Sync-write `values` to register `data_name` of the given motors (all by default).

    Args:
        data_name: register name, a key of the model control table.
        values: a scalar (broadcast to every selected motor) or an array with one
            value per motor, in `motor_names` order. Calibrated registers are
            converted back from the nominal range to raw steps before sending.
        motor_names: a motor name, a list of motor names, or None for all motors.

    Raises:
        RobotDeviceNotConnectedError: if `connect()` was never called.
        ConnectionError: if the sync write transaction fails.
    """
    if not self.is_connected:
        raise RobotDeviceNotConnectedError(
            f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
        )

    start_time = time.perf_counter()

    # Deferred import so tests can swap in a mock SDK.
    if self.mock:
        import tests.motors.mock_dynamixel_sdk as dxl
    else:
        import dynamixel_sdk as dxl

    if motor_names is None:
        motor_names = self.motor_names

    if isinstance(motor_names, str):
        motor_names = [motor_names]

    # Broadcast a scalar to all selected motors.
    if isinstance(values, (int, float, np.integer)):
        values = [int(values)] * len(motor_names)

    values = np.array(values)

    motor_ids = []
    models = []
    for name in motor_names:
        motor_idx, model = self.motors[name]
        motor_ids.append(motor_idx)
        models.append(model)

    if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
        values = self.revert_calibration(values, motor_names)

    values = values.tolist()

    assert_same_address(self.model_ctrl_table, models, data_name)
    addr, bytes = self.model_ctrl_table[model][data_name]
    group_key = get_group_sync_key(data_name, motor_names)

    # BUGFIX: initialization used to be decided by `data_name not in
    # self.group_readers`, i.e. it consulted the *readers* cache with the wrong
    # key, so that condition was always true and a new GroupSyncWrite was rebuilt
    # (and `addParam` re-issued) on every call. Check the writers cache with the
    # actual cache key so `changeParam` is used on subsequent writes as intended.
    init_group = group_key not in self.group_writers
    if init_group:
        self.group_writers[group_key] = dxl.GroupSyncWrite(
            self.port_handler, self.packet_handler, addr, bytes
        )

    # `strict=True` raises if the number of motors and values differ.
    for idx, value in zip(motor_ids, values, strict=True):
        data = convert_to_bytes(value, bytes, self.mock)
        if init_group:
            self.group_writers[group_key].addParam(idx, data)
        else:
            self.group_writers[group_key].changeParam(idx, data)

    comm = self.group_writers[group_key].txPacket()
    if comm != dxl.COMM_SUCCESS:
        raise ConnectionError(
            f"Write failed due to communication error on port {self.port} for group_key {group_key}: "
            f"{self.packet_handler.getTxRxResult(comm)}"
        )

    # log the number of seconds it took to write the data to the motors
    delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
    self.logs[delta_ts_name] = time.perf_counter() - start_time

    # TODO(rcadene): should we log the time before sending the write command?
    # log the utc time when the write has been completed
    ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
    self.logs[ts_utc_name] = capture_timestamp_utc()
|
| 855 |
+
|
| 856 |
+
def disconnect(self):
    """Close the serial port and reset all connection state.

    Raises:
        RobotDeviceNotConnectedError: if the bus is not currently connected.
    """
    if not self.is_connected:
        raise RobotDeviceNotConnectedError(
            f"DynamixelMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first."
        )

    handler = self.port_handler
    if handler is not None:
        handler.closePort()
        self.port_handler = None

    # Drop SDK objects and cached sync groups so a later connect() starts clean.
    self.packet_handler = None
    self.group_readers = {}
    self.group_writers = {}
    self.is_connected = False
|
| 870 |
+
|
| 871 |
+
def __del__(self):
    """Best-effort cleanup: close the port if garbage-collected while connected.

    `getattr` with a default guards against partially-initialized instances
    (e.g. when `__init__` raised before setting `is_connected`).
    """
    still_connected = getattr(self, "is_connected", False)
    if still_connected:
        self.disconnect()
|
lerobot/common/robot_devices/motors/feetech.py
ADDED
|
@@ -0,0 +1,898 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import enum
|
| 16 |
+
import logging
|
| 17 |
+
import math
|
| 18 |
+
import time
|
| 19 |
+
import traceback
|
| 20 |
+
from copy import deepcopy
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import tqdm
|
| 24 |
+
|
| 25 |
+
from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
|
| 26 |
+
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
| 27 |
+
from lerobot.common.utils.utils import capture_timestamp_utc
|
| 28 |
+
|
| 29 |
+
# SCServo SDK protocol version, passed to `scs.PacketHandler` in `connect()`.
PROTOCOL_VERSION = 0
# Nominal bus baud rate.
BAUDRATE = 1_000_000
# Packet timeout (ms) applied via `setPacketTimeoutMillis` after connecting.
TIMEOUT_MS = 1000

# Upper bound of the id range scanned by `find_motor_indices`.
MAX_ID_RANGE = 252

# The following bounds define the lower and upper joints range (after calibration).
# For joints in degree (i.e. revolute joints), their nominal range is [-180, 180] degrees
# which corresponds to a half rotation on the left and half rotation on the right.
# Some joints might require higher range, so we allow up to [-270, 270] degrees until
# an error is raised.
LOWER_BOUND_DEGREE = -270
UPPER_BOUND_DEGREE = 270
# For joints in percentage (i.e. joints that move linearly like the prismatic joint of a gripper),
# their nominal range is [0, 100] %. For instance, for Aloha gripper, 0% is fully
# closed, and 100% is fully open. To account for slight calibration issue, we allow up to
# [-10, 110] until an error is raised.
LOWER_BOUND_LINEAR = -10
UPPER_BOUND_LINEAR = 110

HALF_TURN_DEGREE = 180


# See this link for STS3215 Memory Table:
# https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true
# data_name: (address, size_byte)
SCS_SERIES_CONTROL_TABLE = {
    "Model": (3, 2),
    "ID": (5, 1),
    "Baud_Rate": (6, 1),
    "Return_Delay": (7, 1),
    "Response_Status_Level": (8, 1),
    "Min_Angle_Limit": (9, 2),
    "Max_Angle_Limit": (11, 2),
    "Max_Temperature_Limit": (13, 1),
    "Max_Voltage_Limit": (14, 1),
    "Min_Voltage_Limit": (15, 1),
    "Max_Torque_Limit": (16, 2),
    "Phase": (18, 1),
    "Unloading_Condition": (19, 1),
    "LED_Alarm_Condition": (20, 1),
    "P_Coefficient": (21, 1),
    "D_Coefficient": (22, 1),
    "I_Coefficient": (23, 1),
    "Minimum_Startup_Force": (24, 2),
    "CW_Dead_Zone": (26, 1),
    "CCW_Dead_Zone": (27, 1),
    "Protection_Current": (28, 2),
    "Angular_Resolution": (30, 1),
    "Offset": (31, 2),
    "Mode": (33, 1),
    "Protective_Torque": (34, 1),
    "Protection_Time": (35, 1),
    "Overload_Torque": (36, 1),
    "Speed_closed_loop_P_proportional_coefficient": (37, 1),
    "Over_Current_Protection_Time": (38, 1),
    "Velocity_closed_loop_I_integral_coefficient": (39, 1),
    "Torque_Enable": (40, 1),
    "Acceleration": (41, 1),
    "Goal_Position": (42, 2),
    "Goal_Time": (44, 2),
    "Goal_Speed": (46, 2),
    "Torque_Limit": (48, 2),
    "Lock": (55, 1),
    "Present_Position": (56, 2),
    "Present_Speed": (58, 2),
    "Present_Load": (60, 2),
    "Present_Voltage": (62, 1),
    "Present_Temperature": (63, 1),
    "Status": (65, 1),
    "Moving": (66, 1),
    "Present_Current": (69, 2),
    # Not in the Memory Table
    "Maximum_Acceleration": (85, 2),
}

# Mapping from the motors' "Baud_Rate" register value to the serial baud rate.
SCS_SERIES_BAUDRATE_TABLE = {
    0: 1_000_000,
    1: 500_000,
    2: 250_000,
    3: 128_000,
    4: 115_200,
    5: 57_600,
    6: 38_400,
    7: 19_200,
}

# Registers whose values go through the (de)calibration pipeline on read/write.
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
# Registers whose raw values must be reinterpreted from unsigned to signed ints.
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]


# Per-model lookup tables; both supported models share the SCS-series layout.
MODEL_CONTROL_TABLE = {
    "scs_series": SCS_SERIES_CONTROL_TABLE,
    "sts3215": SCS_SERIES_CONTROL_TABLE,
}

# Steps per full rotation for each supported model.
MODEL_RESOLUTION = {
    "scs_series": 4096,
    "sts3215": 4096,
}

MODEL_BAUDRATE_TABLE = {
    "scs_series": SCS_SERIES_BAUDRATE_TABLE,
    "sts3215": SCS_SERIES_BAUDRATE_TABLE,
}

# High number of retries is needed for feetech compared to dynamixel motors.
NUM_READ_RETRY = 20
NUM_WRITE_RETRY = 20
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
    """Convert degrees to motor steps, one entry per model in `models`.

    A full rotation is assumed to span [-180, +180] degrees, and each model's
    resolution (e.g. 4096) is the number of steps in that full rotation.
    """
    resolution_per_model = np.array([MODEL_RESOLUTION[model] for model in models])
    steps = degrees / 180 * resolution_per_model / 2
    return steps.astype(int)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def convert_to_bytes(value, bytes, mock=False):
    """Split `value` into the per-byte list the SCServo SDK expects for a write.

    `bytes` is the register size (1, 2 or 4); any other size raises
    NotImplementedError. With `mock=True` the value is returned unchanged,
    since the mocked SDK consumes plain ints.
    """
    if mock:
        return value

    import scservo_sdk as scs

    # Note: No need to convert back into unsigned int, since this byte preprocessing
    # already handles it for us.
    low_word = scs.SCS_LOWORD(value)
    if bytes == 1:
        return [scs.SCS_LOBYTE(low_word)]
    if bytes == 2:
        return [scs.SCS_LOBYTE(low_word), scs.SCS_HIBYTE(low_word)]
    if bytes == 4:
        high_word = scs.SCS_HIWORD(value)
        return [
            scs.SCS_LOBYTE(low_word),
            scs.SCS_HIBYTE(low_word),
            scs.SCS_LOBYTE(high_word),
            scs.SCS_HIBYTE(high_word),
        ]
    raise NotImplementedError(
        f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
        f"{bytes} is provided instead."
    )
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def get_group_sync_key(data_name, motor_names):
    """Build the cache key identifying a (register, motor group) sync read/write."""
    return "{}_{}".format(data_name, "_".join(motor_names))
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def get_result_name(fn_name, data_name, motor_names):
    """Name under which the result of `fn_name` for this motor group is stored."""
    return f"{fn_name}_{get_group_sync_key(data_name, motor_names)}"
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def get_queue_name(fn_name, data_name, motor_names):
    """Name of the queue carrying `fn_name` traffic for this motor group."""
    return f"{fn_name}_{get_group_sync_key(data_name, motor_names)}"
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def get_log_name(var_name, fn_name, data_name, motor_names):
    """Key under which `var_name` (e.g. a timestamp) is logged for this operation."""
    return f"{var_name}_{fn_name}_{get_group_sync_key(data_name, motor_names)}"
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def assert_same_address(model_ctrl_table, motor_models, data_name):
    """Ensure every model maps `data_name` to one shared (address, byte-size) pair.

    Group sync reads/writes address all motors with a single (address, size), so
    mixing models that disagree is unsupported and raises NotImplementedError.
    """
    addrs = [model_ctrl_table[model][data_name][0] for model in motor_models]
    sizes = [model_ctrl_table[model][data_name][1] for model in motor_models]

    if len(set(addrs)) != 1:
        raise NotImplementedError(
            f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, addrs, strict=False))}). Contact a LeRobot maintainer."
        )

    if len(set(sizes)) != 1:
        raise NotImplementedError(
            f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, sizes, strict=False))}). Contact a LeRobot maintainer."
        )
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class TorqueMode(enum.Enum):
    """Values for the motors' `Torque_Enable` register (see the control table)."""

    ENABLED = 1
    DISABLED = 0
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class DriveMode(enum.Enum):
    """Whether a joint's rotation direction is inverted relative to the motor's."""

    NON_INVERTED = 0
    INVERTED = 1
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class CalibrationMode(enum.Enum):
    """Unit/range convention used when converting a joint's raw steps to nominal values."""

    # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
    DEGREE = 0
    # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
    LINEAR = 1
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class JointOutOfRangeError(Exception):
    """Raised when a (calibrated) joint value falls outside its permitted bounds."""

    def __init__(self, message="Joint is out of range"):
        super().__init__(message)
        # Keep the text accessible as an attribute for callers that inspect it.
        self.message = message
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class FeetechMotorsBus:
|
| 249 |
+
"""
|
| 250 |
+
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on
|
| 251 |
+
the python feetech sdk (`scservo_sdk`) to communicate with the motors. Its read/write API mirrors the Dynamixel SDK's; note the "feetech SDK Documentation" link previously given here pointed to a non-existent ROBOTIS e-manual page and has been removed.
|
| 252 |
+
|
| 253 |
+
A FeetechMotorsBus instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
| 254 |
+
To find the port, you can run our utility script:
|
| 255 |
+
```bash
|
| 256 |
+
python lerobot/scripts/find_motors_bus_port.py
|
| 257 |
+
>>> Finding all available ports for the MotorsBus.
|
| 258 |
+
>>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
| 259 |
+
>>> Remove the usb cable from your FeetechMotorsBus and press Enter when done.
|
| 260 |
+
>>> The port of this FeetechMotorsBus is /dev/tty.usbmodem575E0031751.
|
| 261 |
+
>>> Reconnect the usb cable.
|
| 262 |
+
```
|
| 263 |
+
|
| 264 |
+
Example of usage for 1 motor connected to the bus:
|
| 265 |
+
```python
|
| 266 |
+
motor_name = "gripper"
|
| 267 |
+
motor_index = 6
|
| 268 |
+
motor_model = "sts3215"
|
| 269 |
+
|
| 270 |
+
config = FeetechMotorsBusConfig(
|
| 271 |
+
port="/dev/tty.usbmodem575E0031751",
|
| 272 |
+
motors={motor_name: (motor_index, motor_model)},
|
| 273 |
+
)
|
| 274 |
+
motors_bus = FeetechMotorsBus(config)
|
| 275 |
+
motors_bus.connect()
|
| 276 |
+
|
| 277 |
+
position = motors_bus.read("Present_Position")
|
| 278 |
+
|
| 279 |
+
# move from a few motor steps as an example
|
| 280 |
+
few_steps = 30
|
| 281 |
+
motors_bus.write("Goal_Position", position + few_steps)
|
| 282 |
+
|
| 283 |
+
# when done, consider disconnecting
|
| 284 |
+
motors_bus.disconnect()
|
| 285 |
+
```
|
| 286 |
+
"""
|
| 287 |
+
|
| 288 |
+
def __init__(
    self,
    config: FeetechMotorsBusConfig,
):
    """Store the bus configuration; no serial I/O happens until `connect()` is called."""
    self.port = config.port
    self.motors = config.motors
    self.mock = config.mock

    # Copy the module-level tables so per-instance mutation cannot leak globally.
    self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
    self.model_resolution = deepcopy(MODEL_RESOLUTION)

    # Populated by `connect()` (handlers) and `set_calibration()` (calibration).
    self.port_handler = None
    self.packet_handler = None
    self.calibration = None
    self.is_connected = False
    # Caches of SDK GroupSyncRead/GroupSyncWrite objects, keyed by group sync key.
    self.group_readers = {}
    self.group_writers = {}
    # Timing/timestamp log entries, keyed by names from `get_log_name`.
    self.logs = {}

    # Per-motor position tracking state — purpose defined by methods outside
    # this chunk; TODO(review) confirm usage.
    self.track_positions = {}
|
| 308 |
+
|
| 309 |
+
def connect(self):
    """Open the serial port and configure the packet timeout.

    Raises:
        RobotDeviceAlreadyConnectedError: if the bus is already connected.
        OSError: if the serial port cannot be opened.
    """
    if self.is_connected:
        raise RobotDeviceAlreadyConnectedError(
            f"FeetechMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice."
        )

    # Deferred import so tests can swap in a mock SDK implementation.
    if self.mock:
        import tests.motors.mock_scservo_sdk as scs
    else:
        import scservo_sdk as scs

    self.port_handler = scs.PortHandler(self.port)
    self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION)

    try:
        if not self.port_handler.openPort():
            raise OSError(f"Failed to open port '{self.port}'.")
    except Exception:
        # Print the traceback plus a hint before re-raising, since a wrong port
        # is by far the most common failure here.
        traceback.print_exc()
        print(
            "\nTry running `python lerobot/scripts/find_motors_bus_port.py` to make sure you are using the correct port.\n"
        )
        raise

    # Allow to read and write
    self.is_connected = True

    self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
|
| 337 |
+
|
| 338 |
+
def reconnect(self):
    """Re-open the serial port after a disconnect.

    Unlike `connect()`, this does not raise if the bus is already marked
    connected, and it skips the port-finding hint on failure.

    Raises:
        OSError: if the serial port cannot be opened.
    """
    if self.mock:
        import tests.motors.mock_scservo_sdk as scs
    else:
        import scservo_sdk as scs

    self.port_handler = scs.PortHandler(self.port)
    self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION)

    if not self.port_handler.openPort():
        raise OSError(f"Failed to open port '{self.port}'.")

    # FIX: keep the packet timeout consistent with `connect()`. Previously a
    # reconnected bus ran with the SDK's default timeout, so reads on a wedged
    # bus could block much longer than after a fresh connect.
    self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)

    self.is_connected = True
|
| 351 |
+
|
| 352 |
+
def are_motors_configured(self):
    """Return True when the ids stored in the motors match the configured ones."""
    # Only check the motor indices and not baudrate, since if the motor baudrates are incorrect,
    # a ConnectionError will be raised anyway.
    try:
        present_ids = self.read("ID")
    except ConnectionError as e:
        print(e)
        return False
    return (self.motor_indices == present_ids).all()
|
| 360 |
+
|
| 361 |
+
def find_motor_indices(self, possible_ids=None, num_retry=2):
    """Scan the bus and return every motor id that answers an "ID" read.

    Args:
        possible_ids: iterable of candidate ids; defaults to range(MAX_ID_RANGE).
        num_retry: read retries per candidate before treating it as absent.

    Raises:
        OSError: if a motor answers with an id different from the one used to
            address it (possible memory corruption).
    """
    candidate_ids = range(MAX_ID_RANGE) if possible_ids is None else possible_ids

    found = []
    for candidate in tqdm.tqdm(candidate_ids):
        try:
            present_idx = self.read_with_motor_ids(self.motor_models, [candidate], "ID", num_retry=num_retry)[0]
        except ConnectionError:
            # No motor answered at this id.
            continue

        if candidate != present_idx:
            # sanity check
            raise OSError(
                "Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged."
            )
        found.append(candidate)

    return found
|
| 380 |
+
|
| 381 |
+
def set_bus_baudrate(self, baudrate):
    """Switch the serial link to `baudrate` unless it is already using it.

    Raises:
        OSError: if the port handler does not report the new rate after setting it.
    """
    current = self.port_handler.getBaudRate()
    if current == baudrate:
        return

    print(f"Setting bus baud rate to {baudrate}. Previously {current}.")
    self.port_handler.setBaudRate(baudrate)

    if self.port_handler.getBaudRate() != baudrate:
        raise OSError("Failed to write bus baud rate.")
|
| 389 |
+
|
| 390 |
+
@property
def motor_names(self) -> list[str]:
    """Configured motor names, in declaration order."""
    return [name for name in self.motors]
|
| 393 |
+
|
| 394 |
+
@property
def motor_models(self) -> list[str]:
    """Model string of every configured motor, aligned with `motor_names`."""
    return [entry[1] for entry in self.motors.values()]
|
| 397 |
+
|
| 398 |
+
@property
def motor_indices(self) -> list[int]:
    """Bus id of every configured motor, aligned with `motor_names`."""
    return [entry[0] for entry in self.motors.values()]
|
| 401 |
+
|
| 402 |
+
def set_calibration(self, calibration: dict[str, list]):
    """Store the calibration dict consumed by `apply_calibration*` on reads/writes."""
    self.calibration = calibration
|
| 404 |
+
|
| 405 |
+
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
|
| 406 |
+
"""This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct.
|
| 407 |
+
|
| 408 |
+
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
|
| 409 |
+
"""
|
| 410 |
+
try:
|
| 411 |
+
values = self.apply_calibration(values, motor_names)
|
| 412 |
+
except JointOutOfRangeError as e:
|
| 413 |
+
print(e)
|
| 414 |
+
self.autocorrect_calibration(values, motor_names)
|
| 415 |
+
values = self.apply_calibration(values, motor_names)
|
| 416 |
+
return values
|
| 417 |
+
|
| 418 |
+
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
    """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
    a "zero position" at 0 degree.

    Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor
    rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range.

    Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation
    when given a goal position that is + or - their resolution. For instance, feetech xl330-m077 have a resolution of 4096, and
    at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830,
    or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor.
    To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work
    in the centered nominal degree range ]-180, 180[.

    Raises:
        JointOutOfRangeError: when a converted value falls outside the permissive
            degree/linear bounds (e.g. after a wiring glitch or a bad calibration).
    """
    if motor_names is None:
        motor_names = self.motor_names

    # Convert from unsigned int32 original range [0, 2**32] to signed float32 range
    values = values.astype(np.float32)

    for i, name in enumerate(motor_names):
        # Calibration lists are aligned on the order stored in the calibration
        # file, which may differ from `motor_names`; look the index up by name.
        calib_idx = self.calibration["motor_names"].index(name)
        calib_mode = self.calibration["calib_mode"][calib_idx]

        if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
            drive_mode = self.calibration["drive_mode"][calib_idx]
            homing_offset = self.calibration["homing_offset"][calib_idx]
            _, model = self.motors[name]
            resolution = self.model_resolution[model]

            # Update direction of rotation of the motor to match between leader and follower.
            # In fact, the motor of the leader for a given joint can be assembled in an
            # opposite direction in term of rotation than the motor of the follower on the same joint.
            if drive_mode:
                values[i] *= -1

            # Convert from range [-2**31, 2**31[ to
            # nominal range ]-resolution, resolution[ (e.g. ]-2048, 2048[)
            values[i] += homing_offset

            # Convert from range ]-resolution, resolution[ to
            # universal float32 centered degree range ]-180, 180[
            values[i] = values[i] / (resolution // 2) * HALF_TURN_DEGREE

            if (values[i] < LOWER_BOUND_DEGREE) or (values[i] > UPPER_BOUND_DEGREE):
                raise JointOutOfRangeError(
                    f"Wrong motor position range detected for {name}. "
                    f"Expected to be in nominal range of [-{HALF_TURN_DEGREE}, {HALF_TURN_DEGREE}] degrees (a full rotation), "
                    f"with a maximum range of [{LOWER_BOUND_DEGREE}, {UPPER_BOUND_DEGREE}] degrees to account for joints that can rotate a bit more, "
                    f"but present value is {values[i]} degree. "
                    "This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. "
                    "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
                )

        elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
            start_pos = self.calibration["start_pos"][calib_idx]
            end_pos = self.calibration["end_pos"][calib_idx]

            # Rescale the present position to a nominal range [0, 100] %,
            # useful for joints with linear motions like Aloha gripper
            values[i] = (values[i] - start_pos) / (end_pos - start_pos) * 100

            if (values[i] < LOWER_BOUND_LINEAR) or (values[i] > UPPER_BOUND_LINEAR):
                raise JointOutOfRangeError(
                    f"Wrong motor position range detected for {name}. "
                    f"Expected to be in nominal range of [0, 100] % (a full linear translation), "
                    f"with a maximum range of [{LOWER_BOUND_LINEAR}, {UPPER_BOUND_LINEAR}] % to account for some imprecision during calibration, "
                    f"but present value is {values[i]} %. "
                    "This might be due to a cable connection issue creating an artificial jump in motor values. "
                    "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
                )

    return values
|
| 491 |
+
|
| 492 |
+
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
    """This function automatically detects issues with values of motors after calibration, and correct for these issues.

    Some motors might have values outside of expected maximum bounds after calibration.
    For instance, for a joint in degree, its value can be outside [-270, 270] degrees, which is totally unexpected given
    a nominal range of [-180, 180] degrees, which represents half a turn to the left or right starting from zero position.

    Known issues:
    #1: Motor value randomly shifts of a full turn, caused by hardware/connection errors.
    #2: Motor internal homing offset is shifted of a full turn, caused by using default calibration (e.g Aloha).
    #3: motor internal homing offset is shifted of less or more than a full turn, caused by using default calibration
    or by human error during manual calibration.

    Issues #1 and #2 can be solved by shifting the calibration homing offset by a full turn.
    Issue #3 will be visually detected by user and potentially captured by the safety feature `max_relative_target`,
    that will slow down the motor, raise an error asking to recalibrate. Manual recalibrating will solve the issue.

    Note: A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
    """
    if motor_names is None:
        motor_names = self.motor_names

    # Convert from unsigned int32 original range [0, 2**32] to signed float32 range
    values = values.astype(np.float32)

    for i, name in enumerate(motor_names):
        calib_idx = self.calibration["motor_names"].index(name)
        calib_mode = self.calibration["calib_mode"][calib_idx]

        if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
            drive_mode = self.calibration["drive_mode"][calib_idx]
            homing_offset = self.calibration["homing_offset"][calib_idx]
            _, model = self.motors[name]
            resolution = self.model_resolution[model]

            if drive_mode:
                values[i] *= -1

            # Convert from initial range to range [-180, 180] degrees
            calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
            in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)

            # Solve this inequality to find the factor to shift the range into [-180, 180] degrees
            # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
            # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
            # (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution
            low_factor = (
                -HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
            ) / resolution
            upp_factor = (
                HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
            ) / resolution

        elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
            start_pos = self.calibration["start_pos"][calib_idx]
            end_pos = self.calibration["end_pos"][calib_idx]

            # Convert from initial range to range [0, 100] in %
            calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
            in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)

            # Solve this inequality to find the factor to shift the range into [0, 100] %
            # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
            # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100
            # 0 <= (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 <= 100
            # (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
            # NOTE(review): `resolution` is only assigned in the DEGREE branch above; if the
            # first motor processed is in LINEAR mode, the two lines below read an unbound
            # (or stale, from a previous iteration) `resolution` — confirm intended behavior.
            low_factor = (start_pos - values[i]) / resolution
            upp_factor = (end_pos - values[i]) / resolution

        if not in_range:
            # Get first integer between the two bounds
            if low_factor < upp_factor:
                factor = math.ceil(low_factor)

                if factor > upp_factor:
                    raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
            else:
                factor = math.ceil(upp_factor)

                if factor > low_factor:
                    raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")

            # NOTE(review): the "from"/"to" strings below are built from the same
            # `calib_val` and bounds, so the warning prints identical ranges — the
            # "to" part presumably should show the corrected value; confirm.
            if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
                out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
                in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
            elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
                out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
                in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"

            logging.warning(
                f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
                f"from '{out_of_range_str}' to '{in_range_str}'."
            )

            # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
            # NOTE(review): this shifts `homing_offset` even for LINEAR-mode motors,
            # whose conversion above does not use homing_offset — verify this is intended.
            self.calibration["homing_offset"][calib_idx] += resolution * factor
|
| 588 |
+
|
| 589 |
+
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
    """Inverse of `apply_calibration`: convert nominal degrees / linear % back to raw int32 motor steps."""
    if motor_names is None:
        motor_names = self.motor_names

    for i, name in enumerate(motor_names):
        calib_idx = self.calibration["motor_names"].index(name)
        calib_mode = self.calibration["calib_mode"][calib_idx]

        if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
            drive_mode = self.calibration["drive_mode"][calib_idx]
            homing_offset = self.calibration["homing_offset"][calib_idx]
            _, model = self.motors[name]
            resolution = self.model_resolution[model]

            # Convert from nominal 0-centered degree range [-180, 180] to
            # 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
            values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)

            # Subtract the homing offsets to come back to actual motor range of values
            # which can be arbitrary.
            values[i] -= homing_offset

            # Remove drive mode, which is the rotation direction of the motor, to come back to
            # actual motor rotation direction which can be arbitrary.
            if drive_mode:
                values[i] *= -1

        elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
            start_pos = self.calibration["start_pos"][calib_idx]
            end_pos = self.calibration["end_pos"][calib_idx]

            # Convert from nominal linear range of [0, 100] % to
            # actual motor range of values which can be arbitrary.
            values[i] = values[i] / 100 * (end_pos - start_pos) + start_pos

    # Motors take integer step commands; round before casting.
    values = np.round(values).astype(np.int32)
    return values
|
| 627 |
+
|
| 628 |
+
def avoid_rotation_reset(self, values, motor_names, data_name):
    """Unwrap position jumps caused by the raw counter resetting at a full turn.

    Tracks the previous reading per motor (keyed by `data_name`); when a reading
    jumps by more than half a turn (2048 steps), it is shifted by a full turn
    (4096 steps) so the sequence of values stays continuous.
    """
    n_motors = len(self.motor_names)
    if data_name not in self.track_positions:
        # Assume False at initialization
        self.track_positions[data_name] = {
            "prev": [None] * n_motors,
            "below_zero": [False] * n_motors,
            "above_max": [False] * n_motors,
        }

    state = self.track_positions[data_name]

    names = self.motor_names if motor_names is None else motor_names

    for i, name in enumerate(names):
        idx = self.motor_names.index(name)
        last = state["prev"][idx]

        # First observation for this motor: just record it.
        if last is None:
            state["prev"][idx] = values[i]
            continue

        # A jump larger than half a turn means the raw counter wrapped around.
        if abs(last - values[i]) > 2048:
            if last < values[i]:
                # Position went below 0 and got reset to 4095:
                # recover the negative value by subtracting a full rotation.
                values[i] -= 4096
            elif last > values[i]:
                # Position went above 4095 and got reset to 0:
                # recover the value by adding a full rotation.
                values[i] += 4096

        state["prev"][idx] = values[i]

    return values
|
| 664 |
+
|
| 665 |
+
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
    """Sync-read `data_name` directly from raw motor ids, bypassing names and calibration.

    Accepts a single id or a list of ids; returns a single raw value or a list
    accordingly. Raises ConnectionError after `num_retry` failed attempts.
    """
    if self.mock:
        import tests.motors.mock_scservo_sdk as scs
    else:
        import scservo_sdk as scs

    # Remember whether the caller passed a scalar so we can mirror that shape on return.
    return_list = True
    if not isinstance(motor_ids, list):
        return_list = False
        motor_ids = [motor_ids]

    # NOTE(review): the address check uses all bus motors (`self.motor_models`) while
    # the address itself comes from the caller-supplied `motor_models[0]` — confirm
    # this asymmetry is intended.
    assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
    addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
    group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
    for idx in motor_ids:
        group.addParam(idx)

    # Retry the bus transaction a few times; serial links are occasionally flaky.
    for _ in range(num_retry):
        comm = group.txRxPacket()
        if comm == scs.COMM_SUCCESS:
            break

    if comm != scs.COMM_SUCCESS:
        raise ConnectionError(
            f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
            f"{self.packet_handler.getTxRxResult(comm)}"
        )

    values = []
    for idx in motor_ids:
        value = group.getData(idx, addr, bytes)
        values.append(value)

    if return_list:
        return values
    else:
        return values[0]
|
| 702 |
+
|
| 703 |
+
def read(self, data_name, motor_names: str | list[str] | None = None):
    """Read `data_name` from the given motors (all motors if None) via a cached sync-read group.

    Returns an np.ndarray of values aligned with `motor_names`, after optional
    signed-int conversion, rotation-reset unwrapping and calibration. Raises
    RobotDeviceNotConnectedError if the bus is not connected and ConnectionError
    on repeated communication failures.
    """
    if self.mock:
        import tests.motors.mock_scservo_sdk as scs
    else:
        import scservo_sdk as scs

    if not self.is_connected:
        raise RobotDeviceNotConnectedError(
            f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
        )

    start_time = time.perf_counter()

    if motor_names is None:
        motor_names = self.motor_names

    if isinstance(motor_names, str):
        motor_names = [motor_names]

    motor_ids = []
    models = []
    for name in motor_names:
        motor_idx, model = self.motors[name]
        motor_ids.append(motor_idx)
        models.append(model)

    assert_same_address(self.model_ctrl_table, models, data_name)
    addr, bytes = self.model_ctrl_table[model][data_name]
    group_key = get_group_sync_key(data_name, motor_names)

    # BUGFIX: readers are cached per group_key (data_name + motor set), so the
    # existence check must use group_key too. The original checked only
    # `data_name`, so a second read of the same data_name with a different motor
    # subset skipped creating its reader and crashed with a KeyError below.
    if group_key not in self.group_readers:
        # Very Important to flush the buffer!
        self.port_handler.ser.reset_output_buffer()
        self.port_handler.ser.reset_input_buffer()

        # create new group reader
        self.group_readers[group_key] = scs.GroupSyncRead(
            self.port_handler, self.packet_handler, addr, bytes
        )
        for idx in motor_ids:
            self.group_readers[group_key].addParam(idx)

    # Retry the bus transaction a few times; serial links are occasionally flaky.
    for _ in range(NUM_READ_RETRY):
        comm = self.group_readers[group_key].txRxPacket()
        if comm == scs.COMM_SUCCESS:
            break

    if comm != scs.COMM_SUCCESS:
        raise ConnectionError(
            f"Read failed due to communication error on port {self.port} for group_key {group_key}: "
            f"{self.packet_handler.getTxRxResult(comm)}"
        )

    values = []
    for idx in motor_ids:
        value = self.group_readers[group_key].getData(idx, addr, bytes)
        values.append(value)

    values = np.array(values)

    # Convert to signed int to use range [-2048, 2048] for our motor positions.
    if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
        values = values.astype(np.int32)

    if data_name in CALIBRATION_REQUIRED:
        values = self.avoid_rotation_reset(values, motor_names, data_name)

    if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
        values = self.apply_calibration_autocorrect(values, motor_names)

    # log the number of seconds it took to read the data from the motors
    delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
    self.logs[delta_ts_name] = time.perf_counter() - start_time

    # log the utc time at which the data was received
    ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names)
    self.logs[ts_utc_name] = capture_timestamp_utc()

    return values
|
| 782 |
+
|
| 783 |
+
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
    """Sync-write `values` for `data_name` directly to raw motor ids, bypassing names and calibration.

    Accepts a single id/value or parallel lists (must be the same length).
    Raises ConnectionError after `num_retry` failed attempts.
    """
    if self.mock:
        import tests.motors.mock_scservo_sdk as scs
    else:
        import scservo_sdk as scs

    # Normalize scalars to one-element lists so ids and values zip together.
    if not isinstance(motor_ids, list):
        motor_ids = [motor_ids]
    if not isinstance(values, list):
        values = [values]

    assert_same_address(self.model_ctrl_table, motor_models, data_name)
    addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
    group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
    for idx, value in zip(motor_ids, values, strict=True):
        data = convert_to_bytes(value, bytes, self.mock)
        group.addParam(idx, data)

    # Retry the bus transaction a few times; serial links are occasionally flaky.
    for _ in range(num_retry):
        comm = group.txPacket()
        if comm == scs.COMM_SUCCESS:
            break

    if comm != scs.COMM_SUCCESS:
        raise ConnectionError(
            f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
            f"{self.packet_handler.getTxRxResult(comm)}"
        )
|
| 811 |
+
|
| 812 |
+
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
    """Write `values` for `data_name` to the given motors (all motors if None) via a cached sync-write group.

    A scalar value is broadcast to every motor in `motor_names`. Values for
    calibrated data names are converted back to raw motor units first. Raises
    RobotDeviceNotConnectedError if the bus is not connected and ConnectionError
    on communication failure.
    """
    if not self.is_connected:
        raise RobotDeviceNotConnectedError(
            f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
        )

    start_time = time.perf_counter()

    if self.mock:
        import tests.motors.mock_scservo_sdk as scs
    else:
        import scservo_sdk as scs

    if motor_names is None:
        motor_names = self.motor_names

    if isinstance(motor_names, str):
        motor_names = [motor_names]

    # Broadcast a scalar command to every targeted motor.
    if isinstance(values, (int, float, np.integer)):
        values = [int(values)] * len(motor_names)

    values = np.array(values)

    motor_ids = []
    models = []
    for name in motor_names:
        motor_idx, model = self.motors[name]
        motor_ids.append(motor_idx)
        models.append(model)

    if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
        values = self.revert_calibration(values, motor_names)

    values = values.tolist()

    assert_same_address(self.model_ctrl_table, models, data_name)
    addr, bytes = self.model_ctrl_table[model][data_name]
    group_key = get_group_sync_key(data_name, motor_names)

    # BUGFIX: the original computed `init_group = data_name not in self.group_readers`,
    # consulting the *reader* cache with the wrong key. After a prior read of the same
    # data_name, init_group was False and the writer for this group_key was never
    # created, so `self.group_writers[group_key]` raised a KeyError. Decide based on
    # the writer cache, keyed by group_key.
    init_group = group_key not in self.group_writers
    if init_group:
        self.group_writers[group_key] = scs.GroupSyncWrite(
            self.port_handler, self.packet_handler, addr, bytes
        )

    for idx, value in zip(motor_ids, values, strict=True):
        data = convert_to_bytes(value, bytes, self.mock)
        if init_group:
            self.group_writers[group_key].addParam(idx, data)
        else:
            self.group_writers[group_key].changeParam(idx, data)

    comm = self.group_writers[group_key].txPacket()
    if comm != scs.COMM_SUCCESS:
        raise ConnectionError(
            f"Write failed due to communication error on port {self.port} for group_key {group_key}: "
            f"{self.packet_handler.getTxRxResult(comm)}"
        )

    # log the number of seconds it took to write the data to the motors
    delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
    self.logs[delta_ts_name] = time.perf_counter() - start_time

    # TODO(rcadene): should we log the time before sending the write command?
    # log the utc time when the write has been completed
    ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
    self.logs[ts_utc_name] = capture_timestamp_utc()
|
| 880 |
+
|
| 881 |
+
def disconnect(self):
    """Close the serial port and reset all cached handlers and group state.

    Raises RobotDeviceNotConnectedError when the bus is already disconnected.
    """
    if not self.is_connected:
        raise RobotDeviceNotConnectedError(
            f"FeetechMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first."
        )

    port = self.port_handler
    if port is not None:
        port.closePort()
        self.port_handler = None

    # Drop everything tied to the old connection.
    self.packet_handler = None
    self.group_readers = {}
    self.group_writers = {}
    self.is_connected = False
|
| 895 |
+
|
| 896 |
+
def __del__(self):
    # Best-effort cleanup at garbage collection. getattr guards against running
    # on a partially initialized instance that never set `is_connected`.
    connected = getattr(self, "is_connected", False)
    if connected:
        self.disconnect()
|
lerobot/common/robot_devices/motors/utils.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Protocol
|
| 16 |
+
|
| 17 |
+
from lerobot.common.robot_devices.motors.configs import (
|
| 18 |
+
DynamixelMotorsBusConfig,
|
| 19 |
+
FeetechMotorsBusConfig,
|
| 20 |
+
MotorsBusConfig,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MotorsBus(Protocol):
    """Structural (duck-typed) interface shared by DynamixelMotorsBus and FeetechMotorsBus.

    Used for type annotations only; concrete buses are not required to inherit from it.
    """

    def motor_names(self): ...
    def set_calibration(self): ...
    def apply_calibration(self): ...
    def revert_calibration(self): ...
    def read(self): ...
    def write(self): ...
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> dict[str, MotorsBus]:
    """Instantiate one motors bus per config entry, keyed like `motors_bus_configs`.

    Bus classes are imported lazily so only the SDK for the requested type must
    be installed. Raises ValueError for an unknown `cfg.type`.

    BUGFIX: the return annotation previously said `list[MotorsBus]` although the
    function builds and returns a dict.
    """
    motors_buses = {}

    for key, cfg in motors_bus_configs.items():
        if cfg.type == "dynamixel":
            from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus

            motors_buses[key] = DynamixelMotorsBus(cfg)

        elif cfg.type == "feetech":
            from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus

            motors_buses[key] = FeetechMotorsBus(cfg)

        else:
            raise ValueError(f"The motor type '{cfg.type}' is not valid.")

    return motors_buses
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
    """Instantiate a single motors bus of the requested type ("dynamixel" or "feetech").

    Keyword arguments are forwarded to the matching config dataclass. Raises
    ValueError for an unknown `motor_type`.
    """
    if motor_type == "dynamixel":
        from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus

        return DynamixelMotorsBus(DynamixelMotorsBusConfig(**kwargs))

    if motor_type == "feetech":
        from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus

        return FeetechMotorsBus(FeetechMotorsBusConfig(**kwargs))

    raise ValueError(f"The motor type '{motor_type}' is not valid.")
|
lerobot/common/robot_devices/robots/configs.py
ADDED
|
@@ -0,0 +1,613 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import abc
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from typing import Sequence
|
| 18 |
+
|
| 19 |
+
import draccus
|
| 20 |
+
|
| 21 |
+
from lerobot.common.robot_devices.cameras.configs import (
|
| 22 |
+
CameraConfig,
|
| 23 |
+
IntelRealSenseCameraConfig,
|
| 24 |
+
OpenCVCameraConfig,
|
| 25 |
+
)
|
| 26 |
+
from lerobot.common.robot_devices.motors.configs import (
|
| 27 |
+
DynamixelMotorsBusConfig,
|
| 28 |
+
FeetechMotorsBusConfig,
|
| 29 |
+
MotorsBusConfig,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
    """Abstract base class for all robot configurations.

    Subclasses register themselves under a string name via
    ``RobotConfig.register_subclass("...")`` (draccus choice registry), which
    is what lets a robot type be selected by name from CLI/config files.
    """

    @property
    def type(self) -> str:
        # The registry name this subclass was registered under (e.g. "aloha").
        return self.get_choice_name(self.__class__)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# TODO(rcadene, aliberts): remove ManipulatorRobotConfig abstraction
|
| 41 |
+
@dataclass
class ManipulatorRobotConfig(RobotConfig):
    """Base configuration shared by all manipulator-style robots.

    Holds the leader/follower arm motor-bus configs and the camera configs,
    plus safety and convenience options common to every manipulator.

    Raises:
        ValueError: in ``__post_init__`` if ``max_relative_target`` is a
            sequence whose length does not match the motor count of every
            follower arm.
    """

    leader_arms: dict[str, MotorsBusConfig] = field(default_factory=dict)
    follower_arms: dict[str, MotorsBusConfig] = field(default_factory=dict)
    cameras: dict[str, CameraConfig] = field(default_factory=dict)

    # Optionally limit the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length
    # as the number of motors in your follower arms (assumes all follower arms have the same number of
    # motors).
    max_relative_target: list[float] | float | None = None

    # Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it
    # possible to squeeze the gripper and have it spring back to an open position on its own. If None, the
    # gripper is not put in torque mode.
    gripper_open_degree: float | None = None

    # When True, the mock flag is propagated to all sub-devices (see __post_init__).
    mock: bool = False

    def __post_init__(self):
        # Propagate the top-level mock flag to every sub-device so a mocked
        # robot never opens real serial ports or cameras. Assignment is
        # unconditional: setting an already-True flag is harmless.
        if self.mock:
            for arm in self.leader_arms.values():
                arm.mock = True
            for arm in self.follower_arms.values():
                arm.mock = True
            for cam in self.cameras.values():
                cam.mock = True

        # When a per-motor list is given, it must match the motor count of
        # every follower arm (a scalar applies to all motors and needs no check).
        if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
            for name in self.follower_arms:
                if len(self.follower_arms[name].motors) != len(self.max_relative_target):
                    raise ValueError(
                        f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
                        f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "
                        f"`max_relative_target` list has as many parameters as there are motors per arm. "
                        "Note: This feature does not yet work with robots where different follower arms have "
                        "different numbers of motors."
                    )
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@RobotConfig.register_subclass("aloha")
@dataclass
class AlohaRobotConfig(ManipulatorRobotConfig):
    """Configuration for the bimanual Aloha robot: two Dynamixel leader arms,
    two Dynamixel follower arms, and four Intel RealSense cameras."""

    # Specific to Aloha, LeRobot comes with default calibration files. Assuming the motors have been
    # properly assembled, no manual calibration step is expected. If you need to run manual calibration,
    # simply update this path to ".cache/calibration/aloha"
    calibration_dir: str = ".cache/calibration/aloha_default"

    # /!\ FOR SAFETY, READ THIS /!\
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    # For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
    # When you feel more confident with teleoperation or running the policy, you can extend
    # this safety limit and even removing it by setting it to `null`.
    # Also, everything is expected to work safely out-of-the-box, but we highly advise to
    # first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
    # then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
    max_relative_target: int | None = 5

    # Leader arms use smaller XM430/XL430/XC430 Dynamixel motors; each arm has
    # "shadow" motors that mirror the shoulder and elbow joints.
    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "left": DynamixelMotorsBusConfig(
                # window_x
                port="/dev/ttyDXL_leader_left",
                motors={
                    # name: (index, model)
                    "waist": [1, "xm430-w350"],
                    "shoulder": [2, "xm430-w350"],
                    "shoulder_shadow": [3, "xm430-w350"],
                    "elbow": [4, "xm430-w350"],
                    "elbow_shadow": [5, "xm430-w350"],
                    "forearm_roll": [6, "xm430-w350"],
                    "wrist_angle": [7, "xm430-w350"],
                    "wrist_rotate": [8, "xl430-w250"],
                    "gripper": [9, "xc430-w150"],
                },
            ),
            "right": DynamixelMotorsBusConfig(
                # window_x
                port="/dev/ttyDXL_leader_right",
                motors={
                    # name: (index, model)
                    "waist": [1, "xm430-w350"],
                    "shoulder": [2, "xm430-w350"],
                    "shoulder_shadow": [3, "xm430-w350"],
                    "elbow": [4, "xm430-w350"],
                    "elbow_shadow": [5, "xm430-w350"],
                    "forearm_roll": [6, "xm430-w350"],
                    "wrist_angle": [7, "xm430-w350"],
                    "wrist_rotate": [8, "xl430-w250"],
                    "gripper": [9, "xc430-w150"],
                },
            ),
        }
    )

    # Follower arms use the stronger XM540 motors for most joints.
    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "left": DynamixelMotorsBusConfig(
                port="/dev/ttyDXL_follower_left",
                motors={
                    # name: (index, model)
                    "waist": [1, "xm540-w270"],
                    "shoulder": [2, "xm540-w270"],
                    "shoulder_shadow": [3, "xm540-w270"],
                    "elbow": [4, "xm540-w270"],
                    "elbow_shadow": [5, "xm540-w270"],
                    "forearm_roll": [6, "xm540-w270"],
                    "wrist_angle": [7, "xm540-w270"],
                    "wrist_rotate": [8, "xm430-w350"],
                    "gripper": [9, "xm430-w350"],
                },
            ),
            "right": DynamixelMotorsBusConfig(
                port="/dev/ttyDXL_follower_right",
                motors={
                    # name: (index, model)
                    "waist": [1, "xm540-w270"],
                    "shoulder": [2, "xm540-w270"],
                    "shoulder_shadow": [3, "xm540-w270"],
                    "elbow": [4, "xm540-w270"],
                    "elbow_shadow": [5, "xm540-w270"],
                    "forearm_roll": [6, "xm540-w270"],
                    "wrist_angle": [7, "xm540-w270"],
                    "wrist_rotate": [8, "xm430-w350"],
                    "gripper": [9, "xm430-w350"],
                },
            ),
        }
    )

    # Troubleshooting: If one of your IntelRealSense cameras freeze during
    # data recording due to bandwidth limit, you might need to plug the camera
    # on another USB hub or PCIe card.
    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "cam_high": IntelRealSenseCameraConfig(
                serial_number=128422271347,
                fps=30,
                width=640,
                height=480,
            ),
            "cam_low": IntelRealSenseCameraConfig(
                serial_number=130322270656,
                fps=30,
                width=640,
                height=480,
            ),
            "cam_left_wrist": IntelRealSenseCameraConfig(
                serial_number=218622272670,
                fps=30,
                width=640,
                height=480,
            ),
            "cam_right_wrist": IntelRealSenseCameraConfig(
                serial_number=130322272300,
                fps=30,
                width=640,
                height=480,
            ),
        }
    )

    mock: bool = False
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
@RobotConfig.register_subclass("koch")
@dataclass
class KochRobotConfig(ManipulatorRobotConfig):
    """Configuration for the single-arm Koch robot (one Dynamixel leader arm,
    one Dynamixel follower arm, laptop + phone webcams)."""

    calibration_dir: str = ".cache/calibration/koch"
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": DynamixelMotorsBusConfig(
                port="/dev/tty.usbmodem585A0085511",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "xl330-m077"],
                    "shoulder_lift": [2, "xl330-m077"],
                    "elbow_flex": [3, "xl330-m077"],
                    "wrist_flex": [4, "xl330-m077"],
                    "wrist_roll": [5, "xl330-m077"],
                    "gripper": [6, "xl330-m077"],
                },
            ),
        }
    )

    # Follower uses stronger XL430 motors for the first two (load-bearing) joints.
    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": DynamixelMotorsBusConfig(
                port="/dev/tty.usbmodem585A0076891",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "xl430-w250"],
                    "shoulder_lift": [2, "xl430-w250"],
                    "elbow_flex": [3, "xl330-m288"],
                    "wrist_flex": [4, "xl330-m288"],
                    "wrist_roll": [5, "xl330-m288"],
                    "gripper": [6, "xl330-m288"],
                },
            ),
        }
    )

    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "laptop": OpenCVCameraConfig(
                camera_index=0,
                fps=30,
                width=640,
                height=480,
            ),
            "phone": OpenCVCameraConfig(
                camera_index=1,
                fps=30,
                width=640,
                height=480,
            ),
        }
    )

    # ~ Koch specific settings ~
    # Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
    # to squeeze the gripper and have it spring back to an open position on its own.
    gripper_open_degree: float = 35.156

    mock: bool = False
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
@RobotConfig.register_subclass("koch_bimanual")
@dataclass
class KochBimanualRobotConfig(ManipulatorRobotConfig):
    """Configuration for the bimanual Koch robot: two Dynamixel leader arms
    and two Dynamixel follower arms, laptop + phone webcams."""

    calibration_dir: str = ".cache/calibration/koch_bimanual"
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "left": DynamixelMotorsBusConfig(
                port="/dev/tty.usbmodem585A0085511",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "xl330-m077"],
                    "shoulder_lift": [2, "xl330-m077"],
                    "elbow_flex": [3, "xl330-m077"],
                    "wrist_flex": [4, "xl330-m077"],
                    "wrist_roll": [5, "xl330-m077"],
                    "gripper": [6, "xl330-m077"],
                },
            ),
            "right": DynamixelMotorsBusConfig(
                port="/dev/tty.usbmodem575E0031751",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "xl330-m077"],
                    "shoulder_lift": [2, "xl330-m077"],
                    "elbow_flex": [3, "xl330-m077"],
                    "wrist_flex": [4, "xl330-m077"],
                    "wrist_roll": [5, "xl330-m077"],
                    "gripper": [6, "xl330-m077"],
                },
            ),
        }
    )

    # Followers use stronger XL430 motors for the first two (load-bearing) joints.
    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "left": DynamixelMotorsBusConfig(
                port="/dev/tty.usbmodem585A0076891",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "xl430-w250"],
                    "shoulder_lift": [2, "xl430-w250"],
                    "elbow_flex": [3, "xl330-m288"],
                    "wrist_flex": [4, "xl330-m288"],
                    "wrist_roll": [5, "xl330-m288"],
                    "gripper": [6, "xl330-m288"],
                },
            ),
            "right": DynamixelMotorsBusConfig(
                port="/dev/tty.usbmodem575E0032081",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "xl430-w250"],
                    "shoulder_lift": [2, "xl430-w250"],
                    "elbow_flex": [3, "xl330-m288"],
                    "wrist_flex": [4, "xl330-m288"],
                    "wrist_roll": [5, "xl330-m288"],
                    "gripper": [6, "xl330-m288"],
                },
            ),
        }
    )

    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "laptop": OpenCVCameraConfig(
                camera_index=0,
                fps=30,
                width=640,
                height=480,
            ),
            "phone": OpenCVCameraConfig(
                camera_index=1,
                fps=30,
                width=640,
                height=480,
            ),
        }
    )

    # ~ Koch specific settings ~
    # Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
    # to squeeze the gripper and have it spring back to an open position on its own.
    gripper_open_degree: float = 35.156

    mock: bool = False
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
@RobotConfig.register_subclass("moss")
@dataclass
class MossRobotConfig(ManipulatorRobotConfig):
    """Configuration for the Moss robot: one Feetech STS3215 leader arm, one
    Feetech STS3215 follower arm, laptop + phone webcams."""

    calibration_dir: str = ".cache/calibration/moss"
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/tty.usbmodem58760431091",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                },
            ),
        }
    )

    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/tty.usbmodem585A0076891",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                },
            ),
        }
    )

    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "laptop": OpenCVCameraConfig(
                camera_index=0,
                fps=30,
                width=640,
                height=480,
            ),
            "phone": OpenCVCameraConfig(
                camera_index=1,
                fps=30,
                width=640,
                height=480,
            ),
        }
    )

    mock: bool = False
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
@RobotConfig.register_subclass("so100")
@dataclass
class So100RobotConfig(ManipulatorRobotConfig):
    """Configuration for the SO-100 robot: one Feetech STS3215 leader arm, one
    Feetech STS3215 follower arm, laptop + phone webcams."""

    calibration_dir: str = ".cache/calibration/so100"
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/tty.usbmodem58760431091",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                },
            ),
        }
    )

    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/tty.usbmodem585A0076891",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                },
            ),
        }
    )

    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "laptop": OpenCVCameraConfig(
                camera_index=0,
                fps=30,
                width=640,
                height=480,
            ),
            "phone": OpenCVCameraConfig(
                camera_index=1,
                fps=30,
                width=640,
                height=480,
            ),
        }
    )

    mock: bool = False
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
@RobotConfig.register_subclass("stretch")
@dataclass
class StretchRobotConfig(RobotConfig):
    """Configuration for the Hello Robot Stretch mobile manipulator.

    Unlike manipulator configs, no motor buses are declared here; only the
    cameras and safety options are configured.
    """

    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "navigation": OpenCVCameraConfig(
                camera_index="/dev/hello-nav-head-camera",
                fps=10,
                width=1280,
                height=720,
                rotation=-90,
            ),
            # NOTE(review): RealSense cameras here are selected by device
            # `name` rather than serial_number — presumably matched against
            # the product name; confirm in IntelRealSenseCameraConfig.
            "head": IntelRealSenseCameraConfig(
                name="Intel RealSense D435I",
                fps=30,
                width=640,
                height=480,
                rotation=90,
            ),
            "wrist": IntelRealSenseCameraConfig(
                name="Intel RealSense D405",
                fps=30,
                width=640,
                height=480,
            ),
        }
    )

    mock: bool = False
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
@RobotConfig.register_subclass("lekiwi")
@dataclass
class LeKiwiRobotConfig(RobotConfig):
    """Configuration for the LeKiwi mobile manipulator.

    The follower bus carries the 6-DoF arm plus three wheel motors; the robot
    is driven remotely over the network (see `ip`/`port`/`video_port`) and
    teleoperated with the keyboard bindings in `teleop_keys`.
    """

    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    # Network Configuration
    ip: str = "192.168.0.193"
    port: int = 5555
    video_port: int = 5556

    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "front": OpenCVCameraConfig(
                camera_index="/dev/video0", fps=30, width=640, height=480, rotation=90
            ),
            "wrist": OpenCVCameraConfig(
                camera_index="/dev/video2", fps=30, width=640, height=480, rotation=180
            ),
        }
    )

    calibration_dir: str = ".cache/calibration/lekiwi"

    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/tty.usbmodem585A0077581",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                },
            ),
        }
    )

    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/ttyACM0",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                    # Wheel motors live on the same bus as the arm. Entries use
                    # lists (not tuples) for consistency with every other motor
                    # entry in this file and for stable config serialization.
                    "left_wheel": [7, "sts3215"],
                    "back_wheel": [8, "sts3215"],
                    "right_wheel": [9, "sts3215"],
                },
            ),
        }
    )

    # Keyboard bindings for remote teleoperation.
    teleop_keys: dict[str, str] = field(
        default_factory=lambda: {
            # Movement
            "forward": "w",
            "backward": "s",
            "left": "a",
            "right": "d",
            "rotate_left": "z",
            "rotate_right": "x",
            # Speed control
            "speed_up": "r",
            "speed_down": "f",
            # quit teleop
            "quit": "q",
        }
    )

    mock: bool = False
|
lerobot/common/robot_devices/robots/dynamixel_calibration.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Logic to calibrate a robot arm built with dynamixel motors"""
|
| 16 |
+
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
from lerobot.common.robot_devices.motors.dynamixel import (
|
| 21 |
+
CalibrationMode,
|
| 22 |
+
TorqueMode,
|
| 23 |
+
convert_degrees_to_steps,
|
| 24 |
+
)
|
| 25 |
+
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
| 26 |
+
|
| 27 |
+
# URL of the reference image shown to the user for each calibration pose
# (formatted with robot type, arm type and position name below).
URL_TEMPLATE = (
    "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
)

# The following positions are provided in nominal degree range ]-180, +180[
# For more info on these constants, see comments in the code where they get used.
ZERO_POSITION_DEGREE = 0
ROTATED_POSITION_DEGREE = 90
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def assert_drive_mode(drive_mode):
    """Validate a drive-mode array.

    Each entry must be 0 (motor keeps its original rotation direction) or
    1 (direction is inverted). Raises ValueError on any other value.
    """
    entries_valid = np.isin(drive_mode, [0, 1])
    if not entries_valid.all():
        raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def apply_drive_mode(position, drive_mode):
    """Flip the sign of positions whose motor has an inverted drive mode.

    Drive-mode entries of 0 (original direction) and 1 (inverted) are mapped
    to multipliers of +1 and -1 respectively. Note that `position` is updated
    in place and also returned.
    """
    assert_drive_mode(drive_mode)
    # Map 0 -> +1 and 1 -> -1 (algebraically identical to -(drive_mode*2 - 1)).
    direction = 1 - 2 * drive_mode
    position *= direction
    return position
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def compute_nearest_rounded_position(position, models):
    """Snap raw motor positions to the nearest multiple of a quarter turn.

    Calibration assumes every joint sits at an exact quarter-turn pose, so
    each measured position is rounded to the closest such step count. The
    result keeps the dtype of `position`.
    """
    quarter_turn_steps = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, models)
    num_quarter_turns = np.round(position.astype(float) / quarter_turn_steps)
    return (num_quarter_turns * quarter_turn_steps).astype(position.dtype)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
    """This function ensures that a neural network trained on data collected on a given robot
    can work on another robot. For instance before calibration, setting a same goal position
    for each motor of two different robots will get two very different positions. But after calibration,
    the two robots will move to the same position.To this end, this function computes the homing offset
    and the drive mode for each motor of a given robot.

    Homing offset is used to shift the motor position to a ]-2048, +2048[ nominal range (when the motor uses 2048 steps
    to complete a half a turn). This range is set around an arbitrary "zero position" corresponding to all motor positions
    being 0. During the calibration process, you will need to manually move the robot to this "zero position".

    Drive mode is used to invert the rotation direction of the motor. This is useful when some motors have been assembled
    in the opposite orientation for some robots. During the calibration process, you will need to manually move the robot
    to the "rotated position".

    After calibration, the homing offsets and drive modes are stored in a cache.

    Args:
        arm: The motors bus of the arm to calibrate (torque must be disabled).
        robot_type: Robot identifier (e.g. "koch", "aloha"), used for prompts
            and to select joint calibration modes.
        arm_name: Human-readable arm name (e.g. "left"), used in prompts only.
        arm_type: "leader" or "follower", used to pick the reference images.

    Returns:
        dict: calibration data (homing offsets, drive modes, start/end
        positions, per-joint calibration modes and motor names), all as plain
        Python lists so it can be serialized to a cache file.

    Raises:
        ValueError: if torque is enabled on any motor.

    Example of usage:
    ```python
    run_arm_calibration(arm, "koch", "left", "follower")
    ```
    """
    # Calibration requires manually moving the arm, which is impossible (and
    # dangerous) while torque is enabled.
    if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
        raise ValueError("To run calibration, the torque must be disabled on all motors.")

    print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")

    print("\nMove arm to zero position")
    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
    input("Press Enter to continue...")

    # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
    # It is easy to identify and all motors are in a "quarter turn" position. Once calibration is done, this position will
    # correspond to every motor angle being 0. If you set all 0 as Goal Position, the arm will move in this position.
    zero_target_pos = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models)

    # Compute homing offset so that `present_position + homing_offset ~= target_position`.
    zero_pos = arm.read("Present_Position")
    zero_nearest_pos = compute_nearest_rounded_position(zero_pos, arm.motor_models)
    homing_offset = zero_target_pos - zero_nearest_pos

    # The rotated target position corresponds to a rotation of a quarter turn from the zero position.
    # This allows to identify the rotation direction of each motor.
    # For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
    # is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
    # Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
    # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
    # of the previous motor in the kinetic chain.
    print("\nMove arm to rotated target position")
    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
    input("Press Enter to continue...")

    rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)

    # Find drive mode by rotating each motor by a quarter of a turn.
    # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
    rotated_pos = arm.read("Present_Position")
    drive_mode = (rotated_pos < zero_pos).astype(np.int32)

    # Re-compute homing offset to take into account drive mode
    rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
    rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models)
    homing_offset = rotated_target_pos - rotated_nearest_pos

    print("\nMove arm to rest position")
    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
    input("Press Enter to continue...")
    print()

    # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
    calib_mode = [CalibrationMode.DEGREE.name] * len(arm.motor_names)

    # TODO(rcadene): make type of joints (DEGREE or LINEAR) configurable from yaml?
    if robot_type in ["aloha"] and "gripper" in arm.motor_names:
        # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
        calib_idx = arm.motor_names.index("gripper")
        calib_mode[calib_idx] = CalibrationMode.LINEAR.name

    calib_data = {
        "homing_offset": homing_offset.tolist(),
        "drive_mode": drive_mode.tolist(),
        "start_pos": zero_pos.tolist(),
        "end_pos": rotated_pos.tolist(),
        "calib_mode": calib_mode,
        "motor_names": arm.motor_names,
    }
    return calib_data
|
lerobot/common/robot_devices/robots/feetech_calibration.py
ADDED
|
@@ -0,0 +1,498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Logic to calibrate a robot arm built with feetech motors"""
|
| 16 |
+
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
| 17 |
+
|
| 18 |
+
import time
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
from lerobot.common.robot_devices.motors.feetech import (
|
| 23 |
+
CalibrationMode,
|
| 24 |
+
TorqueMode,
|
| 25 |
+
convert_degrees_to_steps,
|
| 26 |
+
)
|
| 27 |
+
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
| 28 |
+
|
| 29 |
+
URL_TEMPLATE = (
|
| 30 |
+
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# The following positions are provided in nominal degree range ]-180, +180[
|
| 34 |
+
# For more info on these constants, see comments in the code where they get used.
|
| 35 |
+
ZERO_POSITION_DEGREE = 0
|
| 36 |
+
ROTATED_POSITION_DEGREE = 90
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def assert_drive_mode(drive_mode):
    """Validate that every entry of `drive_mode` is either 0 or 1.

    0 means the motor keeps its original rotation direction, 1 means inverted.
    Raises ValueError otherwise.
    """
    if not np.isin(drive_mode, [0, 1]).all():
        raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")


def apply_drive_mode(position, drive_mode):
    """Invert the sign of positions whose drive mode is 1.

    NOTE: the multiplication is done in place, so the `position` array passed
    by the caller is mutated as well as returned.
    """
    assert_drive_mode(drive_mode)
    # Map drive_mode {0, 1} onto a sign {+1, -1}: 0 keeps the direction, 1 inverts it.
    position *= 1 - 2 * drive_mode
    return position
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def move_until_block(arm, motor_name, positive_direction=True, while_move_hook=None):
    """Nudge `motor_name` in small steps until it stalls against a hard stop.

    Each iteration commands a 100-step move from the current position, then
    checks speed and current. A stall is declared when the motor reports zero
    speed with elevated current for many consecutive iterations (or a very
    high current spike). Returns the position (int) at which the motor stalled.

    `while_move_hook`, if given, is called once per iteration between the
    write and the stall check.
    """
    step = 100 if positive_direction else -100
    stalled_iterations = 0
    while True:
        current_pos = arm.read("Present_Position", motor_name)
        # Lowering the step size lowers the speed at which the arm moves.
        arm.write("Goal_Position", current_pos + step, motor_name)

        if while_move_hook is not None:
            while_move_hook()

        pos = arm.read("Present_Position", motor_name).item()
        speed = arm.read("Present_Speed", motor_name).item()
        current = arm.read("Present_Current", motor_name).item()

        if speed == 0 and current > 40:
            # Motor is pushing but not moving: likely against a hard stop.
            stalled_iterations += 1
            if stalled_iterations > 100 or current > 300:
                return pos
        else:
            stalled_iterations = 0
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def move_to_calibrate(
    arm,
    motor_name,
    invert_drive_mode=False,
    positive_first=True,
    in_between_move_hook=None,
    while_move_hook=None,
):
    """Find a motor's mechanical range by driving it against both hard stops.

    The motor is first moved until blocked in one direction (chosen by
    `positive_first`), then `in_between_move_hook` runs (if given), then the
    motor is moved until blocked in the opposite direction. The midpoint of
    the two stall positions is taken as the motor's zero position.

    Returns a per-motor calibration dict (positions, homing offset, drive mode).
    """
    initial_pos = arm.read("Present_Position", motor_name)

    # First hard stop: direction given by `positive_first`.
    first_stop = move_until_block(
        arm, motor_name, positive_direction=positive_first, while_move_hook=while_move_hook
    )

    if in_between_move_hook is not None:
        in_between_move_hook()

    # Second hard stop: the opposite direction.
    second_stop = move_until_block(
        arm, motor_name, positive_direction=not positive_first, while_move_hook=while_move_hook
    )

    if positive_first:
        p_present_pos, n_present_pos = first_stop, second_stop
    else:
        n_present_pos, p_present_pos = first_stop, second_stop

    # Zero position is the midpoint between the two mechanical limits.
    zero_pos = (n_present_pos + p_present_pos) / 2

    # NOTE(review): "drive_mode" is -1/0 here while assert_drive_mode elsewhere
    # in this file expects values in {0, 1} — confirm downstream consumers only
    # use it as a truthy flag.
    calib_data = {
        "initial_pos": initial_pos,
        "homing_offset": zero_pos if invert_drive_mode else -zero_pos,
        "invert_drive_mode": invert_drive_mode,
        "drive_mode": -1 if invert_drive_mode else 0,
        "zero_pos": zero_pos,
        "start_pos": n_present_pos if invert_drive_mode else p_present_pos,
        "end_pos": p_present_pos if invert_drive_mode else n_present_pos,
    }
    return calib_data
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def apply_offset(calib, offset):
    """Shift a motor's zero position by `offset` steps, keeping the calibration consistent.

    The homing offset moves with the offset when drive mode is truthy
    (inverted rotation) and against it otherwise. Mutates and returns `calib`.
    """
    calib["zero_pos"] += offset
    sign = 1 if calib["drive_mode"] else -1
    calib["homing_offset"] += sign * offset
    return calib
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
    """Dispatch auto-calibration to the routine matching `robot_type`.

    Only "so100" and "moss" are supported; any other value raises ValueError.
    """
    calibrators = {
        "so100": run_arm_auto_calibration_so100,
        "moss": run_arm_auto_calibration_moss,
    }
    calibrate = calibrators.get(robot_type)
    if calibrate is None:
        raise ValueError(robot_type)
    return calibrate(arm, robot_type, arm_name, arm_type)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str) -> dict:
    """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
    # Calibration requires driving motors against hard stops, which is unsafe with torque on.
    if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
        raise ValueError("To run calibration, the torque must be disabled on all motors.")

    if not (robot_type == "so100" and arm_type == "follower"):
        raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.")

    print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")

    # Operator must pre-position the arm before the motors take over.
    print("\nMove arm to initial position")
    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
    input("Press Enter to continue...")

    # Lower the acceleration of the motors (in [0,254]) so stall detection is gentle;
    # the original value is restored at the end.
    initial_acceleration = arm.read("Acceleration")
    arm.write("Lock", 0)
    arm.write("Acceleration", 10)
    time.sleep(1)

    arm.write("Torque_Enable", TorqueMode.ENABLED.value)

    # Debug print of the elbow position before starting (kept intentionally).
    print(f'{arm.read("Present_Position", "elbow_flex")=}')

    # Per-motor calibration results, filled in one joint at a time.
    calib = {}

    # Pre-position wrist/shoulder/elbow so each joint can reach its hard stops.
    # All step offsets below are hand tuned for the SO-100 geometry.
    init_wf_pos = arm.read("Present_Position", "wrist_flex")
    init_sl_pos = arm.read("Present_Position", "shoulder_lift")
    init_ef_pos = arm.read("Present_Position", "elbow_flex")
    arm.write("Goal_Position", init_wf_pos - 800, "wrist_flex")
    arm.write("Goal_Position", init_sl_pos + 150 + 1024, "shoulder_lift")
    arm.write("Goal_Position", init_ef_pos - 2048, "elbow_flex")
    time.sleep(2)

    print("Calibrate shoulder_pan")
    calib["shoulder_pan"] = move_to_calibrate(arm, "shoulder_pan")
    arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
    time.sleep(1)

    print("Calibrate gripper")
    calib["gripper"] = move_to_calibrate(arm, "gripper", invert_drive_mode=True)
    time.sleep(1)

    print("Calibrate wrist_flex")
    calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex")
    calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=80)

    # Hook run between the two elbow hard-stop sweeps: back the elbow and
    # shoulder off by a quarter turn so the second sweep has clearance.
    def in_between_move_hook():
        nonlocal arm, calib
        time.sleep(2)
        ef_pos = arm.read("Present_Position", "elbow_flex")
        sl_pos = arm.read("Present_Position", "shoulder_lift")
        arm.write("Goal_Position", ef_pos + 1024, "elbow_flex")
        arm.write("Goal_Position", sl_pos - 1024, "shoulder_lift")
        time.sleep(2)

    print("Calibrate elbow_flex")
    calib["elbow_flex"] = move_to_calibrate(
        arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook
    )
    calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)

    arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
    time.sleep(1)

    # Redefines the hook above for the shoulder sweep: park the elbow at its
    # zero position between the two shoulder hard-stop sweeps.
    def in_between_move_hook():
        nonlocal arm, calib
        arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"], "elbow_flex")

    print("Calibrate shoulder_lift")
    calib["shoulder_lift"] = move_to_calibrate(
        arm,
        "shoulder_lift",
        invert_drive_mode=True,
        positive_first=False,
        in_between_move_hook=in_between_move_hook,
    )
    # Shift by a quarter turn (1024) minus 50 steps to align with the body (hand tuned).
    calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=1024 - 50)

    # Hook run on every iteration of the wrist_roll sweep: hold the other
    # joints at a fixed pose so the roll can reach its stops.
    def while_move_hook():
        nonlocal arm, calib
        positions = {
            "shoulder_lift": round(calib["shoulder_lift"]["zero_pos"] - 1600),
            "elbow_flex": round(calib["elbow_flex"]["zero_pos"] + 1700),
            "wrist_flex": round(calib["wrist_flex"]["zero_pos"] + 800),
            "gripper": round(calib["gripper"]["end_pos"]),
        }
        arm.write("Goal_Position", list(positions.values()), list(positions.keys()))

    # Move into the wrist_roll calibration pose one joint at a time.
    arm.write("Goal_Position", round(calib["shoulder_lift"]["zero_pos"] - 1600), "shoulder_lift")
    time.sleep(2)
    arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
    time.sleep(2)
    arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex")
    time.sleep(2)
    arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper")
    time.sleep(2)

    print("Calibrate wrist_roll")
    calib["wrist_roll"] = move_to_calibrate(
        arm, "wrist_roll", invert_drive_mode=True, positive_first=False, while_move_hook=while_move_hook
    )

    # Return the arm to a safe resting pose, joint by joint.
    arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
    time.sleep(1)
    arm.write("Goal_Position", calib["gripper"]["start_pos"], "gripper")
    time.sleep(1)
    arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
    time.sleep(1)
    arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex")
    arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift")
    time.sleep(1)
    arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
    time.sleep(1)

    # Gripper is a linear joint; every other joint is rotational (degrees).
    calib_modes = []
    for name in arm.motor_names:
        if name == "gripper":
            calib_modes.append(CalibrationMode.LINEAR.name)
        else:
            calib_modes.append(CalibrationMode.DEGREE.name)

    # Flatten the per-motor dicts into parallel lists ordered like motor_names.
    calib_dict = {
        "homing_offset": [calib[name]["homing_offset"] for name in arm.motor_names],
        "drive_mode": [calib[name]["drive_mode"] for name in arm.motor_names],
        "start_pos": [calib[name]["start_pos"] for name in arm.motor_names],
        "end_pos": [calib[name]["end_pos"] for name in arm.motor_names],
        "calib_mode": calib_modes,
        "motor_names": arm.motor_names,
    }

    # Re-enable original acceleration
    arm.write("Lock", 0)
    arm.write("Acceleration", initial_acceleration)
    time.sleep(1)

    return calib_dict
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str) -> dict:
    """All the offsets and magic numbers are hand tuned, and are unique to Moss follower arms."""
    # Calibration requires driving motors against hard stops, which is unsafe with torque on.
    if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
        raise ValueError("To run calibration, the torque must be disabled on all motors.")

    if not (robot_type == "moss" and arm_type == "follower"):
        raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.")

    print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")

    # Operator must pre-position the arm before the motors take over.
    print("\nMove arm to initial position")
    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
    input("Press Enter to continue...")

    # Lower the acceleration of the motors (in [0,254]) so stall detection is gentle;
    # the original value is restored at the end.
    initial_acceleration = arm.read("Acceleration")
    arm.write("Lock", 0)
    arm.write("Acceleration", 10)
    time.sleep(1)

    arm.write("Torque_Enable", TorqueMode.ENABLED.value)

    # Pre-position shoulder and elbow so the joints can reach their hard stops.
    # All step offsets below are hand tuned for the Moss geometry.
    sl_pos = arm.read("Present_Position", "shoulder_lift")
    arm.write("Goal_Position", sl_pos - 1024 - 450, "shoulder_lift")
    ef_pos = arm.read("Present_Position", "elbow_flex")
    arm.write("Goal_Position", ef_pos + 1024 + 450, "elbow_flex")
    time.sleep(2)

    # Per-motor calibration results, filled in one joint at a time.
    calib = {}

    print("Calibrate shoulder_pan")
    calib["shoulder_pan"] = move_to_calibrate(arm, "shoulder_pan")
    arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
    time.sleep(1)

    print("Calibrate gripper")
    calib["gripper"] = move_to_calibrate(arm, "gripper", invert_drive_mode=True)
    time.sleep(1)

    print("Calibrate wrist_flex")
    calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex", invert_drive_mode=True)
    calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=-210 + 1024)

    # Move into the wrist_roll calibration pose one joint at a time.
    wr_pos = arm.read("Present_Position", "wrist_roll")
    arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
    time.sleep(1)
    arm.write("Goal_Position", wr_pos - 1024, "wrist_roll")
    time.sleep(1)
    arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 2048, "wrist_flex")
    time.sleep(1)
    arm.write("Goal_Position", calib["gripper"]["end_pos"], "gripper")
    time.sleep(1)

    print("Calibrate wrist_roll")
    calib["wrist_roll"] = move_to_calibrate(arm, "wrist_roll", invert_drive_mode=True)
    calib["wrist_roll"] = apply_offset(calib["wrist_roll"], offset=790)

    # Return wrist and gripper toward neutral before calibrating the elbow.
    arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"] - 1024, "wrist_roll")
    arm.write("Goal_Position", calib["gripper"]["start_pos"], "gripper")
    arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
    time.sleep(1)
    arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
    arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 2048, "wrist_flex")

    # Hook run between the two elbow hard-stop sweeps: park the wrist at its
    # zero position so the second sweep has clearance.
    def in_between_move_elbow_flex_hook():
        nonlocal arm, calib
        arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")

    print("Calibrate elbow_flex")
    calib["elbow_flex"] = move_to_calibrate(
        arm,
        "elbow_flex",
        invert_drive_mode=True,
        in_between_move_hook=in_between_move_elbow_flex_hook,
    )
    arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")

    # Hook run between the two shoulder hard-stop sweeps: reposition the
    # shoulder, elbow and wrist so the opposite stop is reachable.
    def in_between_move_shoulder_lift_hook():
        nonlocal arm, calib
        sl = arm.read("Present_Position", "shoulder_lift")
        arm.write("Goal_Position", sl - 1500, "shoulder_lift")
        time.sleep(1)
        arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1536, "elbow_flex")
        time.sleep(1)
        arm.write("Goal_Position", calib["wrist_flex"]["start_pos"], "wrist_flex")
        time.sleep(1)

    print("Calibrate shoulder_lift")
    calib["shoulder_lift"] = move_to_calibrate(
        arm, "shoulder_lift", in_between_move_hook=in_between_move_shoulder_lift_hook
    )
    calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=-1024)

    # Return the arm to a safe resting pose.
    arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
    time.sleep(1)
    arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift")
    arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex")
    time.sleep(2)

    # Gripper is a linear joint; every other joint is rotational (degrees).
    calib_modes = []
    for name in arm.motor_names:
        if name == "gripper":
            calib_modes.append(CalibrationMode.LINEAR.name)
        else:
            calib_modes.append(CalibrationMode.DEGREE.name)

    # Flatten the per-motor dicts into parallel lists ordered like motor_names.
    calib_dict = {
        "homing_offset": [calib[name]["homing_offset"] for name in arm.motor_names],
        "drive_mode": [calib[name]["drive_mode"] for name in arm.motor_names],
        "start_pos": [calib[name]["start_pos"] for name in arm.motor_names],
        "end_pos": [calib[name]["end_pos"] for name in arm.motor_names],
        "calib_mode": calib_modes,
        "motor_names": arm.motor_names,
    }

    # Re-enable original acceleration
    arm.write("Lock", 0)
    arm.write("Acceleration", initial_acceleration)
    time.sleep(1)

    return calib_dict
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str) -> dict:
    """This function ensures that a neural network trained on data collected on a given robot
    can work on another robot. For instance before calibration, setting a same goal position
    for each motor of two different robots will get two very different positions. But after calibration,
    the two robots will move to the same position.To this end, this function computes the homing offset
    and the drive mode for each motor of a given robot.

    Homing offset is used to shift the motor position to a ]-2048, +2048[ nominal range (when the motor uses 2048 steps
    to complete a half a turn). This range is set around an arbitrary "zero position" corresponding to all motor positions
    being 0. During the calibration process, you will need to manually move the robot to this "zero position".

    Drive mode is used to invert the rotation direction of the motor. This is useful when some motors have been assembled
    in the opposite orientation for some robots. During the calibration process, you will need to manually move the robot
    to the "rotated position".

    After calibration, the homing offsets and drive modes are stored in a cache.

    Example of usage:
    ```python
    run_arm_calibration(arm, "so100", "left", "follower")
    ```
    """
    # Manual calibration requires the operator to move the arm by hand.
    if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
        raise ValueError("To run calibration, the torque must be disabled on all motors.")

    print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")

    print("\nMove arm to zero position")
    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
    input("Press Enter to continue...")

    # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
    # It is easy to identify and all motors are in a "quarter turn" position. Once calibration is done, this position will
    # correspond to every motor angle being 0. If you set all 0 as Goal Position, the arm will move in this position.
    zero_target_pos = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models)

    # Compute homing offset so that `present_position + homing_offset ~= target_position`.
    # NOTE(review): this first homing_offset value is overwritten below once the
    # drive mode is known; the assignment here appears to be a dead store.
    zero_pos = arm.read("Present_Position")
    homing_offset = zero_target_pos - zero_pos

    # The rotated target position corresponds to a rotation of a quarter turn from the zero position.
    # This allows to identify the rotation direction of each motor.
    # For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
    # is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
    # Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
    # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
    # of the previous motor in the kinetic chain.
    print("\nMove arm to rotated target position")
    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
    input("Press Enter to continue...")

    rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)

    # Find drive mode by rotating each motor by a quarter of a turn.
    # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
    rotated_pos = arm.read("Present_Position")
    drive_mode = (rotated_pos < zero_pos).astype(np.int32)

    # Re-compute homing offset to take into account drive mode
    # NOTE(review): apply_drive_mode multiplies `rotated_pos` in place, so the
    # "end_pos" stored below holds the drive-mode-adjusted positions — confirm
    # this is intended.
    rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
    homing_offset = rotated_target_pos - rotated_drived_pos

    print("\nMove arm to rest position")
    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
    input("Press Enter to continue...")
    print()

    # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180];
    # the gripper is a linear joint.
    calib_modes = []
    for name in arm.motor_names:
        if name == "gripper":
            calib_modes.append(CalibrationMode.LINEAR.name)
        else:
            calib_modes.append(CalibrationMode.DEGREE.name)

    # Parallel lists ordered like motor_names, as consumed by the motors bus.
    calib_dict = {
        "homing_offset": homing_offset.tolist(),
        "drive_mode": drive_mode.tolist(),
        "start_pos": zero_pos.tolist(),
        "end_pos": rotated_pos.tolist(),
        "calib_mode": calib_modes,
        "motor_names": arm.motor_names,
    }
    return calib_dict
|
lerobot/common/robot_devices/robots/lekiwi_remote.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import base64
|
| 16 |
+
import json
|
| 17 |
+
import threading
|
| 18 |
+
import time
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import cv2
|
| 22 |
+
import zmq
|
| 23 |
+
|
| 24 |
+
from lerobot.common.robot_devices.robots.mobile_manipulator import LeKiwi
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def setup_zmq_sockets(config):
    """Create the ZeroMQ context plus the two sockets the robot listens/streams on.

    The command socket (PULL, bound on `config.port`) receives incoming
    commands; the video socket (PUSH, bound on `config.video_port`) streams
    observations out. Both use CONFLATE so only the most recent message is kept.

    Returns (context, cmd_socket, video_socket).
    """
    context = zmq.Context()

    def _bound_socket(socket_kind, port):
        # Shared setup: conflated socket bound on all interfaces.
        sock = context.socket(socket_kind)
        sock.setsockopt(zmq.CONFLATE, 1)
        sock.bind(f"tcp://*:{port}")
        return sock

    cmd_socket = _bound_socket(zmq.PULL, config.port)
    video_socket = _bound_socket(zmq.PUSH, config.video_port)
    return context, cmd_socket, video_socket
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
    """Continuously grab frames from every camera and publish them as base64 JPEGs.

    Runs until `stop_event` is set. Each iteration encodes one frame per camera
    (empty string on encode failure) and swaps the batch into
    `latest_images_dict` under `images_lock`.
    """
    while not stop_event.is_set():
        encoded = {}
        for name, cam in cameras.items():
            frame = cam.async_read()
            ok, jpeg = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
            encoded[name] = base64.b64encode(jpeg).decode("utf-8") if ok else ""
        with images_lock:
            latest_images_dict.update(encoded)
        # Small pause to avoid a busy loop between captures.
        time.sleep(0.01)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def calibrate_follower_arm(motors_bus, calib_dir_str):
    """Load — or create via manual calibration — the follower-arm calibration and apply it.

    Looks for `main_follower.json` under `calib_dir_str` (creating the
    directory if needed). When absent, runs manual calibration and saves the
    result. Returns early with a warning if the calibration helper cannot be
    imported; applies the calibration to `motors_bus` on a best-effort basis.
    """
    calib_dir = Path(calib_dir_str)
    calib_dir.mkdir(parents=True, exist_ok=True)
    calib_file = calib_dir / "main_follower.json"

    try:
        from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
    except ImportError:
        print("[WARNING] Calibration function not available. Skipping calibration.")
        return

    if calib_file.exists():
        calibration = json.loads(calib_file.read_text())
        print(f"[INFO] Loaded calibration from {calib_file}")
    else:
        print("[INFO] Calibration file not found. Running manual calibration...")
        calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
        print(f"[INFO] Calibration complete. Saving to {calib_file}")
        calib_file.write_text(json.dumps(calibration))

    # Applying may fail on a partially-connected bus; warn instead of crashing.
    try:
        motors_bus.set_calibration(calibration)
        print("[INFO] Applied calibration for follower arm.")
    except Exception as e:
        print(f"[WARNING] Could not apply calibration: {e}")
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def run_lekiwi(robot_config):
    """
    Runs the LeKiwi robot:
      - Sets up cameras and connects them.
      - Initializes the follower arm motors.
      - Calibrates the follower arm if necessary.
      - Creates ZeroMQ sockets for receiving commands and streaming observations.
      - Processes incoming commands (arm and wheel commands) and sends back sensor and camera data.

    Blocks forever (about 30 Hz main loop) until KeyboardInterrupt, then tears
    down the capture thread, motors bus, and ZMQ resources.
    """
    # Import helper functions and classes
    from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
    from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode

    # Initialize cameras from the robot configuration.
    cameras = make_cameras_from_configs(robot_config.cameras)
    for cam in cameras.values():
        cam.connect()

    # Initialize the motors bus using the follower arm configuration.
    motor_config = robot_config.follower_arms.get("main")
    if motor_config is None:
        print("[ERROR] Follower arm 'main' configuration not found.")
        return
    motors_bus = FeetechMotorsBus(motor_config)
    motors_bus.connect()

    # Calibrate the follower arm.
    calibrate_follower_arm(motors_bus, robot_config.calibration_dir)

    # Create the LeKiwi robot instance.
    robot = LeKiwi(motors_bus)

    # Define the expected arm motor IDs.
    arm_motor_ids = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]

    # Disable torque for each arm motor so incoming goal positions start from a
    # compliant state.
    for motor in arm_motor_ids:
        motors_bus.write("Torque_Enable", TorqueMode.DISABLED.value, motor)

    # Set up ZeroMQ sockets.
    context, cmd_socket, video_socket = setup_zmq_sockets(robot_config)

    # Start the camera capture thread (daemon: it must not block shutdown).
    latest_images_dict = {}
    images_lock = threading.Lock()
    stop_event = threading.Event()
    cam_thread = threading.Thread(
        target=run_camera_capture, args=(cameras, images_lock, latest_images_dict, stop_event), daemon=True
    )
    cam_thread.start()

    last_cmd_time = time.time()
    print("LeKiwi robot server started. Waiting for commands...")

    try:
        while True:
            loop_start_time = time.time()

            # Process incoming commands (non-blocking). The inner loop drains
            # everything currently queued; zmq.Again means the queue is empty.
            while True:
                try:
                    msg = cmd_socket.recv_string(zmq.NOBLOCK)
                except zmq.Again:
                    break
                try:
                    data = json.loads(msg)
                    # Process arm position commands.
                    if "arm_positions" in data:
                        arm_positions = data["arm_positions"]
                        if not isinstance(arm_positions, list):
                            print(f"[ERROR] Invalid arm_positions: {arm_positions}")
                        elif len(arm_positions) < len(arm_motor_ids):
                            print(
                                f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}"
                            )
                        else:
                            # strict=False: extra positions beyond the known motors are ignored.
                            for motor, pos in zip(arm_motor_ids, arm_positions, strict=False):
                                motors_bus.write("Goal_Position", pos, motor)
                    # Process wheel (base) commands.
                    if "raw_velocity" in data:
                        raw_command = data["raw_velocity"]
                        # Expect keys: "left_wheel", "back_wheel", "right_wheel".
                        command_speeds = [
                            int(raw_command.get("left_wheel", 0)),
                            int(raw_command.get("back_wheel", 0)),
                            int(raw_command.get("right_wheel", 0)),
                        ]
                        robot.set_velocity(command_speeds)
                        # NOTE(review): only wheel commands refresh the watchdog;
                        # presumably arm-only streams are fine with the base
                        # being stopped every 0.5 s — confirm.
                        last_cmd_time = time.time()
                except Exception as e:
                    print(f"[ERROR] Parsing message failed: {e}")

            # Watchdog: stop the robot if no command is received for over 0.5 seconds.
            now = time.time()
            if now - last_cmd_time > 0.5:
                robot.stop()
                last_cmd_time = now

            # Read current wheel speeds from the robot.
            current_velocity = robot.read_velocity()

            # Read the follower arm state from the motors bus.
            follower_arm_state = []
            for motor in arm_motor_ids:
                try:
                    pos = motors_bus.read("Present_Position", motor)
                    # Convert the position to a float (or use as is if already numeric).
                    follower_arm_state.append(float(pos) if not isinstance(pos, (int, float)) else pos)
                except Exception as e:
                    # A failed read shortens the state vector for this tick.
                    print(f"[ERROR] Reading motor {motor} failed: {e}")

            # Get the latest camera images (copy under the lock, serialize outside it).
            with images_lock:
                images_dict_copy = dict(latest_images_dict)

            # Build the observation dictionary.
            observation = {
                "images": images_dict_copy,
                "present_speed": current_velocity,
                "follower_arm_state": follower_arm_state,
            }
            # Send the observation over the video socket.
            video_socket.send_string(json.dumps(observation))

            # Ensure a short sleep to avoid overloading the CPU (~30 Hz loop).
            elapsed = time.time() - loop_start_time
            time.sleep(
                max(0.033 - elapsed, 0)
            )  # If robot jitters increase the sleep and monitor cpu load with `top` in cmd
    except KeyboardInterrupt:
        print("Shutting down LeKiwi server.")
    finally:
        stop_event.set()
        cam_thread.join()
        robot.stop()
        motors_bus.disconnect()
        cmd_socket.close()
        video_socket.close()
        context.term()
|
lerobot/common/robot_devices/robots/manipulator.py
ADDED
|
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Contains logic to instantiate a robot, read information from its motors and cameras,
|
| 16 |
+
and send orders to its motors.
|
| 17 |
+
"""
|
| 18 |
+
# TODO(rcadene, aliberts): reorganize the codebase into one file per robot, with the associated
|
| 19 |
+
# calibration procedure, to make it easy for people to add their own robot.
|
| 20 |
+
|
| 21 |
+
import json
|
| 22 |
+
import logging
|
| 23 |
+
import time
|
| 24 |
+
import warnings
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
|
| 27 |
+
import numpy as np
|
| 28 |
+
import torch
|
| 29 |
+
|
| 30 |
+
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
| 31 |
+
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
|
| 32 |
+
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
|
| 33 |
+
from lerobot.common.robot_devices.robots.utils import get_arm_id
|
| 34 |
+
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def ensure_safe_goal_position(
|
| 38 |
+
goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
|
| 39 |
+
):
|
| 40 |
+
# Cap relative action target magnitude for safety.
|
| 41 |
+
diff = goal_pos - present_pos
|
| 42 |
+
max_relative_target = torch.tensor(max_relative_target)
|
| 43 |
+
safe_diff = torch.minimum(diff, max_relative_target)
|
| 44 |
+
safe_diff = torch.maximum(safe_diff, -max_relative_target)
|
| 45 |
+
safe_goal_pos = present_pos + safe_diff
|
| 46 |
+
|
| 47 |
+
if not torch.allclose(goal_pos, safe_goal_pos):
|
| 48 |
+
logging.warning(
|
| 49 |
+
"Relative goal position magnitude had to be clamped to be safe.\n"
|
| 50 |
+
f" requested relative goal position target: {diff}\n"
|
| 51 |
+
f" clamped relative goal position target: {safe_diff}"
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
return safe_goal_pos
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class ManipulatorRobot:
|
| 58 |
+
# TODO(rcadene): Implement force feedback
|
| 59 |
+
"""This class allows to control any manipulator robot of various number of motors.
|
| 60 |
+
|
| 61 |
+
Non exhaustive list of robots:
|
| 62 |
+
- [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow expansion, developed
|
| 63 |
+
by Alexander Koch from [Tau Robotics](https://tau-robotics.com)
|
| 64 |
+
- [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss
|
| 65 |
+
- [Aloha](https://www.trossenrobotics.com/aloha-kits) developed by Trossen Robotics
|
| 66 |
+
|
| 67 |
+
Example of instantiation, a pre-defined robot config is required:
|
| 68 |
+
```python
|
| 69 |
+
robot = ManipulatorRobot(KochRobotConfig())
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
Example of overwriting motors during instantiation:
|
| 73 |
+
```python
|
| 74 |
+
# Defines how to communicate with the motors of the leader and follower arms
|
| 75 |
+
leader_arms = {
|
| 76 |
+
"main": DynamixelMotorsBusConfig(
|
| 77 |
+
port="/dev/tty.usbmodem575E0031751",
|
| 78 |
+
motors={
|
| 79 |
+
# name: (index, model)
|
| 80 |
+
"shoulder_pan": (1, "xl330-m077"),
|
| 81 |
+
"shoulder_lift": (2, "xl330-m077"),
|
| 82 |
+
"elbow_flex": (3, "xl330-m077"),
|
| 83 |
+
"wrist_flex": (4, "xl330-m077"),
|
| 84 |
+
"wrist_roll": (5, "xl330-m077"),
|
| 85 |
+
"gripper": (6, "xl330-m077"),
|
| 86 |
+
},
|
| 87 |
+
),
|
| 88 |
+
}
|
| 89 |
+
follower_arms = {
|
| 90 |
+
"main": DynamixelMotorsBusConfig(
|
| 91 |
+
port="/dev/tty.usbmodem575E0032081",
|
| 92 |
+
motors={
|
| 93 |
+
# name: (index, model)
|
| 94 |
+
"shoulder_pan": (1, "xl430-w250"),
|
| 95 |
+
"shoulder_lift": (2, "xl430-w250"),
|
| 96 |
+
"elbow_flex": (3, "xl330-m288"),
|
| 97 |
+
"wrist_flex": (4, "xl330-m288"),
|
| 98 |
+
"wrist_roll": (5, "xl330-m288"),
|
| 99 |
+
"gripper": (6, "xl330-m288"),
|
| 100 |
+
},
|
| 101 |
+
),
|
| 102 |
+
}
|
| 103 |
+
robot_config = KochRobotConfig(leader_arms=leader_arms, follower_arms=follower_arms)
|
| 104 |
+
robot = ManipulatorRobot(robot_config)
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
Example of overwriting cameras during instantiation:
|
| 108 |
+
```python
|
| 109 |
+
# Defines how to communicate with 2 cameras connected to the computer.
|
| 110 |
+
# Here, the webcam of the laptop and the phone (connected in USB to the laptop)
|
| 111 |
+
# can be reached respectively using the camera indices 0 and 1. These indices can be
|
| 112 |
+
# arbitrary. See the documentation of `OpenCVCamera` to find your own camera indices.
|
| 113 |
+
cameras = {
|
| 114 |
+
"laptop": OpenCVCamera(camera_index=0, fps=30, width=640, height=480),
|
| 115 |
+
"phone": OpenCVCamera(camera_index=1, fps=30, width=640, height=480),
|
| 116 |
+
}
|
| 117 |
+
robot = ManipulatorRobot(KochRobotConfig(cameras=cameras))
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
Once the robot is instantiated, connect motors buses and cameras if any (Required):
|
| 121 |
+
```python
|
| 122 |
+
robot.connect()
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
Example of highest frequency teleoperation, which doesn't require cameras:
|
| 126 |
+
```python
|
| 127 |
+
while True:
|
| 128 |
+
robot.teleop_step()
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
Example of highest frequency data collection from motors and cameras (if any):
|
| 132 |
+
```python
|
| 133 |
+
while True:
|
| 134 |
+
observation, action = robot.teleop_step(record_data=True)
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
Example of controlling the robot with a policy:
|
| 138 |
+
```python
|
| 139 |
+
while True:
|
| 140 |
+
# Uses the follower arms and cameras to capture an observation
|
| 141 |
+
observation = robot.capture_observation()
|
| 142 |
+
|
| 143 |
+
# Assumes a policy has been instantiated
|
| 144 |
+
with torch.inference_mode():
|
| 145 |
+
action = policy.select_action(observation)
|
| 146 |
+
|
| 147 |
+
# Orders the robot to move
|
| 148 |
+
robot.send_action(action)
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
Example of disconnecting which is not mandatory since we disconnect when the object is deleted:
|
| 152 |
+
```python
|
| 153 |
+
robot.disconnect()
|
| 154 |
+
```
|
| 155 |
+
"""
|
| 156 |
+
|
| 157 |
+
def __init__(
|
| 158 |
+
self,
|
| 159 |
+
config: ManipulatorRobotConfig,
|
| 160 |
+
):
|
| 161 |
+
self.config = config
|
| 162 |
+
self.robot_type = self.config.type
|
| 163 |
+
self.calibration_dir = Path(self.config.calibration_dir)
|
| 164 |
+
self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
|
| 165 |
+
self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
|
| 166 |
+
self.cameras = make_cameras_from_configs(self.config.cameras)
|
| 167 |
+
self.is_connected = False
|
| 168 |
+
self.logs = {}
|
| 169 |
+
|
| 170 |
+
def get_motor_names(self, arm: dict[str, MotorsBus]) -> list:
|
| 171 |
+
return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors]
|
| 172 |
+
|
| 173 |
+
@property
|
| 174 |
+
def camera_features(self) -> dict:
|
| 175 |
+
cam_ft = {}
|
| 176 |
+
for cam_key, cam in self.cameras.items():
|
| 177 |
+
key = f"observation.images.{cam_key}"
|
| 178 |
+
cam_ft[key] = {
|
| 179 |
+
"shape": (cam.height, cam.width, cam.channels),
|
| 180 |
+
"names": ["height", "width", "channels"],
|
| 181 |
+
"info": None,
|
| 182 |
+
}
|
| 183 |
+
return cam_ft
|
| 184 |
+
|
| 185 |
+
@property
|
| 186 |
+
def motor_features(self) -> dict:
|
| 187 |
+
action_names = self.get_motor_names(self.leader_arms)
|
| 188 |
+
state_names = self.get_motor_names(self.leader_arms)
|
| 189 |
+
return {
|
| 190 |
+
"action": {
|
| 191 |
+
"dtype": "float32",
|
| 192 |
+
"shape": (len(action_names),),
|
| 193 |
+
"names": action_names,
|
| 194 |
+
},
|
| 195 |
+
"observation.state": {
|
| 196 |
+
"dtype": "float32",
|
| 197 |
+
"shape": (len(state_names),),
|
| 198 |
+
"names": state_names,
|
| 199 |
+
},
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
@property
|
| 203 |
+
def features(self):
|
| 204 |
+
return {**self.motor_features, **self.camera_features}
|
| 205 |
+
|
| 206 |
+
@property
|
| 207 |
+
def has_camera(self):
|
| 208 |
+
return len(self.cameras) > 0
|
| 209 |
+
|
| 210 |
+
    @property
    def num_cameras(self):
        # Number of configured cameras (whether or not they are connected yet).
        return len(self.cameras)
|
| 213 |
+
|
| 214 |
+
@property
|
| 215 |
+
def available_arms(self):
|
| 216 |
+
available_arms = []
|
| 217 |
+
for name in self.follower_arms:
|
| 218 |
+
arm_id = get_arm_id(name, "follower")
|
| 219 |
+
available_arms.append(arm_id)
|
| 220 |
+
for name in self.leader_arms:
|
| 221 |
+
arm_id = get_arm_id(name, "leader")
|
| 222 |
+
available_arms.append(arm_id)
|
| 223 |
+
return available_arms
|
| 224 |
+
|
| 225 |
+
    def connect(self):
        """Connect all motor buses and cameras and prepare the robot.

        Sequence: connect arms → disable torque → calibrate → apply per-robot
        preset → re-enable follower torque → optional leader gripper setup →
        sanity-read both arms → connect cameras. Raises
        RobotDeviceAlreadyConnectedError on a second call and ValueError when
        no device is configured.
        """
        if self.is_connected:
            raise RobotDeviceAlreadyConnectedError(
                "ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
            )

        if not self.leader_arms and not self.follower_arms and not self.cameras:
            raise ValueError(
                "ManipulatorRobot doesn't have any device to connect. See example of usage in docstring of the class."
            )

        # Connect the arms
        for name in self.follower_arms:
            print(f"Connecting {name} follower arm.")
            self.follower_arms[name].connect()
        for name in self.leader_arms:
            print(f"Connecting {name} leader arm.")
            self.leader_arms[name].connect()

        # NOTE(review): an unrecognized robot_type leaves TorqueMode unbound and
        # the writes below raise NameError — confirm whether an explicit error
        # for unknown types is wanted.
        if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
            from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
        elif self.robot_type in ["so100", "moss", "lekiwi"]:
            from lerobot.common.robot_devices.motors.feetech import TorqueMode

        # We assume that at connection time, arms are in a rest position, and torque can
        # be safely disabled to run calibration and/or set robot preset configurations.
        for name in self.follower_arms:
            self.follower_arms[name].write("Torque_Enable", TorqueMode.DISABLED.value)
        for name in self.leader_arms:
            self.leader_arms[name].write("Torque_Enable", TorqueMode.DISABLED.value)

        self.activate_calibration()

        # Set robot preset (e.g. torque in leader gripper for Koch v1.1)
        if self.robot_type in ["koch", "koch_bimanual"]:
            self.set_koch_robot_preset()
        elif self.robot_type == "aloha":
            self.set_aloha_robot_preset()
        elif self.robot_type in ["so100", "moss", "lekiwi"]:
            self.set_so100_robot_preset()

        # Enable torque on all motors of the follower arms
        for name in self.follower_arms:
            print(f"Activating torque on {name} follower arm.")
            self.follower_arms[name].write("Torque_Enable", 1)

        if self.config.gripper_open_degree is not None:
            if self.robot_type not in ["koch", "koch_bimanual"]:
                raise NotImplementedError(
                    f"{self.robot_type} does not support position AND current control in the handle, which is require to set the gripper open."
                )
            # Set the leader arm in torque mode with the gripper motor set to an angle. This makes it possible
            # to squeeze the gripper and have it spring back to an open position on its own.
            for name in self.leader_arms:
                self.leader_arms[name].write("Torque_Enable", 1, "gripper")
                self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")

        # Check both arms can be read
        for name in self.follower_arms:
            self.follower_arms[name].read("Present_Position")
        for name in self.leader_arms:
            self.leader_arms[name].read("Present_Position")

        # Connect the cameras
        for name in self.cameras:
            self.cameras[name].connect()

        self.is_connected = True
|
| 293 |
+
|
| 294 |
+
    def activate_calibration(self):
        """After calibration all motors function in human interpretable ranges.
        Rotations are expressed in degrees in nominal range of [-180, 180],
        and linear motions (like gripper of Aloha) in nominal range of [0, 100].

        For each arm: load the cached calibration JSON when present, otherwise
        run the robot-type-specific calibration procedure and cache it, then
        apply the calibration to the bus.
        """

        def load_or_run_calibration_(name, arm, arm_type):
            # Calibration files are keyed by "<name>_<arm_type>" under calibration_dir.
            arm_id = get_arm_id(name, arm_type)
            arm_calib_path = self.calibration_dir / f"{arm_id}.json"

            if arm_calib_path.exists():
                with open(arm_calib_path) as f:
                    calibration = json.load(f)
            else:
                # TODO(rcadene): display a warning in __init__ if calibration file not available
                print(f"Missing calibration file '{arm_calib_path}'")

                # NOTE(review): an unrecognized robot_type falls through both
                # branches and `calibration` is unbound below — confirm whether
                # that should raise explicitly.
                if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
                    from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration

                    calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)

                elif self.robot_type in ["so100", "moss", "lekiwi"]:
                    from lerobot.common.robot_devices.robots.feetech_calibration import (
                        run_arm_manual_calibration,
                    )

                    calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)

                print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
                arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
                with open(arm_calib_path, "w") as f:
                    json.dump(calibration, f)

            return calibration

        for name, arm in self.follower_arms.items():
            calibration = load_or_run_calibration_(name, arm, "follower")
            arm.set_calibration(calibration)
        for name, arm in self.leader_arms.items():
            calibration = load_or_run_calibration_(name, arm, "leader")
            arm.set_calibration(calibration)
|
| 336 |
+
|
| 337 |
+
    def set_koch_robot_preset(self):
        """Apply Koch-specific motor presets.

        Requires torque to be disabled on every motor (checked inside the
        helper). Sets operating modes on the follower arms, tunes the
        elbow PID gains, and optionally configures the leader gripper as a
        spring-loaded trigger.
        """

        def set_operating_mode_(arm):
            from lerobot.common.robot_devices.motors.dynamixel import TorqueMode

            if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
                raise ValueError("To run set robot preset, the torque must be disabled on all motors.")

            # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
            # rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
            # you could end up with a servo with a position 0 or 4095 at a crucial point See [
            # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
            all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
            if len(all_motors_except_gripper) > 0:
                # 4 corresponds to Extended Position on Koch motors
                arm.write("Operating_Mode", 4, all_motors_except_gripper)

            # Use 'position control current based' for gripper to be limited by the limit of the current.
            # For the follower gripper, it means it can grasp an object without forcing too much even tho,
            # it's goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
            # For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger
            # to make it move, and it will move back to its original target position when we release the force.
            # 5 corresponds to Current Controlled Position on Koch gripper motors "xl330-m077, xl330-m288"
            arm.write("Operating_Mode", 5, "gripper")

        for name in self.follower_arms:
            set_operating_mode_(self.follower_arms[name])

            # Set better PID values to close the gap between recorded states and actions
            # TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor
            self.follower_arms[name].write("Position_P_Gain", 1500, "elbow_flex")
            self.follower_arms[name].write("Position_I_Gain", 0, "elbow_flex")
            self.follower_arms[name].write("Position_D_Gain", 600, "elbow_flex")

        if self.config.gripper_open_degree is not None:
            for name in self.leader_arms:
                set_operating_mode_(self.leader_arms[name])

                # Enable torque on the gripper of the leader arms, and move it to 45 degrees,
                # so that we can use it as a trigger to close the gripper of the follower arms.
                self.leader_arms[name].write("Torque_Enable", 1, "gripper")
                self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
|
| 378 |
+
|
| 379 |
+
    def set_aloha_robot_preset(self):
        """Apply Aloha-specific motor presets.

        Configures shadow (secondary) IDs on the dual-motor joints of both arm
        sets, then sets velocity limits and operating modes on the follower
        arms. Warns when `gripper_open_degree` is set, which Aloha ignores.
        """

        def set_shadow_(arm):
            # Set secondary/shadow ID for shoulder and elbow. These joints have two motors.
            # As a result, if only one of them is required to move to a certain position,
            # the other will follow. This is to avoid breaking the motors.
            if "shoulder_shadow" in arm.motor_names:
                shoulder_idx = arm.read("ID", "shoulder")
                arm.write("Secondary_ID", shoulder_idx, "shoulder_shadow")

            if "elbow_shadow" in arm.motor_names:
                elbow_idx = arm.read("ID", "elbow")
                arm.write("Secondary_ID", elbow_idx, "elbow_shadow")

        for name in self.follower_arms:
            set_shadow_(self.follower_arms[name])

        for name in self.leader_arms:
            set_shadow_(self.leader_arms[name])

        for name in self.follower_arms:
            # Set a velocity limit of 131 as advised by Trossen Robotics
            self.follower_arms[name].write("Velocity_Limit", 131)

            # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
            # rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
            # you could end up with a servo with a position 0 or 4095 at a crucial point See [
            # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
            # NOTE(review): the comprehension variable shadows the outer loop
            # variable `name`; harmless in Python 3 (comprehension scope is
            # separate) but confusing — worth renaming.
            all_motors_except_gripper = [
                name for name in self.follower_arms[name].motor_names if name != "gripper"
            ]
            if len(all_motors_except_gripper) > 0:
                # 4 corresponds to Extended Position on Aloha motors
                self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper)

            # Use 'position control current based' for follower gripper to be limited by the limit of the current.
            # It can grasp an object without forcing too much even tho,
            # it's goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
            # 5 corresponds to Current Controlled Position on Aloha gripper follower "xm430-w350"
            self.follower_arms[name].write("Operating_Mode", 5, "gripper")

        # Note: We can't enable torque on the leader gripper since "xc430-w150" doesn't have
        # a Current Controlled Position mode.

        if self.config.gripper_open_degree is not None:
            warnings.warn(
                f"`gripper_open_degree` is set to {self.config.gripper_open_degree}, but None is expected for Aloha instead",
                stacklevel=1,
            )
|
| 427 |
+
|
| 428 |
+
def set_so100_robot_preset(self):
|
| 429 |
+
for name in self.follower_arms:
|
| 430 |
+
# Mode=0 for Position Control
|
| 431 |
+
self.follower_arms[name].write("Mode", 0)
|
| 432 |
+
# Set P_Coefficient to lower value to avoid shakiness (Default is 32)
|
| 433 |
+
self.follower_arms[name].write("P_Coefficient", 16)
|
| 434 |
+
# Set I_Coefficient and D_Coefficient to default value 0 and 32
|
| 435 |
+
self.follower_arms[name].write("I_Coefficient", 0)
|
| 436 |
+
self.follower_arms[name].write("D_Coefficient", 32)
|
| 437 |
+
# Close the write lock so that Maximum_Acceleration gets written to EPROM address,
|
| 438 |
+
# which is mandatory for Maximum_Acceleration to take effect after rebooting.
|
| 439 |
+
self.follower_arms[name].write("Lock", 0)
|
| 440 |
+
# Set Maximum_Acceleration to 254 to speedup acceleration and deceleration of
|
| 441 |
+
# the motors. Note: this configuration is not in the official STS3215 Memory Table
|
| 442 |
+
self.follower_arms[name].write("Maximum_Acceleration", 254)
|
| 443 |
+
self.follower_arms[name].write("Acceleration", 254)
|
| 444 |
+
|
| 445 |
+
    def teleop_step(
        self, record_data=False
    ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """Run one teleoperation step: mirror leader arm positions onto the followers.

        Reads each leader arm's present position and writes it as the goal of the
        matching follower arm (optionally clamped by ``config.max_relative_target``).
        When ``record_data`` is True, additionally reads follower positions and
        camera frames and returns ``(obs_dict, action_dict)``; otherwise returns
        ``None``. Per-device timings are stored in ``self.logs``.

        Raises:
            RobotDeviceNotConnectedError: if `connect()` has not been called.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                "ManipulatorRobot is not connected. You need to run `robot.connect()`."
            )

        # Prepare to assign the position of the leader to the follower
        leader_pos = {}
        for name in self.leader_arms:
            before_lread_t = time.perf_counter()
            leader_pos[name] = self.leader_arms[name].read("Present_Position")
            leader_pos[name] = torch.from_numpy(leader_pos[name])
            self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t

        # Send goal position to the follower
        follower_goal_pos = {}
        for name in self.follower_arms:
            before_fwrite_t = time.perf_counter()
            # NOTE(review): assumes leader and follower arms share the same names;
            # a follower without a matching leader raises KeyError here.
            goal_pos = leader_pos[name]

            # Cap goal position when too far away from present position.
            # Slower fps expected due to reading from the follower.
            if self.config.max_relative_target is not None:
                present_pos = self.follower_arms[name].read("Present_Position")
                present_pos = torch.from_numpy(present_pos)
                goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)

            # Used when record_data=True
            follower_goal_pos[name] = goal_pos

            goal_pos = goal_pos.numpy().astype(np.float32)
            self.follower_arms[name].write("Goal_Position", goal_pos)
            self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t

        # Early exit when recording data is not requested
        if not record_data:
            return

        # TODO(rcadene): Add velocity and other info
        # Read follower position
        follower_pos = {}
        for name in self.follower_arms:
            before_fread_t = time.perf_counter()
            follower_pos[name] = self.follower_arms[name].read("Present_Position")
            follower_pos[name] = torch.from_numpy(follower_pos[name])
            self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t

        # Create state by concatenating follower current position
        state = []
        for name in self.follower_arms:
            if name in follower_pos:
                state.append(follower_pos[name])
        state = torch.cat(state)

        # Create action by concatenating follower goal position
        action = []
        for name in self.follower_arms:
            if name in follower_goal_pos:
                action.append(follower_goal_pos[name])
        action = torch.cat(action)

        # Capture images from cameras
        images = {}
        for name in self.cameras:
            before_camread_t = time.perf_counter()
            images[name] = self.cameras[name].async_read()
            images[name] = torch.from_numpy(images[name])
            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

        # Populate output dictionaries
        obs_dict, action_dict = {}, {}
        obs_dict["observation.state"] = state
        action_dict["action"] = action
        for name in self.cameras:
            obs_dict[f"observation.images.{name}"] = images[name]

        return obs_dict, action_dict
|
| 525 |
+
|
| 526 |
+
def capture_observation(self):
|
| 527 |
+
"""The returned observations do not have a batch dimension."""
|
| 528 |
+
if not self.is_connected:
|
| 529 |
+
raise RobotDeviceNotConnectedError(
|
| 530 |
+
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
# Read follower position
|
| 534 |
+
follower_pos = {}
|
| 535 |
+
for name in self.follower_arms:
|
| 536 |
+
before_fread_t = time.perf_counter()
|
| 537 |
+
follower_pos[name] = self.follower_arms[name].read("Present_Position")
|
| 538 |
+
follower_pos[name] = torch.from_numpy(follower_pos[name])
|
| 539 |
+
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
|
| 540 |
+
|
| 541 |
+
# Create state by concatenating follower current position
|
| 542 |
+
state = []
|
| 543 |
+
for name in self.follower_arms:
|
| 544 |
+
if name in follower_pos:
|
| 545 |
+
state.append(follower_pos[name])
|
| 546 |
+
state = torch.cat(state)
|
| 547 |
+
|
| 548 |
+
# Capture images from cameras
|
| 549 |
+
images = {}
|
| 550 |
+
for name in self.cameras:
|
| 551 |
+
before_camread_t = time.perf_counter()
|
| 552 |
+
images[name] = self.cameras[name].async_read()
|
| 553 |
+
images[name] = torch.from_numpy(images[name])
|
| 554 |
+
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
| 555 |
+
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
| 556 |
+
|
| 557 |
+
# Populate output dictionaries and format to pytorch
|
| 558 |
+
obs_dict = {}
|
| 559 |
+
obs_dict["observation.state"] = state
|
| 560 |
+
for name in self.cameras:
|
| 561 |
+
obs_dict[f"observation.images.{name}"] = images[name]
|
| 562 |
+
return obs_dict
|
| 563 |
+
|
| 564 |
+
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
| 565 |
+
"""Command the follower arms to move to a target joint configuration.
|
| 566 |
+
|
| 567 |
+
The relative action magnitude may be clipped depending on the configuration parameter
|
| 568 |
+
`max_relative_target`. In this case, the action sent differs from original action.
|
| 569 |
+
Thus, this function always returns the action actually sent.
|
| 570 |
+
|
| 571 |
+
Args:
|
| 572 |
+
action: tensor containing the concatenated goal positions for the follower arms.
|
| 573 |
+
"""
|
| 574 |
+
if not self.is_connected:
|
| 575 |
+
raise RobotDeviceNotConnectedError(
|
| 576 |
+
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
from_idx = 0
|
| 580 |
+
to_idx = 0
|
| 581 |
+
action_sent = []
|
| 582 |
+
for name in self.follower_arms:
|
| 583 |
+
# Get goal position of each follower arm by splitting the action vector
|
| 584 |
+
to_idx += len(self.follower_arms[name].motor_names)
|
| 585 |
+
goal_pos = action[from_idx:to_idx]
|
| 586 |
+
from_idx = to_idx
|
| 587 |
+
|
| 588 |
+
# Cap goal position when too far away from present position.
|
| 589 |
+
# Slower fps expected due to reading from the follower.
|
| 590 |
+
if self.config.max_relative_target is not None:
|
| 591 |
+
present_pos = self.follower_arms[name].read("Present_Position")
|
| 592 |
+
present_pos = torch.from_numpy(present_pos)
|
| 593 |
+
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
|
| 594 |
+
|
| 595 |
+
# Save tensor to concat and return
|
| 596 |
+
action_sent.append(goal_pos)
|
| 597 |
+
|
| 598 |
+
# Send goal position to each follower
|
| 599 |
+
goal_pos = goal_pos.numpy().astype(np.float32)
|
| 600 |
+
self.follower_arms[name].write("Goal_Position", goal_pos)
|
| 601 |
+
|
| 602 |
+
return torch.cat(action_sent)
|
| 603 |
+
|
| 604 |
+
def print_logs(self):
|
| 605 |
+
pass
|
| 606 |
+
# TODO(aliberts): move robot-specific logs logic here
|
| 607 |
+
|
| 608 |
+
def disconnect(self):
|
| 609 |
+
if not self.is_connected:
|
| 610 |
+
raise RobotDeviceNotConnectedError(
|
| 611 |
+
"ManipulatorRobot is not connected. You need to run `robot.connect()` before disconnecting."
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
for name in self.follower_arms:
|
| 615 |
+
self.follower_arms[name].disconnect()
|
| 616 |
+
|
| 617 |
+
for name in self.leader_arms:
|
| 618 |
+
self.leader_arms[name].disconnect()
|
| 619 |
+
|
| 620 |
+
for name in self.cameras:
|
| 621 |
+
self.cameras[name].disconnect()
|
| 622 |
+
|
| 623 |
+
self.is_connected = False
|
| 624 |
+
|
| 625 |
+
def __del__(self):
|
| 626 |
+
if getattr(self, "is_connected", False):
|
| 627 |
+
self.disconnect()
|
lerobot/common/robot_devices/robots/mobile_manipulator.py
ADDED
|
@@ -0,0 +1,703 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import base64
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import cv2
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
import zmq
|
| 25 |
+
|
| 26 |
+
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
| 27 |
+
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
| 28 |
+
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
|
| 29 |
+
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
|
| 30 |
+
from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
|
| 31 |
+
from lerobot.common.robot_devices.robots.utils import get_arm_id
|
| 32 |
+
from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError
|
| 33 |
+
|
| 34 |
+
# pynput (keyboard teleop) is optional: importing it requires a display server
# on Linux, so probe the environment first and degrade gracefully on failure.
PYNPUT_AVAILABLE = True
try:
    # Only import if there's a valid X server or if we're not on a Pi
    if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
        print("No DISPLAY set. Skipping pynput import.")
        raise ImportError("pynput blocked intentionally due to no display.")

    from pynput import keyboard
except ImportError:
    keyboard = None
    PYNPUT_AVAILABLE = False
except Exception as e:
    # Any other import-time failure (e.g. a broken input backend) also
    # disables the local keyboard listener instead of crashing.
    keyboard = None
    PYNPUT_AVAILABLE = False
    print(f"Could not import pynput: {e}")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class MobileManipulator:
|
| 52 |
+
"""
|
| 53 |
+
MobileManipulator is a class for connecting to and controlling a remote mobile manipulator robot.
|
| 54 |
+
The robot includes a three omniwheel mobile base and a remote follower arm.
|
| 55 |
+
The leader arm is connected locally (on the laptop) and its joint positions are recorded and then
|
| 56 |
+
forwarded to the remote follower arm (after applying a safety clamp).
|
| 57 |
+
In parallel, keyboard teleoperation is used to generate raw velocity commands for the wheels.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
    def __init__(self, config: LeKiwiRobotConfig):
        """
        Expected keys in config:
        - ip, port, video_port for the remote connection.
        - calibration_dir, leader_arms, follower_arms, max_relative_target, etc.
        """
        self.robot_type = config.type
        self.config = config
        # Remote (on-robot) endpoints: one port for commands, one for video/state.
        self.remote_ip = config.ip
        self.remote_port = config.port
        self.remote_port_video = config.video_port
        self.calibration_dir = Path(self.config.calibration_dir)
        # Timing/diagnostic logs keyed by operation name.
        self.logs = {}

        # Mapping from teleop action names to keyboard characters.
        self.teleop_keys = self.config.teleop_keys

        # For teleoperation, the leader arm (local) is used to record the desired arm pose.
        self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)

        self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)

        self.cameras = make_cameras_from_configs(self.config.cameras)

        self.is_connected = False

        # Caches of the most recent remote data, reused when no new message arrives.
        self.last_frames = {}
        self.last_present_speed = {}
        # 6 values: one per follower arm joint — TODO confirm against remote sender.
        self.last_remote_arm_state = torch.zeros(6, dtype=torch.float32)

        # Define three speed levels and a current index
        self.speed_levels = [
            {"xy": 0.1, "theta": 30},  # slow
            {"xy": 0.2, "theta": 60},  # medium
            {"xy": 0.3, "theta": 90},  # fast
        ]
        self.speed_index = 0  # Start at slow

        # ZeroMQ context and sockets.
        self.context = None
        self.cmd_socket = None
        self.video_socket = None

        # Keyboard state for base teleoperation.
        self.running = True
        self.pressed_keys = {
            "forward": False,
            "backward": False,
            "left": False,
            "right": False,
            "rotate_left": False,
            "rotate_right": False,
        }

        # Start the background keyboard listener only when pynput imported cleanly.
        if PYNPUT_AVAILABLE:
            print("pynput is available - enabling local keyboard listener.")
            self.listener = keyboard.Listener(
                on_press=self.on_press,
                on_release=self.on_release,
            )
            self.listener.start()
        else:
            print("pynput not available - skipping local keyboard listener.")
            self.listener = None
|
| 123 |
+
|
| 124 |
+
def get_motor_names(self, arms: dict[str, MotorsBus]) -> list:
|
| 125 |
+
return [f"{arm}_{motor}" for arm, bus in arms.items() for motor in bus.motors]
|
| 126 |
+
|
| 127 |
+
@property
|
| 128 |
+
def camera_features(self) -> dict:
|
| 129 |
+
cam_ft = {}
|
| 130 |
+
for cam_key, cam in self.cameras.items():
|
| 131 |
+
key = f"observation.images.{cam_key}"
|
| 132 |
+
cam_ft[key] = {
|
| 133 |
+
"shape": (cam.height, cam.width, cam.channels),
|
| 134 |
+
"names": ["height", "width", "channels"],
|
| 135 |
+
"info": None,
|
| 136 |
+
}
|
| 137 |
+
return cam_ft
|
| 138 |
+
|
| 139 |
+
@property
|
| 140 |
+
def motor_features(self) -> dict:
|
| 141 |
+
follower_arm_names = [
|
| 142 |
+
"shoulder_pan",
|
| 143 |
+
"shoulder_lift",
|
| 144 |
+
"elbow_flex",
|
| 145 |
+
"wrist_flex",
|
| 146 |
+
"wrist_roll",
|
| 147 |
+
"gripper",
|
| 148 |
+
]
|
| 149 |
+
observations = ["x_mm", "y_mm", "theta"]
|
| 150 |
+
combined_names = follower_arm_names + observations
|
| 151 |
+
return {
|
| 152 |
+
"action": {
|
| 153 |
+
"dtype": "float32",
|
| 154 |
+
"shape": (len(combined_names),),
|
| 155 |
+
"names": combined_names,
|
| 156 |
+
},
|
| 157 |
+
"observation.state": {
|
| 158 |
+
"dtype": "float32",
|
| 159 |
+
"shape": (len(combined_names),),
|
| 160 |
+
"names": combined_names,
|
| 161 |
+
},
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
@property
|
| 165 |
+
def features(self):
|
| 166 |
+
return {**self.motor_features, **self.camera_features}
|
| 167 |
+
|
| 168 |
+
@property
|
| 169 |
+
def has_camera(self):
|
| 170 |
+
return len(self.cameras) > 0
|
| 171 |
+
|
| 172 |
+
    @property
    def num_cameras(self):
        """Number of configured cameras."""
        return len(self.cameras)
|
| 175 |
+
|
| 176 |
+
@property
|
| 177 |
+
def available_arms(self):
|
| 178 |
+
available = []
|
| 179 |
+
for name in self.leader_arms:
|
| 180 |
+
available.append(get_arm_id(name, "leader"))
|
| 181 |
+
for name in self.follower_arms:
|
| 182 |
+
available.append(get_arm_id(name, "follower"))
|
| 183 |
+
return available
|
| 184 |
+
|
| 185 |
+
def on_press(self, key):
|
| 186 |
+
try:
|
| 187 |
+
# Movement
|
| 188 |
+
if key.char == self.teleop_keys["forward"]:
|
| 189 |
+
self.pressed_keys["forward"] = True
|
| 190 |
+
elif key.char == self.teleop_keys["backward"]:
|
| 191 |
+
self.pressed_keys["backward"] = True
|
| 192 |
+
elif key.char == self.teleop_keys["left"]:
|
| 193 |
+
self.pressed_keys["left"] = True
|
| 194 |
+
elif key.char == self.teleop_keys["right"]:
|
| 195 |
+
self.pressed_keys["right"] = True
|
| 196 |
+
elif key.char == self.teleop_keys["rotate_left"]:
|
| 197 |
+
self.pressed_keys["rotate_left"] = True
|
| 198 |
+
elif key.char == self.teleop_keys["rotate_right"]:
|
| 199 |
+
self.pressed_keys["rotate_right"] = True
|
| 200 |
+
|
| 201 |
+
# Quit teleoperation
|
| 202 |
+
elif key.char == self.teleop_keys["quit"]:
|
| 203 |
+
self.running = False
|
| 204 |
+
return False
|
| 205 |
+
|
| 206 |
+
# Speed control
|
| 207 |
+
elif key.char == self.teleop_keys["speed_up"]:
|
| 208 |
+
self.speed_index = min(self.speed_index + 1, 2)
|
| 209 |
+
print(f"Speed index increased to {self.speed_index}")
|
| 210 |
+
elif key.char == self.teleop_keys["speed_down"]:
|
| 211 |
+
self.speed_index = max(self.speed_index - 1, 0)
|
| 212 |
+
print(f"Speed index decreased to {self.speed_index}")
|
| 213 |
+
|
| 214 |
+
except AttributeError:
|
| 215 |
+
# e.g., if key is special like Key.esc
|
| 216 |
+
if key == keyboard.Key.esc:
|
| 217 |
+
self.running = False
|
| 218 |
+
return False
|
| 219 |
+
|
| 220 |
+
def on_release(self, key):
|
| 221 |
+
try:
|
| 222 |
+
if hasattr(key, "char"):
|
| 223 |
+
if key.char == self.teleop_keys["forward"]:
|
| 224 |
+
self.pressed_keys["forward"] = False
|
| 225 |
+
elif key.char == self.teleop_keys["backward"]:
|
| 226 |
+
self.pressed_keys["backward"] = False
|
| 227 |
+
elif key.char == self.teleop_keys["left"]:
|
| 228 |
+
self.pressed_keys["left"] = False
|
| 229 |
+
elif key.char == self.teleop_keys["right"]:
|
| 230 |
+
self.pressed_keys["right"] = False
|
| 231 |
+
elif key.char == self.teleop_keys["rotate_left"]:
|
| 232 |
+
self.pressed_keys["rotate_left"] = False
|
| 233 |
+
elif key.char == self.teleop_keys["rotate_right"]:
|
| 234 |
+
self.pressed_keys["rotate_right"] = False
|
| 235 |
+
except AttributeError:
|
| 236 |
+
pass
|
| 237 |
+
|
| 238 |
+
    def connect(self):
        """Calibrate the local leader arm(s) and open the ZeroMQ links to the robot.

        Raises:
            ValueError: if no leader arm is configured.
        """
        if not self.leader_arms:
            raise ValueError("MobileManipulator has no leader arm to connect.")
        # NOTE(review): calibrate_leader() itself loops over all leader arms, so
        # it is re-run once per arm name here — confirm this repetition is intended.
        for name in self.leader_arms:
            print(f"Connecting {name} leader arm.")
            self.calibrate_leader()

        # Set up ZeroMQ sockets to communicate with the remote mobile robot.
        self.context = zmq.Context()
        # PUSH socket carries velocity/arm commands; CONFLATE keeps only the
        # newest unsent message so the robot never acts on stale commands.
        self.cmd_socket = self.context.socket(zmq.PUSH)
        connection_string = f"tcp://{self.remote_ip}:{self.remote_port}"
        self.cmd_socket.connect(connection_string)
        self.cmd_socket.setsockopt(zmq.CONFLATE, 1)
        # PULL socket receives the video/observation stream from the robot.
        self.video_socket = self.context.socket(zmq.PULL)
        video_connection = f"tcp://{self.remote_ip}:{self.remote_port_video}"
        self.video_socket.connect(video_connection)
        self.video_socket.setsockopt(zmq.CONFLATE, 1)
        print(
            f"[INFO] Connected to remote robot at {connection_string} and video stream at {video_connection}."
        )
        self.is_connected = True
|
| 259 |
+
|
| 260 |
+
def load_or_run_calibration_(self, name, arm, arm_type):
|
| 261 |
+
arm_id = get_arm_id(name, arm_type)
|
| 262 |
+
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
| 263 |
+
|
| 264 |
+
if arm_calib_path.exists():
|
| 265 |
+
with open(arm_calib_path) as f:
|
| 266 |
+
calibration = json.load(f)
|
| 267 |
+
else:
|
| 268 |
+
print(f"Missing calibration file '{arm_calib_path}'")
|
| 269 |
+
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
| 270 |
+
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
| 271 |
+
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
| 272 |
+
with open(arm_calib_path, "w") as f:
|
| 273 |
+
json.dump(calibration, f)
|
| 274 |
+
|
| 275 |
+
return calibration
|
| 276 |
+
|
| 277 |
+
def calibrate_leader(self):
|
| 278 |
+
for name, arm in self.leader_arms.items():
|
| 279 |
+
# Connect the bus
|
| 280 |
+
arm.connect()
|
| 281 |
+
|
| 282 |
+
# Disable torque on all motors
|
| 283 |
+
for motor_id in arm.motors:
|
| 284 |
+
arm.write("Torque_Enable", TorqueMode.DISABLED.value, motor_id)
|
| 285 |
+
|
| 286 |
+
# Now run calibration
|
| 287 |
+
calibration = self.load_or_run_calibration_(name, arm, "leader")
|
| 288 |
+
arm.set_calibration(calibration)
|
| 289 |
+
|
| 290 |
+
def calibrate_follower(self):
|
| 291 |
+
for name, bus in self.follower_arms.items():
|
| 292 |
+
bus.connect()
|
| 293 |
+
|
| 294 |
+
# Disable torque on all motors
|
| 295 |
+
for motor_id in bus.motors:
|
| 296 |
+
bus.write("Torque_Enable", 0, motor_id)
|
| 297 |
+
|
| 298 |
+
# Then filter out wheels
|
| 299 |
+
arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")}
|
| 300 |
+
if not arm_only_dict:
|
| 301 |
+
continue
|
| 302 |
+
|
| 303 |
+
original_motors = bus.motors
|
| 304 |
+
bus.motors = arm_only_dict
|
| 305 |
+
|
| 306 |
+
calibration = self.load_or_run_calibration_(name, bus, "follower")
|
| 307 |
+
bus.set_calibration(calibration)
|
| 308 |
+
|
| 309 |
+
bus.motors = original_motors
|
| 310 |
+
|
| 311 |
+
    def _get_data(self):
        """
        Polls the video socket for up to 15 ms. If data arrives, decode only
        the *latest* message, returning frames, speed, and arm state. If
        nothing arrives for any field, use the last known values.
        """
        frames = {}
        present_speed = {}
        remote_arm_state_tensor = torch.zeros(6, dtype=torch.float32)

        # Poll up to 15 ms
        poller = zmq.Poller()
        poller.register(self.video_socket, zmq.POLLIN)
        socks = dict(poller.poll(15))
        if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
            # No new data arrived → reuse ALL old data
            return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)

        # Drain all messages, keep only the last
        last_msg = None
        while True:
            try:
                obs_string = self.video_socket.recv_string(zmq.NOBLOCK)
                last_msg = obs_string
            except zmq.Again:
                break

        if not last_msg:
            # No new message → also reuse old
            return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)

        # Decode only the final message
        try:
            observation = json.loads(last_msg)

            images_dict = observation.get("images", {})
            new_speed = observation.get("present_speed", {})
            new_arm_state = observation.get("follower_arm_state", None)

            # Convert images
            # Each image arrives base64-encoded JPEG; decode to a BGR ndarray.
            for cam_name, image_b64 in images_dict.items():
                if image_b64:
                    jpg_data = base64.b64decode(image_b64)
                    np_arr = np.frombuffer(jpg_data, dtype=np.uint8)
                    frame_candidate = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
                    if frame_candidate is not None:
                        frames[cam_name] = frame_candidate

            # If remote_arm_state is None and frames is None there is no message then use the previous message
            # NOTE(review): `frames` is always a dict here (possibly empty), so
            # `frames is not None` is vacuously true and the branch depends only
            # on `new_arm_state`. Confirm whether a truthiness check on `frames`
            # was intended instead.
            if new_arm_state is not None and frames is not None:
                self.last_frames = frames

                remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32)
                self.last_remote_arm_state = remote_arm_state_tensor

                present_speed = new_speed
                self.last_present_speed = new_speed
            else:
                frames = self.last_frames

                remote_arm_state_tensor = self.last_remote_arm_state

                present_speed = self.last_present_speed

        except Exception as e:
            print(f"[DEBUG] Error decoding video message: {e}")
            # If decode fails, fall back to old data
            return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)

        return frames, present_speed, remote_arm_state_tensor
|
| 381 |
+
|
| 382 |
+
def _process_present_speed(self, present_speed: dict) -> torch.Tensor:
|
| 383 |
+
state_tensor = torch.zeros(3, dtype=torch.int32)
|
| 384 |
+
if present_speed:
|
| 385 |
+
decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()}
|
| 386 |
+
if "1" in decoded:
|
| 387 |
+
state_tensor[0] = decoded["1"]
|
| 388 |
+
if "2" in decoded:
|
| 389 |
+
state_tensor[1] = decoded["2"]
|
| 390 |
+
if "3" in decoded:
|
| 391 |
+
state_tensor[2] = decoded["3"]
|
| 392 |
+
return state_tensor
|
| 393 |
+
|
| 394 |
+
    def teleop_step(
        self, record_data: bool = False
    ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """One teleop tick: send leader arm pose + keyboard base command to the robot.

        When ``record_data`` is True, also returns ``(obs_dict, action_dict)``
        where the action is the 6 arm joint targets followed by the base command
        expressed as (x mm/s, y mm/s, theta deg/s).

        Raises:
            RobotDeviceNotConnectedError: if `connect()` has not been called.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")

        speed_setting = self.speed_levels[self.speed_index]
        xy_speed = speed_setting["xy"]  # e.g. 0.1, 0.25, or 0.4
        theta_speed = speed_setting["theta"]  # e.g. 30, 60, or 90

        # Prepare to assign the position of the leader to the follower
        arm_positions = []
        for name in self.leader_arms:
            pos = self.leader_arms[name].read("Present_Position")
            pos_tensor = torch.from_numpy(pos).float()
            arm_positions.extend(pos_tensor.tolist())

        # Accumulate the base velocity command from the currently pressed keys.
        y_cmd = 0.0  # m/s forward/backward
        x_cmd = 0.0  # m/s lateral
        theta_cmd = 0.0  # deg/s rotation
        if self.pressed_keys["forward"]:
            y_cmd += xy_speed
        if self.pressed_keys["backward"]:
            y_cmd -= xy_speed
        if self.pressed_keys["left"]:
            x_cmd += xy_speed
        if self.pressed_keys["right"]:
            x_cmd -= xy_speed
        if self.pressed_keys["rotate_left"]:
            theta_cmd += theta_speed
        if self.pressed_keys["rotate_right"]:
            theta_cmd -= theta_speed

        # Convert the body-frame command into raw per-wheel velocities.
        wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)

        # Ship both the base command and the arm targets to the remote robot.
        message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions}
        self.cmd_socket.send_string(json.dumps(message))

        if not record_data:
            return

        obs_dict = self.capture_observation()

        arm_state_tensor = torch.tensor(arm_positions, dtype=torch.float32)

        # Express the commanded base velocity in dataset units (x/y in mm/s).
        wheel_velocity_tuple = self.wheel_raw_to_body(wheel_commands)
        wheel_velocity_mm = (
            wheel_velocity_tuple[0] * 1000.0,
            wheel_velocity_tuple[1] * 1000.0,
            wheel_velocity_tuple[2],
        )
        wheel_tensor = torch.tensor(wheel_velocity_mm, dtype=torch.float32)
        action_tensor = torch.cat([arm_state_tensor, wheel_tensor])
        action_dict = {"action": action_tensor}

        return obs_dict, action_dict
|
| 450 |
+
|
| 451 |
+
def capture_observation(self) -> dict:
    """
    Capture observations from the remote robot: current follower arm positions,
    present wheel speeds (converted to body-frame velocities: x, y, theta),
    and a camera frame.

    Returns:
        dict with "observation.state" (arm positions followed by x/y in mm/s and
        theta in deg/s, as a float32 tensor) and one
        "observation.images.<cam_name>" tensor per configured camera.

    Raises:
        RobotDeviceNotConnectedError: if `connect()` was not called first.
    """
    if not self.is_connected:
        raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.")

    # frames: per-camera images; present_speed: raw wheel feedback;
    # remote_arm_state_tensor: follower arm joint positions.
    frames, present_speed, remote_arm_state_tensor = self._get_data()

    body_state = self.wheel_raw_to_body(present_speed)

    body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2])  # Convert x,y to mm/s
    wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
    combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)

    obs_dict = {"observation.state": combined_state_tensor}

    # Loop over each configured camera
    for cam_name, cam in self.cameras.items():
        frame = frames.get(cam_name, None)
        if frame is None:
            # Create a black image using the camera's configured width, height, and channels
            # so downstream consumers always receive a frame of the expected shape.
            frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8)
        obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame)

    return obs_dict
|
| 479 |
+
|
| 480 |
+
def send_action(self, action: torch.Tensor) -> torch.Tensor:
    """Send one action to the remote robot: 6 arm positions + 3 base commands.

    The base commands are interpreted as (x mm/s, y mm/s, theta deg/s) and
    converted to raw wheel velocities before being sent over the command socket.
    Returns the (possibly zero-padded) action tensor that was sent.
    """
    if not self.is_connected:
        raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.")

    # The action must hold at least 9 values (6 arm + 3 base); zero-pad short ones.
    if action.numel() < 9:
        full_action = torch.zeros(9, dtype=action.dtype)
        full_action[: action.numel()] = action
        action = full_action

    # Split into the arm and base portions.
    arm_part = action[:6].flatten()
    base_part = action[6:].flatten()

    # Base commands arrive in mm/s (x, y) and deg/s (theta);
    # kinematics expects m/s, so divide the linear terms by 1000.
    x_cmd = base_part[0].item() / 1000.0  # m/s
    y_cmd = base_part[1].item() / 1000.0  # m/s
    theta_cmd = base_part[2].item()  # deg/s

    wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)

    payload = {"raw_velocity": wheel_commands, "arm_positions": arm_part.tolist()}
    self.cmd_socket.send_string(json.dumps(payload))

    return action
|
| 514 |
+
|
| 515 |
+
def print_logs(self):
    """No-op: this robot does not expose per-step logs yet."""
    pass
|
| 517 |
+
|
| 518 |
+
def disconnect(self):
    """Stop the base, close all ZMQ sockets, and tear down the keyboard listener.

    A zero-velocity stop command is sent first so the robot does not keep
    moving after the connection is dropped.

    Raises:
        RobotDeviceNotConnectedError: if the robot is not currently connected.
    """
    if not self.is_connected:
        raise RobotDeviceNotConnectedError("Not connected.")
    if self.cmd_socket:
        # Command all wheels to zero before closing so the base halts.
        stop_cmd = {
            "raw_velocity": {"left_wheel": 0, "back_wheel": 0, "right_wheel": 0},
            "arm_positions": {},
        }
        self.cmd_socket.send_string(json.dumps(stop_cmd))
        self.cmd_socket.close()
    if self.video_socket:
        self.video_socket.close()
    if self.context:
        # Terminate the ZMQ context after its sockets are closed.
        self.context.term()
    if PYNPUT_AVAILABLE:
        self.listener.stop()
    self.is_connected = False
    print("[INFO] Disconnected from remote robot.")
|
| 536 |
+
|
| 537 |
+
def __del__(self):
    """Best-effort cleanup at garbage collection time.

    Uses getattr with a default because __del__ may run on a partially
    initialized instance (e.g. if __init__ raised before setting attributes).
    """
    if getattr(self, "is_connected", False):
        self.disconnect()
    if PYNPUT_AVAILABLE:
        # disconnect() also stops the listener; calling stop() twice is harmless.
        self.listener.stop()
|
| 542 |
+
|
| 543 |
+
@staticmethod
|
| 544 |
+
def degps_to_raw(degps: float) -> int:
|
| 545 |
+
steps_per_deg = 4096.0 / 360.0
|
| 546 |
+
speed_in_steps = abs(degps) * steps_per_deg
|
| 547 |
+
speed_int = int(round(speed_in_steps))
|
| 548 |
+
if speed_int > 0x7FFF:
|
| 549 |
+
speed_int = 0x7FFF
|
| 550 |
+
if degps < 0:
|
| 551 |
+
return speed_int | 0x8000
|
| 552 |
+
else:
|
| 553 |
+
return speed_int & 0x7FFF
|
| 554 |
+
|
| 555 |
+
@staticmethod
|
| 556 |
+
def raw_to_degps(raw_speed: int) -> float:
|
| 557 |
+
steps_per_deg = 4096.0 / 360.0
|
| 558 |
+
magnitude = raw_speed & 0x7FFF
|
| 559 |
+
degps = magnitude / steps_per_deg
|
| 560 |
+
if raw_speed & 0x8000:
|
| 561 |
+
degps = -degps
|
| 562 |
+
return degps
|
| 563 |
+
|
| 564 |
+
def body_to_wheel_raw(
    self,
    x_cmd: float,
    y_cmd: float,
    theta_cmd: float,
    wheel_radius: float = 0.05,
    base_radius: float = 0.125,
    max_raw: int = 3000,
) -> dict:
    """
    Convert desired body-frame velocities into wheel raw commands.

    Parameters:
      x_cmd      : Linear velocity in x (m/s).
      y_cmd      : Linear velocity in y (m/s).
      theta_cmd  : Rotational velocity (deg/s).
      wheel_radius: Radius of each wheel (meters).
      base_radius : Distance from the center of rotation to each wheel (meters).
      max_raw    : Maximum allowed raw command (ticks) per wheel.

    Returns:
      A dictionary with wheel raw commands:
         {"left_wheel": value, "back_wheel": value, "right_wheel": value}.

    Notes:
      - Internally, the method converts theta_cmd to rad/s for the kinematics.
      - The raw command is computed from the wheels angular speed in deg/s
        using degps_to_raw(). If any command exceeds max_raw, all commands
        are scaled down proportionally.
    """
    # Convert rotational velocity from deg/s to rad/s.
    theta_rad = theta_cmd * (np.pi / 180.0)
    # Create the body velocity vector [x, y, theta_rad].
    velocity_vector = np.array([x_cmd, y_cmd, theta_rad])

    # Define the wheel mounting angles (defined from y axis cw)
    angles = np.radians(np.array([300, 180, 60]))
    # Build the kinematic matrix: each row maps body velocities to a wheel's linear speed.
    # The third column (base_radius) accounts for the effect of rotation.
    m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])

    # Compute each wheel's linear speed (m/s) and then its angular speed (rad/s).
    wheel_linear_speeds = m.dot(velocity_vector)
    wheel_angular_speeds = wheel_linear_speeds / wheel_radius

    # Convert wheel angular speeds from rad/s to deg/s.
    wheel_degps = wheel_angular_speeds * (180.0 / np.pi)

    # Scaling: if any wheel would exceed max_raw ticks, shrink all three
    # uniformly so the commanded motion direction is preserved.
    steps_per_deg = 4096.0 / 360.0
    raw_floats = [abs(degps) * steps_per_deg for degps in wheel_degps]
    max_raw_computed = max(raw_floats)
    if max_raw_computed > max_raw:
        scale = max_raw / max_raw_computed
        wheel_degps = wheel_degps * scale

    # Convert each wheel's angular speed (deg/s) to a raw integer.
    wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]

    return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
|
| 624 |
+
|
| 625 |
+
def wheel_raw_to_body(
    self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125
) -> tuple:
    """
    Convert wheel raw command feedback back into body-frame velocities.

    This is the inverse of body_to_wheel_raw(): it uses the same mounting
    angles and kinematic matrix, inverted.

    Parameters:
      wheel_raw   : Dictionary with raw wheel commands (keys: "left_wheel", "back_wheel", "right_wheel").
                    Missing keys default to 0.
      wheel_radius: Radius of each wheel (meters).
      base_radius : Distance from the robot center to each wheel (meters).

    Returns:
      A tuple (x_cmd, y_cmd, theta_cmd) where:
         x_cmd      : Linear velocity in x (m/s).
         y_cmd      : Linear velocity in y (m/s).
         theta_cmd  : Rotational velocity in deg/s.
    """
    # Extract the raw values in order.
    raw_list = [
        int(wheel_raw.get("left_wheel", 0)),
        int(wheel_raw.get("back_wheel", 0)),
        int(wheel_raw.get("right_wheel", 0)),
    ]

    # Convert each raw command back to an angular speed in deg/s.
    wheel_degps = np.array([MobileManipulator.raw_to_degps(r) for r in raw_list])
    # Convert from deg/s to rad/s.
    wheel_radps = wheel_degps * (np.pi / 180.0)
    # Compute each wheel's linear speed (m/s) from its angular speed.
    wheel_linear_speeds = wheel_radps * wheel_radius

    # Define the wheel mounting angles (defined from y axis cw)
    # NOTE: must match the angles used in body_to_wheel_raw for the inversion to hold.
    angles = np.radians(np.array([300, 180, 60]))
    m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])

    # Solve the inverse kinematics: body_velocity = M^-1 . wheel_linear_speeds.
    m_inv = np.linalg.inv(m)
    velocity_vector = m_inv.dot(wheel_linear_speeds)
    x_cmd, y_cmd, theta_rad = velocity_vector
    theta_cmd = theta_rad * (180.0 / np.pi)
    return (x_cmd, y_cmd, theta_cmd)
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
class LeKiwi:
    """Minimal driver for the LeKiwi three-wheel base over a Feetech motor bus."""

    def __init__(self, motor_bus):
        """
        Initializes the LeKiwi with Feetech motors bus.
        """
        self.motor_bus = motor_bus
        self.motor_ids = ["left_wheel", "back_wheel", "right_wheel"]

        # Initialize motors in velocity mode.
        self.motor_bus.write("Lock", 0)
        self.motor_bus.write("Mode", [1, 1, 1], self.motor_ids)
        self.motor_bus.write("Lock", 1)
        print("Motors set to velocity mode.")

    def read_velocity(self):
        """
        Reads the raw speeds for all wheels. Returns a dictionary with motor names:
        """
        raw_speeds = self.motor_bus.read("Present_Speed", self.motor_ids)
        # Pair each motor name with its reading, in bus order.
        return {name: int(speed) for name, speed in zip(self.motor_ids, raw_speeds)}

    def set_velocity(self, command_speeds):
        """
        Sends raw velocity commands (16-bit encoded values) directly to the motor bus.
        The order of speeds must correspond to self.motor_ids.
        """
        self.motor_bus.write("Goal_Speed", command_speeds, self.motor_ids)

    def stop(self):
        """Stops the robot by setting all motor speeds to zero."""
        self.set_velocity([0, 0, 0])
        print("Motors stopped.")
|
lerobot/common/robot_devices/robots/stretch.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import time
|
| 18 |
+
from dataclasses import replace
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from stretch_body.gamepad_teleop import GamePadTeleop
|
| 22 |
+
from stretch_body.robot import Robot as StretchAPI
|
| 23 |
+
from stretch_body.robot_params import RobotParams
|
| 24 |
+
|
| 25 |
+
from lerobot.common.robot_devices.robots.configs import StretchRobotConfig
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class StretchRobot(StretchAPI):
    """Wrapper of stretch_body.robot.Robot

    Adds camera handling, gamepad teleoperation, and the LeRobot
    observation/action dictionary format on top of the Stretch body API.
    """

    def __init__(self, config: StretchRobotConfig | None = None, **kwargs):
        super().__init__()
        if config is None:
            self.config = StretchRobotConfig(**kwargs)
        else:
            # Overwrite config arguments using kwargs
            self.config = replace(config, **kwargs)

        self.robot_type = self.config.type
        self.cameras = self.config.cameras
        self.is_connected = False
        self.teleop = None  # GamePadTeleop, created lazily on first use
        self.logs = {}  # per-step timing logs (read/write/camera durations)

        # TODO(aliberts): test this
        RobotParams.set_logging_level("WARNING")
        RobotParams.set_logging_formatter("brief_console_formatter")

        # Key orderings are captured lazily on first read so tensor layouts stay stable.
        self.state_keys = None
        self.action_keys = None

    def connect(self) -> None:
        """Start the Stretch process, connect all cameras, then home if needed.

        Raises:
            ConnectionError: if the Stretch body or any camera fails to connect.
        """
        self.is_connected = self.startup()
        if not self.is_connected:
            print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
            raise ConnectionError()

        for name in self.cameras:
            self.cameras[name].connect()
            self.is_connected = self.is_connected and self.cameras[name].is_connected

        if not self.is_connected:
            print("Could not connect to the cameras, check that all cameras are plugged-in.")
            raise ConnectionError()

        self.run_calibration()

    def run_calibration(self) -> None:
        """Home the robot unless it has already been homed."""
        if not self.is_homed():
            self.home()

    def teleop_step(
        self, record_data=False
    ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """Run one gamepad teleoperation step; optionally return (obs, action) tensors."""
        # TODO(aliberts): return ndarrays instead of torch.Tensors
        if not self.is_connected:
            raise ConnectionError()

        if self.teleop is None:
            # Lazy init: only create the gamepad teleop on first teleoperated step.
            self.teleop = GamePadTeleop(robot_instance=False)
            self.teleop.startup(robot=self)

        before_read_t = time.perf_counter()
        state = self.get_state()
        action = self.teleop.gamepad_controller.get_state()
        self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t

        before_write_t = time.perf_counter()
        self.teleop.do_motion(robot=self)
        self.push_command()
        self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t

        if self.state_keys is None:
            self.state_keys = list(state)

        if not record_data:
            return

        # Flatten the state/action dicts into tensors (order fixed by dict insertion order).
        state = torch.as_tensor(list(state.values()))
        action = torch.as_tensor(list(action.values()))

        # Capture images from cameras
        images = {}
        for name in self.cameras:
            before_camread_t = time.perf_counter()
            images[name] = self.cameras[name].async_read()
            images[name] = torch.from_numpy(images[name])
            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

        # Populate output dictionaries
        obs_dict, action_dict = {}, {}
        obs_dict["observation.state"] = state
        action_dict["action"] = action
        for name in self.cameras:
            obs_dict[f"observation.images.{name}"] = images[name]

        return obs_dict, action_dict

    def get_state(self) -> dict:
        """Collect joint positions and base velocities from the Stretch status dict."""
        status = self.get_status()
        return {
            "head_pan.pos": status["head"]["head_pan"]["pos"],
            "head_tilt.pos": status["head"]["head_tilt"]["pos"],
            "lift.pos": status["lift"]["pos"],
            "arm.pos": status["arm"]["pos"],
            "wrist_pitch.pos": status["end_of_arm"]["wrist_pitch"]["pos"],
            "wrist_roll.pos": status["end_of_arm"]["wrist_roll"]["pos"],
            "wrist_yaw.pos": status["end_of_arm"]["wrist_yaw"]["pos"],
            "gripper.pos": status["end_of_arm"]["stretch_gripper"]["pos"],
            "base_x.vel": status["base"]["x_vel"],
            "base_y.vel": status["base"]["y_vel"],
            "base_theta.vel": status["base"]["theta_vel"],
        }

    def capture_observation(self) -> dict:
        """Return the current state tensor and one image tensor per camera."""
        # TODO(aliberts): return ndarrays instead of torch.Tensors
        before_read_t = time.perf_counter()
        state = self.get_state()
        self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t

        if self.state_keys is None:
            self.state_keys = list(state)

        state = torch.as_tensor(list(state.values()))

        # Capture images from cameras
        images = {}
        for name in self.cameras:
            before_camread_t = time.perf_counter()
            images[name] = self.cameras[name].async_read()
            images[name] = torch.from_numpy(images[name])
            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

        # Populate output dictionaries
        obs_dict = {}
        obs_dict["observation.state"] = state
        for name in self.cameras:
            obs_dict[f"observation.images.{name}"] = images[name]

        return obs_dict

    def send_action(self, action: torch.Tensor) -> torch.Tensor:
        """Replay an action tensor through the gamepad teleop motion pipeline.

        The tensor layout must match the gamepad controller's state dict keys,
        which are captured on first call.
        """
        # TODO(aliberts): return ndarrays instead of torch.Tensors
        if not self.is_connected:
            raise ConnectionError()

        if self.teleop is None:
            self.teleop = GamePadTeleop(robot_instance=False)
            self.teleop.startup(robot=self)

        if self.action_keys is None:
            # Use a throwaway controller read just to learn the key ordering.
            dummy_action = self.teleop.gamepad_controller.get_state()
            self.action_keys = list(dummy_action.keys())

        action_dict = dict(zip(self.action_keys, action.tolist(), strict=True))

        before_write_t = time.perf_counter()
        self.teleop.do_motion(state=action_dict, robot=self)
        self.push_command()
        self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t

        # TODO(aliberts): return action_sent when motion is limited
        return action

    def print_logs(self) -> None:
        pass
        # TODO(aliberts): move robot-specific logs logic here

    def teleop_safety_stop(self) -> None:
        """Trigger the teleop safety stop, if a teleop session exists."""
        if self.teleop is not None:
            self.teleop._safety_stop(robot=self)

    def disconnect(self) -> None:
        """Stop the robot, the teleop session, and all cameras."""
        self.stop()
        if self.teleop is not None:
            self.teleop.gamepad_controller.stop()
            self.teleop.stop()

        if len(self.cameras) > 0:
            for cam in self.cameras.values():
                cam.disconnect()

        self.is_connected = False

    def __del__(self):
        # NOTE(review): disconnect() is called unconditionally here; if __init__
        # failed partway this may raise during GC — consider guarding with
        # getattr(self, "is_connected", False) as MobileManipulator does. TODO confirm.
        self.disconnect()
|
lerobot/common/robot_devices/robots/utils.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Protocol
|
| 16 |
+
|
| 17 |
+
from lerobot.common.robot_devices.robots.configs import (
|
| 18 |
+
AlohaRobotConfig,
|
| 19 |
+
KochBimanualRobotConfig,
|
| 20 |
+
KochRobotConfig,
|
| 21 |
+
LeKiwiRobotConfig,
|
| 22 |
+
ManipulatorRobotConfig,
|
| 23 |
+
MossRobotConfig,
|
| 24 |
+
RobotConfig,
|
| 25 |
+
So100RobotConfig,
|
| 26 |
+
StretchRobotConfig,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_arm_id(name, arm_type):
    """Returns the string identifier of a robot arm. For instance, for a bimanual manipulator
    like Aloha, it could be left_follower, right_follower, left_leader, or right_leader.
    """
    return "{}_{}".format(name, arm_type)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Robot(Protocol):
    """Structural (duck-typed) interface that every robot implementation must satisfy."""

    # TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes
    robot_type: str  # short identifier, e.g. "aloha", "koch", "so100", "stretch", "lekiwi"
    features: dict

    def connect(self): ...
    def run_calibration(self): ...
    def teleop_step(self, record_data=False): ...
    def capture_observation(self): ...
    def send_action(self, action): ...
    def disconnect(self): ...
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
    """Instantiate the robot config class matching `robot_type`, forwarding kwargs.

    Raises:
        ValueError: if `robot_type` is not a known robot type.
    """
    # Dispatch table: robot type name -> config dataclass.
    config_classes = {
        "aloha": AlohaRobotConfig,
        "koch": KochRobotConfig,
        "koch_bimanual": KochBimanualRobotConfig,
        "moss": MossRobotConfig,
        "so100": So100RobotConfig,
        "stretch": StretchRobotConfig,
        "lekiwi": LeKiwiRobotConfig,
    }
    config_cls = config_classes.get(robot_type)
    if config_cls is None:
        raise ValueError(f"Robot type '{robot_type}' is not available.")
    return config_cls(**kwargs)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def make_robot_from_config(config: RobotConfig):
    """Instantiate the concrete robot class matching the given config.

    Imports are local so that each robot's heavy device dependencies are only
    loaded when that robot type is actually used.
    """
    if isinstance(config, ManipulatorRobotConfig):
        from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot

        return ManipulatorRobot(config)
    elif isinstance(config, LeKiwiRobotConfig):
        from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator

        return MobileManipulator(config)
    else:
        # NOTE(review): any config that is neither Manipulator nor LeKiwi falls
        # through to Stretch, including unknown future config types — verify this
        # catch-all is intended rather than an explicit StretchRobotConfig check.
        from lerobot.common.robot_devices.robots.stretch import StretchRobot

        return StretchRobot(config)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def make_robot(robot_type: str, **kwargs) -> Robot:
    """Convenience helper: build the config for `robot_type` and instantiate the robot."""
    return make_robot_from_config(make_robot_config(robot_type, **kwargs))
|
lerobot/common/robot_devices/utils.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import platform
|
| 16 |
+
import time
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def busy_wait(seconds):
    """Wait for `seconds`, spin-waiting on macOS where time.sleep is imprecise."""
    if platform.system() == "Darwin":
        # On Mac, `time.sleep` is not accurate, so spin until the deadline instead.
        # This burns CPU cycles.
        # TODO(rcadene): find an alternative: from python 11, time.sleep is precise
        deadline = time.perf_counter() + seconds
        while time.perf_counter() < deadline:
            pass
    elif seconds > 0:
        # On Linux time.sleep is accurate
        time.sleep(seconds)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def safe_disconnect(func):
    """Decorator: ensure `robot.disconnect()` is called when `func` raises.

    The wrapped function must take the robot as its first positional argument.
    The original exception is always re-raised after the disconnect attempt.
    """
    # TODO(aliberts): Allow to pass custom exceptions
    # (e.g. ThreadServiceExit, KeyboardInterrupt, SystemExit, UnpluggedError, DynamixelCommError)
    from functools import wraps  # local import to leave the file's top-level imports untouched

    @wraps(func)  # fix: preserve func's __name__/__doc__ for debugging and introspection
    def wrapper(robot, *args, **kwargs):
        try:
            return func(robot, *args, **kwargs)
        except Exception as e:
            # Best-effort cleanup: only disconnect when a connection is live.
            if robot.is_connected:
                robot.disconnect()
            raise e

    return wrapper
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class RobotDeviceNotConnectedError(Exception):
    """Raised when an operation requires a robot device that is not connected."""

    def __init__(
        self, message="This robot device is not connected. Try calling `robot_device.connect()` first."
    ):
        # Keep the message on the instance for callers that inspect it directly.
        self.message = message
        super().__init__(message)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class RobotDeviceAlreadyConnectedError(Exception):
    """Raised when connect() is called on a robot device that is already connected."""

    def __init__(
        self,
        message="This robot device is already connected. Try not calling `robot_device.connect()` twice.",
    ):
        # Keep the message on the instance for callers that inspect it directly.
        self.message = message
        super().__init__(message)
|
lerobot/common/utils/benchmark.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
import threading
|
| 17 |
+
import time
|
| 18 |
+
from contextlib import ContextDecorator
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class TimeBenchmark(ContextDecorator):
|
| 22 |
+
"""
|
| 23 |
+
Measures execution time using a context manager or decorator.
|
| 24 |
+
|
| 25 |
+
This class supports both context manager and decorator usage, and is thread-safe for multithreaded
|
| 26 |
+
environments.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
print: If True, prints the elapsed time upon exiting the context or completing the function. Defaults
|
| 30 |
+
to False.
|
| 31 |
+
|
| 32 |
+
Examples:
|
| 33 |
+
|
| 34 |
+
Using as a context manager:
|
| 35 |
+
|
| 36 |
+
>>> benchmark = TimeBenchmark()
|
| 37 |
+
>>> with benchmark:
|
| 38 |
+
... time.sleep(1)
|
| 39 |
+
>>> print(f"Block took {benchmark.result:.4f} seconds")
|
| 40 |
+
Block took approximately 1.0000 seconds
|
| 41 |
+
|
| 42 |
+
Using with multithreading:
|
| 43 |
+
|
| 44 |
+
```python
|
| 45 |
+
import threading
|
| 46 |
+
|
| 47 |
+
benchmark = TimeBenchmark()
|
| 48 |
+
|
| 49 |
+
def context_manager_example():
|
| 50 |
+
with benchmark:
|
| 51 |
+
time.sleep(0.01)
|
| 52 |
+
print(f"Block took {benchmark.result_ms:.2f} milliseconds")
|
| 53 |
+
|
| 54 |
+
threads = []
|
| 55 |
+
for _ in range(3):
|
| 56 |
+
t1 = threading.Thread(target=context_manager_example)
|
| 57 |
+
threads.append(t1)
|
| 58 |
+
|
| 59 |
+
for t in threads:
|
| 60 |
+
t.start()
|
| 61 |
+
|
| 62 |
+
for t in threads:
|
| 63 |
+
t.join()
|
| 64 |
+
```
|
| 65 |
+
Expected output:
|
| 66 |
+
Block took approximately 10.00 milliseconds
|
| 67 |
+
Block took approximately 10.00 milliseconds
|
| 68 |
+
Block took approximately 10.00 milliseconds
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
def __init__(self, print=False):
|
| 72 |
+
self.local = threading.local()
|
| 73 |
+
self.print_time = print
|
| 74 |
+
|
| 75 |
+
def __enter__(self):
|
| 76 |
+
self.local.start_time = time.perf_counter()
|
| 77 |
+
return self
|
| 78 |
+
|
| 79 |
+
def __exit__(self, *exc):
|
| 80 |
+
self.local.end_time = time.perf_counter()
|
| 81 |
+
self.local.elapsed_time = self.local.end_time - self.local.start_time
|
| 82 |
+
if self.print_time:
|
| 83 |
+
print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds")
|
| 84 |
+
return False
|
| 85 |
+
|
| 86 |
+
    @property
    def result(self):
        # Elapsed seconds of the last completed measurement on this thread,
        # or None if no measurement has finished yet.
        return getattr(self.local, "elapsed_time", None)
|
| 90 |
+
@property
|
| 91 |
+
def result_ms(self):
|
| 92 |
+
return self.result * 1e3
|
lerobot/common/utils/hub.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from tempfile import TemporaryDirectory
|
| 17 |
+
from typing import Any, Type, TypeVar
|
| 18 |
+
|
| 19 |
+
from huggingface_hub import HfApi
|
| 20 |
+
from huggingface_hub.utils import validate_hf_hub_args
|
| 21 |
+
|
| 22 |
+
T = TypeVar("T", bound="HubMixin")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class HubMixin:
    """
    A Mixin containing the functionality to push an object to the hub.

    This is similar to huggingface_hub.ModelHubMixin but is lighter and makes less assumptions about its
    subclasses (in particular, the fact that it's not necessarily a model).

    The inheriting classes must implement '_save_pretrained' and 'from_pretrained'.
    """

    def save_pretrained(
        self,
        save_directory: str | Path,
        *,
        repo_id: str | None = None,
        push_to_hub: bool = False,
        card_kwargs: dict[str, Any] | None = None,
        **push_to_hub_kwargs,
    ) -> str | None:
        """
        Save object in local directory.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the object will be saved.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your object to the Huggingface Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
            card_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments passed to the card template to customize the card.
            push_to_hub_kwargs:
                Additional key word arguments passed along to the [`~HubMixin.push_to_hub`] method.
        Returns:
            `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # save object (weights, files, etc.)
        self._save_pretrained(save_directory)

        # push to the Hub if required
        if push_to_hub:
            if repo_id is None:
                repo_id = save_directory.name  # Defaults to `save_directory` name
            return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
        return None

    def _save_pretrained(self, save_directory: Path) -> None:
        """
        Overwrite this method in subclass to define how to save your object.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the object files will be saved.
        """
        raise NotImplementedError

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls: Type[T],
        pretrained_name_or_path: str | Path,
        *,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **kwargs,
    ) -> T:
        """
        Download the object from the Huggingface Hub and instantiate it.

        Args:
            pretrained_name_or_path (`str`, `Path`):
                - Either the `repo_id` (string) of the object hosted on the Hub, e.g. `lerobot/diffusion_pusht`.
                - Or a path to a `directory` containing the object files saved using `.save_pretrained`,
                    e.g., `../path/to/my_model_directory/`.
            revision (`str`, *optional*):
                Revision on the Hub. Can be a branch name, a git tag or any commit id.
                Defaults to the latest commit on `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the files from the Hub, overriding the existing cache.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            kwargs (`Dict`, *optional*):
                Additional kwargs to pass to the object during initialization.
        """
        # Abstract: subclasses must decide how to resolve the repo/dir and build the object.
        raise NotImplementedError

    @validate_hf_hub_args
    def push_to_hub(
        self,
        repo_id: str,
        *,
        commit_message: str | None = None,
        private: bool | None = None,
        token: str | None = None,
        branch: str | None = None,
        create_pr: bool | None = None,
        allow_patterns: list[str] | str | None = None,
        ignore_patterns: list[str] | str | None = None,
        delete_patterns: list[str] | str | None = None,
        card_kwargs: dict[str, Any] | None = None,
    ) -> str:
        """
        Upload model checkpoint to the Hub.

        Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
        `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
        details.

        Args:
            repo_id (`str`):
                ID of the repository to push to (example: `"username/my-model"`).
            commit_message (`str`, *optional*):
                Message to commit while pushing.
            private (`bool`, *optional*):
                Whether the repository created should be private.
                If `None` (default), the repo will be public unless the organization's default is private.
            token (`str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            branch (`str`, *optional*):
                The git branch on which to push the model. This defaults to `"main"`.
            create_pr (`boolean`, *optional*):
                Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
            allow_patterns (`List[str]` or `str`, *optional*):
                If provided, only files matching at least one pattern are pushed.
            ignore_patterns (`List[str]` or `str`, *optional*):
                If provided, files matching any of the patterns are not pushed.
            delete_patterns (`List[str]` or `str`, *optional*):
                If provided, remote files matching any of the patterns will be deleted from the repo.
            card_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments passed to the card template to customize the card.

        Returns:
            The url of the commit of your object in the given repository.
        """
        api = HfApi(token=token)
        # create_repo is idempotent with exist_ok=True; use the returned repo_id
        # since the Hub may normalize it (e.g. fill in the implicit user namespace).
        repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id

        if commit_message is None:
            # Derive a friendly default message from the class name.
            if "Policy" in self.__class__.__name__:
                commit_message = "Upload policy"
            elif "Config" in self.__class__.__name__:
                commit_message = "Upload config"
            else:
                commit_message = f"Upload {self.__class__.__name__}"

        # Push the files to the repo in a single commit
        with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
            # NOTE(review): `repo_id` contains a "/" so this creates a nested
            # directory under the temp dir — presumably intentional; confirm on Windows.
            saved_path = Path(tmp) / repo_id
            self.save_pretrained(saved_path, card_kwargs=card_kwargs)
            return api.upload_folder(
                repo_id=repo_id,
                repo_type="model",
                folder_path=saved_path,
                commit_message=commit_message,
                revision=branch,
                create_pr=create_pr,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                delete_patterns=delete_patterns,
            )
|
lerobot/common/utils/import_utils.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
import importlib
|
| 17 |
+
import logging
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
|
| 21 |
+
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
|
| 22 |
+
Check if the package spec exists and grab its version to avoid importing a local directory.
|
| 23 |
+
**Note:** this doesn't work for all packages.
|
| 24 |
+
"""
|
| 25 |
+
package_exists = importlib.util.find_spec(pkg_name) is not None
|
| 26 |
+
package_version = "N/A"
|
| 27 |
+
if package_exists:
|
| 28 |
+
try:
|
| 29 |
+
# Primary method to get the package version
|
| 30 |
+
package_version = importlib.metadata.version(pkg_name)
|
| 31 |
+
except importlib.metadata.PackageNotFoundError:
|
| 32 |
+
# Fallback method: Only for "torch" and versions containing "dev"
|
| 33 |
+
if pkg_name == "torch":
|
| 34 |
+
try:
|
| 35 |
+
package = importlib.import_module(pkg_name)
|
| 36 |
+
temp_version = getattr(package, "__version__", "N/A")
|
| 37 |
+
# Check if the version contains "dev"
|
| 38 |
+
if "dev" in temp_version:
|
| 39 |
+
package_version = temp_version
|
| 40 |
+
package_exists = True
|
| 41 |
+
else:
|
| 42 |
+
package_exists = False
|
| 43 |
+
except ImportError:
|
| 44 |
+
# If the package can't be imported, it's not available
|
| 45 |
+
package_exists = False
|
| 46 |
+
else:
|
| 47 |
+
# For packages other than "torch", don't attempt the fallback and set as not available
|
| 48 |
+
package_exists = False
|
| 49 |
+
logging.debug(f"Detected {pkg_name} version: {package_version}")
|
| 50 |
+
if return_version:
|
| 51 |
+
return package_exists, package_version
|
| 52 |
+
else:
|
| 53 |
+
return package_exists
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Cache availability flags once at import time so call sites can do cheap
# boolean checks instead of re-probing package metadata on every use.
_torch_available, _torch_version = is_package_available("torch", return_version=True)
_gym_xarm_available = is_package_available("gym_xarm")
_gym_aloha_available = is_package_available("gym_aloha")
_gym_pusht_available = is_package_available("gym_pusht")
|
lerobot/common/utils/io_utils.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
import json
|
| 17 |
+
import warnings
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import TypeVar
|
| 20 |
+
|
| 21 |
+
import imageio
|
| 22 |
+
|
| 23 |
+
# Recursive alias for any value representable in JSON. Tuples are included
# because this module maps JSON lists back onto tuple-shaped templates during
# deserialization (see `deserialize_json_into_object`).
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
T = TypeVar("T", bound=JsonLike)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def write_video(video_path, stacked_frames, fps):
    """Encode `stacked_frames` as a video file at `video_path` with the given `fps`."""
    # imageio's dependency chain can emit a pkg_resources DeprecationWarning;
    # silence only that specific warning, only for the duration of the write.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
        )
        imageio.mimsave(video_path, stacked_frames, fps=fps)
+
|
| 35 |
+
|
| 36 |
+
def deserialize_json_into_object(fpath: Path, obj: T) -> T:
|
| 37 |
+
"""
|
| 38 |
+
Loads the JSON data from `fpath` and recursively fills `obj` with the
|
| 39 |
+
corresponding values (strictly matching structure and types).
|
| 40 |
+
Tuples in `obj` are expected to be lists in the JSON data, which will be
|
| 41 |
+
converted back into tuples.
|
| 42 |
+
"""
|
| 43 |
+
with open(fpath, encoding="utf-8") as f:
|
| 44 |
+
data = json.load(f)
|
| 45 |
+
|
| 46 |
+
def _deserialize(target, source):
|
| 47 |
+
"""
|
| 48 |
+
Recursively overwrite the structure in `target` with data from `source`,
|
| 49 |
+
performing strict checks on structure and type.
|
| 50 |
+
Returns the updated version of `target` (especially important for tuples).
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
# If the target is a dictionary, source must be a dictionary as well.
|
| 54 |
+
if isinstance(target, dict):
|
| 55 |
+
if not isinstance(source, dict):
|
| 56 |
+
raise TypeError(f"Type mismatch: expected dict, got {type(source)}")
|
| 57 |
+
|
| 58 |
+
# Check that they have exactly the same set of keys.
|
| 59 |
+
if target.keys() != source.keys():
|
| 60 |
+
raise ValueError(
|
| 61 |
+
f"Dictionary keys do not match.\nExpected: {target.keys()}, got: {source.keys()}"
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# Recursively update each key.
|
| 65 |
+
for k in target:
|
| 66 |
+
target[k] = _deserialize(target[k], source[k])
|
| 67 |
+
|
| 68 |
+
return target
|
| 69 |
+
|
| 70 |
+
# If the target is a list, source must be a list as well.
|
| 71 |
+
elif isinstance(target, list):
|
| 72 |
+
if not isinstance(source, list):
|
| 73 |
+
raise TypeError(f"Type mismatch: expected list, got {type(source)}")
|
| 74 |
+
|
| 75 |
+
# Check length
|
| 76 |
+
if len(target) != len(source):
|
| 77 |
+
raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}")
|
| 78 |
+
|
| 79 |
+
# Recursively update each element.
|
| 80 |
+
for i in range(len(target)):
|
| 81 |
+
target[i] = _deserialize(target[i], source[i])
|
| 82 |
+
|
| 83 |
+
return target
|
| 84 |
+
|
| 85 |
+
# If the target is a tuple, the source must be a list in JSON,
|
| 86 |
+
# which we'll convert back to a tuple.
|
| 87 |
+
elif isinstance(target, tuple):
|
| 88 |
+
if not isinstance(source, list):
|
| 89 |
+
raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}")
|
| 90 |
+
|
| 91 |
+
if len(target) != len(source):
|
| 92 |
+
raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}")
|
| 93 |
+
|
| 94 |
+
# Convert each element, forming a new tuple.
|
| 95 |
+
converted_items = []
|
| 96 |
+
for t_item, s_item in zip(target, source, strict=False):
|
| 97 |
+
converted_items.append(_deserialize(t_item, s_item))
|
| 98 |
+
|
| 99 |
+
# Return a brand new tuple (tuples are immutable in Python).
|
| 100 |
+
return tuple(converted_items)
|
| 101 |
+
|
| 102 |
+
# Otherwise, we're dealing with a "primitive" (int, float, str, bool, None).
|
| 103 |
+
else:
|
| 104 |
+
# Check the exact type. If these must match 1:1, do:
|
| 105 |
+
if type(target) is not type(source):
|
| 106 |
+
raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}")
|
| 107 |
+
return source
|
| 108 |
+
|
| 109 |
+
# Perform the in-place/recursive deserialization
|
| 110 |
+
updated_obj = _deserialize(obj, data)
|
| 111 |
+
return updated_obj
|
lerobot/common/utils/logging_utils.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
from lerobot.common.utils.utils import format_big_number
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class AverageMeter:
    """
    Tracks the most recent value and a running (weighted) average of a metric.
    Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py
    """

    def __init__(self, name: str, fmt: str = ":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self) -> None:
        """Clear all accumulated statistics back to zero."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val: float, n: int = 1) -> None:
        """Record observation `val`, weighted as `n` samples, and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Render as "<name>:<avg>" using the configured format spec (e.g. ":.3f").
        template = "{name}:{avg" + self.fmt + "}"
        return template.format(**self.__dict__)
+
|
| 48 |
+
|
| 49 |
+
class MetricsTracker:
    """
    A helper class to track and log metrics over time.

    Usage pattern:

    ```python
    # initialize, potentially with non-zero initial step (e.g. if resuming run)
    metrics = {"loss": AverageMeter("loss", ":.3f")}
    train_metrics = MetricsTracker(cfg, dataset, metrics, initial_step=step)

    # update metrics derived from step (samples, episodes, epochs) at each training step
    train_metrics.step()

    # update various metrics
    loss = policy.forward(batch)
    train_metrics.loss = loss

    # display current metrics
    logging.info(train_metrics)

    # export for wandb
    wandb.log(train_metrics.to_dict())

    # reset averages after logging
    train_metrics.reset_averages()
    ```
    """

    # Attribute names handled as plain instance attributes; any other name is
    # routed to `self.metrics` by the __getattr__/__setattr__ overrides below.
    __keys__ = [
        "_batch_size",
        "_num_frames",
        "_avg_samples_per_ep",
        "metrics",
        "steps",
        "samples",
        "episodes",
        "epochs",
    ]

    def __init__(
        self,
        batch_size: int,
        num_frames: int,
        num_episodes: int,
        metrics: dict[str, AverageMeter],
        initial_step: int = 0,
    ):
        # Seed __dict__ up front so __setattr__ recognizes these names as plain
        # attributes instead of trying to route them to the metrics dict.
        self.__dict__.update({k: None for k in self.__keys__})
        self._batch_size = batch_size
        self._num_frames = num_frames
        self._avg_samples_per_ep = num_frames / num_episodes
        self.metrics = metrics

        self.steps = initial_step
        # A sample is an (observation,action) pair, where observation and action
        # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
        self.samples = self.steps * self._batch_size
        self.episodes = self.samples / self._avg_samples_per_ep
        self.epochs = self.samples / self._num_frames

    def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
        # Fall through to the tracked metrics so `tracker.loss` returns the
        # corresponding AverageMeter; unknown names raise as usual.
        if name in self.__dict__:
            return self.__dict__[name]
        elif name in self.metrics:
            return self.metrics[name]
        else:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        # Assigning to a metric name updates its AverageMeter (via .update)
        # rather than shadowing the meter with a plain attribute.
        if name in self.__dict__:
            super().__setattr__(name, value)
        elif name in self.metrics:
            self.metrics[name].update(value)
        else:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def step(self) -> None:
        """
        Updates metrics that depend on 'step' for one step.
        """
        self.steps += 1
        self.samples += self._batch_size
        self.episodes = self.samples / self._avg_samples_per_ep
        self.epochs = self.samples / self._num_frames

    def __str__(self) -> str:
        # Compact single-line summary used for periodic console logging.
        display_list = [
            f"step:{format_big_number(self.steps)}",
            # number of samples seen during training
            f"smpl:{format_big_number(self.samples)}",
            # number of episodes seen during training
            f"ep:{format_big_number(self.episodes)}",
            # number of time all unique samples are seen
            f"epch:{self.epochs:.2f}",
            *[str(m) for m in self.metrics.values()],
        ]
        return " ".join(display_list)

    def to_dict(self, use_avg: bool = True) -> dict[str, int | float]:
        """
        Returns the current metric values (or averages if `use_avg=True`) as a dict.
        """
        return {
            "steps": self.steps,
            "samples": self.samples,
            "episodes": self.episodes,
            "epochs": self.epochs,
            **{k: m.avg if use_avg else m.val for k, m in self.metrics.items()},
        }

    def reset_averages(self) -> None:
        """Resets average meters."""
        for m in self.metrics.values():
            m.reset()
|
lerobot/common/utils/random_utils.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
import random
|
| 17 |
+
from contextlib import contextmanager
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Any, Generator
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
from safetensors.torch import load_file, save_file
|
| 24 |
+
|
| 25 |
+
from lerobot.common.constants import RNG_STATE
|
| 26 |
+
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def serialize_python_rng_state() -> dict[str, torch.Tensor]:
    """
    Returns the rng state for `random` in the form of a flat dict[str, torch.Tensor] to be saved using
    `safetensors.save_file()` or `torch.save()`.
    """
    # random.getstate() -> (version, internal MT state tuple, cached gaussian).
    version, internal_state, _ = random.getstate()
    return {
        "py_rng_version": torch.tensor([version], dtype=torch.int64),
        "py_rng_state": torch.tensor(internal_state, dtype=torch.int64),
    }
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.
    """
    version = rng_state_dict["py_rng_version"].item()
    internal_state = tuple(rng_state_dict["py_rng_state"].tolist())
    # The third element (cached gaussian) is not serialized; None clears it.
    random.setstate((version, internal_state, None))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def serialize_numpy_rng_state() -> dict[str, torch.Tensor]:
|
| 50 |
+
"""
|
| 51 |
+
Returns the rng state for `numpy` in the form of a flat dict[str, torch.Tensor] to be saved using
|
| 52 |
+
`safetensors.save_file()` or `torch.save()`.
|
| 53 |
+
"""
|
| 54 |
+
np_state = np.random.get_state()
|
| 55 |
+
# Ensure no breaking changes from numpy
|
| 56 |
+
assert np_state[0] == "MT19937"
|
| 57 |
+
return {
|
| 58 |
+
"np_rng_state_values": torch.tensor(np_state[1], dtype=torch.int64),
|
| 59 |
+
"np_rng_state_index": torch.tensor([np_state[2]], dtype=torch.int64),
|
| 60 |
+
"np_rng_has_gauss": torch.tensor([np_state[3]], dtype=torch.int64),
|
| 61 |
+
"np_rng_cached_gaussian": torch.tensor([np_state[4]], dtype=torch.float32),
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def deserialize_numpy_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Restores the rng state for `numpy` from a dictionary produced by `serialize_numpy_rng_state()`.
    """
    # Rebuild the 5-tuple expected by the legacy numpy RandomState API.
    restored = (
        "MT19937",
        rng_state_dict["np_rng_state_values"].numpy(),
        rng_state_dict["np_rng_state_index"].item(),
        rng_state_dict["np_rng_has_gauss"].item(),
        rng_state_dict["np_rng_cached_gaussian"].item(),
    )
    np.random.set_state(restored)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def serialize_torch_rng_state() -> dict[str, torch.Tensor]:
    """
    Return the `torch` rng state (CPU, plus CUDA when available) as a flat
    dict[str, torch.Tensor] to be saved using `safetensors.save_file()` or `torch.save()`.
    """
    states = {"torch_rng_state": torch.get_rng_state()}
    if torch.cuda.is_available():
        # NOTE: only the current CUDA device's generator state is captured here.
        states["torch_cuda_rng_state"] = torch.cuda.get_rng_state()
    return states
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def deserialize_torch_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """Restore the `torch` rng state from a dict made by `serialize_torch_rng_state()`."""
    torch.set_rng_state(rng_state_dict["torch_rng_state"])
    # The CUDA entry only exists when serialization happened on a CUDA machine, and is
    # only applied when this machine has CUDA as well.
    cuda_state = rng_state_dict.get("torch_cuda_rng_state")
    if cuda_state is not None and torch.cuda.is_available():
        torch.cuda.set_rng_state(cuda_state)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def serialize_rng_state() -> dict[str, torch.Tensor]:
    """
    Return the combined rng state for `random`, `numpy`, and `torch` as a single flat
    dict[str, torch.Tensor] to be saved using `safetensors.save_file()` or `torch.save()`.
    """
    # Merge the three per-library dicts; their key prefixes ("py", "np", "torch") keep
    # the entries disjoint so no key can be clobbered.
    merged: dict[str, torch.Tensor] = {}
    merged.update(serialize_python_rng_state())
    merged.update(serialize_numpy_rng_state())
    merged.update(serialize_torch_rng_state())
    return merged
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Restore the rng state for `random`, `numpy`, and `torch` from a dictionary produced by
    `serialize_rng_state()`.
    """
    # Route each entry back to its library based on the key prefix chosen at serialization time.
    by_prefix: dict[str, dict[str, torch.Tensor]] = {"py": {}, "np": {}, "torch": {}}
    for key, value in rng_state_dict.items():
        for prefix, bucket in by_prefix.items():
            if key.startswith(prefix):
                bucket[key] = value
                break

    deserialize_python_rng_state(by_prefix["py"])
    deserialize_numpy_rng_state(by_prefix["np"])
    deserialize_torch_rng_state(by_prefix["torch"])
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def save_rng_state(save_dir: Path) -> None:
    """Serialize the global rng states and write them to `save_dir / RNG_STATE`."""
    # Flatten before writing so the dict is compatible with safetensors' flat layout.
    save_file(flatten_dict(serialize_rng_state()), save_dir / RNG_STATE)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def load_rng_state(save_dir: Path) -> None:
    """Read `save_dir / RNG_STATE` and restore the global rng states from it."""
    flat_state = load_file(save_dir / RNG_STATE)
    deserialize_rng_state(unflatten_dict(flat_state))
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def get_rng_state() -> dict[str, Any]:
    """Get the random state for `random`, `numpy`, and `torch` (native objects, not tensors)."""
    states: dict[str, Any] = {"random_state": random.getstate()}
    states["numpy_random_state"] = np.random.get_state()
    states["torch_random_state"] = torch.random.get_rng_state()
    # The CUDA generator state is only captured when a CUDA device is present.
    if torch.cuda.is_available():
        states["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
    return states
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def set_rng_state(random_state_dict: dict[str, Any]):
    """Set the random state for `random`, `numpy`, and `torch`.

    Args:
        random_state_dict: A dictionary of the form returned by `get_rng_state`.
    """
    setters = (
        ("random_state", random.setstate),
        ("numpy_random_state", np.random.set_state),
        ("torch_random_state", torch.random.set_rng_state),
    )
    for key, setter in setters:
        setter(random_state_dict[key])
    # NOTE(review): assumes the dict was captured with CUDA available whenever this
    # machine has CUDA — the key is only written by `get_rng_state` in that case.
    if torch.cuda.is_available():
        torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def set_seed(seed) -> None:
    """Seed `random`, `numpy` and `torch` (all CUDA devices included) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
    """Set the seed when entering a context, and restore the prior random state at exit.

    Example usage:

    ```
    a = random.random()  # produces some random number
    with seeded_context(1337):
        b = random.random()  # produces some other random number
    c = random.random()  # produces yet another random number, but the same it would have if we never made `b`
    ```
    """
    outer_state = get_rng_state()
    set_seed(seed)
    # Fix: restore the outer rng state even when the context body raises, so an
    # exception inside the `with` block can no longer leak the seeded state.
    try:
        yield None
    finally:
        set_rng_state(outer_state)
|
lerobot/common/utils/train_utils.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
import logging
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
from termcolor import colored
|
| 20 |
+
from torch.optim import Optimizer
|
| 21 |
+
from torch.optim.lr_scheduler import LRScheduler
|
| 22 |
+
|
| 23 |
+
from lerobot.common.constants import (
|
| 24 |
+
CHECKPOINTS_DIR,
|
| 25 |
+
LAST_CHECKPOINT_LINK,
|
| 26 |
+
PRETRAINED_MODEL_DIR,
|
| 27 |
+
TRAINING_STATE_DIR,
|
| 28 |
+
TRAINING_STEP,
|
| 29 |
+
)
|
| 30 |
+
from lerobot.common.datasets.utils import load_json, write_json
|
| 31 |
+
from lerobot.common.optim.optimizers import load_optimizer_state, save_optimizer_state
|
| 32 |
+
from lerobot.common.optim.schedulers import load_scheduler_state, save_scheduler_state
|
| 33 |
+
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
| 34 |
+
from lerobot.common.utils.random_utils import load_rng_state, save_rng_state
|
| 35 |
+
from lerobot.configs.train import TrainPipelineConfig
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def log_output_dir(out_dir):
    """Log the run's output directory with a bold yellow prefix."""
    prefix = colored("Output dir:", "yellow", attrs=["bold"])
    logging.info(prefix + f" {out_dir}")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_step_identifier(step: int, total_steps: int) -> str:
    """Zero-pad `step` to at least 6 digits (more when `total_steps` needs them)."""
    width = max(6, len(str(total_steps)))
    return str(step).zfill(width)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Path:
    """Returns the checkpoint sub-directory corresponding to the step number."""
    return output_dir / CHECKPOINTS_DIR / get_step_identifier(step, total_steps)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def save_training_step(step: int, save_dir: Path) -> None:
    """Persist the current training step to `save_dir / TRAINING_STEP` as JSON."""
    payload = {"step": step}
    write_json(payload, save_dir / TRAINING_STEP)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def load_training_step(save_dir: Path) -> int:
    """Read back the training step written by `save_training_step()`."""
    return load_json(save_dir / TRAINING_STEP)["step"]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
    """Point the `last` symlink (sibling of `checkpoint_dir`) at `checkpoint_dir`.

    Any previous symlink is replaced. A relative target is used so the checkpoints
    directory remains valid if it is moved as a whole.

    Returns:
        Path: the path of the symlink itself.
    """
    last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
    if last_checkpoint_dir.is_symlink():
        last_checkpoint_dir.unlink()
    relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent)
    last_checkpoint_dir.symlink_to(relative_target)
    # Fix: the signature declares `-> Path` but the original returned None.
    return last_checkpoint_dir
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def save_checkpoint(
    checkpoint_dir: Path,
    step: int,
    cfg: TrainPipelineConfig,
    policy: PreTrainedPolicy,
    optimizer: Optimizer,
    scheduler: LRScheduler | None = None,
) -> None:
    """This function creates the following directory structure:

    005000/  # training step at checkpoint
    ├── pretrained_model/
    │   ├── config.json  # policy config
    │   ├── model.safetensors  # policy weights
    │   └── train_config.json  # train config
    └── training_state/
        ├── optimizer_param_groups.json  # optimizer param groups
        ├── optimizer_state.safetensors  # optimizer state
        ├── rng_state.safetensors  # rng states
        ├── scheduler_state.json  # scheduler state
        └── training_step.json  # training step

    Args:
        cfg (TrainPipelineConfig): The training config used for this run.
        step (int): The training step at that checkpoint.
        policy (PreTrainedPolicy): The policy to save.
        optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
        scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
    """
    # Policy weights and both configs go under pretrained_model/ ...
    pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
    policy.save_pretrained(pretrained_dir)
    cfg.save_pretrained(pretrained_dir)
    # ... while step / optimizer / scheduler / rng state go under training_state/.
    save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def save_training_state(
    checkpoint_dir: Path,
    train_step: int,
    optimizer: Optimizer | None = None,
    scheduler: LRScheduler | None = None,
) -> None:
    """
    Saves the training step, optimizer state, scheduler state, and rng state.

    Args:
        save_dir (Path): The directory to save artifacts to.
        train_step (int): Current training step.
        optimizer (Optimizer | None, optional): The optimizer from which to save the state_dict.
            Defaults to None.
        scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
            Defaults to None.
    """
    state_dir = checkpoint_dir / TRAINING_STATE_DIR
    state_dir.mkdir(parents=True, exist_ok=True)
    save_training_step(train_step, state_dir)
    save_rng_state(state_dir)
    # Optimizer and scheduler states are only written when the objects are provided.
    if optimizer is not None:
        save_optimizer_state(optimizer, state_dir)
    if scheduler is not None:
        save_scheduler_state(scheduler, state_dir)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def load_training_state(
    checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None
) -> tuple[int, Optimizer, LRScheduler | None]:
    """
    Loads the training step, optimizer state, scheduler state, and rng state.
    This is used to resume a training run.

    Args:
        checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
        optimizer (Optimizer): The optimizer to load the state_dict to.
        scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None).

    Raises:
        NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir

    Returns:
        tuple[int, Optimizer, LRScheduler | None]: training step, optimizer and scheduler with their
            state_dict loaded.
    """
    state_dir = checkpoint_dir / TRAINING_STATE_DIR
    if not state_dir.is_dir():
        raise NotADirectoryError(state_dir)

    # Restore the rng state first, then the step and optimizer/scheduler states.
    load_rng_state(state_dir)
    step = load_training_step(state_dir)
    optimizer = load_optimizer_state(optimizer, state_dir)
    if scheduler is not None:
        scheduler = load_scheduler_state(scheduler, state_dir)

    return step, optimizer, scheduler
|
lerobot/common/utils/utils.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
import logging
|
| 17 |
+
import os
|
| 18 |
+
import os.path as osp
|
| 19 |
+
import platform
|
| 20 |
+
import subprocess
|
| 21 |
+
from copy import copy
|
| 22 |
+
from datetime import datetime, timezone
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def none_or_int(value):
    """Parse a CLI-style string: the literal "None" maps to None, anything else to int."""
    return None if value == "None" else int(value)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def inside_slurm():
    """Check whether the python process was launched through slurm"""
    # TODO(rcadene): return False for interactive mode `--pty bash`
    return os.environ.get("SLURM_JOB_ID") is not None
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def auto_select_torch_device() -> torch.device:
    """Tries to select automatically a torch device (cuda > mps > cpu)."""
    if torch.cuda.is_available():
        logging.info("Cuda backend detected, using cuda.")
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        # Fix: the log message previously said "using cuda." on the MPS branch.
        logging.info("Metal backend detected, using mps.")
        return torch.device("mps")
    else:
        logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
        return torch.device("cpu")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
|
| 55 |
+
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
    """Given a string, return a torch.device with checks on whether the device is available."""
    try_device = str(try_device)
    if try_device == "cuda":
        assert torch.cuda.is_available()
        return torch.device("cuda")
    if try_device == "mps":
        assert torch.backends.mps.is_available()
        return torch.device("mps")
    if try_device == "cpu":
        if log:
            logging.warning("Using CPU, this will be slow.")
        return torch.device("cpu")
    # Anything else (e.g. "cuda:1") is passed straight to torch.device.
    if log:
        logging.warning(f"Using custom {try_device} device.")
    return torch.device(try_device)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
|
| 78 |
+
"""
|
| 79 |
+
mps is currently not compatible with float64
|
| 80 |
+
"""
|
| 81 |
+
if isinstance(device, torch.device):
|
| 82 |
+
device = device.type
|
| 83 |
+
if device == "mps" and dtype == torch.float64:
|
| 84 |
+
return torch.float32
|
| 85 |
+
else:
|
| 86 |
+
return dtype
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def is_torch_device_available(try_device: str) -> bool:
    """Return whether the named backend ("cuda", "mps" or "cpu") can be used."""
    try_device = str(try_device)  # Ensure try_device is a string
    checkers = {
        "cuda": torch.cuda.is_available,
        "mps": torch.backends.mps.is_available,
        "cpu": lambda: True,
    }
    if try_device not in checkers:
        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
    return checkers[try_device]()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def is_amp_available(device: str):
    """Return whether automatic mixed precision is supported on `device`.

    Raises:
        ValueError: for any device other than "cuda", "cpu" or "mps".
    """
    if device in ["cuda", "cpu"]:
        return True
    elif device == "mps":
        return False
    else:
        # Fix: the error message was missing the closing quote around the device name.
        raise ValueError(f"Unknown device '{device}'.")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def init_logging():
    """Configure root logging with a compact `LEVEL timestamp path:lineno message` layout."""

    def custom_format(record):
        dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        fnameline = f"{record.pathname}:{record.lineno}"
        return f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"

    logging.basicConfig(level=logging.INFO)

    # Drop the handlers installed by basicConfig so only our custom handler remains.
    root = logging.getLogger()
    for handler in list(root.handlers):
        root.removeHandler(handler)

    formatter = logging.Formatter()
    formatter.format = custom_format
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    root.addHandler(console_handler)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def format_big_number(num, precision=0):
    """Render `num` with a metric suffix (K, M, B, T, Q), e.g. 1200 -> "1K"."""
    divisor = 1000.0
    for suffix in ("", "K", "M", "B", "T", "Q"):
        if abs(num) < divisor:
            return f"{num:.{precision}f}{suffix}"
        num /= divisor
    # Past "Q" the (already divided) number is returned as-is.
    return num
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _relative_path_between(path1: Path, path2: Path) -> Path:
|
| 142 |
+
"""Returns path1 relative to path2."""
|
| 143 |
+
path1 = path1.absolute()
|
| 144 |
+
path2 = path2.absolute()
|
| 145 |
+
try:
|
| 146 |
+
return path1.relative_to(path2)
|
| 147 |
+
except ValueError: # most likely because path1 is not a subpath of path2
|
| 148 |
+
common_parts = Path(osp.commonpath([path1, path2])).parts
|
| 149 |
+
return Path(
|
| 150 |
+
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def print_cuda_memory_usage():
    """Use this function to locate and debug memory leak."""
    import gc

    gc.collect()
    # Also clear the cache if you want to fully release the memory
    torch.cuda.empty_cache()
    mb = 1024**2
    print(f"Current GPU Memory Allocated: {torch.cuda.memory_allocated(0) / mb:.2f} MB")
    print(f"Maximum GPU Memory Allocated: {torch.cuda.max_memory_allocated(0) / mb:.2f} MB")
    print(f"Current GPU Memory Reserved: {torch.cuda.memory_reserved(0) / mb:.2f} MB")
    print(f"Maximum GPU Memory Reserved: {torch.cuda.max_memory_reserved(0) / mb:.2f} MB")
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def capture_timestamp_utc():
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def say(text, blocking=False):
    """Speak `text` aloud via the platform's text-to-speech command.

    Raises:
        RuntimeError: on operating systems other than macOS, Linux and Windows.
    """
    system = platform.system()

    if system == "Darwin":
        cmd = ["say", text]
    elif system == "Linux":
        cmd = ["spd-say", text]
        if blocking:
            cmd.append("--wait")
    elif system == "Windows":
        cmd = [
            "PowerShell",
            "-Command",
            "Add-Type -AssemblyName System.Speech; "
            f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')",
        ]
    else:
        raise RuntimeError("Unsupported operating system for text-to-speech.")

    if blocking:
        subprocess.run(cmd, check=True)
    else:
        # CREATE_NO_WINDOW only exists (and is only evaluated) on Windows.
        subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def log_say(text, play_sounds, blocking=False):
    """Log `text` and, when `play_sounds` is set, also speak it via `say()`."""
    logging.info(text)

    if play_sounds:
        say(text, blocking)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def get_channel_first_image_shape(image_shape: tuple) -> tuple:
    """Convert an (h, w, c) shape to (c, h, w); pass a (c, h, w) shape through unchanged.

    Raises:
        ValueError: when neither the first nor the last dim is strictly smallest.
    """
    shape = copy(image_shape)
    channels_last = shape[2] < shape[0] and shape[2] < shape[1]
    channels_first = shape[0] < shape[1] and shape[0] < shape[2]
    if channels_last:  # (h, w, c) -> (c, h, w)
        return (shape[2], shape[0], shape[1])
    if not channels_first:
        raise ValueError(image_shape)

    return shape
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def has_method(cls: object, method_name: str) -> bool:
    """Return True when `cls` exposes a callable attribute named `method_name`."""
    return callable(getattr(cls, method_name, None))
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
    """
    Return True if a given string can be converted to a numpy dtype.
    """
    try:
        np.dtype(dtype_str)
    except TypeError:
        # numpy raises TypeError for strings that do not name a dtype.
        return False
    return True
|
lerobot/common/utils/wandb_utils.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
import logging
|
| 17 |
+
import os
|
| 18 |
+
import re
|
| 19 |
+
from glob import glob
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
| 23 |
+
from termcolor import colored
|
| 24 |
+
|
| 25 |
+
from lerobot.common.constants import PRETRAINED_MODEL_DIR
|
| 26 |
+
from lerobot.configs.train import TrainPipelineConfig
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
    """Return a group name for logging. Optionally returns group name as list."""
    parts = [
        f"policy:{cfg.policy.type}",
        f"dataset:{cfg.dataset.repo_id}",
        f"seed:{cfg.seed}",
    ]
    # The env tag is only present for runs that evaluate in a simulation environment.
    if cfg.env is not None:
        parts.append(f"env:{cfg.env.type}")
    if return_list:
        return parts
    return "-".join(parts)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_wandb_run_id_from_filesystem(log_dir: Path) -> str:
    """Recover the run ID of the latest WandB run recorded under `log_dir`.

    Raises:
        RuntimeError: when no unique `run-<id>.wandb` file can be located.
    """
    # Get the WandB run ID.
    paths = glob(str(log_dir / "wandb/latest-run/run-*"))
    if len(paths) != 1:
        raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
    # Fix: escape the dot so only a literal ".wandb" extension matches, and use
    # `match.group(1)` instead of the misused `match.groups(0)[0]` (groups() takes a
    # default value, not an index).
    match = re.search(r"run-([^.]+)\.wandb", paths[0].split("/")[-1])
    if match is None:
        raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
    return match.group(1)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_safe_wandb_artifact_name(name: str):
    """WandB artifacts don't accept ":" or "/" in their name."""
    for forbidden in (":", "/"):
        name = name.replace(forbidden, "_")
    return name
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class WandBLogger:
    """A helper class to log object using wandb."""

    def __init__(self, cfg: TrainPipelineConfig):
        self.cfg = cfg.wandb
        self.log_dir = cfg.output_dir
        self.job_name = cfg.job_name
        self.env_fps = cfg.env.fps if cfg.env else None
        self._group = cfg_to_group(cfg)

        # Set up WandB.
        os.environ["WANDB_SILENT"] = "True"
        import wandb

        # Reuse an explicit run id when given; otherwise recover it from disk on resume.
        if cfg.wandb.run_id:
            wandb_run_id = cfg.wandb.run_id
        elif cfg.resume:
            wandb_run_id = get_wandb_run_id_from_filesystem(self.log_dir)
        else:
            wandb_run_id = None

        wandb.init(
            id=wandb_run_id,
            project=self.cfg.project,
            entity=self.cfg.entity,
            name=self.job_name,
            notes=self.cfg.notes,
            tags=cfg_to_group(cfg, return_list=True),
            dir=self.log_dir,
            config=cfg.to_dict(),
            # TODO(rcadene): try set to True
            save_code=False,
            # TODO(rcadene): split train and eval, and run async eval with job_type="eval"
            job_type="train_eval",
            resume="must" if cfg.resume else None,
        )
        print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
        logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
        self._wandb = wandb

    def log_policy(self, checkpoint_dir: Path):
        """Checkpoints the policy to wandb."""
        if self.cfg.disable_artifact:
            return

        step_id = checkpoint_dir.name
        artifact_name = get_safe_wandb_artifact_name(f"{self._group}-{step_id}")
        artifact = self._wandb.Artifact(artifact_name, type="model")
        artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
        self._wandb.log_artifact(artifact)

    def log_dict(self, d: dict, step: int, mode: str = "train"):
        """Log int/float/str entries of `d` under the `mode` namespace; skip other types."""
        if mode not in {"train", "eval"}:
            raise ValueError(mode)

        for k, v in d.items():
            if not isinstance(v, (int, float, str)):
                logging.warning(
                    f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
                )
                continue
            self._wandb.log({f"{mode}/{k}": v}, step=step)

    def log_video(self, video_path: str, step: int, mode: str = "train"):
        """Upload an mp4 video to wandb under the `mode` namespace."""
        if mode not in {"train", "eval"}:
            raise ValueError(mode)

        wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
        self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
lerobot/configs/default.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
|
| 19 |
+
from lerobot.common import (
|
| 20 |
+
policies, # noqa: F401
|
| 21 |
+
)
|
| 22 |
+
from lerobot.common.datasets.transforms import ImageTransformsConfig
|
| 23 |
+
from lerobot.common.datasets.video_utils import get_safe_default_codec
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
class DatasetConfig:
    # You may provide a list of datasets here. `train.py` creates them all and concatenates them,
    # keeping only the data keys the datasets have in common. Each dataset also gets an additional
    # transform inserting its "dataset_index" into the returned item; indices are assigned in the
    # order the datasets are provided.
    repo_id: str
    # Root directory where the dataset will be stored (e.g. 'dataset/path').
    root: str | None = None
    # Presumably a subset of episode indices to load; None loads everything — verify against loader.
    episodes: list[int] | None = None
    # Image augmentations applied to sampled frames.
    image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
    # Dataset revision (e.g. branch, tag or commit hash) to pull.
    revision: str | None = None
    # Whether to normalize images with ImageNet statistics.
    use_imagenet_stats: bool = True
    # Video decoding backend; the default is chosen per-machine by `get_safe_default_codec`.
    video_backend: str = field(default_factory=get_safe_default_codec)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
class WandBConfig:
    """Weights & Biases experiment-tracking options."""

    # Presumably gates all W&B logging — confirm in the WandB logger that consumes this config.
    enable: bool = False
    # Set to true to disable saving an artifact despite training.save_checkpoint=True
    disable_artifact: bool = False
    # W&B project name.
    project: str = "lerobot"
    # W&B entity (user or team); None uses the account default.
    entity: str | None = None
    # Free-form notes attached to the run.
    notes: str | None = None
    # Existing run id to attach to (e.g. when resuming); None starts a new run.
    run_id: str | None = None
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass
class EvalConfig:
    """Settings controlling policy evaluation rollouts.

    Raises at construction when `batch_size` exceeds `n_episodes`, since that
    would instantiate more vectorized environments than can ever be used.
    """

    n_episodes: int = 50
    # `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
    batch_size: int = 50
    # `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
    use_async_envs: bool = False

    def __post_init__(self):
        # Guard clause: a batch size no larger than the episode count is always valid.
        if self.batch_size <= self.n_episodes:
            return
        raise ValueError(
            "The eval batch size is greater than the number of eval episodes "
            f"({self.batch_size} > {self.n_episodes}). As a result, {self.batch_size} "
            f"eval environments will be instantiated, but only {self.n_episodes} will be used. "
            "This might significantly slow down evaluation. To fix this, you should update your command "
            f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), "
            f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)."
        )
|
lerobot/configs/eval.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import datetime as dt
|
| 16 |
+
import logging
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
from lerobot.common import envs, policies # noqa: F401
|
| 21 |
+
from lerobot.configs import parser
|
| 22 |
+
from lerobot.configs.default import EvalConfig
|
| 23 |
+
from lerobot.configs.policies import PreTrainedConfig
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
class EvalPipelineConfig:
    """Top-level config for the evaluation script (parsed from CLI/draccus)."""

    # Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
    # saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch
    # (useful for debugging). This argument is mutually exclusive with `--config`.
    # (The note above documents the `--policy.path` CLI argument handled in __post_init__.)
    env: envs.EnvConfig
    eval: EvalConfig = field(default_factory=EvalConfig)
    policy: PreTrainedConfig | None = None
    # Where eval artifacts are written; auto-generated from date/time + job name when unset.
    output_dir: Path | None = None
    job_name: str | None = None
    seed: int | None = 1000

    def __post_init__(self):
        # HACK: We parse again the cli args here to get the pretrained path if there was one.
        # `--policy.path` args are stripped before draccus parsing (see __get_path_fields__),
        # so they must be re-read from sys.argv here.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            # Load the policy config from the pretrained location, applying any
            # remaining `--policy.*` CLI overrides on top.
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path

        else:
            logging.warning(
                "No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
            )

        if not self.job_name:
            # NOTE(review): if neither `--policy.path` nor `--policy.type` was given,
            # self.policy is still None here and `.type` below raises AttributeError — confirm
            # whether a policy is meant to be mandatory for eval.
            if self.env is None:
                self.job_name = f"{self.policy.type}"
            else:
                self.job_name = f"{self.env.type}_{self.policy.type}"

        if not self.output_dir:
            # Default layout: outputs/eval/YYYY-MM-DD/HH-MM-SS_<job_name>
            now = dt.datetime.now()
            eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
            self.output_dir = Path("outputs/eval") / eval_dir

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]
|
lerobot/configs/parser.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import importlib
|
| 15 |
+
import inspect
|
| 16 |
+
import pkgutil
|
| 17 |
+
import sys
|
| 18 |
+
from argparse import ArgumentError
|
| 19 |
+
from functools import wraps
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Sequence
|
| 22 |
+
|
| 23 |
+
import draccus
|
| 24 |
+
|
| 25 |
+
from lerobot.common.utils.utils import has_method
|
| 26 |
+
|
| 27 |
+
PATH_KEY = "path"
|
| 28 |
+
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
| 29 |
+
draccus.set_config_type("json")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
    """Extract the CLI args nested under `field_name` and re-root them one level up.

    For example, supposing the main script was called with:
        python myscript.py --arg1=1 --arg2.subarg1=abc --arg2.subarg2=some/path

    If called during execution of myscript.py, get_cli_overrides("arg2") returns:
        ["--subarg1=abc", "--subarg2=some/path"]

    The special `.type` (draccus choice) and `.path` sub-arguments are excluded.
    """
    if args is None:
        args = sys.argv[1:]
    prefix = f"--{field_name}."
    excluded = (f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", f"--{field_name}.{PATH_KEY}=")
    return [
        f"--{arg.removeprefix(prefix)}"
        for arg in args
        if arg.startswith(prefix) and not arg.startswith(excluded)
    ]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
|
| 55 |
+
if args is None:
|
| 56 |
+
args = sys.argv[1:]
|
| 57 |
+
prefix = f"--{arg_name}="
|
| 58 |
+
for arg in args:
|
| 59 |
+
if arg.startswith(prefix):
|
| 60 |
+
return arg[len(prefix) :]
|
| 61 |
+
return None
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
    """Parse plugin-related arguments from command-line arguments.

    This function extracts arguments from command-line arguments that match a specified suffix pattern.
    It processes arguments in the format '--key=value' and returns them as a dictionary.

    Args:
        plugin_arg_suffix (str): The suffix to identify plugin-related arguments.
        args (Sequence[str]): A sequence of command-line arguments to parse.

    Returns:
        dict: A dictionary containing the parsed plugin arguments where:
            - Keys are the argument names (with '--' prefix removed if present)
            - Values are the corresponding argument values

    Example:
        >>> args = ['--env.discover_packages_path=my_package',
        ...         '--other_arg=value']
        >>> parse_plugin_args('discover_packages_path', args)
        {'env.discover_packages_path': 'my_package'}
    """
    plugin_args = {}
    for arg in args:
        # Only '--key=value' style args whose key contains the suffix are plugin args.
        if "=" in arg and plugin_arg_suffix in arg:
            key, value = arg.split("=", 1)
            # Remove leading '--' if present (no-op otherwise).
            plugin_args[key.removeprefix("--")] = value
    return plugin_args
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class PluginLoadError(Exception):
    """Raised when a plugin fails to load.

    Wraps import failures of a plugin package or any of its submodules (see `load_plugin`).
    """
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def load_plugin(plugin_path: str) -> None:
    """Import a plugin package and every one of its direct submodules.

    Registration side effects are expected to happen at import time: importing
    the package should register the gym environment and register config classes
    with their parents via the `register_subclass` decorator.

    Args:
        plugin_path (str): The Python package path to the plugin (e.g. "mypackage.plugins.myplugin")

    Raises:
        PluginLoadError: If the plugin cannot be loaded due to import errors or if the package path
            is invalid.

    Examples:
        >>> load_plugin("external_plugin.core")  # Loads plugin from external package

    See Also:
        https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/
    """
    try:
        root = importlib.import_module(plugin_path, __package__)
    except (ImportError, ModuleNotFoundError) as e:
        raise PluginLoadError(
            f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
        ) from e

    # Walk the package's namespace and import each submodule so their
    # registration side effects run too.
    submodules = pkgutil.iter_modules(root.__path__, root.__name__ + ".")
    try:
        for _finder, submodule_name, _ispkg in submodules:
            importlib.import_module(submodule_name)
    except ImportError as e:
        raise PluginLoadError(
            f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
        ) from e
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
    """Return the value of `--{field_name}.path` from CLI args (or sys.argv), else None."""
    return parse_arg(f"{field_name}.{PATH_KEY}", args)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
    """Return the value of the draccus choice-type arg (`--{field_name}.type`), else None."""
    return parse_arg(f"{field_name}.{draccus.CHOICE_TYPE_KEY}", args)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
|
| 153 |
+
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
    """
    Filters command-line arguments related to fields with specific path arguments.

    Args:
        fields_to_filter (str | list[str]): A single str or a list of str whose arguments need to be filtered.
        args (Sequence[str] | None): The sequence of command-line arguments to be filtered.
            Defaults to None.

    Returns:
        list[str]: A filtered list of arguments, with arguments related to the specified
            fields removed.

    Raises:
        ArgumentError: If both a path argument (e.g., `--field_name.path`) and a type
            argument (e.g., `--field_name.type`) are specified for the same field.
    """
    if isinstance(fields_to_filter, str):
        fields_to_filter = [fields_to_filter]

    filtered_args = args
    for field in fields_to_filter:
        # Only fields that actually have a `.path` argument get their nested args removed.
        if get_path_arg(field, args):
            if get_type_arg(field, args):
                raise ArgumentError(
                    argument=None,
                    message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
                )
            # Drop every argument nested under this field (e.g. `--policy.*`); they are
            # re-parsed later by the `.path` loading mechanism (see `wrap`).
            filtered_args = [arg for arg in filtered_args if not arg.startswith(f"--{field}.")]

    return filtered_args
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def wrap(config_path: Path | None = None):
    """
    HACK: Similar to draccus.wrap but does three additional things:
    - Will remove '.path' arguments from CLI in order to process them later on.
    - If a 'config_path' is passed and the main config class has a 'from_pretrained' method, will
      initialize it from there to allow to fetch configs from the hub directly
    - Will load plugins specified in the CLI arguments. These plugins will typically register
      their own subclasses of config classes, so that draccus can find the right class to instantiate
      from the CLI '.type' arguments
    """

    def wrapper_outer(fn):
        @wraps(fn)
        def wrapper_inner(*args, **kwargs):
            # The config type is taken from the annotation of the wrapped function's
            # first parameter.
            argspec = inspect.getfullargspec(fn)
            argtype = argspec.annotations[argspec.args[0]]
            if len(args) > 0 and type(args[0]) is argtype:
                # A fully-built config was passed programmatically: skip CLI parsing.
                cfg = args[0]
                args = args[1:]
            else:
                cli_args = sys.argv[1:]
                # Load any plugins requested via `*.discover_packages_path=...` args,
                # then strip those args so draccus does not see them.
                plugin_args = parse_plugin_args(PLUGIN_DISCOVERY_SUFFIX, cli_args)
                for plugin_cli_arg, plugin_path in plugin_args.items():
                    try:
                        load_plugin(plugin_path)
                    except PluginLoadError as e:
                        # add the relevant CLI arg to the error message
                        raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
                    cli_args = filter_arg(plugin_cli_arg, cli_args)
                config_path_cli = parse_arg("config_path", cli_args)
                if has_method(argtype, "__get_path_fields__"):
                    # Strip `--<field>.*` args for fields loaded via `.path`; the config
                    # class re-reads them itself (see e.g. EvalPipelineConfig.__post_init__).
                    path_fields = argtype.__get_path_fields__()
                    cli_args = filter_path_args(path_fields, cli_args)
                if has_method(argtype, "from_pretrained") and config_path_cli:
                    # Build the config from a pretrained location (local dir or hub repo).
                    cli_args = filter_arg("config_path", cli_args)
                    cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
                else:
                    # Standard draccus parsing from defaults + optional file + CLI overrides.
                    cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
            response = fn(cfg, *args, **kwargs)
            return response

        return wrapper_inner

    return wrapper_outer
|
lerobot/configs/policies.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import abc
|
| 15 |
+
import logging
|
| 16 |
+
import os
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Type, TypeVar
|
| 20 |
+
|
| 21 |
+
import draccus
|
| 22 |
+
from huggingface_hub import hf_hub_download
|
| 23 |
+
from huggingface_hub.constants import CONFIG_NAME
|
| 24 |
+
from huggingface_hub.errors import HfHubHTTPError
|
| 25 |
+
|
| 26 |
+
from lerobot.common.optim.optimizers import OptimizerConfig
|
| 27 |
+
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
| 28 |
+
from lerobot.common.utils.hub import HubMixin
|
| 29 |
+
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
| 30 |
+
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
| 31 |
+
|
| 32 |
+
# Generic variable that is either PreTrainedConfig or a subclass thereof
|
| 33 |
+
T = TypeVar("T", bound="PreTrainedConfig")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
    """
    Base configuration class for policy models.

    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
        normalization_mapping: Maps a modality name to the `NormalizationMode` applied to it.
        input_features: A dictionary defining the features used as input by the policy.
        output_features: A dictionary defining the features output by the policy.
        device: Torch device string; auto-selected in __post_init__ when unset or unavailable.
        use_amp: Whether to use Automatic Mixed Precision for training and evaluation.
    """

    n_obs_steps: int = 1
    normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)

    input_features: dict[str, PolicyFeature] = field(default_factory=dict)
    output_features: dict[str, PolicyFeature] = field(default_factory=dict)

    device: str | None = None  # cuda | cpu | mps
    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
    # automatic gradient scaling is used.
    use_amp: bool = False

    def __post_init__(self):
        self.pretrained_path = None
        # Fall back to an automatically selected device when none was provided or the
        # requested one is not available on this machine.
        if not self.device or not is_torch_device_available(self.device):
            auto_device = auto_select_torch_device()
            logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
            self.device = auto_device.type

        # Automatically deactivate AMP if necessary
        if self.use_amp and not is_amp_available(self.device):
            logging.warning(
                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
            )
            self.use_amp = False

    @property
    def type(self) -> str:
        """Name under which this config subclass is registered in the draccus choice registry."""
        return self.get_choice_name(self.__class__)

    # NOTE: `abc.abstractproperty` is deprecated since Python 3.3; the stacked
    # `@property` + `@abc.abstractmethod` form below is the supported equivalent.
    @property
    @abc.abstractmethod
    def observation_delta_indices(self) -> list | None:
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def action_delta_indices(self) -> list | None:
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def reward_delta_indices(self) -> list | None:
        raise NotImplementedError

    @abc.abstractmethod
    def get_optimizer_preset(self) -> OptimizerConfig:
        raise NotImplementedError

    @abc.abstractmethod
    def get_scheduler_preset(self) -> LRSchedulerConfig | None:
        raise NotImplementedError

    @abc.abstractmethod
    def validate_features(self) -> None:
        raise NotImplementedError

    @property
    def robot_state_feature(self) -> PolicyFeature | None:
        """First input feature of type STATE, if any."""
        for ft in self.input_features.values():
            if ft.type is FeatureType.STATE:
                return ft
        return None

    @property
    def env_state_feature(self) -> PolicyFeature | None:
        """First input feature of type ENV, if any."""
        for ft in self.input_features.values():
            if ft.type is FeatureType.ENV:
                return ft
        return None

    @property
    def image_features(self) -> dict[str, PolicyFeature]:
        """All VISUAL input features, keyed by feature name."""
        return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}

    @property
    def action_feature(self) -> PolicyFeature | None:
        """First output feature of type ACTION, if any."""
        for ft in self.output_features.values():
            if ft.type is FeatureType.ACTION:
                return ft
        return None

    def _save_pretrained(self, save_directory: Path) -> None:
        """Serialize this config as JSON to `save_directory / CONFIG_NAME`."""
        with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
            draccus.dump(self, f, indent=4)

    @classmethod
    def from_pretrained(
        cls: Type[T],
        pretrained_name_or_path: str | Path,
        *,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **policy_kwargs,
    ) -> T:
        """Load a policy config from a local directory or from a repo on the Hub.

        Any `cli_overrides` entry in `policy_kwargs` is applied on top of the loaded
        config via draccus.

        Raises:
            FileNotFoundError: When the config file cannot be retrieved from the Hub.
        """
        model_id = str(pretrained_name_or_path)
        config_file: str | None = None
        if Path(model_id).is_dir():
            if CONFIG_NAME in os.listdir(model_id):
                config_file = os.path.join(model_id, CONFIG_NAME)
            else:
                # Report through logging for consistency with the rest of this class
                # (was a bare print()).
                logging.warning(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
        else:
            try:
                config_file = hf_hub_download(
                    repo_id=model_id,
                    filename=CONFIG_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
                ) from e

        # HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
        # something like --policy.path (in addition to --policy.type)
        cli_overrides = policy_kwargs.pop("cli_overrides", [])
        return draccus.parse(cls, config_file, args=cli_overrides)
|
lerobot/configs/train.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import datetime as dt
|
| 15 |
+
import os
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Type
|
| 19 |
+
|
| 20 |
+
import draccus
|
| 21 |
+
from huggingface_hub import hf_hub_download
|
| 22 |
+
from huggingface_hub.errors import HfHubHTTPError
|
| 23 |
+
|
| 24 |
+
from lerobot.common import envs
|
| 25 |
+
from lerobot.common.optim import OptimizerConfig
|
| 26 |
+
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
| 27 |
+
from lerobot.common.utils.hub import HubMixin
|
| 28 |
+
from lerobot.configs import parser
|
| 29 |
+
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
| 30 |
+
from lerobot.configs.policies import PreTrainedConfig
|
| 31 |
+
|
| 32 |
+
TRAIN_CONFIG_NAME = "train_config.json"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
class TrainPipelineConfig(HubMixin):
    """Top-level configuration for a training run.

    Bundles the dataset, (optional) evaluation environment, policy, optimization
    and logging settings, and handles (de)serialization of itself to/from
    `train_config.json`, either locally or on the HuggingFace Hub.
    """

    dataset: DatasetConfig
    env: envs.EnvConfig | None = None
    policy: PreTrainedConfig | None = None
    # Set `dir` to where you would like to save all of the run outputs. If you run another training session
    # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
    output_dir: Path | None = None
    job_name: str | None = None
    # Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
    # `dir` is the directory of an existing run with at least one checkpoint in it.
    # Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
    # regardless of what's provided with the training command at the time of resumption.
    resume: bool = False
    # `seed` is used for training (eg: model initialization, dataset shuffling)
    # AND for the evaluation environments.
    seed: int | None = 1000
    # Number of workers for the dataloader.
    num_workers: int = 4
    batch_size: int = 8
    steps: int = 100_000
    eval_freq: int = 20_000
    log_freq: int = 200
    save_checkpoint: bool = True
    # Checkpoint is saved every `save_freq` training iterations and after the last training step.
    save_freq: int = 20_000
    use_policy_training_preset: bool = True
    optimizer: OptimizerConfig | None = None
    scheduler: LRSchedulerConfig | None = None
    eval: EvalConfig = field(default_factory=EvalConfig)
    wandb: WandBConfig = field(default_factory=WandBConfig)

    def __post_init__(self):
        # Set by `validate()` when resuming from an existing checkpoint; None otherwise.
        self.checkpoint_path = None

    def validate(self):
        """Resolve pretrained/resume paths, fill in derived defaults and sanity-check the config.

        Raises:
            ValueError: if resuming without a `config_path`, or if optimizer/scheduler
                are missing while `use_policy_training_preset` is disabled.
            NotADirectoryError: if the provided `config_path` is not a local path.
            FileExistsError: if `output_dir` already exists and `resume` is False.
            NotImplementedError: if a list of dataset repo_ids is given.
        """
        # HACK: We parse again the cli args here to get the pretrained paths if there was some.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            # Only load the policy config
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path
        elif self.resume:
            # The entire train config is already loaded, we just need to get the checkpoint dir
            config_path = parser.parse_arg("config_path")
            if not config_path:
                raise ValueError(
                    f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
                )
            if not Path(config_path).resolve().exists():
                raise NotADirectoryError(
                    f"{config_path=} is expected to be a local path. "
                    "Resuming from the hub is not supported for now."
                )
            policy_path = Path(config_path).parent
            self.policy.pretrained_path = policy_path
            self.checkpoint_path = policy_path.parent

        if not self.job_name:
            if self.env is None:
                self.job_name = f"{self.policy.type}"
            else:
                self.job_name = f"{self.env.type}_{self.policy.type}"

        if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
            raise FileExistsError(
                f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
                f"Please change your output directory so that {self.output_dir} is not overwritten."
            )
        elif not self.output_dir:
            # Default to a timestamped directory so separate runs never collide.
            now = dt.datetime.now()
            train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
            self.output_dir = Path("outputs/train") / train_dir

        if isinstance(self.dataset.repo_id, list):
            raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")

        if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
            raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
        elif self.use_policy_training_preset and not self.resume:
            self.optimizer = self.policy.get_optimizer_preset()
            self.scheduler = self.policy.get_scheduler_preset()

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]

    def to_dict(self) -> dict:
        """Serialize this config to a plain dict via draccus."""
        return draccus.encode(self)

    def _save_pretrained(self, save_directory: Path) -> None:
        """Write this config as `train_config.json` inside `save_directory`."""
        with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
            draccus.dump(self, f, indent=4)

    @classmethod
    def from_pretrained(
        cls: Type["TrainPipelineConfig"],
        pretrained_name_or_path: str | Path,
        *,
        force_download: bool = False,
        # Fixed annotation: the default is None, so the type must admit it.
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **kwargs,
    ) -> "TrainPipelineConfig":
        """Load a TrainPipelineConfig from a local dir/file or from the HuggingFace Hub.

        Extra CLI overrides may be passed via `kwargs["cli_args"]` and are applied
        on top of the loaded file by draccus.
        """
        model_id = str(pretrained_name_or_path)
        config_file: str | None = None
        if Path(model_id).is_dir():
            if TRAIN_CONFIG_NAME in os.listdir(model_id):
                config_file = os.path.join(model_id, TRAIN_CONFIG_NAME)
            else:
                print(f"{TRAIN_CONFIG_NAME} not found in {Path(model_id).resolve()}")
        elif Path(model_id).is_file():
            config_file = model_id
        else:
            try:
                config_file = hf_hub_download(
                    repo_id=model_id,
                    filename=TRAIN_CONFIG_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
                ) from e

        cli_args = kwargs.pop("cli_args", [])
        cfg = draccus.parse(cls, config_file, args=cli_args)

        return cfg
|
lerobot/configs/types.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# Note: We subclass str so that serialization is straightforward
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
from dataclasses import dataclass
from enum import Enum
from typing import Any, Protocol


class FeatureType(str, Enum):
    """Kinds of features a policy consumes or produces."""

    STATE = "STATE"
    VISUAL = "VISUAL"
    ENV = "ENV"
    ACTION = "ACTION"


class NormalizationMode(str, Enum):
    """How a feature is normalized before being fed to a policy."""

    MIN_MAX = "MIN_MAX"
    MEAN_STD = "MEAN_STD"
    IDENTITY = "IDENTITY"


class DictLike(Protocol):
    """Structural type for any object that supports `obj[key]` access."""

    def __getitem__(self, key: Any) -> Any: ...


@dataclass
class PolicyFeature:
    """A single policy input/output feature: its kind and its tensor shape."""

    type: FeatureType
    shape: tuple
|
lerobot/scripts/configure_motor.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
This script configure a single motor at a time to a given ID and baudrate.
|
| 16 |
+
|
| 17 |
+
Example of usage:
|
| 18 |
+
```bash
|
| 19 |
+
python lerobot/scripts/configure_motor.py \
|
| 20 |
+
--port /dev/tty.usbmodem585A0080521 \
|
| 21 |
+
--brand feetech \
|
| 22 |
+
--model sts3215 \
|
| 23 |
+
--baudrate 1000000 \
|
| 24 |
+
--ID 1
|
| 25 |
+
```
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import argparse
|
| 29 |
+
import time
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_motor_bus_cls(brand: str) -> tuple:
    """Resolve the motor-bus classes and baudrate tables for a given brand.

    Returns a 4-tuple: (bus config class, bus class, per-model baudrate table,
    series baudrate table). Imports are performed lazily so that only the
    selected brand's driver needs to be installed.

    Raises:
        ValueError: if `brand` is not one of the supported brands.
    """
    if brand == "feetech":
        from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
        from lerobot.common.robot_devices.motors.feetech import (
            MODEL_BAUDRATE_TABLE,
            SCS_SERIES_BAUDRATE_TABLE,
            FeetechMotorsBus,
        )

        return FeetechMotorsBusConfig, FeetechMotorsBus, MODEL_BAUDRATE_TABLE, SCS_SERIES_BAUDRATE_TABLE

    if brand == "dynamixel":
        from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
        from lerobot.common.robot_devices.motors.dynamixel import (
            MODEL_BAUDRATE_TABLE,
            X_SERIES_BAUDRATE_TABLE,
            DynamixelMotorsBus,
        )

        return DynamixelMotorsBusConfig, DynamixelMotorsBus, MODEL_BAUDRATE_TABLE, X_SERIES_BAUDRATE_TABLE

    raise ValueError(
        f"Currently we do not support this motor brand: {brand}. We currently support feetech and dynamixel motors."
    )
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
    """Configure a single connected motor: set its baudrate and ID, then home it.

    Exactly one motor must be attached to the bus. The bus is scanned at every
    baudrate of the series to locate the motor; its baudrate and ID are rewritten
    if needed, and (for feetech) the acceleration limit is raised before moving
    the motor to the centered position 2048 and zeroing its offset.

    Args:
        port: serial port of the motor bus.
        brand: "feetech" or "dynamixel".
        model: motor model name, must exist in the brand's baudrate table.
        motor_idx_des: desired motor ID to write.
        baudrate_des: desired baudrate to write.
    """
    motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = get_motor_bus_cls(
        brand
    )

    # Check if the provided model exists in the model_baud_rate_table
    if model not in model_baudrate_table:
        raise ValueError(
            f"Invalid model '{model}' for brand '{brand}'. Supported models: {list(model_baudrate_table.keys())}"
        )

    # Setup motor names, indices, and models
    motor_name = "motor"
    motor_index_arbitrary = motor_idx_des  # Use the motor ID passed via argument
    motor_model = model  # Use the motor model passed via argument

    config = motor_bus_config_cls(port=port, motors={motor_name: (motor_index_arbitrary, motor_model)})

    # Initialize the MotorBus with the correct port and motor configurations
    motor_bus = motor_bus_cls(config=config)

    # Try to connect to the motor bus and handle any connection-specific errors
    try:
        motor_bus.connect()
        print(f"Connected on port {motor_bus.port}")
    except OSError as e:
        print(f"Error occurred when connecting to the motor bus: {e}")
        return

    # Motor bus is connected, proceed with the rest of the operations
    try:
        print("Scanning all baudrates and motor indices")
        all_baudrates = set(series_baudrate_table.values())
        motor_index = -1  # Out-of-range sentinel until a motor is found.

        for baudrate in all_baudrates:
            motor_bus.set_bus_baudrate(baudrate)
            present_ids = motor_bus.find_motor_indices(list(range(1, 10)))
            if len(present_ids) > 1:
                raise ValueError(
                    "Error: More than one motor ID detected. This script is designed to only handle one motor at a time. Please disconnect all but one motor."
                )

            if len(present_ids) == 1:
                # NOTE: the original code re-checked `motor_index != -1` here, but
                # that branch was unreachable: we break right after the first find,
                # so `motor_index` is always -1 at this point. Removed as dead code.
                motor_index = present_ids[0]
                break

        if motor_index == -1:
            raise ValueError("No motors detected. Please ensure you have one motor connected.")

        print(f"Motor index found at: {motor_index}")

        if brand == "feetech":
            # Allows ID and BAUDRATE to be written in memory
            motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)

        # `baudrate` still holds the rate at which the motor answered during the scan.
        if baudrate != baudrate_des:
            print(f"Setting its baudrate to {baudrate_des}")
            baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des)

            # The write can fail, so we allow retries
            motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx)
            time.sleep(0.5)
            motor_bus.set_bus_baudrate(baudrate_des)
            present_baudrate_idx = motor_bus.read_with_motor_ids(
                motor_bus.motor_models, motor_index, "Baud_Rate", num_retry=2
            )

            if present_baudrate_idx != baudrate_idx:
                raise OSError("Failed to write baudrate.")

        print(f"Setting its index to desired index {motor_idx_des}")
        if brand == "feetech":
            motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
        motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)

        present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2)
        if present_idx != motor_idx_des:
            raise OSError("Failed to write index.")

        if brand == "feetech":
            # Set Maximum_Acceleration to 254 to speedup acceleration and deceleration of
            # the motors. Note: this configuration is not in the official STS3215 Memory Table
            motor_bus.write("Lock", 0)
            motor_bus.write("Maximum_Acceleration", 254)

            motor_bus.write("Goal_Position", 2048)
            time.sleep(4)
            print("Present Position", motor_bus.read("Present_Position"))

            motor_bus.write("Offset", 0)
            time.sleep(4)
            print("Offset", motor_bus.read("Offset"))

    except Exception as e:
        print(f"Error occurred during motor configuration: {e}")

    finally:
        motor_bus.disconnect()
        print("Disconnected from motor bus.")
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
if __name__ == "__main__":
    # CLI entry point: configure exactly one connected motor.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--port", type=str, required=True, help="Motors bus port (e.g. dynamixel,feetech)")
    arg_parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)")
    arg_parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)")
    arg_parser.add_argument("--ID", type=int, required=True, help="Desired ID of the current motor (e.g. 1,2,3)")
    arg_parser.add_argument(
        "--baudrate", type=int, default=1000000, help="Desired baudrate for the motor (default: 1000000)"
    )
    cli = arg_parser.parse_args()

    configure_motor(cli.port, cli.brand, cli.model, cli.ID, cli.baudrate)
|
lerobot/scripts/control_robot.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Utilities to control a robot.
|
| 16 |
+
|
| 17 |
+
Useful to record a dataset, replay a recorded episode, run the policy on your robot
|
| 18 |
+
and record an evaluation dataset, and to recalibrate your robot if needed.
|
| 19 |
+
|
| 20 |
+
Examples of usage:
|
| 21 |
+
|
| 22 |
+
- Recalibrate your robot:
|
| 23 |
+
```bash
|
| 24 |
+
python lerobot/scripts/control_robot.py \
|
| 25 |
+
--robot.type=so100 \
|
| 26 |
+
--control.type=calibrate
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
- Unlimited teleoperation at highest frequency (~200 Hz is expected), to exit with CTRL+C:
|
| 30 |
+
```bash
|
| 31 |
+
python lerobot/scripts/control_robot.py \
|
| 32 |
+
--robot.type=so100 \
|
| 33 |
+
--robot.cameras='{}' \
|
| 34 |
+
--control.type=teleoperate
|
| 35 |
+
|
| 36 |
+
# Add the cameras from the robot definition to visualize them:
|
| 37 |
+
python lerobot/scripts/control_robot.py \
|
| 38 |
+
--robot.type=so100 \
|
| 39 |
+
--control.type=teleoperate
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency:
|
| 43 |
+
```bash
|
| 44 |
+
python lerobot/scripts/control_robot.py \
|
| 45 |
+
--robot.type=so100 \
|
| 46 |
+
--control.type=teleoperate \
|
| 47 |
+
--control.fps=30
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
- Record one episode in order to test replay:
|
| 51 |
+
```bash
|
| 52 |
+
python lerobot/scripts/control_robot.py \
|
| 53 |
+
--robot.type=so100 \
|
| 54 |
+
--control.type=record \
|
| 55 |
+
--control.fps=30 \
|
| 56 |
+
--control.single_task="Grasp a lego block and put it in the bin." \
|
| 57 |
+
--control.repo_id=$USER/koch_test \
|
| 58 |
+
--control.num_episodes=1 \
|
| 59 |
+
--control.push_to_hub=True
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
- Visualize dataset:
|
| 63 |
+
```bash
|
| 64 |
+
python lerobot/scripts/visualize_dataset.py \
|
| 65 |
+
--repo-id $USER/koch_test \
|
| 66 |
+
--episode-index 0
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
- Replay this test episode:
|
| 70 |
+
```bash
|
| 71 |
+
python lerobot/scripts/control_robot.py replay \
|
| 72 |
+
--robot.type=so100 \
|
| 73 |
+
--control.type=replay \
|
| 74 |
+
--control.fps=30 \
|
| 75 |
+
--control.repo_id=$USER/koch_test \
|
| 76 |
+
--control.episode=0
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
- Record a full dataset in order to train a policy, with 2 seconds of warmup,
|
| 80 |
+
30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
|
| 81 |
+
```bash
|
| 82 |
+
python lerobot/scripts/control_robot.py record \
|
| 83 |
+
--robot.type=so100 \
|
| 84 |
+
--control.type=record \
|
| 85 |
+
--control.fps 30 \
|
| 86 |
+
--control.repo_id=$USER/koch_pick_place_lego \
|
| 87 |
+
--control.num_episodes=50 \
|
| 88 |
+
--control.warmup_time_s=2 \
|
| 89 |
+
--control.episode_time_s=30 \
|
| 90 |
+
--control.reset_time_s=10
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
- For remote controlled robots like LeKiwi, run this script on the robot edge device (e.g. RaspBerryPi):
|
| 94 |
+
```bash
|
| 95 |
+
python lerobot/scripts/control_robot.py \
|
| 96 |
+
--robot.type=lekiwi \
|
| 97 |
+
--control.type=remote_robot
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
**NOTE**: You can use your keyboard to control data recording flow.
|
| 101 |
+
- Tap right arrow key '->' to early exit while recording an episode and go to resseting the environment.
|
| 102 |
+
- Tap right arrow key '->' to early exit while resetting the environment and got to recording the next episode.
|
| 103 |
+
- Tap left arrow key '<-' to early exit and re-record the current episode.
|
| 104 |
+
- Tap escape key 'esc' to stop the data recording.
|
| 105 |
+
This might require a sudo permission to allow your terminal to monitor keyboard events.
|
| 106 |
+
|
| 107 |
+
**NOTE**: You can resume/continue data recording by running the same data recording command and adding `--control.resume=true`.
|
| 108 |
+
|
| 109 |
+
- Train on this dataset with the ACT policy:
|
| 110 |
+
```bash
|
| 111 |
+
python lerobot/scripts/train.py \
|
| 112 |
+
--dataset.repo_id=${HF_USER}/koch_pick_place_lego \
|
| 113 |
+
--policy.type=act \
|
| 114 |
+
--output_dir=outputs/train/act_koch_pick_place_lego \
|
| 115 |
+
--job_name=act_koch_pick_place_lego \
|
| 116 |
+
--device=cuda \
|
| 117 |
+
--wandb.enable=true
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
- Run the pretrained policy on the robot:
|
| 121 |
+
```bash
|
| 122 |
+
python lerobot/scripts/control_robot.py \
|
| 123 |
+
--robot.type=so100 \
|
| 124 |
+
--control.type=record \
|
| 125 |
+
--control.fps=30 \
|
| 126 |
+
--control.single_task="Grasp a lego block and put it in the bin." \
|
| 127 |
+
--control.repo_id=$USER/eval_act_koch_pick_place_lego \
|
| 128 |
+
--control.num_episodes=10 \
|
| 129 |
+
--control.warmup_time_s=2 \
|
| 130 |
+
--control.episode_time_s=30 \
|
| 131 |
+
--control.reset_time_s=10 \
|
| 132 |
+
--control.push_to_hub=true \
|
| 133 |
+
--control.policy.path=outputs/train/act_koch_pick_place_lego/checkpoints/080000/pretrained_model
|
| 134 |
+
```
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
import logging
|
| 138 |
+
import time
|
| 139 |
+
from dataclasses import asdict
|
| 140 |
+
from pprint import pformat
|
| 141 |
+
|
| 142 |
+
# from safetensors.torch import load_file, save_file
|
| 143 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
| 144 |
+
from lerobot.common.policies.factory import make_policy
|
| 145 |
+
from lerobot.common.robot_devices.control_configs import (
|
| 146 |
+
CalibrateControlConfig,
|
| 147 |
+
ControlPipelineConfig,
|
| 148 |
+
RecordControlConfig,
|
| 149 |
+
RemoteRobotConfig,
|
| 150 |
+
ReplayControlConfig,
|
| 151 |
+
TeleoperateControlConfig,
|
| 152 |
+
)
|
| 153 |
+
from lerobot.common.robot_devices.control_utils import (
|
| 154 |
+
control_loop,
|
| 155 |
+
init_keyboard_listener,
|
| 156 |
+
log_control_info,
|
| 157 |
+
record_episode,
|
| 158 |
+
reset_environment,
|
| 159 |
+
sanity_check_dataset_name,
|
| 160 |
+
sanity_check_dataset_robot_compatibility,
|
| 161 |
+
stop_recording,
|
| 162 |
+
warmup_record,
|
| 163 |
+
)
|
| 164 |
+
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
|
| 165 |
+
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
|
| 166 |
+
from lerobot.common.utils.utils import has_method, init_logging, log_say
|
| 167 |
+
from lerobot.configs import parser
|
| 168 |
+
|
| 169 |
+
########################################################################################
|
| 170 |
+
# Control modes
|
| 171 |
+
########################################################################################
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
@safe_disconnect
def calibrate(robot: Robot, cfg: CalibrateControlConfig):
    """Remove stale calibration files for the requested arms and recalibrate.

    Stretch robots are simply homed. LeKiwi arms go through their dedicated
    helpers. Any other robot recalibrates automatically on `connect()` once its
    calibration file has been deleted.
    """
    # TODO(aliberts): move this code in robots' classes
    if robot.robot_type.startswith("stretch"):
        if not robot.is_connected:
            robot.connect()
        if not robot.is_homed():
            robot.home()
        return

    selected_arms = robot.available_arms if cfg.arms is None else cfg.arms
    unrecognized = [arm_id for arm_id in selected_arms if arm_id not in robot.available_arms]
    available_arms_str = " ".join(robot.available_arms)
    unknown_arms_str = " ".join(unrecognized)

    if selected_arms is None or len(selected_arms) == 0:
        raise ValueError(
            "No arm provided. Use `--arms` as argument with one or more available arms.\n"
            f"For instance, to recalibrate all arms add: `--arms {available_arms_str}`"
        )

    if len(unrecognized) > 0:
        raise ValueError(
            f"Unknown arms provided ('{unknown_arms_str}'). Available arms are `{available_arms_str}`."
        )

    # Delete existing calibration files so calibration is forced to run again.
    for arm_id in selected_arms:
        calib_file = robot.calibration_dir / f"{arm_id}.json"
        if calib_file.exists():
            print(f"Removing '{calib_file}'")
            calib_file.unlink()
        else:
            print(f"Calibration file not found '{calib_file}'")

    if robot.is_connected:
        robot.disconnect()

    is_lekiwi = robot.robot_type.startswith("lekiwi")
    if is_lekiwi and "main_follower" in selected_arms:
        print("Calibrating only the lekiwi follower arm 'main_follower'...")
        robot.calibrate_follower()
        return

    if is_lekiwi and "main_leader" in selected_arms:
        print("Calibrating only the lekiwi leader arm 'main_leader'...")
        robot.calibrate_leader()
        return

    # Calling `connect` automatically runs calibration
    # when the calibration file is missing
    robot.connect()
    robot.disconnect()
    print("Calibration is done! You can now teleoperate and record datasets!")
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
@safe_disconnect
def teleoperate(robot: Robot, cfg: TeleoperateControlConfig):
    """Run a teleoperation loop on `robot` at the configured rate and duration."""
    control_loop(
        robot,
        teleoperate=True,
        fps=cfg.fps,
        control_time_s=cfg.teleop_time_s,
        display_cameras=cfg.display_cameras,
    )
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@safe_disconnect
def record(
    robot: Robot,
    cfg: RecordControlConfig,
) -> LeRobotDataset:
    """Record a dataset of teleoperated (or policy-driven) episodes on `robot`.

    Resumes an existing local dataset when `cfg.resume` is set, otherwise creates
    a fresh one. Keyboard events drive early-exit, re-record and stop. Returns
    the recorded dataset, optionally pushed to the Hub.
    """
    # TODO(rcadene): Add option to record logs
    if cfg.resume:
        dataset = LeRobotDataset(
            cfg.repo_id,
            root=cfg.root,
        )
        if len(robot.cameras) > 0:
            dataset.start_image_writer(
                num_processes=cfg.num_image_writer_processes,
                num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
            )
        sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video)
    else:
        # Create empty dataset or load existing saved episodes
        sanity_check_dataset_name(cfg.repo_id, cfg.policy)
        dataset = LeRobotDataset.create(
            cfg.repo_id,
            cfg.fps,
            root=cfg.root,
            robot=robot,
            use_videos=cfg.video,
            image_writer_processes=cfg.num_image_writer_processes,
            image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
        )

    # Load pretrained policy
    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)

    if not robot.is_connected:
        robot.connect()

    listener, events = init_keyboard_listener()

    # Execute a few seconds without recording to:
    # 1. teleoperate the robot to move it in starting position if no policy provided,
    # 2. give times to the robot devices to connect and start synchronizing,
    # 3. place the cameras windows on screen
    teleop_enabled = policy is None
    log_say("Warmup record", cfg.play_sounds)
    warmup_record(robot, events, teleop_enabled, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)

    if has_method(robot, "teleop_safety_stop"):
        robot.teleop_safety_stop()

    recorded_episodes = 0
    # `recorded_episodes` only grows, so this bound replaces the original
    # `while True` + top-of-loop break.
    while recorded_episodes < cfg.num_episodes:
        log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
        record_episode(
            robot=robot,
            dataset=dataset,
            events=events,
            episode_time_s=cfg.episode_time_s,
            display_cameras=cfg.display_cameras,
            policy=policy,
            fps=cfg.fps,
            single_task=cfg.single_task,
        )

        # Execute a few seconds without recording to give time to manually reset the environment
        # Current code logic doesn't allow to teleoperate during this time.
        # TODO(rcadene): add an option to enable teleoperation during reset
        # Skip reset for the last episode to be recorded
        needs_reset = not events["stop_recording"] and (
            (recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"]
        )
        if needs_reset:
            log_say("Reset the environment", cfg.play_sounds)
            reset_environment(robot, events, cfg.reset_time_s, cfg.fps)

        if events["rerecord_episode"]:
            log_say("Re-record episode", cfg.play_sounds)
            events["rerecord_episode"] = False
            events["exit_early"] = False
            dataset.clear_episode_buffer()
            continue

        dataset.save_episode()
        recorded_episodes += 1

        if events["stop_recording"]:
            break

    log_say("Stop recording", cfg.play_sounds, blocking=True)
    stop_recording(robot, listener, cfg.display_cameras)

    if cfg.push_to_hub:
        dataset.push_to_hub(tags=cfg.tags, private=cfg.private)

    log_say("Exiting", cfg.play_sounds)
    return dataset
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
@safe_disconnect
def replay(
    robot: Robot,
    cfg: ReplayControlConfig,
):
    """Replay one recorded episode's actions on the real robot at `cfg.fps`.

    Loads episode `cfg.episode` from the dataset `cfg.repo_id`, connects the
    robot if needed, then streams each recorded action back to the robot while
    pacing the loop to the dataset's recording frequency.
    """
    # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
    # TODO(rcadene): Add option to record logs

    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
    actions = dataset.hf_dataset.select_columns("action")

    if not robot.is_connected:
        robot.connect()

    log_say("Replaying episode", cfg.play_sounds, blocking=True)
    for frame_idx in range(dataset.num_frames):
        frame_start_t = time.perf_counter()

        recorded_action = actions[frame_idx]["action"]
        robot.send_action(recorded_action)

        # Sleep away whatever is left of this frame's time budget.
        elapsed_s = time.perf_counter() - frame_start_t
        busy_wait(1 / cfg.fps - elapsed_s)

        elapsed_s = time.perf_counter() - frame_start_t
        log_control_info(robot, elapsed_s, fps=cfg.fps)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
@parser.wrap()
def control_robot(cfg: ControlPipelineConfig):
    """Entry point: build a robot from `cfg.robot` and dispatch on `cfg.control`'s type.

    Supported modes (one per control config class): calibrate, teleoperate,
    record, replay, and remote LeKiwi operation.
    """
    init_logging()
    logging.info(pformat(asdict(cfg)))

    robot = make_robot_from_config(cfg.robot)

    # Dispatch on the concrete control-config class. Assumes these config
    # classes are disjoint (no subclassing between them) — TODO confirm.
    if isinstance(cfg.control, CalibrateControlConfig):
        calibrate(robot, cfg.control)
    elif isinstance(cfg.control, TeleoperateControlConfig):
        teleoperate(robot, cfg.control)
    elif isinstance(cfg.control, RecordControlConfig):
        record(robot, cfg.control)
    elif isinstance(cfg.control, ReplayControlConfig):
        replay(robot, cfg.control)
    elif isinstance(cfg.control, RemoteRobotConfig):
        # Imported lazily: the LeKiwi remote stack is only needed in this mode.
        from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi

        run_lekiwi(cfg.robot)

    if robot.is_connected:
        # Disconnect manually to avoid a "Core dump" during process
        # termination due to camera threads not properly exiting.
        robot.disconnect()
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
if __name__ == "__main__":
    # CLI entry: argument parsing is handled by the @parser.wrap() decorator.
    control_robot()
|
lerobot/scripts/control_sim_robot.py
ADDED
|
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Utilities to control a robot in simulation.
|
| 16 |
+
|
| 17 |
+
Useful to record a dataset, replay a recorded episode and record an evaluation dataset.
|
| 18 |
+
|
| 19 |
+
Examples of usage:
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency.
|
| 23 |
+
You can modify this value depending on how fast your simulation can run:
|
| 24 |
+
```bash
|
| 25 |
+
    python lerobot/scripts/control_sim_robot.py teleoperate \
|
| 26 |
+
--fps 30 \
|
| 27 |
+
--robot-path lerobot/configs/robot/your_robot_config.yaml \
|
| 28 |
+
--sim-config lerobot/configs/env/your_sim_config.yaml
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
- Record one episode in order to test replay:
|
| 32 |
+
```bash
|
| 33 |
+
python lerobot/scripts/control_sim_robot.py record \
|
| 34 |
+
--robot-path lerobot/configs/robot/your_robot_config.yaml \
|
| 35 |
+
--sim-config lerobot/configs/env/your_sim_config.yaml \
|
| 36 |
+
--fps 30 \
|
| 37 |
+
--repo-id $USER/robot_sim_test \
|
| 38 |
+
--num-episodes 1 \
|
| 39 |
+
--run-compute-stats 0
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
Enable the --push-to-hub 1 to push the recorded dataset to the huggingface hub.
|
| 43 |
+
|
| 44 |
+
- Visualize dataset:
|
| 45 |
+
```bash
|
| 46 |
+
python lerobot/scripts/visualize_dataset.py \
|
| 47 |
+
--repo-id $USER/robot_sim_test \
|
| 48 |
+
--episode-index 0
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
- Replay a sequence of test episodes:
|
| 52 |
+
```bash
|
| 53 |
+
python lerobot/scripts/control_sim_robot.py replay \
|
| 54 |
+
--robot-path lerobot/configs/robot/your_robot_config.yaml \
|
| 55 |
+
--sim-config lerobot/configs/env/your_sim_config.yaml \
|
| 56 |
+
--fps 30 \
|
| 57 |
+
--repo-id $USER/robot_sim_test \
|
| 58 |
+
--episode 0
|
| 59 |
+
```
|
| 60 |
+
Note: The seed is saved, therefore, during replay we can load the same environment state as the one during collection.
|
| 61 |
+
|
| 62 |
+
- Record a full dataset in order to train a policy,
|
| 63 |
+
30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
|
| 64 |
+
```bash
|
| 65 |
+
python lerobot/scripts/control_sim_robot.py record \
|
| 66 |
+
--robot-path lerobot/configs/robot/your_robot_config.yaml \
|
| 67 |
+
--sim-config lerobot/configs/env/your_sim_config.yaml \
|
| 68 |
+
--fps 30 \
|
| 69 |
+
--repo-id $USER/robot_sim_test \
|
| 70 |
+
--num-episodes 50 \
|
| 71 |
+
--episode-time-s 30 \
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
**NOTE**: You can use your keyboard to control data recording flow.
|
| 75 |
+
- Tap right arrow key '->' to early exit while recording an episode and go to resetting the environment.
|
| 76 |
+
- Tap right arrow key '->' to early exit while resetting the environment and go to recording the next episode.
|
| 77 |
+
- Tap left arrow key '<-' to early exit and re-record the current episode.
|
| 78 |
+
- Tap escape key 'esc' to stop the data recording.
|
| 79 |
+
This might require a sudo permission to allow your terminal to monitor keyboard events.
|
| 80 |
+
|
| 81 |
+
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
import argparse
|
| 85 |
+
import importlib
|
| 86 |
+
import logging
|
| 87 |
+
import time
|
| 88 |
+
from pathlib import Path
|
| 89 |
+
|
| 90 |
+
import cv2
|
| 91 |
+
import gymnasium as gym
|
| 92 |
+
import numpy as np
|
| 93 |
+
import torch
|
| 94 |
+
|
| 95 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
| 96 |
+
from lerobot.common.robot_devices.control_utils import (
|
| 97 |
+
init_keyboard_listener,
|
| 98 |
+
init_policy,
|
| 99 |
+
is_headless,
|
| 100 |
+
log_control_info,
|
| 101 |
+
predict_action,
|
| 102 |
+
sanity_check_dataset_name,
|
| 103 |
+
sanity_check_dataset_robot_compatibility,
|
| 104 |
+
stop_recording,
|
| 105 |
+
)
|
| 106 |
+
from lerobot.common.robot_devices.robots.utils import Robot, make_robot
|
| 107 |
+
from lerobot.common.robot_devices.utils import busy_wait
|
| 108 |
+
from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say
|
| 109 |
+
|
| 110 |
+
# This script predates the current config/CLI system; fail fast at import time
# so it cannot be run by accident until it is ported.
raise NotImplementedError("This script is currently deactivated")

# Per-frame features recorded for every simulated episode, on top of the
# image/state/action features that `record` builds from the env's spaces.
DEFAULT_FEATURES = {
    "next.reward": {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    },
    "next.success": {
        "dtype": "bool",
        "shape": (1,),
        "names": None,
    },
    "seed": {
        "dtype": "int64",
        "shape": (1,),
        "names": None,
    },
    "timestamp": {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    },
}
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
########################################################################################
|
| 137 |
+
# Utilities
|
| 138 |
+
########################################################################################
|
| 139 |
+
def none_or_int(value):
    """argparse type: parse the literal string "None" as None, anything else as int."""
    return None if value == "None" else int(value)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def init_sim_calibration(robot, cfg):
    """Gather the constants mapping real leader-arm joint readings to sim joints.

    The start position comes from the real robot's calibration; axis directions
    and offsets come from the sim config (defaulting to identity / zero) and
    depend on the robot description used in that sim.
    """
    calibration = robot.leader_arms.main.calibration
    return {
        "start_pos": np.array(calibration["start_pos"]),
        "axis_directions": np.array(cfg.get("axis_directions", [1])),
        "offsets": np.array(cfg.get("offsets", [0])) * np.pi,
    }
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
    """Counts - starting position -> radians -> align axes -> offset"""
    # 4096 counts per revolution, so one count is 2*pi/4096 radians.
    delta_counts = real_positions - start_pos
    return axis_directions * delta_counts * 2.0 * np.pi / 4096 + offsets
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
########################################################################################
|
| 161 |
+
# Control modes
|
| 162 |
+
########################################################################################
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None):
    """Drive the sim env from the real leader arm until `teleop_time_s` elapses.

    With the default `teleop_time_s=None`, loops forever.
    """
    sim_env = env()
    sim_env.reset()
    started_at = time.perf_counter()
    while True:
        # Read raw leader joint positions and convert them to sim-space actions.
        leader_positions = robot.leader_arms.main.read("Present_Position")
        sim_action = process_action_fn(leader_positions)
        sim_env.step(np.expand_dims(sim_action, 0))
        if teleop_time_s is not None and time.perf_counter() - started_at > teleop_time_s:
            print("Teleoperation processes finished.")
            break
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def record(
    env,
    robot: Robot,
    process_action_from_leader,
    root: Path,
    repo_id: str,
    task: str,
    fps: int | None = None,
    tags: list[str] | None = None,
    pretrained_policy_name_or_path: str = None,
    policy_overrides: bool | None = None,
    episode_time_s: int = 30,
    num_episodes: int = 50,
    video: bool = True,
    push_to_hub: bool = True,
    num_image_writer_processes: int = 0,
    num_image_writer_threads_per_camera: int = 4,
    display_cameras: bool = False,
    play_sounds: bool = True,
    resume: bool = False,
    local_files_only: bool = False,
    run_compute_stats: bool = True,
) -> LeRobotDataset:
    """Record episodes in a simulated env, driven by a policy or the real leader arm.

    Args:
        env: zero-arg callable returning a fresh gym environment.
        robot: connected robot whose leader arm supplies teleop actions
            (only read when no pretrained policy is given).
        process_action_from_leader: maps raw leader joint positions to sim actions.
        root: local directory where the dataset is stored at '{root}/{repo_id}'.
        repo_id: dataset identifier, e.g. 'user/dataset'.
        task: language description of the task, saved with each episode.
        fps: control frequency; falls back to the policy's fps, else uncapped.
        (remaining flags mirror this script's CLI options)

    Returns:
        The populated LeRobotDataset.

    NOTE(review): reads the module-level `env_cfg` (set in `__main__`) for
    `state_keys`, so this function cannot be imported and called standalone —
    TODO confirm and refactor into a parameter.
    """
    # Load pretrained policy
    policy = None
    if pretrained_policy_name_or_path is not None:
        policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)

        if fps is None:
            fps = policy_fps
            logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")

    if policy is None and process_action_from_leader is None:
        raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")

    # initialize listener before sim env
    listener, events = init_keyboard_listener()

    # create sim env
    env = env()

    # Create empty dataset or load existing saved episodes
    num_cameras = sum(1 for key in env.observation_space if "image" in key)

    # get image keys
    image_keys = [key for key in env.observation_space if "image" in key]
    state_keys_dict = env_cfg.state_keys

    if resume:
        dataset = LeRobotDataset(
            repo_id,
            root=root,
            local_files_only=local_files_only,
        )
        dataset.start_image_writer(
            num_processes=num_image_writer_processes,
            num_threads=num_image_writer_threads_per_camera * num_cameras,
        )
        sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
    else:
        # Fix: copy the module-level template instead of aliasing it. The
        # original `features = DEFAULT_FEATURES` mutated the shared constant,
        # leaking image/state/action keys into every subsequent call.
        features = dict(DEFAULT_FEATURES)
        # add image keys to features
        for key in image_keys:
            shape = env.observation_space[key].shape
            if not key.startswith("observation.image."):
                key = "observation.image." + key
            features[key] = {"dtype": "video", "names": ["channels", "height", "width"], "shape": shape}

        for key, obs_key in state_keys_dict.items():
            features[key] = {
                "dtype": "float32",
                "names": None,
                "shape": env.observation_space[obs_key].shape,
            }

        features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}

        # Create empty dataset or load existing saved episodes
        sanity_check_dataset_name(repo_id, policy)
        dataset = LeRobotDataset.create(
            repo_id,
            fps,
            root=root,
            features=features,
            use_videos=video,
            image_writer_processes=num_image_writer_processes,
            image_writer_threads=num_image_writer_threads_per_camera * num_cameras,
        )

    recorded_episodes = 0
    while True:
        log_say(f"Recording episode {dataset.num_episodes}", play_sounds)

        # NOTE(review): `events` comes from init_keyboard_listener(); this
        # fallback only provides "exit_early", while "rerecord_episode" and
        # "stop_recording" are read below — verify the listener's contract.
        if events is None:
            events = {"exit_early": False}

        if episode_time_s is None:
            episode_time_s = float("inf")

        timestamp = 0
        start_episode_t = time.perf_counter()

        # Save the reset seed so replay can restore the same initial state.
        seed = np.random.randint(0, 1e5)
        observation, info = env.reset(seed=seed)

        while timestamp < episode_time_s:
            start_loop_t = time.perf_counter()

            if policy is not None:
                action = predict_action(observation, policy, device, use_amp)
            else:
                leader_pos = robot.leader_arms.main.read("Present_Position")
                action = process_action_from_leader(leader_pos)

            observation, reward, terminated, _, info = env.step(action)

            success = info.get("is_success", False)
            env_timestamp = info.get("timestamp", dataset.episode_buffer["size"] / fps)

            frame = {
                "action": torch.from_numpy(action),
                "next.reward": reward,
                "next.success": success,
                "seed": seed,
                "timestamp": env_timestamp,
            }

            # Normalize image keys to the "observation.image." namespace.
            for key in image_keys:
                if not key.startswith("observation.image"):
                    frame["observation.image." + key] = observation[key]
                else:
                    frame[key] = observation[key]

            for key, obs_key in state_keys_dict.items():
                frame[key] = torch.from_numpy(observation[obs_key])

            dataset.add_frame(frame)

            if display_cameras and not is_headless():
                for key in image_keys:
                    cv2.imshow(key, cv2.cvtColor(observation[key], cv2.COLOR_RGB2BGR))
                cv2.waitKey(1)

            # Pace the loop to the target fps (None = run as fast as possible).
            if fps is not None:
                dt_s = time.perf_counter() - start_loop_t
                busy_wait(1 / fps - dt_s)

            dt_s = time.perf_counter() - start_loop_t
            log_control_info(robot, dt_s, fps=fps)

            timestamp = time.perf_counter() - start_episode_t
            if events["exit_early"] or terminated:
                events["exit_early"] = False
                break

        if events["rerecord_episode"]:
            log_say("Re-record episode", play_sounds)
            events["rerecord_episode"] = False
            events["exit_early"] = False
            dataset.clear_episode_buffer()
            continue

        dataset.save_episode(task=task)
        recorded_episodes += 1

        if events["stop_recording"] or recorded_episodes >= num_episodes:
            break
        else:
            logging.info("Waiting for a few seconds before starting next episode recording...")
            busy_wait(3)

    log_say("Stop recording", play_sounds, blocking=True)
    stop_recording(robot, listener, display_cameras)

    if run_compute_stats:
        logging.info("Computing dataset statistics")
        dataset.consolidate(run_compute_stats)

    if push_to_hub:
        dataset.push_to_hub(tags=tags)

    log_say("Exiting", play_sounds)
    return dataset
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def replay(
    env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True
):
    """Replay the recorded actions of one episode inside the sim env.

    The episode's saved seed is used to reset the env, so the initial state
    matches the one seen during data collection.

    Args:
        env: zero-arg callable returning a fresh gym environment.
        root: local directory where the dataset is stored at '{root}/{repo_id}'.
        repo_id: dataset identifier, e.g. 'user/dataset'.
        episode: index of the episode to replay.
        fps: target replay frequency; None replays as fast as possible.
        local_files_only: load the dataset from disk without hitting the hub.

    Raises:
        ValueError: if the dataset directory does not exist locally.
    """
    env = env()

    local_dir = Path(root) / repo_id
    if not local_dir.exists():
        raise ValueError(f"Dataset directory does not exist: {local_dir}")

    dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
    items = dataset.hf_dataset.select_columns("action")
    seeds = dataset.hf_dataset.select_columns("seed")["seed"]

    from_idx = dataset.episode_data_index["from"][episode].item()
    to_idx = dataset.episode_data_index["to"][episode].item()
    # Restore the exact env state used during collection.
    env.reset(seed=seeds[from_idx].item())
    logging.info("Replaying episode")
    log_say("Replaying episode", play_sounds=True)
    for idx in range(from_idx, to_idx):
        start_episode_t = time.perf_counter()
        action = items[idx]["action"]
        env.step(action.unsqueeze(0).numpy())
        # Fix: the default fps=None previously crashed with a TypeError on
        # `1 / fps`; None now means "no rate limiting", matching record().
        if fps is not None:
            dt_s = time.perf_counter() - start_episode_t
            busy_wait(1 / fps - dt_s)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="mode", required=True)

    # Set common options for all the subparsers
    base_parser = argparse.ArgumentParser(add_help=False)
    base_parser.add_argument(
        "--robot-path",
        type=str,
        default="lerobot/configs/robot/koch.yaml",
        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
    )

    base_parser.add_argument(
        "--sim-config",
        help="Path to a yaml config you want to use for initializing a sim environment based on gym ",
    )

    # NOTE(review): this name is immediately reassigned to the "record"
    # sub-parser below; "teleoperate" still registers but takes no extra options.
    parser_record = subparsers.add_parser("teleoperate", parents=[base_parser])

    parser_record = subparsers.add_parser("record", parents=[base_parser])
    parser_record.add_argument(
        "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
    )
    parser_record.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
    )
    parser_record.add_argument(
        "--repo-id",
        type=str,
        default="lerobot/test",
        help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
    )
    parser_record.add_argument(
        "--episode-time-s",
        type=int,
        default=60,
        help="Number of seconds for data recording for each episode.",
    )
    parser_record.add_argument(
        "--task",
        type=str,
        required=True,
        help="A description of the task preformed during recording that can be used as a language instruction.",
    )
    parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
    parser_record.add_argument(
        "--run-compute-stats",
        type=int,
        default=1,
        help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
    )
    parser_record.add_argument(
        "--push-to-hub",
        type=int,
        default=1,
        help="Upload dataset to Hugging Face hub.",
    )
    parser_record.add_argument(
        "--tags",
        type=str,
        nargs="*",
        help="Add tags to your dataset on the hub.",
    )
    parser_record.add_argument(
        "--num-image-writer-processes",
        type=int,
        default=0,
        help=(
            "Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; "
            "set to ≥1 to use subprocesses, each using threads to write images. The best number of processes "
            "and threads depends on your system. We recommend 4 threads per camera with 0 processes. "
            "If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses."
        ),
    )
    parser_record.add_argument(
        "--num-image-writer-threads-per-camera",
        type=int,
        default=4,
        help=(
            "Number of threads writing the frames as png images on disk, per camera. "
            "Too much threads might cause unstable teleoperation fps due to main thread being blocked. "
            "Not enough threads might cause low camera fps."
        ),
    )
    parser_record.add_argument(
        "--display-cameras",
        type=int,
        default=0,
        help="Visualize image observations with opencv.",
    )
    parser_record.add_argument(
        "--resume",
        type=int,
        default=0,
        help="Resume recording on an existing dataset.",
    )
    parser_replay = subparsers.add_parser("replay", parents=[base_parser])
    parser_replay.add_argument(
        "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
    )
    parser_replay.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory where the dataset will be stored locally (e.g. 'data/hf_username/dataset_name'). By default, stored in cache folder.",
    )
    parser_replay.add_argument(
        "--repo-id",
        type=str,
        default="lerobot/test",
        help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
    )
    parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.")

    args = parser.parse_args()

    init_logging()

    # Mode/robot-path/sim-config are consumed here; the remaining CLI options
    # are forwarded verbatim to record()/replay() as keyword arguments.
    control_mode = args.mode
    robot_path = args.robot_path
    env_config_path = args.sim_config
    kwargs = vars(args)
    del kwargs["mode"]
    del kwargs["robot_path"]
    del kwargs["sim_config"]

    # make gym env
    # NOTE: `env_cfg` lives at module scope and is read directly by record().
    env_cfg = init_hydra_config(env_config_path)
    importlib.import_module(f"gym_{env_cfg.env.type}")

    def env_constructor():
        # Zero-arg factory so each control mode can create a fresh env instance.
        return gym.make(env_cfg.env.handle, disable_env_checker=True, **env_cfg.env.gym)

    robot = None
    process_leader_actions_fn = None

    if control_mode in ["teleoperate", "record"]:
        # make robot (leader arm only: cameras and follower arms are removed).
        robot_overrides = ["~cameras", "~follower_arms"]
        # TODO(rcadene): remove
        robot_cfg = init_hydra_config(robot_path, robot_overrides)
        robot = make_robot(robot_cfg)
        robot.connect()

        calib_kwgs = init_sim_calibration(robot, env_cfg.calibration)

        def process_leader_actions_fn(action):
            # Convert raw leader joint counts to sim joint angles.
            return real_positions_to_sim(action, **calib_kwgs)

        robot.leader_arms.main.calibration = None

    if control_mode == "teleoperate":
        teleoperate(env_constructor, robot, process_leader_actions_fn)

    elif control_mode == "record":
        record(env_constructor, robot, process_leader_actions_fn, **kwargs)

    elif control_mode == "replay":
        replay(env_constructor, **kwargs)

    else:
        raise ValueError(
            f"Invalid control mode: '{control_mode}', only valid modes are teleoperate, record and replay."
        )

    if robot and robot.is_connected:
        # Disconnect manually to avoid a "Core dump" during process
        # termination due to camera threads not properly exiting.
        robot.disconnect()
|
lerobot/scripts/display_sys_info.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
"""Use this script to get a quick summary of your system config.
|
| 18 |
+
It should be able to run without any of LeRobot's dependencies or LeRobot itself installed.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import platform
|
| 22 |
+
|
| 23 |
+
# Availability flags: each is flipped to False below if the corresponding
# import fails, so this script runs even in a minimal environment.
HAS_HF_HUB = True
HAS_HF_DATASETS = True
HAS_NP = True
HAS_TORCH = True
HAS_LEROBOT = True

try:
    import huggingface_hub
except ImportError:
    HAS_HF_HUB = False

try:
    import datasets
except ImportError:
    HAS_HF_DATASETS = False

try:
    import numpy as np
except ImportError:
    HAS_NP = False

try:
    import torch
except ImportError:
    HAS_TORCH = False

try:
    import lerobot
except ImportError:
    HAS_LEROBOT = False


# Version strings resolved once at import time; "N/A" when a package is absent.
lerobot_version = lerobot.__version__ if HAS_LEROBOT else "N/A"
hf_hub_version = huggingface_hub.__version__ if HAS_HF_HUB else "N/A"
hf_datasets_version = datasets.__version__ if HAS_HF_DATASETS else "N/A"
np_version = np.__version__ if HAS_NP else "N/A"

torch_version = torch.__version__ if HAS_TORCH else "N/A"
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# TODO(aliberts): refactor into an actual command `lerobot env`
def display_sys_info() -> dict:
    """Gather and print basic system info to help with tracking issues & bugs.

    Prints a copy-paste-ready summary for GitHub issues and returns the
    collected info as a dict so callers can reuse it programmatically.
    """
    sys_info = {
        "`lerobot` version": lerobot_version,
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
        "Huggingface_hub version": hf_hub_version,
        "Dataset version": hf_datasets_version,
        "Numpy version": np_version,
        "PyTorch version (GPU?)": f"{torch_version} ({torch_cuda_available})",
        "Cuda version": cuda_version,
        "Using GPU in script?": "<fill in>",
        # "Using distributed or parallel set-up in script?": "<fill in>",
    }
    print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
    print(format_dict(sys_info))
    return sys_info
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def format_dict(d: dict) -> str:
    """Render *d* as a markdown-style bullet list: one `- key: value` line per entry."""
    bullet_lines = (f"- {key}: {value}" for key, value in d.items())
    return "\n".join(bullet_lines) + "\n"
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
if __name__ == "__main__":
    # Entry point: print the system summary when the script is run directly.
    display_sys_info()
|
lerobot/scripts/eval.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""Evaluate a policy on an environment by running rollouts and computing metrics.
|
| 17 |
+
|
| 18 |
+
Usage examples:
|
| 19 |
+
|
| 20 |
+
You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/diffusion_pusht)
|
| 21 |
+
for 10 episodes.
|
| 22 |
+
|
| 23 |
+
```
|
| 24 |
+
python lerobot/scripts/eval.py \
|
| 25 |
+
--policy.path=lerobot/diffusion_pusht \
|
| 26 |
+
--env.type=pusht \
|
| 27 |
+
--eval.batch_size=10 \
|
| 28 |
+
--eval.n_episodes=10 \
|
| 29 |
+
--use_amp=false \
|
| 30 |
+
--device=cuda
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes.
|
| 34 |
+
```
|
| 35 |
+
python lerobot/scripts/eval.py \
|
| 36 |
+
--policy.path=outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \
|
| 37 |
+
--env.type=pusht \
|
| 38 |
+
--eval.batch_size=10 \
|
| 39 |
+
--eval.n_episodes=10 \
|
| 40 |
+
--use_amp=false \
|
| 41 |
+
--device=cuda
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files.
|
| 45 |
+
|
| 46 |
+
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
import json
|
| 50 |
+
import logging
|
| 51 |
+
import threading
|
| 52 |
+
import time
|
| 53 |
+
from contextlib import nullcontext
|
| 54 |
+
from copy import deepcopy
|
| 55 |
+
from dataclasses import asdict
|
| 56 |
+
from pathlib import Path
|
| 57 |
+
from pprint import pformat
|
| 58 |
+
from typing import Callable
|
| 59 |
+
|
| 60 |
+
import einops
|
| 61 |
+
import gymnasium as gym
|
| 62 |
+
import numpy as np
|
| 63 |
+
import torch
|
| 64 |
+
from termcolor import colored
|
| 65 |
+
from torch import Tensor, nn
|
| 66 |
+
from tqdm import trange
|
| 67 |
+
|
| 68 |
+
from lerobot.common.envs.factory import make_env
|
| 69 |
+
from lerobot.common.envs.utils import preprocess_observation
|
| 70 |
+
from lerobot.common.policies.factory import make_policy
|
| 71 |
+
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
| 72 |
+
from lerobot.common.policies.utils import get_device_from_parameters
|
| 73 |
+
from lerobot.common.utils.io_utils import write_video
|
| 74 |
+
from lerobot.common.utils.random_utils import set_seed
|
| 75 |
+
from lerobot.common.utils.utils import (
|
| 76 |
+
get_safe_torch_device,
|
| 77 |
+
init_logging,
|
| 78 |
+
inside_slurm,
|
| 79 |
+
)
|
| 80 |
+
from lerobot.configs import parser
|
| 81 |
+
from lerobot.configs.eval import EvalPipelineConfig
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def rollout(
    env: gym.vector.VectorEnv,
    policy: PreTrainedPolicy,
    seeds: list[int] | None = None,
    return_observations: bool = False,
    render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
) -> dict:
    """Run a batched policy rollout once through a batch of environments.

    Note that all environments in the batch are run until the last environment is done. This means some
    data will probably need to be discarded (for environments that aren't the first one to be done).

    The return dictionary contains:
        (optional) "observation": A dictionary of (batch, sequence + 1, *) tensors mapped to observation
            keys. NOTE that this has an extra sequence element relative to the other keys in the
            dictionary. This is because an extra observation is included for after the environment is
            terminated or truncated.
        "action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not
            including the last observations).
        "reward": A (batch, sequence) tensor of rewards received for applying the actions.
        "success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
            environment termination/truncation).
        "done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
            the first True is followed by True's all the way till the end. This can be used for masking
            extraneous elements from the sequences above.

    Args:
        env: The batch of environments.
        policy: The policy. Must be a PyTorch nn module.
        seeds: The environments are seeded once at the start of the rollout. If provided, this argument
            specifies the seeds for each of the environments.
        return_observations: Whether to include all observations in the returned rollout data. Observations
            are returned optionally because they typically take more memory to cache. Defaults to False.
        render_callback: Optional rendering callback to be used after the environments are reset, and after
            every step.
    Returns:
        The dictionary described above.
    """
    assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
    device = get_device_from_parameters(policy)

    # Reset the policy and environments.
    policy.reset()

    observation, info = env.reset(seed=seeds)
    if render_callback is not None:
        render_callback(env)

    # Per-step accumulators; stacked along dim=1 at the end of the rollout.
    all_observations = []
    all_actions = []
    all_rewards = []
    all_successes = []
    all_dones = []

    step = 0
    # Keep track of which environments are done.
    done = np.array([False] * env.num_envs)
    # NOTE(review): assumes every env in the batch shares the same
    # `_max_episode_steps`; only the first env's value is used — confirm.
    max_steps = env.call("_max_episode_steps")[0]
    progbar = trange(
        max_steps,
        desc=f"Running rollout with at most {max_steps} steps",
        disable=inside_slurm(),  # we dont want progress bar when we use slurm, since it clutters the logs
        leave=False,
    )
    while not np.all(done):
        # Numpy array to tensor and changing dictionary keys to LeRobot policy format.
        observation = preprocess_observation(observation)
        if return_observations:
            # Deep-copy before moving to device so the cached copy stays on CPU.
            all_observations.append(deepcopy(observation))

        # Move tensors to the policy's device; non_blocking only helps (and is
        # only safe) for CUDA transfers.
        observation = {
            key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
        }

        with torch.inference_mode():
            action = policy.select_action(observation)

        # Convert to CPU / numpy.
        action = action.to("cpu").numpy()
        assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"

        # Apply the next action.
        observation, reward, terminated, truncated, info = env.step(action)
        if render_callback is not None:
            render_callback(env)

        # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
        # available if none of the envs finished.
        # NOTE: the comprehension variable deliberately shadows the outer `info`;
        # in Python 3 comprehension scope is isolated, so the outer `info` is untouched.
        if "final_info" in info:
            successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
        else:
            successes = [False] * env.num_envs

        # Keep track of which environments are done so far (cumulative OR).
        done = terminated | truncated | done

        all_actions.append(torch.from_numpy(action))
        all_rewards.append(torch.from_numpy(reward))
        all_dones.append(torch.from_numpy(done))
        all_successes.append(torch.tensor(successes))

        step += 1
        # "Any success so far" per batch element, averaged over the batch.
        running_success_rate = (
            einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
        )
        progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
        progbar.update()

    # Track the final observation.
    if return_observations:
        observation = preprocess_observation(observation)
        all_observations.append(deepcopy(observation))

    # Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
    ret = {
        "action": torch.stack(all_actions, dim=1),
        "reward": torch.stack(all_rewards, dim=1),
        "success": torch.stack(all_successes, dim=1),
        "done": torch.stack(all_dones, dim=1),
    }
    if return_observations:
        stacked_observations = {}
        for key in all_observations[0]:
            stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
        ret["observation"] = stacked_observations

    # Some policies temporarily swap modules for evaluation; restore them here.
    if hasattr(policy, "use_original_modules"):
        policy.use_original_modules()

    return ret
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def eval_policy(
    env: gym.vector.VectorEnv,
    policy: PreTrainedPolicy,
    n_episodes: int,
    max_episodes_rendered: int = 0,
    videos_dir: Path | None = None,
    return_episode_data: bool = False,
    start_seed: int | None = None,
) -> dict:
    """Evaluate a policy over `n_episodes`, optionally rendering videos and collecting episode data.

    Args:
        env: The batch of environments.
        policy: The policy.
        n_episodes: The number of episodes to evaluate.
        max_episodes_rendered: Maximum number of episodes to render into videos.
        videos_dir: Where to save rendered videos.
        return_episode_data: Whether to return episode data for online training. Incorporates the data into
            the "episodes" key of the returned dictionary.
        start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the
            seed is incremented by 1. If not provided, the environments are not manually seeded.
    Returns:
        Dictionary with metrics and data regarding the rollouts.
    """
    if max_episodes_rendered > 0 and not videos_dir:
        raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")

    if not isinstance(policy, PreTrainedPolicy):
        raise ValueError(
            f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided."
        )

    start = time.time()
    policy.eval()

    # Determine how many batched rollouts we need to get n_episodes. Note that if n_episodes is not evenly
    # divisible by env.num_envs we end up discarding some data in the last batch.
    n_batches = n_episodes // env.num_envs + int((n_episodes % env.num_envs) != 0)

    # Keep track of some metrics.
    sum_rewards = []
    max_rewards = []
    all_successes = []
    all_seeds = []
    threads = []  # for video saving threads
    n_episodes_rendered = 0  # for saving the correct number of videos

    # Callback for visualization. Closes over `n_episodes_rendered` and
    # `ep_frames` from the enclosing scope (read-only here), which is why the
    # B023 lint suppressions are needed.
    def render_frame(env: gym.vector.VectorEnv):
        # noqa: B023
        if n_episodes_rendered >= max_episodes_rendered:
            return
        n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
        if isinstance(env, gym.vector.SyncVectorEnv):
            ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)]))  # noqa: B023
        elif isinstance(env, gym.vector.AsyncVectorEnv):
            # Here we must render all frames and discard any we don't need.
            ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))

    if max_episodes_rendered > 0:
        video_paths: list[str] = []

    if return_episode_data:
        episode_data: dict | None = None

    # we dont want progress bar when we use slurm, since it clutters the logs
    progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
    for batch_ix in progbar:
        # Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
        # step.
        if max_episodes_rendered > 0:
            ep_frames: list[np.ndarray] = []

        if start_seed is None:
            seeds = None
        else:
            # Consecutive, non-overlapping seed ranges per batch.
            seeds = range(
                start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
            )
        rollout_data = rollout(
            env,
            policy,
            seeds=list(seeds) if seeds else None,
            return_observations=return_episode_data,
            render_callback=render_frame if max_episodes_rendered > 0 else None,
        )

        # Figure out where in each rollout sequence the first done condition was encountered (results after
        # this won't be included).
        n_steps = rollout_data["done"].shape[1]
        # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
        done_indices = torch.argmax(rollout_data["done"].to(int), dim=1)

        # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
        # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
        mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
        # Extend metrics.
        batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
        sum_rewards.extend(batch_sum_rewards.tolist())
        batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
        max_rewards.extend(batch_max_rewards.tolist())
        batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
        all_successes.extend(batch_successes.tolist())
        # NOTE(review): when unseeded, only a single None is appended per batch
        # while the metric lists grow by env.num_envs — the `strict=True` zip
        # below would then raise for n_episodes > n_batches. Verify intent.
        if seeds:
            all_seeds.extend(seeds)
        else:
            all_seeds.append(None)

        # FIXME: episode_data is either None or it doesn't exist
        if return_episode_data:
            this_episode_data = _compile_episode_data(
                rollout_data,
                done_indices,
                start_episode_index=batch_ix * env.num_envs,
                start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
                fps=env.unwrapped.metadata["render_fps"],
            )
            if episode_data is None:
                episode_data = this_episode_data
            else:
                # Some sanity checks to make sure we are correctly compiling the data.
                assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
                assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
                # Concatenate the episode data.
                episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}

        # Maybe render video for visualization.
        if max_episodes_rendered > 0 and len(ep_frames) > 0:
            batch_stacked_frames = np.stack(ep_frames, axis=1)  # (b, t, *)
            for stacked_frames, done_index in zip(
                batch_stacked_frames, done_indices.flatten().tolist(), strict=False
            ):
                if n_episodes_rendered >= max_episodes_rendered:
                    break

                videos_dir.mkdir(parents=True, exist_ok=True)
                video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4"
                video_paths.append(str(video_path))
                # Encode each video on a background thread; joined after the loop.
                thread = threading.Thread(
                    target=write_video,
                    args=(
                        str(video_path),
                        stacked_frames[: done_index + 1],  # + 1 to capture the last observation
                        env.unwrapped.metadata["render_fps"],
                    ),
                )
                thread.start()
                threads.append(thread)
                n_episodes_rendered += 1

        progbar.set_postfix(
            {"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
        )

    # Wait till all video rendering threads are done.
    for thread in threads:
        thread.join()

    # Compile eval info. Lists are truncated to n_episodes to drop the surplus
    # episodes from a partially-used last batch.
    info = {
        "per_episode": [
            {
                "episode_ix": i,
                "sum_reward": sum_reward,
                "max_reward": max_reward,
                "success": success,
                "seed": seed,
            }
            for i, (sum_reward, max_reward, success, seed) in enumerate(
                zip(
                    sum_rewards[:n_episodes],
                    max_rewards[:n_episodes],
                    all_successes[:n_episodes],
                    all_seeds[:n_episodes],
                    strict=True,
                )
            )
        ],
        "aggregated": {
            "avg_sum_reward": float(np.nanmean(sum_rewards[:n_episodes])),
            "avg_max_reward": float(np.nanmean(max_rewards[:n_episodes])),
            "pc_success": float(np.nanmean(all_successes[:n_episodes]) * 100),
            "eval_s": time.time() - start,
            "eval_ep_s": (time.time() - start) / n_episodes,
        },
    }

    if return_episode_data:
        info["episodes"] = episode_data

    if max_episodes_rendered > 0:
        info["video_paths"] = video_paths

    return info
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def _compile_episode_data(
|
| 412 |
+
rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float
|
| 413 |
+
) -> dict:
|
| 414 |
+
"""Convenience function for `eval_policy(return_episode_data=True)`
|
| 415 |
+
|
| 416 |
+
Compiles all the rollout data into a Hugging Face dataset.
|
| 417 |
+
|
| 418 |
+
Similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`).
|
| 419 |
+
"""
|
| 420 |
+
ep_dicts = []
|
| 421 |
+
total_frames = 0
|
| 422 |
+
for ep_ix in range(rollout_data["action"].shape[0]):
|
| 423 |
+
# + 2 to include the first done frame and the last observation frame.
|
| 424 |
+
num_frames = done_indices[ep_ix].item() + 2
|
| 425 |
+
total_frames += num_frames
|
| 426 |
+
|
| 427 |
+
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
|
| 428 |
+
ep_dict = {
|
| 429 |
+
"action": rollout_data["action"][ep_ix, : num_frames - 1],
|
| 430 |
+
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
|
| 431 |
+
"frame_index": torch.arange(0, num_frames - 1, 1),
|
| 432 |
+
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
|
| 433 |
+
"next.done": rollout_data["done"][ep_ix, : num_frames - 1],
|
| 434 |
+
"next.success": rollout_data["success"][ep_ix, : num_frames - 1],
|
| 435 |
+
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
|
| 436 |
+
}
|
| 437 |
+
|
| 438 |
+
# For the last observation frame, all other keys will just be copy padded.
|
| 439 |
+
for k in ep_dict:
|
| 440 |
+
ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]])
|
| 441 |
+
|
| 442 |
+
for key in rollout_data["observation"]:
|
| 443 |
+
ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames]
|
| 444 |
+
|
| 445 |
+
ep_dicts.append(ep_dict)
|
| 446 |
+
|
| 447 |
+
data_dict = {}
|
| 448 |
+
for key in ep_dicts[0]:
|
| 449 |
+
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
| 450 |
+
|
| 451 |
+
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
|
| 452 |
+
|
| 453 |
+
return data_dict
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
@parser.wrap()
def eval_main(cfg: EvalPipelineConfig):
    """CLI entry point: build the env and policy from `cfg`, run evaluation, and save metrics.

    Writes aggregated results to `<output_dir>/eval_info.json` and renders up to
    10 episode videos under `<output_dir>/videos`.
    """
    logging.info(pformat(asdict(cfg)))

    # Check device is available
    device = get_safe_torch_device(cfg.policy.device, log=True)

    # Enable cuDNN autotuning and TF32 matmuls for faster inference on CUDA.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    set_seed(cfg.seed)

    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")

    logging.info("Making environment.")
    env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

    logging.info("Making policy.")

    policy = make_policy(
        cfg=cfg.policy,
        env_cfg=cfg.env,
    )
    policy.eval()

    # Mixed-precision autocast only when requested by the config.
    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
        info = eval_policy(
            env,
            policy,
            cfg.eval.n_episodes,
            max_episodes_rendered=10,
            videos_dir=Path(cfg.output_dir) / "videos",
            start_seed=cfg.seed,
        )
    print(info["aggregated"])

    # Save info
    with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
        json.dump(info, f, indent=2)

    env.close()

    logging.info("End of eval")


if __name__ == "__main__":
    init_logging()
    eval_main()
|
lerobot/scripts/find_motors_bus_port.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import os
|
| 15 |
+
import time
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
from serial.tools import list_ports # Part of pyserial library
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def find_available_ports():
    """Return the list of serial port names currently visible on this machine.

    Uses pyserial's COM-port enumeration on Windows and a `/dev/tty*` glob on
    Unix-like systems (Linux/macOS).
    """
    if os.name == "nt":  # Windows
        # List COM ports using pyserial
        return [com_port.device for com_port in list_ports.comports()]
    # Linux/macOS: list /dev/tty* device nodes.
    return [str(tty_path) for tty_path in Path("/dev").glob("tty*")]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def find_port():
    """Interactively identify the serial port of a MotorsBus.

    Snapshots the available ports, asks the user to unplug the bus, then
    reports the single port that disappeared.

    Raises:
        OSError: If zero or more than one port changed between snapshots.
    """
    print("Finding all available ports for the MotorsBus.")
    before_snapshot = find_available_ports()
    print("Ports before disconnecting:", before_snapshot)

    print("Remove the USB cable from your MotorsBus and press Enter when done.")
    input()  # Wait for user to disconnect the device

    time.sleep(0.5)  # Allow some time for port to be released
    after_snapshot = find_available_ports()
    removed_ports = list(set(before_snapshot) - set(after_snapshot))

    # Guard clauses: anything other than exactly one missing port is ambiguous.
    if len(removed_ports) == 0:
        raise OSError(f"Could not detect the port. No difference was found ({removed_ports}).")
    if len(removed_ports) > 1:
        raise OSError(f"Could not detect the port. More than one port was found ({removed_ports}).")

    port = removed_ports[0]
    print(f"The port of this MotorsBus is '{port}'")
    print("Reconnect the USB cable.")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
    # Helper to find the USB port associated with your MotorsBus.
    find_port()
|
lerobot/scripts/push_dataset_to_hub.py
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""
|
| 17 |
+
Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
|
| 18 |
+
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
|
| 19 |
+
installation of neural net specific packages like pytorch, tensorflow, jax.
|
| 20 |
+
|
| 21 |
+
Example of how to download raw datasets, convert them into LeRobotDataset format, and push them to the hub:
|
| 22 |
+
```
|
| 23 |
+
python lerobot/scripts/push_dataset_to_hub.py \
|
| 24 |
+
--raw-dir data/pusht_raw \
|
| 25 |
+
--raw-format pusht_zarr \
|
| 26 |
+
--repo-id lerobot/pusht
|
| 27 |
+
|
| 28 |
+
python lerobot/scripts/push_dataset_to_hub.py \
|
| 29 |
+
--raw-dir data/xarm_lift_medium_raw \
|
| 30 |
+
--raw-format xarm_pkl \
|
| 31 |
+
--repo-id lerobot/xarm_lift_medium
|
| 32 |
+
|
| 33 |
+
python lerobot/scripts/push_dataset_to_hub.py \
|
| 34 |
+
--raw-dir data/aloha_sim_insertion_scripted_raw \
|
| 35 |
+
--raw-format aloha_hdf5 \
|
| 36 |
+
--repo-id lerobot/aloha_sim_insertion_scripted
|
| 37 |
+
|
| 38 |
+
python lerobot/scripts/push_dataset_to_hub.py \
|
| 39 |
+
--raw-dir data/umi_cup_in_the_wild_raw \
|
| 40 |
+
--raw-format umi_zarr \
|
| 41 |
+
--repo-id lerobot/umi_cup_in_the_wild
|
| 42 |
+
```
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
import argparse
|
| 46 |
+
import json
|
| 47 |
+
import shutil
|
| 48 |
+
import warnings
|
| 49 |
+
from pathlib import Path
|
| 50 |
+
from typing import Any
|
| 51 |
+
|
| 52 |
+
import torch
|
| 53 |
+
from huggingface_hub import HfApi
|
| 54 |
+
from safetensors.torch import save_file
|
| 55 |
+
|
| 56 |
+
from lerobot.common.datasets.compute_stats import compute_stats
|
| 57 |
+
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
| 58 |
+
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
| 59 |
+
from lerobot.common.datasets.utils import create_branch, create_lerobot_dataset_card, flatten_dict
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def get_from_raw_to_lerobot_format_fn(raw_format: str):
    """Return the `from_raw_to_lerobot_format` conversion function for `raw_format`.

    Args:
        raw_format: identifier of the raw dataset layout (e.g. "pusht_zarr", "aloha_hdf5").

    Raises:
        ValueError: if `raw_format` is not one of the supported formats.
    """
    import importlib

    # Map each supported raw format to the module implementing its converter.
    # NOTE: "rlds" and "openx" share the same converter module.
    format_to_module = {
        "pusht_zarr": "pusht_zarr_format",
        "umi_zarr": "umi_zarr_format",
        "aloha_hdf5": "aloha_hdf5_format",
        "rlds": "openx_rlds_format",
        "openx": "openx_rlds_format",
        "dora_parquet": "dora_parquet_format",
        "xarm_pkl": "xarm_pkl_format",
        "cam_png": "cam_png_format",
    }
    if raw_format not in format_to_module:
        raise ValueError(
            f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
        )

    # Import lazily so that only the dependencies of the chosen format are required.
    module = importlib.import_module(
        f"lerobot.common.datasets.push_dataset_to_hub.{format_to_module[raw_format]}"
    )
    return module.from_raw_to_lerobot_format
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def save_meta_data(
    info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
):
    """Write dataset metadata (info, stats, episode index) into `meta_data_dir`.

    Creates the directory if needed. `info` is stored as JSON; `stats` (flattened)
    and `episode_data_index` (as tensors) are stored as safetensors files.
    """
    meta_data_dir.mkdir(parents=True, exist_ok=True)

    # info -> plain JSON
    with open(str(meta_data_dir / "info.json"), "w") as f:
        json.dump(info, f, indent=4)

    # stats -> flattened dict stored as safetensors
    save_file(flatten_dict(stats), meta_data_dir / "stats.safetensors")

    # episode_data_index -> safetensors requires tensor values, so convert the lists first
    index_tensors = {name: torch.tensor(values) for name, values in episode_data_index.items()}
    save_file(index_tensors, meta_data_dir / "episode_data_index.safetensors")
|
| 104 |
+
|
| 105 |
+
def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
    """Upload the local metadata directory to the hub dataset repo.

    All metadata files are expected to live in the single `meta_data_dir` directory;
    on the Hugging Face repository they end up under "meta_data" at the root.
    """
    HfApi().upload_folder(
        repo_id=repo_id,
        repo_type="dataset",
        folder_path=meta_data_dir,
        path_in_repo="meta_data",
        revision=revision,
    )
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def push_dataset_card_to_hub(
    repo_id: str,
    revision: str | None,
    tags: list | None = None,
    license: str = "apache-2.0",
    **card_kwargs,
):
    """Create a LeRobotDataset card (with tags for hub discoverability) and push it."""
    dataset_card = create_lerobot_dataset_card(tags=tags, license=license, **card_kwargs)
    dataset_card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
    """Upload the local videos directory (mp4 files only) to the hub dataset repo.

    All mp4 files are expected to live in the single `videos_dir` directory; on the
    Hugging Face repository they end up under "videos" at the root.
    """
    HfApi().upload_folder(
        repo_id=repo_id,
        repo_type="dataset",
        folder_path=videos_dir,
        path_in_repo="videos",
        revision=revision,
        allow_patterns="*.mp4",
    )
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def push_dataset_to_hub(
    raw_dir: Path,
    raw_format: str,
    repo_id: str,
    push_to_hub: bool = True,
    local_dir: Path | None = None,
    fps: int | None = None,
    video: bool = True,
    batch_size: int = 32,
    num_workers: int = 8,
    episodes: list[int] | None = None,
    force_override: bool = False,
    resume: bool = False,
    cache_dir: Path = Path("/tmp"),
    tests_data_dir: Path | None = None,
    encoding: dict | None = None,
):
    """Convert a raw dataset into LeRobotDataset format and optionally push it to the hub.

    Pipeline: validate inputs -> convert raw data with the format-specific converter ->
    compute dataset statistics -> save metadata (and optionally test artifacts) ->
    upload dataset, metadata, card and videos to the hub.

    Args:
        raw_dir: directory containing the raw dataset.
        raw_format: raw layout identifier (e.g. "pusht_zarr", "aloha_hdf5").
        repo_id: hub dataset id, "user_or_org/dataset_name".
        push_to_hub: upload the converted dataset to the hub when True.
        local_dir: when provided, also write the converted dataset locally.
        fps: frame rate override; converter default used when None.
        video: encode camera frames as mp4 videos.
        batch_size / num_workers: DataLoader settings for statistics computation.
        episodes: restrict conversion to these episode indices.
        force_override: delete an existing `local_dir` instead of raising.
        resume: allow reusing an existing `local_dir`.
        cache_dir: scratch space used when `local_dir` is not given.
        tests_data_dir: when provided, save first-episode artifacts for tests.
        encoding: video-encoding options forwarded to the converter.

    Returns:
        The converted `LeRobotDataset`.
    """
    check_repo_id(repo_id)
    user_id, dataset_id = repo_id.split("/")

    # Robustify when `raw_dir` is str instead of Path
    raw_dir = Path(raw_dir)
    if not raw_dir.exists():
        raise NotADirectoryError(
            f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub: "
            f"`python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw`"
        )

    if local_dir:
        # Robustify when `local_dir` is str instead of Path
        local_dir = Path(local_dir)

        # Send warning if local_dir isn't well formatted
        # NOTE(review): assumes `local_dir` has at least two path components — confirm with callers.
        if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
            warnings.warn(
                f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
                stacklevel=1,
            )

        # Check we don't override an existing `local_dir` by mistake
        if local_dir.exists():
            if force_override:
                shutil.rmtree(local_dir)
            elif not resume:
                raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")

        meta_data_dir = local_dir / "meta_data"
        videos_dir = local_dir / "videos"
    else:
        # Temporary directory used to store images, videos, meta_data
        meta_data_dir = Path(cache_dir) / "meta_data"
        videos_dir = Path(cache_dir) / "videos"

    if raw_format is None:
        # TODO(rcadene, adilzouitine): implement auto_find_raw_format
        raise NotImplementedError()
        # raw_format = auto_find_raw_format(raw_dir)

    # convert dataset from original raw format to LeRobot format
    from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)

    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
        raw_dir,
        videos_dir,
        fps,
        video,
        episodes,
        encoding,
    )

    # Wrap the converted data so that statistics can be computed through the standard dataset API.
    lerobot_dataset = LeRobotDataset.from_preloaded(
        repo_id=repo_id,
        hf_dataset=hf_dataset,
        episode_data_index=episode_data_index,
        info=info,
        videos_dir=videos_dir,
    )
    stats = compute_stats(lerobot_dataset, batch_size, num_workers)

    if local_dir:
        hf_dataset = hf_dataset.with_format(None)  # to remove transforms that cant be saved
        hf_dataset.save_to_disk(str(local_dir / "train"))

    if push_to_hub or local_dir:
        # mandatory for upload
        save_meta_data(info, stats, episode_data_index, meta_data_dir)

    if push_to_hub:
        hf_dataset.push_to_hub(repo_id, revision="main")
        push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
        push_dataset_card_to_hub(repo_id, revision="main")
        if video:
            push_videos_to_hub(repo_id, videos_dir, revision="main")
        # Tag the upload with the codebase version so loaders can pin a compatible revision.
        create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)

    if tests_data_dir:
        # get the first episode
        num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
        test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
        episode_data_index = {k: v[:1] for k, v in episode_data_index.items()}

        test_hf_dataset = test_hf_dataset.with_format(None)
        test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))

        tests_meta_data = tests_data_dir / repo_id / "meta_data"
        save_meta_data(info, stats, episode_data_index, tests_meta_data)

        # copy videos of first episode to tests directory
        episode_index = 0
        tests_videos_dir = tests_data_dir / repo_id / "videos"
        tests_videos_dir.mkdir(parents=True, exist_ok=True)
        for key in lerobot_dataset.camera_keys:
            fname = f"{key}_episode_{episode_index:06d}.mp4"
            shutil.copy(videos_dir / fname, tests_videos_dir / fname)

    if local_dir is None:
        # clear cache
        shutil.rmtree(meta_data_dir)
        shutil.rmtree(videos_dir)

    return lerobot_dataset
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def main():
    """Parse command-line arguments and run the dataset conversion/upload."""
    cli = argparse.ArgumentParser()

    cli.add_argument("--raw-dir", type=Path, required=True,
        help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).")
    # TODO(rcadene): add automatic detection of the format
    cli.add_argument("--raw-format", type=str, required=True,
        help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `rlds`, `openx`).")
    cli.add_argument("--repo-id", type=str, required=True,
        help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).")
    cli.add_argument("--local-dir", type=Path,
        help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).")
    cli.add_argument("--push-to-hub", type=int, default=1,
        help="Upload to hub.")
    cli.add_argument("--fps", type=int,
        help="Frame rate used to collect videos. If not provided, use the default one specified in the code.")
    cli.add_argument("--video", type=int, default=1,
        help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.")
    cli.add_argument("--batch-size", type=int, default=32,
        help="Batch size loaded by DataLoader for computing the dataset statistics.")
    cli.add_argument("--num-workers", type=int, default=8,
        help="Number of processes of Dataloader for computing the dataset statistics.")
    cli.add_argument("--episodes", type=int, nargs="*",
        help="When provided, only converts the provided episodes (e.g `--episodes 2 3 4`). Useful to test the code on 1 episode.")
    cli.add_argument("--force-override", type=int, default=0,
        help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.")
    cli.add_argument("--resume", type=int, default=0,
        help="When set to 1, resumes a previous run.")
    cli.add_argument("--cache-dir", type=Path, required=False, default="/tmp",
        help="Directory to store the temporary videos and images generated while creating the dataset.")
    cli.add_argument("--tests-data-dir", type=Path,
        help="When provided, save tests artifacts into the given directory (e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id}).")

    # Forward every CLI flag straight to the conversion function (names match its signature).
    push_dataset_to_hub(**vars(cli.parse_args()))


if __name__ == "__main__":
    main()
|
lerobot/scripts/push_pretrained.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""
|
| 17 |
+
Once you have trained a policy with our training script (lerobot/scripts/train.py), use this script to push it
|
| 18 |
+
to the hub.
|
| 19 |
+
|
| 20 |
+
Example:
|
| 21 |
+
|
| 22 |
+
```bash
|
| 23 |
+
python lerobot/scripts/push_pretrained.py \
|
| 24 |
+
--pretrained_path=outputs/train/act_aloha_sim_transfer_cube_human/checkpoints/last/pretrained_model \
|
| 25 |
+
--repo_id=lerobot/act_aloha_sim_transfer_cube_human
|
| 26 |
+
```
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from dataclasses import dataclass
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
|
| 32 |
+
import draccus
|
| 33 |
+
from huggingface_hub import HfApi
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class PushPreTrainedConfig:
    """CLI configuration for uploading a pretrained policy checkpoint to the hub."""

    # Local path to the pretrained model directory to upload.
    pretrained_path: Path
    # Target hub repo id, e.g. "lerobot/act_aloha_sim_transfer_cube_human".
    repo_id: str
    # Optional branch to create and upload to; None targets the default branch.
    branch: str | None = None
    # Create the hub repo as private.
    private: bool = False
    # Do not error if the repo (or branch) already exists.
    exist_ok: bool = False
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@draccus.wrap()
def main(cfg: PushPreTrainedConfig):
    """Create the hub model repo (and optional branch), then upload the pretrained folder."""
    api = HfApi()
    api.create_repo(
        repo_id=cfg.repo_id,
        repo_type="model",
        private=cfg.private,
        exist_ok=cfg.exist_ok,
    )

    # Only create a branch when one was requested; the upload then targets it via `revision`.
    if cfg.branch:
        api.create_branch(
            repo_id=cfg.repo_id,
            repo_type="model",
            branch=cfg.branch,
            exist_ok=cfg.exist_ok,
        )

    api.upload_folder(
        repo_id=cfg.repo_id,
        repo_type="model",
        folder_path=cfg.pretrained_path,
        revision=cfg.branch,
    )


if __name__ == "__main__":
    main()
|
lerobot/scripts/train.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
import logging
|
| 17 |
+
import time
|
| 18 |
+
from contextlib import nullcontext
|
| 19 |
+
from pprint import pformat
|
| 20 |
+
from typing import Any
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from termcolor import colored
|
| 24 |
+
from torch.amp import GradScaler
|
| 25 |
+
from torch.optim import Optimizer
|
| 26 |
+
|
| 27 |
+
from lerobot.common.datasets.factory import make_dataset
|
| 28 |
+
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
| 29 |
+
from lerobot.common.datasets.utils import cycle
|
| 30 |
+
from lerobot.common.envs.factory import make_env
|
| 31 |
+
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
| 32 |
+
from lerobot.common.policies.factory import make_policy
|
| 33 |
+
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
| 34 |
+
from lerobot.common.policies.utils import get_device_from_parameters
|
| 35 |
+
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
| 36 |
+
from lerobot.common.utils.random_utils import set_seed
|
| 37 |
+
from lerobot.common.utils.train_utils import (
|
| 38 |
+
get_step_checkpoint_dir,
|
| 39 |
+
get_step_identifier,
|
| 40 |
+
load_training_state,
|
| 41 |
+
save_checkpoint,
|
| 42 |
+
update_last_checkpoint,
|
| 43 |
+
)
|
| 44 |
+
from lerobot.common.utils.utils import (
|
| 45 |
+
format_big_number,
|
| 46 |
+
get_safe_torch_device,
|
| 47 |
+
has_method,
|
| 48 |
+
init_logging,
|
| 49 |
+
)
|
| 50 |
+
from lerobot.common.utils.wandb_utils import WandBLogger
|
| 51 |
+
from lerobot.configs import parser
|
| 52 |
+
from lerobot.configs.train import TrainPipelineConfig
|
| 53 |
+
from lerobot.scripts.eval import eval_policy
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def update_policy(
    train_metrics: MetricsTracker,
    policy: PreTrainedPolicy,
    batch: Any,
    optimizer: Optimizer,
    grad_clip_norm: float,
    grad_scaler: GradScaler,
    lr_scheduler=None,
    use_amp: bool = False,
    lock=None,
) -> tuple[MetricsTracker, dict]:
    """Run one optimization step: forward, backward, gradient clipping and optimizer update.

    Supports automatic mixed precision via `grad_scaler` (the scale/unscale/step order
    below is mandated by the GradScaler API). An optional `lock` serializes the
    optimizer step when the policy is shared across threads.

    Returns:
        The updated `train_metrics` tracker and the policy's auxiliary output dict.
    """
    start_time = time.perf_counter()
    device = get_device_from_parameters(policy)
    policy.train()
    # Autocast only wraps the forward pass; backward runs outside the autocast context.
    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
        loss, output_dict = policy.forward(batch)
        # TODO(rcadene): policy.unnormalize_outputs(out_dict)
    grad_scaler.scale(loss).backward()

    # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
    grad_scaler.unscale_(optimizer)

    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(),
        grad_clip_norm,
        error_if_nonfinite=False,
    )

    # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
    with lock if lock is not None else nullcontext():
        grad_scaler.step(optimizer)
    # Updates the scale for next iteration.
    grad_scaler.update()

    optimizer.zero_grad()

    # Step through pytorch scheduler at every batch instead of epoch
    if lr_scheduler is not None:
        lr_scheduler.step()

    if has_method(policy, "update"):
        # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
        policy.update()

    train_metrics.loss = loss.item()
    train_metrics.grad_norm = grad_norm.item()
    train_metrics.lr = optimizer.param_groups[0]["lr"]
    train_metrics.update_s = time.perf_counter() - start_time
    return train_metrics, output_dict
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@parser.wrap()
def train(cfg: TrainPipelineConfig):
    """Offline training entry point.

    Builds the dataset, (optionally) an eval environment, the policy, optimizer and
    scheduler, then runs `cfg.steps` optimization steps with periodic logging,
    checkpointing and evaluation. Supports resuming from a checkpoint via `cfg.resume`.
    """
    cfg.validate()
    logging.info(pformat(cfg.to_dict()))

    if cfg.wandb.enable and cfg.wandb.project:
        wandb_logger = WandBLogger(cfg)
    else:
        wandb_logger = None
        logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Check device is available
    device = get_safe_torch_device(cfg.policy.device, log=True)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    logging.info("Creating dataset")
    dataset = make_dataset(cfg)

    # Create environment used for evaluating checkpoints during training on simulation data.
    # On real-world data, no need to create an environment as evaluations are done outside train.py,
    # using the eval.py instead, with gym_dora environment and dora-rs.
    eval_env = None
    if cfg.eval_freq > 0 and cfg.env is not None:
        logging.info("Creating env")
        eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)

    logging.info("Creating policy")
    policy = make_policy(
        cfg=cfg.policy,
        ds_meta=dataset.meta,
    )

    logging.info("Creating optimizer and scheduler")
    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
    grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)

    step = 0  # number of policy updates (forward + backward + optim)

    if cfg.resume:
        step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)

    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())

    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
    if cfg.env is not None:
        logging.info(f"{cfg.env.task=}")
    logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
    logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
    logging.info(f"{dataset.num_episodes=}")
    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

    # create dataloader for offline training
    # Policies that drop trailing frames need the episode-aware sampler; shuffling is
    # then delegated to the sampler instead of the DataLoader.
    if hasattr(cfg.policy, "drop_n_last_frames"):
        shuffle = False
        sampler = EpisodeAwareSampler(
            dataset.episode_data_index,
            drop_n_last_frames=cfg.policy.drop_n_last_frames,
            shuffle=True,
        )
    else:
        shuffle = True
        sampler = None

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.num_workers,
        batch_size=cfg.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        pin_memory=device.type != "cpu",
        drop_last=False,
    )
    # Infinite iterator: training is bounded by `cfg.steps`, not by epochs.
    dl_iter = cycle(dataloader)

    policy.train()

    train_metrics = {
        "loss": AverageMeter("loss", ":.3f"),
        "grad_norm": AverageMeter("grdn", ":.3f"),
        "lr": AverageMeter("lr", ":0.1e"),
        "update_s": AverageMeter("updt_s", ":.3f"),
        "dataloading_s": AverageMeter("data_s", ":.3f"),
    }

    train_tracker = MetricsTracker(
        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
    )

    logging.info("Start offline training on a fixed dataset")
    for _ in range(step, cfg.steps):
        start_time = time.perf_counter()
        batch = next(dl_iter)
        train_tracker.dataloading_s = time.perf_counter() - start_time

        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device, non_blocking=True)

        train_tracker, output_dict = update_policy(
            train_tracker,
            policy,
            batch,
            optimizer,
            cfg.optimizer.grad_clip_norm,
            grad_scaler=grad_scaler,
            lr_scheduler=lr_scheduler,
            use_amp=cfg.policy.use_amp,
        )

        # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
        # increment `step` here.
        step += 1
        train_tracker.step()
        is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
        is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
        is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0

        if is_log_step:
            logging.info(train_tracker)
            if wandb_logger:
                wandb_log_dict = train_tracker.to_dict()
                if output_dict:
                    wandb_log_dict.update(output_dict)
                wandb_logger.log_dict(wandb_log_dict, step)
            train_tracker.reset_averages()

        if cfg.save_checkpoint and is_saving_step:
            logging.info(f"Checkpoint policy after step {step}")
            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
            save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
            update_last_checkpoint(checkpoint_dir)
            if wandb_logger:
                wandb_logger.log_policy(checkpoint_dir)

        if cfg.env and is_eval_step:
            step_id = get_step_identifier(step, cfg.steps)
            logging.info(f"Eval policy at step {step}")
            with (
                torch.no_grad(),
                torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
            ):
                eval_info = eval_policy(
                    eval_env,
                    policy,
                    cfg.eval.n_episodes,
                    videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
                    max_episodes_rendered=4,
                    start_seed=cfg.seed,
                )

            eval_metrics = {
                "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
                "pc_success": AverageMeter("success", ":.1f"),
                "eval_s": AverageMeter("eval_s", ":.3f"),
            }
            eval_tracker = MetricsTracker(
                cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
            )
            # `pop` the aggregated scalars so only the remaining entries are merged below.
            eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
            eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
            eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
            logging.info(eval_tracker)
            if wandb_logger:
                wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
                wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
                wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")

    if eval_env:
        eval_env.close()
    logging.info("End of training")


if __name__ == "__main__":
    init_logging()
    train()
|
lerobot/scripts/visualize_dataset.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
|
| 17 |
+
|
| 18 |
+
Note: The last frame of the episode doesn't always correspond to a final state.
|
| 19 |
+
That's because our datasets are composed of transition from state to state up to
|
| 20 |
+
the antepenultimate state associated to the ultimate action to arrive in the final state.
|
| 21 |
+
However, there might not be a transition from a final state to another state.
|
| 22 |
+
|
| 23 |
+
Note: This script aims to visualize the data used to train the neural networks.
|
| 24 |
+
~What you see is what you get~. When visualizing image modality, it is often expected to observe
|
| 25 |
+
lossy compression artifacts since these images have been decoded from compressed mp4 videos to
|
| 26 |
+
save disk space. The compression factor applied has been tuned to not affect success rate.
|
| 27 |
+
|
| 28 |
+
Examples:
|
| 29 |
+
|
| 30 |
+
- Visualize data stored on a local machine:
|
| 31 |
+
```
|
| 32 |
+
local$ python lerobot/scripts/visualize_dataset.py \
|
| 33 |
+
--repo-id lerobot/pusht \
|
| 34 |
+
--episode-index 0
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
- Visualize data stored on a distant machine with a local viewer:
|
| 38 |
+
```
|
| 39 |
+
distant$ python lerobot/scripts/visualize_dataset.py \
|
| 40 |
+
--repo-id lerobot/pusht \
|
| 41 |
+
--episode-index 0 \
|
| 42 |
+
--save 1 \
|
| 43 |
+
--output-dir path/to/directory
|
| 44 |
+
|
| 45 |
+
local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
|
| 46 |
+
local$ rerun lerobot_pusht_episode_0.rrd
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
- Visualize data stored on a distant machine through streaming:
|
| 50 |
+
(You need to forward the websocket port to the distant machine, with
|
| 51 |
+
`ssh -L 9087:localhost:9087 username@remote-host`)
|
| 52 |
+
```
|
| 53 |
+
distant$ python lerobot/scripts/visualize_dataset.py \
|
| 54 |
+
--repo-id lerobot/pusht \
|
| 55 |
+
--episode-index 0 \
|
| 56 |
+
--mode distant \
|
| 57 |
+
--ws-port 9087
|
| 58 |
+
|
| 59 |
+
local$ rerun ws://localhost:9087
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
import argparse
|
| 65 |
+
import gc
|
| 66 |
+
import logging
|
| 67 |
+
import time
|
| 68 |
+
from pathlib import Path
|
| 69 |
+
from typing import Iterator
|
| 70 |
+
|
| 71 |
+
import numpy as np
|
| 72 |
+
import rerun as rr
|
| 73 |
+
import torch
|
| 74 |
+
import torch.utils.data
|
| 75 |
+
import tqdm
|
| 76 |
+
|
| 77 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class EpisodeSampler(torch.utils.data.Sampler):
    """Sampler yielding, in order, the dataset indices that belong to one episode."""

    def __init__(self, dataset: LeRobotDataset, episode_index: int):
        # Episode boundaries are stored as tensors in `episode_data_index`;
        # convert them to plain ints so we can build a `range`.
        boundaries = dataset.episode_data_index
        start = boundaries["from"][episode_index].item()
        stop = boundaries["to"][episode_index].item()
        self.frame_ids = range(start, stop)

    def __iter__(self) -> Iterator:
        yield from self.frame_ids

    def __len__(self) -> int:
        return len(self.frame_ids)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
    """Convert a float32 channel-first (CHW) torch image into a uint8 channel-last (HWC) numpy image."""
    assert chw_float32_torch.dtype == torch.float32
    assert chw_float32_torch.ndim == 3
    c, h, w = chw_float32_torch.shape
    # A channel-first image should have far fewer channels than pixels per spatial side.
    assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
    scaled = chw_float32_torch * 255
    return scaled.type(torch.uint8).permute(1, 2, 0).numpy()
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def visualize_dataset(
    dataset: LeRobotDataset,
    episode_index: int,
    batch_size: int = 32,
    num_workers: int = 0,
    mode: str = "local",
    web_port: int = 9090,
    ws_port: int = 9087,
    save: bool = False,
    output_dir: Path | None = None,
) -> Path | None:
    """Stream every frame of one episode to Rerun for visualization.

    Args:
        dataset: the dataset to read frames from.
        episode_index: which episode to visualize.
        batch_size: DataLoader batch size used to read frames.
        num_workers: DataLoader worker processes.
        mode: "local" spawns a viewer on this machine; "distant" serves a websocket
            for a remote viewer. Any other value raises ValueError.
        web_port / ws_port: ports used only when mode == "distant".
        save: when True (local mode only), write a .rrd file instead of spawning a viewer.
        output_dir: destination directory for the .rrd file; required when save is True.

    Returns:
        The path of the saved .rrd file when `mode == "local"` and `save` is True,
        otherwise None. In "distant" mode this function blocks until Ctrl-C.
    """
    if save:
        assert output_dir is not None, (
            "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
        )

    repo_id = dataset.repo_id

    logging.info("Loading dataloader")
    # Sampler restricts the dataloader to the frames of the requested episode, in order.
    episode_sampler = EpisodeSampler(dataset, episode_index)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_size=batch_size,
        sampler=episode_sampler,
    )

    logging.info("Starting Rerun")

    if mode not in ["local", "distant"]:
        raise ValueError(mode)

    # When saving to .rrd there is no point spawning an interactive viewer.
    spawn_local_viewer = mode == "local" and not save
    rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)

    # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
    # when iterating on a dataloader with `num_workers` > 0
    # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
    gc.collect()

    if mode == "distant":
        rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)

    logging.info("Logging to Rerun")

    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
        # iterate over the batch
        for i in range(len(batch["index"])):
            # Two synchronized timelines: discrete frame index and wall-clock timestamp.
            rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
            rr.set_time_seconds("timestamp", batch["timestamp"][i].item())

            # display each camera image
            for key in dataset.meta.camera_keys:
                # TODO(rcadene): add `.compress()`? is it lossless?
                rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))

            # display each dimension of action space (e.g. actuators command)
            if "action" in batch:
                for dim_idx, val in enumerate(batch["action"][i]):
                    rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))

            # display each dimension of observed state space (e.g. agent position in joint space)
            if "observation.state" in batch:
                for dim_idx, val in enumerate(batch["observation.state"][i]):
                    rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))

            if "next.done" in batch:
                rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))

            if "next.reward" in batch:
                rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))

            if "next.success" in batch:
                rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))

    if mode == "local" and save:
        # save .rrd locally
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        repo_id_str = repo_id.replace("/", "_")
        rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
        rr.save(rrd_path)
        return rrd_path

    elif mode == "distant":
        # stop the process from exiting since it is serving the websocket connection
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print("Ctrl-C received. Exiting.")
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def main():
    """CLI entry point: parse arguments, load the dataset, and run `visualize_dataset`.

    Arguments `--repo-id`, `--root` and `--tolerance-s` are consumed to build the
    dataset; everything else is forwarded verbatim to `visualize_dataset`.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
    )
    parser.add_argument(
        "--episode-index",
        type=int,
        required=True,
        help="Episode to visualize.",
    )
    parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory for the dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="Directory path to write a .rrd file when `--save 1` is set.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size loaded by DataLoader.",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of processes of Dataloader for loading the data.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="local",
        help=(
            "Mode of viewing between 'local' or 'distant'. "
            "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
            "'distant' creates a server on the distant machine where the data is stored. "
            "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
        ),
    )
    parser.add_argument(
        "--web-port",
        type=int,
        default=9090,
        help="Web port for rerun.io when `--mode distant` is set.",
    )
    parser.add_argument(
        "--ws-port",
        type=int,
        default=9087,
        help="Web socket port for rerun.io when `--mode distant` is set.",
    )
    parser.add_argument(
        "--save",
        type=int,
        default=0,
        help=(
            "Save a .rrd file in the directory provided by `--output-dir`. "
            "It also deactivates the spawning of a viewer. "
            "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
        ),
    )

    parser.add_argument(
        "--tolerance-s",
        type=float,
        default=1e-4,
        help=(
            "Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
            "This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
            "If not given, defaults to 1e-4."
        ),
    )

    args = parser.parse_args()
    kwargs = vars(args)
    # Pop the dataset-construction arguments so `kwargs` is left holding exactly
    # the remaining parameters of `visualize_dataset`.
    repo_id = kwargs.pop("repo_id")
    root = kwargs.pop("root")
    tolerance_s = kwargs.pop("tolerance_s")

    logging.info("Loading dataset")
    dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)

    # Fix: the original called `visualize_dataset(dataset, **vars(args))`, which only
    # worked because `kwargs` happens to alias `args.__dict__`. Pass `kwargs`
    # explicitly so the pops above are visibly reflected in the forwarded arguments.
    visualize_dataset(dataset, **kwargs)


if __name__ == "__main__":
    main()
|
lerobot/scripts/visualize_dataset_html.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
|
| 17 |
+
|
| 18 |
+
Note: The last frame of the episode doesn't always correspond to a final state.
|
| 19 |
+
That's because our datasets are composed of transition from state to state up to
|
| 20 |
+
the antepenultimate state associated to the ultimate action to arrive in the final state.
|
| 21 |
+
However, there might not be a transition from a final state to another state.
|
| 22 |
+
|
| 23 |
+
Note: This script aims to visualize the data used to train the neural networks.
|
| 24 |
+
~What you see is what you get~. When visualizing image modality, it is often expected to observe
|
| 25 |
+
lossy compression artifacts since these images have been decoded from compressed mp4 videos to
|
| 26 |
+
save disk space. The compression factor applied has been tuned to not affect success rate.
|
| 27 |
+
|
| 28 |
+
Example of usage:
|
| 29 |
+
|
| 30 |
+
- Visualize data stored on a local machine:
|
| 31 |
+
```bash
|
| 32 |
+
local$ python lerobot/scripts/visualize_dataset_html.py \
|
| 33 |
+
--repo-id lerobot/pusht
|
| 34 |
+
|
| 35 |
+
local$ open http://localhost:9090
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
- Visualize data stored on a distant machine with a local viewer:
|
| 39 |
+
```bash
|
| 40 |
+
distant$ python lerobot/scripts/visualize_dataset_html.py \
|
| 41 |
+
--repo-id lerobot/pusht
|
| 42 |
+
|
| 43 |
+
local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
|
| 44 |
+
local$ open http://localhost:9090
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
- Select episodes to visualize:
|
| 48 |
+
```bash
|
| 49 |
+
python lerobot/scripts/visualize_dataset_html.py \
|
| 50 |
+
--repo-id lerobot/pusht \
|
| 51 |
+
--episodes 7 3 5 1 4
|
| 52 |
+
```
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
import argparse
|
| 56 |
+
import csv
|
| 57 |
+
import json
|
| 58 |
+
import logging
|
| 59 |
+
import re
|
| 60 |
+
import shutil
|
| 61 |
+
import tempfile
|
| 62 |
+
from io import StringIO
|
| 63 |
+
from pathlib import Path
|
| 64 |
+
|
| 65 |
+
import numpy as np
|
| 66 |
+
import pandas as pd
|
| 67 |
+
import requests
|
| 68 |
+
from flask import Flask, redirect, render_template, request, url_for
|
| 69 |
+
|
| 70 |
+
from lerobot import available_datasets
|
| 71 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
| 72 |
+
from lerobot.common.datasets.utils import IterableNamespace
|
| 73 |
+
from lerobot.common.utils.utils import init_logging
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def run_server(
    dataset: LeRobotDataset | IterableNamespace | None,
    episodes: list[int] | None,
    host: str,
    port: str,
    static_folder: Path,
    template_folder: Path,
):
    """Serve a Flask app that renders episode pages for a LeRobotDataset.

    When `dataset` is None, each request resolves the dataset from the URL and
    fetches its metadata from the hub; otherwise the given dataset is used for
    every request. `episodes` restricts the episode list shown in the sidebar
    (None means all episodes). Blocks in `app.run` until interrupted.
    """
    app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
    app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache

    @app.route("/")
    # NOTE(review): "hommepage" is a typo for "homepage". Renaming it would change
    # the Flask endpoint name; verify no template calls url_for("hommepage") before fixing.
    def hommepage(dataset=dataset):
        # A dataset fixed at server start: always redirect to its first episode.
        if dataset:
            dataset_namespace, dataset_name = dataset.repo_id.split("/")
            return redirect(
                url_for(
                    "show_episode",
                    dataset_namespace=dataset_namespace,
                    dataset_name=dataset_name,
                    episode_id=0,
                )
            )

        # Otherwise allow `?dataset=ns/name&episode=N` query parameters.
        dataset_param, episode_param = None, None
        all_params = request.args
        if "dataset" in all_params:
            dataset_param = all_params["dataset"]
        if "episode" in all_params:
            episode_param = int(all_params["episode"])

        if dataset_param:
            dataset_namespace, dataset_name = dataset_param.split("/")
            return redirect(
                url_for(
                    "show_episode",
                    dataset_namespace=dataset_namespace,
                    dataset_name=dataset_name,
                    episode_id=episode_param if episode_param is not None else 0,
                )
            )

        # No dataset selected: render the landing page with a curated list.
        featured_datasets = [
            "lerobot/aloha_static_cups_open",
            "lerobot/columbia_cairlab_pusht_real",
            "lerobot/taco_play",
        ]
        return render_template(
            "visualize_dataset_homepage.html",
            featured_datasets=featured_datasets,
            lerobot_datasets=available_datasets,
        )

    @app.route("/<string:dataset_namespace>/<string:dataset_name>")
    def show_first_episode(dataset_namespace, dataset_name):
        # Convenience route: /ns/name redirects to the first episode page.
        first_episode_id = 0
        return redirect(
            url_for(
                "show_episode",
                dataset_namespace=dataset_namespace,
                dataset_name=dataset_name,
                episode_id=first_episode_id,
            )
        )

    @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
    def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
        repo_id = f"{dataset_namespace}/{dataset_name}"
        try:
            # No local dataset: fall back to hub metadata (IterableNamespace).
            if dataset is None:
                dataset = get_dataset_info(repo_id)
        except FileNotFoundError:
            return (
                "Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461",
                400,
            )
        # Dataset codebase version lives in different places for local vs hub metadata.
        dataset_version = (
            str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
        )
        match = re.search(r"v(\d+)\.", dataset_version)
        if match:
            major_version = int(match.group(1))
            if major_version < 2:
                return "Make sure to convert your LeRobotDataset to v2 & above."

        episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
        dataset_info = {
            "repo_id": f"{dataset_namespace}/{dataset_name}",
            "num_samples": dataset.num_frames
            if isinstance(dataset, LeRobotDataset)
            else dataset.total_frames,
            "num_episodes": dataset.num_episodes
            if isinstance(dataset, LeRobotDataset)
            else dataset.total_episodes,
            "fps": dataset.fps,
        }
        if isinstance(dataset, LeRobotDataset):
            # Local dataset: videos are served from the app's static folder.
            video_paths = [
                dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
            ]
            videos_info = [
                {"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
                for video_path in video_paths
            ]
            tasks = dataset.meta.episodes[episode_id]["tasks"]
        else:
            # Hub-only dataset: build direct resolve URLs to the hosted mp4 files.
            video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
            videos_info = [
                {
                    "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
                    + dataset.video_path.format(
                        episode_chunk=int(episode_id) // dataset.chunks_size,
                        video_key=video_key,
                        episode_index=episode_id,
                    ),
                    "filename": video_key,
                }
                for video_key in video_keys
            ]

            response = requests.get(
                f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
            )
            response.raise_for_status()
            # Split into lines and parse each line as JSON
            tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]

            filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
            tasks = filtered_tasks_jsonl[0]["tasks"]

        # The template reads the language instruction from the first video entry.
        videos_info[0]["language_instruction"] = tasks

        if episodes is None:
            episodes = list(
                range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
            )

        return render_template(
            "visualize_dataset_template.html",
            episode_id=episode_id,
            episodes=episodes,
            dataset_info=dataset_info,
            videos_info=videos_info,
            episode_data_csv_str=episode_data_csv_str,
            columns=columns,
            ignored_columns=ignored_columns,
        )

    app.run(host=host, port=port)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def get_ep_csv_fname(episode_id: int):
    """Return the canonical csv filename for a given episode index."""
    return f"episode_{episode_id}.csv"
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
    """Get a csv str containing timeseries data of an episode (e.g. state and action).
    This file will be loaded by Dygraph javascript to plot data in real time.

    Returns a tuple `(csv_string, columns, ignored_columns)` where `columns` maps each
    plotted feature key to its per-dimension names and `ignored_columns` lists the
    multi-dimensional features that cannot be plotted as timeseries.
    """
    columns = []

    # Only scalar-per-dimension numeric features can be plotted.
    selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
    selected_columns.remove("timestamp")

    ignored_columns = []
    # Fix: iterate over a copy. The original removed items from `selected_columns`
    # while iterating it, which skips the element following each removal (e.g. of two
    # consecutive multi-dimensional columns, the second one was silently kept).
    for column_name in list(selected_columns):
        shape = dataset.features[column_name]["shape"]
        shape_dim = len(shape)
        if shape_dim > 1:
            selected_columns.remove(column_name)
            ignored_columns.append(column_name)

    # init header of csv with state and action names
    header = ["timestamp"]

    for column_name in selected_columns:
        dim_state = (
            dataset.meta.shapes[column_name][0]
            if isinstance(dataset, LeRobotDataset)
            else dataset.features[column_name].shape[0]
        )

        if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
            column_names = dataset.features[column_name]["names"]
            # Names may be nested one or more dicts deep; unwrap until we reach the list.
            while not isinstance(column_names, list):
                column_names = list(column_names.values())[0]
        else:
            # No names provided: synthesize "<feature>_<i>" labels.
            column_names = [f"{column_name}_{i}" for i in range(dim_state)]
        columns.append({"key": column_name, "value": column_names})

        header += column_names

    selected_columns.insert(0, "timestamp")

    if isinstance(dataset, LeRobotDataset):
        # Local dataset: slice the episode's rows straight out of the hf dataset.
        from_idx = dataset.episode_data_index["from"][episode_index]
        to_idx = dataset.episode_data_index["to"][episode_index]
        data = (
            dataset.hf_dataset.select(range(from_idx, to_idx))
            .select_columns(selected_columns)
            .with_format("pandas")
        )
    else:
        # Hub-only dataset: download the episode's parquet file.
        repo_id = dataset.repo_id

        url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
            episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
        )
        df = pd.read_parquet(url)
        data = df[selected_columns]  # Select specific columns

    # One row per frame: timestamp followed by every dimension of every selected feature.
    rows = np.hstack(
        (
            np.expand_dims(data["timestamp"], axis=1),
            *[np.vstack(data[col]) for col in selected_columns[1:]],
        )
    ).tolist()

    # Convert data to CSV string
    csv_buffer = StringIO()
    csv_writer = csv.writer(csv_buffer)
    # Write header
    csv_writer.writerow(header)
    # Write data rows
    csv_writer.writerows(rows)
    csv_string = csv_buffer.getvalue()

    return csv_string, columns, ignored_columns
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
    """Return the video file path for each camera of the given episode."""
    # The video path is stored on every frame, so reading the episode's first frame
    # is a cheap way to recover it (hack to get video_path of the episode).
    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
    paths = []
    for key in dataset.meta.video_keys:
        paths.append(dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"])
    return paths
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
    """Return the episode's language instruction, or None when the dataset has none."""
    if "language_instruction" not in dataset.features:
        return None

    # The instruction is replicated on every frame; read it from the episode's first one.
    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
    raw = dataset.hf_dataset[first_frame_idx]["language_instruction"]
    # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
    # with the tf.tensor appearing in the string
    stripped = raw.removeprefix("tf.Tensor(b'")
    return stripped.removesuffix("', shape=(), dtype=string)")
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def get_dataset_info(repo_id: str) -> IterableNamespace:
    """Fetch a dataset's `meta/info.json` from the hub and wrap it in an IterableNamespace."""
    url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json"
    response = requests.get(url, timeout=5)
    response.raise_for_status()  # Raises an HTTPError for bad responses
    dataset_info = response.json()
    dataset_info["repo_id"] = repo_id
    return IterableNamespace(dataset_info)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def visualize_dataset_html(
    dataset: LeRobotDataset | None,
    episodes: list[int] | None = None,
    output_dir: Path | None = None,
    serve: bool = True,
    host: str = "127.0.0.1",
    port: int = 9090,
    force_override: bool = False,
) -> Path | None:
    """Prepare the static/template folders for the HTML visualizer and start the Flask server.

    Args:
        dataset: dataset to visualize; None means the server resolves datasets
            from the hub per request.
        episodes: optional subset of episode indices to expose.
        output_dir: where static assets are placed; a temp dir is created when None.
        serve: when True, block in `run_server`.
        host / port: address the Flask server binds to.
        force_override: delete a pre-existing `output_dir` before reusing it.
    """
    init_logging()

    template_dir = Path(__file__).resolve().parent.parent / "templates"

    if output_dir is None:
        # NOTE(review): mkdtemp does NOT clean the directory up automatically
        # (unlike TemporaryDirectory); it only creates a unique temp dir.
        output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")

    output_dir = Path(output_dir)
    if output_dir.exists():
        if force_override:
            shutil.rmtree(output_dir)
        else:
            logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")

    output_dir.mkdir(parents=True, exist_ok=True)

    # Flask serves videos out of this folder.
    static_dir = output_dir / "static"
    static_dir.mkdir(parents=True, exist_ok=True)

    if dataset is None:
        if serve:
            run_server(
                dataset=None,
                episodes=None,
                host=host,
                port=port,
                static_folder=static_dir,
                template_folder=template_dir,
            )
    else:
        # Create a symlink from the dataset video folder containing mp4 files to the output directory
        # so that the http server can get access to the mp4 files.
        if isinstance(dataset, LeRobotDataset):
            ln_videos_dir = static_dir / "videos"
            if not ln_videos_dir.exists():
                ln_videos_dir.symlink_to((dataset.root / "videos").resolve())

        if serve:
            run_server(dataset, episodes, host, port, static_dir, template_dir)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def main():
    """CLI entry point: parse arguments, load the dataset and launch the visualizer.

    Fixes: "repositery" typo in the `--repo-id` help text, and missing
    separating spaces in the implicitly concatenated `--tolerance-s` help
    string (it previously rendered as "...fps valueThis is argument...").
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo-id",
        type=str,
        default=None,
        help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
    )
    parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
    )
    parser.add_argument(
        "--load-from-hf-hub",
        type=int,
        default=0,
        help="Load videos and parquet files from HF Hub rather than local system.",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="*",
        default=None,
        help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
    )
    parser.add_argument(
        "--serve",
        type=int,
        default=1,
        help="Launch web server.",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Web host used by the http server.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=9090,
        help="Web port used by the http server.",
    )
    parser.add_argument(
        "--force-override",
        type=int,
        default=0,
        help="Delete the output directory if it exists already.",
    )

    parser.add_argument(
        "--tolerance-s",
        type=float,
        default=1e-4,
        help=(
            "Tolerance in seconds used to ensure data timestamps respect the dataset fps value. "
            "This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument. "
            "If not given, defaults to 1e-4."
        ),
    )

    args = parser.parse_args()
    kwargs = vars(args)
    # Dataset-loading arguments are popped here; the remaining kwargs are
    # forwarded to visualize_dataset_html. Note that `kwargs` IS `args.__dict__`,
    # so popping also removes the keys from `vars(args)`.
    repo_id = kwargs.pop("repo_id")
    load_from_hf_hub = kwargs.pop("load_from_hf_hub")
    root = kwargs.pop("root")
    tolerance_s = kwargs.pop("tolerance_s")

    dataset = None
    if repo_id:
        # Either load the full dataset locally, or only its hub metadata.
        dataset = (
            LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
            if not load_from_hf_hub
            else get_dataset_info(repo_id)
        )

    visualize_dataset_html(dataset, **kwargs)


if __name__ == "__main__":
    main()
|
lerobot/scripts/visualize_image_transforms.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
""" Visualize effects of image transforms for a given configuration.
|
| 17 |
+
|
| 18 |
+
This script will generate examples of transformed images as they are output by LeRobot dataset.
|
| 19 |
+
Additionally, each individual transform can be visualized separately as well as examples of combined transforms
|
| 20 |
+
|
| 21 |
+
Example:
|
| 22 |
+
```bash
|
| 23 |
+
python lerobot/scripts/visualize_image_transforms.py \
|
| 24 |
+
--repo_id=lerobot/pusht \
|
| 25 |
+
--episodes='[0]' \
|
| 26 |
+
--image_transforms.enable=True
|
| 27 |
+
```
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
import logging
|
| 31 |
+
from copy import deepcopy
|
| 32 |
+
from dataclasses import replace
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
|
| 35 |
+
import draccus
|
| 36 |
+
from torchvision.transforms import ToPILImage
|
| 37 |
+
|
| 38 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
| 39 |
+
from lerobot.common.datasets.transforms import (
|
| 40 |
+
ImageTransforms,
|
| 41 |
+
ImageTransformsConfig,
|
| 42 |
+
make_transform_from_config,
|
| 43 |
+
)
|
| 44 |
+
from lerobot.configs.default import DatasetConfig
|
| 45 |
+
|
| 46 |
+
# Default root directory for generated example images (overridable via the CLI).
OUTPUT_DIR = Path("outputs/image_transforms")
# Shared tensor -> PIL converter used by all save helpers in this script.
to_pil = ToPILImage()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
    """Save `n_examples` frames produced by the full (combined) transform pipeline.

    Images are written as `<output_dir>/all/<i>.png` for i in 1..n_examples.
    """
    combined_dir = output_dir / "all"
    combined_dir.mkdir(parents=True, exist_ok=True)

    pipeline = ImageTransforms(cfg)
    for example_idx in range(1, n_examples + 1):
        frame = pipeline(original_frame)
        to_pil(frame).save(combined_dir / f"{example_idx}.png", quality=100)

    print("Combined transforms examples saved to:")
    print(f" {combined_dir}")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
    """Save example outputs of every individual transform in the config.

    For each transform this writes `n_examples` randomly sampled outputs plus
    three deterministic ones where every kwarg range is pinned to its minimum,
    maximum and midpoint (`min.png`, `max.png`, `mean.png`). Does nothing (but
    warns) when transforms are disabled in the config.
    """
    if not cfg.enable:
        logging.warning(
            "No single transforms will be saved, because `image_transforms.enable=False`. To enable, set `enable` to True in `ImageTransformsConfig` or in the command line with `--image_transforms.enable=True`."
        )
        return

    print("Individual transforms examples saved to:")
    for tf_name, tf_cfg in cfg.tfs.items():
        # Apply a few transformation with random value in min_max range
        tf_dir = output_dir / tf_name
        tf_dir.mkdir(parents=True, exist_ok=True)

        tf = make_transform_from_config(tf_cfg)
        for example_idx in range(1, n_examples + 1):
            to_pil(tf(original_frame)).save(tf_dir / f"{example_idx}.png", quality=100)

        # Apply min, max, average transformations: pin each kwarg's [min, max]
        # range to a single value so the output is no longer random.
        kwargs_min = deepcopy(tf_cfg.kwargs)
        kwargs_max = deepcopy(tf_cfg.kwargs)
        kwargs_avg = deepcopy(tf_cfg.kwargs)

        for key, (low, high) in tf_cfg.kwargs.items():
            midpoint = (low + high) / 2
            kwargs_min[key] = [low, low]
            kwargs_max[key] = [high, high]
            kwargs_avg[key] = [midpoint, midpoint]

        tf_min = make_transform_from_config(replace(tf_cfg, kwargs=kwargs_min))
        tf_max = make_transform_from_config(replace(tf_cfg, kwargs=kwargs_max))
        tf_avg = make_transform_from_config(replace(tf_cfg, kwargs=kwargs_avg))

        frame_min = tf_min(original_frame)
        frame_max = tf_max(original_frame)
        frame_avg = tf_avg(original_frame)

        to_pil(frame_min).save(tf_dir / "min.png", quality=100)
        to_pil(frame_max).save(tf_dir / "max.png", quality=100)
        to_pil(frame_avg).save(tf_dir / "mean.png", quality=100)

        print(f" {tf_dir}")
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@draccus.wrap()
def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5):
    """Generate example images for the image transforms described in `cfg`.

    Saves the original frame, combined-transform examples and per-transform
    examples under `<output_dir>/<dataset name>/`.
    """
    dataset = LeRobotDataset(
        repo_id=cfg.repo_id,
        episodes=cfg.episodes,
        revision=cfg.revision,
        video_backend=cfg.video_backend,
    )

    # One sub-directory per dataset, named after the repo's last path segment.
    repo_dir = output_dir / cfg.repo_id.split("/")[-1]
    repo_dir.mkdir(parents=True, exist_ok=True)

    # Get 1st frame from 1st camera of 1st episode
    first_camera = dataset.meta.camera_keys[0]
    original_frame = dataset[0][first_camera]
    to_pil(original_frame).save(repo_dir / "original_frame.png", quality=100)
    print("\nOriginal frame saved to:")
    print(f" {repo_dir / 'original_frame.png'}.")

    save_all_transforms(cfg.image_transforms, original_frame, repo_dir, n_examples)
    save_each_transform(cfg.image_transforms, original_frame, repo_dir, n_examples)


if __name__ == "__main__":
    visualize_image_transforms()
|
lerobot/templates/visualize_dataset_homepage.html
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Interactive Video Background Page</title>
    <!-- Tailwind for styling, Alpine.js for the small amount of interactivity below. -->
    <script src="https://cdn.tailwindcss.com"></script>
    <script defer src="https://cdn.jsdelivr.net/npm/alpinejs@3.x.x/dist/cdn.min.js"></script>
</head>
<!-- Alpine state: the dataset id typed by the user; submitting navigates to /<repo_id>. -->
<body class="h-screen overflow-hidden font-mono text-white" x-data="{
    inputValue: '',
    navigateToDataset() {
        const trimmedValue = this.inputValue.trim();
        if (trimmedValue) {
            window.location.href = `/${trimmedValue}`;
        }
    }
}">
    <!-- Full-screen looping background video. -->
    <div class="fixed inset-0 w-full h-full overflow-hidden">
        <video class="absolute min-w-full min-h-full w-auto h-auto top-1/2 left-1/2 transform -translate-x-1/2 -translate-y-1/2" autoplay muted loop>
            <source src="https://huggingface.co/datasets/cadene/koch_bimanual_folding/resolve/v1.6/videos/observation.images.phone_episode_000037.mp4" type="video/mp4">
            Your browser does not support HTML5 video.
        </video>
    </div>
    <!-- Dark overlay so the foreground text stays readable over the video. -->
    <div class="fixed inset-0 bg-black bg-opacity-80"></div>
    <div class="relative z-10 flex flex-col items-center justify-center h-screen">
        <div class="text-center mb-8">
            <h1 class="text-4xl font-bold mb-4">LeRobot Dataset Visualizer</h1>

            <a href="https://x.com/RemiCadene/status/1825455895561859185" target="_blank" rel="noopener noreferrer" class="underline">create & train your own robots</a>

            <p class="text-xl mb-4"></p>
            <!-- Curated list of datasets supplied by the server as `featured_datasets`. -->
            <div class="text-left inline-block">
                <h3 class="font-semibold mb-2 mt-4">Example Datasets:</h3>
                <ul class="list-disc list-inside">
                    {% for dataset in featured_datasets %}
                    <li><a href="/{{ dataset }}" class="text-blue-300 hover:text-blue-100 hover:underline">{{ dataset }}</a></li>
                    {% endfor %}
                </ul>
            </div>
        </div>
        <!-- Free-form dataset id input; Enter key and the Go button both navigate. -->
        <div class="flex w-full max-w-lg px-4 mb-4">
            <input
                type="text"
                x-model="inputValue"
                @keyup.enter="navigateToDataset"
                placeholder="enter dataset id (ex: lerobot/droid_100)"
                class="flex-grow px-4 py-2 rounded-l bg-white bg-opacity-20 text-white placeholder-gray-300 focus:outline-none focus:ring-2 focus:ring-blue-300"
            >
            <button
                @click="navigateToDataset"
                class="px-4 py-2 bg-blue-500 text-white rounded-r hover:bg-blue-600 focus:outline-none focus:ring-2 focus:ring-blue-300"
            >
                Go
            </button>
        </div>

        <!-- Collapsible longer list supplied by the server as `lerobot_datasets`. -->
        <details class="mt-4 max-w-full px-4">
            <summary>More example datasets</summary>
            <ul class="list-disc list-inside max-h-28 overflow-y-auto break-all">
                {% for dataset in lerobot_datasets %}
                <li><a href="/{{ dataset }}" class="text-blue-300 hover:text-blue-100 hover:underline">{{ dataset }}</a></li>
                {% endfor %}
            </ul>
        </details>
    </div>
</body>
</html>
|