joaoocruz00 commited on
Commit
b97495f
·
verified ·
1 Parent(s): 069c4d7

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +32 -0
  2. lerobot/common/policies/vqbet/configuration_vqbet.py +200 -0
  3. lerobot/common/policies/vqbet/modeling_vqbet.py +911 -0
  4. lerobot/common/robot_devices/cameras/configs.py +114 -0
  5. lerobot/common/robot_devices/cameras/intelrealsense.py +538 -0
  6. lerobot/common/robot_devices/cameras/opencv.py +518 -0
  7. lerobot/common/robot_devices/cameras/utils.py +67 -0
  8. lerobot/common/robot_devices/control_configs.py +129 -0
  9. lerobot/common/robot_devices/control_utils.py +347 -0
  10. lerobot/common/robot_devices/motors/configs.py +41 -0
  11. lerobot/common/robot_devices/motors/dynamixel.py +873 -0
  12. lerobot/common/robot_devices/motors/feetech.py +898 -0
  13. lerobot/common/robot_devices/motors/utils.py +67 -0
  14. lerobot/common/robot_devices/robots/configs.py +613 -0
  15. lerobot/common/robot_devices/robots/dynamixel_calibration.py +144 -0
  16. lerobot/common/robot_devices/robots/feetech_calibration.py +498 -0
  17. lerobot/common/robot_devices/robots/lekiwi_remote.py +224 -0
  18. lerobot/common/robot_devices/robots/manipulator.py +627 -0
  19. lerobot/common/robot_devices/robots/mobile_manipulator.py +703 -0
  20. lerobot/common/robot_devices/robots/stretch.py +208 -0
  21. lerobot/common/robot_devices/robots/utils.py +86 -0
  22. lerobot/common/robot_devices/utils.py +65 -0
  23. lerobot/common/utils/benchmark.py +92 -0
  24. lerobot/common/utils/hub.py +202 -0
  25. lerobot/common/utils/import_utils.py +59 -0
  26. lerobot/common/utils/io_utils.py +111 -0
  27. lerobot/common/utils/logging_utils.py +163 -0
  28. lerobot/common/utils/random_utils.py +191 -0
  29. lerobot/common/utils/train_utils.py +161 -0
  30. lerobot/common/utils/utils.py +230 -0
  31. lerobot/common/utils/wandb_utils.py +127 -0
  32. lerobot/configs/default.py +70 -0
  33. lerobot/configs/eval.py +65 -0
  34. lerobot/configs/parser.py +232 -0
  35. lerobot/configs/policies.py +176 -0
  36. lerobot/configs/train.py +175 -0
  37. lerobot/configs/types.py +41 -0
  38. lerobot/scripts/configure_motor.py +176 -0
  39. lerobot/scripts/control_robot.py +393 -0
  40. lerobot/scripts/control_sim_robot.py +561 -0
  41. lerobot/scripts/display_sys_info.py +90 -0
  42. lerobot/scripts/eval.py +502 -0
  43. lerobot/scripts/find_motors_bus_port.py +55 -0
  44. lerobot/scripts/push_dataset_to_hub.py +364 -0
  45. lerobot/scripts/push_pretrained.py +71 -0
  46. lerobot/scripts/train.py +288 -0
  47. lerobot/scripts/visualize_dataset.py +292 -0
  48. lerobot/scripts/visualize_dataset_html.py +479 -0
  49. lerobot/scripts/visualize_image_transforms.py +130 -0
  50. lerobot/templates/visualize_dataset_homepage.html +68 -0
.gitattributes CHANGED
@@ -18,3 +18,35 @@
18
  *.mp4 filter=lfs diff=lfs merge=lfs -text
19
  *.arrow filter=lfs diff=lfs merge=lfs -text
20
  *.json !text !filter !merge !diff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  *.mp4 filter=lfs diff=lfs merge=lfs -text
19
  *.arrow filter=lfs diff=lfs merge=lfs -text
20
  *.json !text !filter !merge !diff
21
+ media/lerobot-logo-light.png filter=lfs diff=lfs merge=lfs -text
22
+ media/aloha/follower_rotated.webp filter=lfs diff=lfs merge=lfs -text
23
+ media/aloha/follower_zero.webp filter=lfs diff=lfs merge=lfs -text
24
+ media/lerobot-logo-thumbnail.png filter=lfs diff=lfs merge=lfs -text
25
+ media/wandb.png filter=lfs diff=lfs merge=lfs -text
26
+ media/gym/pusht_diffusion.gif filter=lfs diff=lfs merge=lfs -text
27
+ media/aloha/leader_rest.webp filter=lfs diff=lfs merge=lfs -text
28
+ media/aloha/leader_zero.webp filter=lfs diff=lfs merge=lfs -text
29
+ media/koch/follower_rest.webp filter=lfs diff=lfs merge=lfs -text
30
+ media/koch/follower_zero.webp filter=lfs diff=lfs merge=lfs -text
31
+ media/gym/aloha_act.gif filter=lfs diff=lfs merge=lfs -text
32
+ media/aloha/leader_rotated.webp filter=lfs diff=lfs merge=lfs -text
33
+ media/koch/follower_rotated.webp filter=lfs diff=lfs merge=lfs -text
34
+ media/koch/leader_rest.webp filter=lfs diff=lfs merge=lfs -text
35
+ media/gym/simxarm_tdmpc.gif filter=lfs diff=lfs merge=lfs -text
36
+ media/koch/leader_rotated.webp filter=lfs diff=lfs merge=lfs -text
37
+ media/aloha/follower_rest.webp filter=lfs diff=lfs merge=lfs -text
38
+ media/koch/leader_zero.webp filter=lfs diff=lfs merge=lfs -text
39
+ media/lekiwi/kiwi.webp filter=lfs diff=lfs merge=lfs -text
40
+ media/lekiwi/mobile_calib_rest.webp filter=lfs diff=lfs merge=lfs -text
41
+ media/lekiwi/mobile_calib_rotated.webp filter=lfs diff=lfs merge=lfs -text
42
+ media/moss/follower_initial.webp filter=lfs diff=lfs merge=lfs -text
43
+ media/lekiwi/motor_ids.webp filter=lfs diff=lfs merge=lfs -text
44
+ media/moss/follower_rotated.webp filter=lfs diff=lfs merge=lfs -text
45
+ media/lekiwi/mobile_calib_zero.webp filter=lfs diff=lfs merge=lfs -text
46
+ media/moss/follower_rest.webp filter=lfs diff=lfs merge=lfs -text
47
+ media/moss/leader_rotated.webp filter=lfs diff=lfs merge=lfs -text
48
+ media/moss/follower_zero.webp filter=lfs diff=lfs merge=lfs -text
49
+ media/moss/leader_zero.webp filter=lfs diff=lfs merge=lfs -text
50
+ media/so100/follower_initial.webp filter=lfs diff=lfs merge=lfs -text
51
+ media/so100/follower_rest.webp filter=lfs diff=lfs merge=lfs -text
52
+ media/so100/leader_follower.webp filter=lfs diff=lfs merge=lfs -text
lerobot/common/policies/vqbet/configuration_vqbet.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
4
+ # and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
5
+ # and The HuggingFace Inc. team. All rights reserved.
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ from dataclasses import dataclass, field
20
+
21
+ from lerobot.common.optim.optimizers import AdamConfig
22
+ from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
23
+ from lerobot.configs.policies import PreTrainedConfig
24
+ from lerobot.configs.types import NormalizationMode
25
+
26
+
27
+ @PreTrainedConfig.register_subclass("vqbet")
28
+ @dataclass
29
+ class VQBeTConfig(PreTrainedConfig):
30
+ """Configuration class for VQ-BeT.
31
+
32
+ Defaults are configured for training with PushT providing proprioceptive and single camera observations.
33
+
34
+ The parameters you will most likely need to change are the ones which depend on the environment / sensors.
35
+ Those are: `input_shapes` and `output_shapes`.
36
+
37
+ Notes on the inputs and outputs:
38
+ - "observation.state" is required as an input key.
39
+ - At least one key starting with "observation.image" is required as an input.
40
+ - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
41
+ views. Right now we only support all images having the same shape.
42
+ - "action" is required as an output key.
43
+
44
+ Args:
45
+ n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
46
+ current step and additional steps going back).
47
+ n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts.
48
+ action_chunk_size: Action chunk size of each action prediction token.
49
+ input_shapes: A dictionary defining the shapes of the input data for the policy.
50
+ The key represents the input data name, and the value is a list indicating the dimensions
51
+ of the corresponding data. For example, "observation.image" refers to an input from
52
+ a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
53
+ Importantly, shapes doesn't include batch dimension or temporal dimension.
54
+ output_shapes: A dictionary defining the shapes of the output data for the policy.
55
+ The key represents the output data name, and the value is a list indicating the dimensions
56
+ of the corresponding data. For example, "action" refers to an output shape of [14], indicating
57
+ 14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
58
+ input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
59
+ and the value specifies the normalization mode to apply. The two available modes are "mean_std"
60
+ which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
61
+ [-1, 1] range.
62
+ output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to the
63
+ original scale. Note that this is also used for normalizing the training targets.
64
+ vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
65
+ crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
66
+ within the image size. If None, no cropping is done.
67
+ crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
68
+ mode).
69
+ pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
70
+ `None` means no pretrained weights.
71
+ use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
72
+ The group sizes are set to be about 16 (to be precise, feature_dim // 16).
73
+ spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
74
+ n_vqvae_training_steps: Number of optimization steps for training Residual VQ.
75
+ vqvae_n_embed: Number of embedding vectors in the RVQ dictionary (each layer).
76
+ vqvae_embedding_dim: Dimension of each embedding vector in the RVQ dictionary.
77
+ vqvae_enc_hidden_dim: Size of hidden dimensions of Encoder / Decoder part of Residual VQ-VAE
78
+ gpt_block_size: Max block size of minGPT (should be larger than the number of input tokens)
79
+ gpt_input_dim: Size of input dimension of GPT. This is also used as the dimension of observation features.
80
+ gpt_output_dim: Size of output dimension of GPT. This is also used as a input dimension of offset / bin prediction headers.
81
+ gpt_n_layer: Number of layers of GPT
82
+ gpt_n_head: Number of attention heads of GPT
83
+ gpt_hidden_dim: Size of hidden dimensions of GPT
84
+ dropout: Dropout rate for GPT
85
+ mlp_hidden_dim: Size of hidden dimensions of the offset head / bin prediction head parts of VQ-BeT
86
+ offset_loss_weight: A constant that is multiplied to the offset loss
87
+ primary_code_loss_weight: A constant that is multiplied to the primary code prediction loss
88
+ secondary_code_loss_weight: A constant that is multiplied to the secondary code prediction loss
89
+ bet_softmax_temperature: Sampling temperature of code for rollout with VQ-BeT
90
+ sequentially_select: Whether to select the primary / secondary codes sequentially (pick the primary code,
91
+ and then select the secondary code), or both at the same time.
92
+ """
93
+
94
+ # Inputs / output structure.
95
+ n_obs_steps: int = 5
96
+ n_action_pred_token: int = 3
97
+ action_chunk_size: int = 5
98
+
99
+ normalization_mapping: dict[str, NormalizationMode] = field(
100
+ default_factory=lambda: {
101
+ "VISUAL": NormalizationMode.IDENTITY,
102
+ "STATE": NormalizationMode.MIN_MAX,
103
+ "ACTION": NormalizationMode.MIN_MAX,
104
+ }
105
+ )
106
+
107
+ # Architecture / modeling.
108
+ # Vision backbone.
109
+ vision_backbone: str = "resnet18"
110
+ crop_shape: tuple[int, int] | None = (84, 84)
111
+ crop_is_random: bool = True
112
+ pretrained_backbone_weights: str | None = None
113
+ use_group_norm: bool = True
114
+ spatial_softmax_num_keypoints: int = 32
115
+ # VQ-VAE
116
+ n_vqvae_training_steps: int = 20000
117
+ vqvae_n_embed: int = 16
118
+ vqvae_embedding_dim: int = 256
119
+ vqvae_enc_hidden_dim: int = 128
120
+ # VQ-BeT
121
+ gpt_block_size: int = 500
122
+ gpt_input_dim: int = 512
123
+ gpt_output_dim: int = 512
124
+ gpt_n_layer: int = 8
125
+ gpt_n_head: int = 8
126
+ gpt_hidden_dim: int = 512
127
+ dropout: float = 0.1
128
+ mlp_hidden_dim: int = 1024
129
+ offset_loss_weight: float = 10000.0
130
+ primary_code_loss_weight: float = 5.0
131
+ secondary_code_loss_weight: float = 0.5
132
+ bet_softmax_temperature: float = 0.1
133
+ sequentially_select: bool = False
134
+
135
+ # Training presets
136
+ optimizer_lr: float = 1e-4
137
+ optimizer_betas: tuple = (0.95, 0.999)
138
+ optimizer_eps: float = 1e-8
139
+ optimizer_weight_decay: float = 1e-6
140
+ optimizer_vqvae_lr: float = 1e-3
141
+ optimizer_vqvae_weight_decay: float = 1e-4
142
+ scheduler_warmup_steps: int = 500
143
+
144
+ def __post_init__(self):
145
+ super().__post_init__()
146
+
147
+ """Input validation (not exhaustive)."""
148
+ if not self.vision_backbone.startswith("resnet"):
149
+ raise ValueError(
150
+ f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
151
+ )
152
+
153
+ def get_optimizer_preset(self) -> AdamConfig:
154
+ return AdamConfig(
155
+ lr=self.optimizer_lr,
156
+ betas=self.optimizer_betas,
157
+ eps=self.optimizer_eps,
158
+ weight_decay=self.optimizer_weight_decay,
159
+ )
160
+
161
+ def get_scheduler_preset(self) -> VQBeTSchedulerConfig:
162
+ return VQBeTSchedulerConfig(
163
+ num_warmup_steps=self.scheduler_warmup_steps,
164
+ num_vqvae_training_steps=self.n_vqvae_training_steps,
165
+ )
166
+
167
+ def validate_features(self) -> None:
168
+ # Note: this check was previously performed inside VQBeTRgbEncoder in the form of
169
+ # assert len(image_keys) == 1
170
+ if not len(self.image_features) == 1:
171
+ raise ValueError("You must provide only one image among the inputs.")
172
+
173
+ if self.crop_shape is not None:
174
+ for key, image_ft in self.image_features.items():
175
+ if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
176
+ raise ValueError(
177
+ f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
178
+ f"for `crop_shape` and {image_ft.shape} for "
179
+ f"`{key}`."
180
+ )
181
+
182
+ # Check that all input images have the same shape.
183
+ first_image_key, first_image_ft = next(iter(self.image_features.items()))
184
+ for key, image_ft in self.image_features.items():
185
+ if image_ft.shape != first_image_ft.shape:
186
+ raise ValueError(
187
+ f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
188
+ )
189
+
190
+ @property
191
+ def observation_delta_indices(self) -> list:
192
+ return list(range(1 - self.n_obs_steps, 1))
193
+
194
+ @property
195
+ def action_delta_indices(self) -> list:
196
+ return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
197
+
198
+ @property
199
+ def reward_delta_indices(self) -> None:
200
+ return None
lerobot/common/policies/vqbet/modeling_vqbet.py ADDED
@@ -0,0 +1,911 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
4
+ # and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
5
+ # and The HuggingFace Inc. team. All rights reserved.
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ import warnings
20
+ from collections import deque
21
+ from typing import Callable, List
22
+
23
+ import einops
24
+ import numpy as np
25
+ import torch
26
+ import torch.nn.functional as F # noqa: N812
27
+ import torchvision
28
+ from torch import Tensor, nn
29
+
30
+ from lerobot.common.policies.normalize import Normalize, Unnormalize
31
+ from lerobot.common.policies.pretrained import PreTrainedPolicy
32
+ from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
33
+ from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
34
+ from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
35
+
36
+ # ruff: noqa: N806
37
+
38
+
39
+ class VQBeTPolicy(PreTrainedPolicy):
40
+ """
41
+ VQ-BeT Policy as per "Behavior Generation with Latent Actions"
42
+ """
43
+
44
+ config_class = VQBeTConfig
45
+ name = "vqbet"
46
+
47
+ def __init__(
48
+ self,
49
+ config: VQBeTConfig | None = None,
50
+ dataset_stats: dict[str, dict[str, Tensor]] | None = None,
51
+ ):
52
+ """
53
+ Args:
54
+ config: Policy configuration class instance or None, in which case the default instantiation of
55
+ the configuration class is used.
56
+ dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
57
+ that they will be passed with a call to `load_state_dict` before the policy is used.
58
+ """
59
+ super().__init__(config)
60
+ config.validate_features()
61
+ self.config = config
62
+
63
+ self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
64
+ self.normalize_targets = Normalize(
65
+ config.output_features, config.normalization_mapping, dataset_stats
66
+ )
67
+ self.unnormalize_outputs = Unnormalize(
68
+ config.output_features, config.normalization_mapping, dataset_stats
69
+ )
70
+
71
+ self.vqbet = VQBeTModel(config)
72
+
73
+ self.reset()
74
+
75
+ def get_optim_params(self) -> dict:
76
+ vqvae_params = (
77
+ list(self.vqbet.action_head.vqvae_model.encoder.parameters())
78
+ + list(self.vqbet.action_head.vqvae_model.decoder.parameters())
79
+ + list(self.vqbet.action_head.vqvae_model.vq_layer.parameters())
80
+ )
81
+ decay_params, no_decay_params = self.vqbet.policy.configure_parameters()
82
+ decay_params = (
83
+ decay_params
84
+ + list(self.vqbet.rgb_encoder.parameters())
85
+ + list(self.vqbet.state_projector.parameters())
86
+ + list(self.vqbet.rgb_feature_projector.parameters())
87
+ + [self.vqbet.action_token]
88
+ + list(self.vqbet.action_head.map_to_cbet_preds_offset.parameters())
89
+ )
90
+
91
+ if self.config.sequentially_select:
92
+ decay_params = (
93
+ decay_params
94
+ + list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
95
+ + list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
96
+ )
97
+ else:
98
+ decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters())
99
+
100
+ return [
101
+ {
102
+ "params": decay_params,
103
+ },
104
+ {
105
+ "params": vqvae_params,
106
+ "weight_decay": self.config.optimizer_vqvae_weight_decay,
107
+ "lr": self.config.optimizer_vqvae_lr,
108
+ },
109
+ {
110
+ "params": no_decay_params,
111
+ "weight_decay": 0.0,
112
+ },
113
+ ]
114
+
115
+ def reset(self):
116
+ """
117
+ Clear observation and action queues. Should be called on `env.reset()`
118
+ queues are populated during rollout of the policy, they contain the n latest observations and actions
119
+ """
120
+ self._queues = {
121
+ "observation.images": deque(maxlen=self.config.n_obs_steps),
122
+ "observation.state": deque(maxlen=self.config.n_obs_steps),
123
+ "action": deque(maxlen=self.config.action_chunk_size),
124
+ }
125
+
126
+ @torch.no_grad
127
+ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
128
+ """Select a single action given environment observations.
129
+
130
+ This method wraps `select_actions` in order to return one action at a time for execution in the
131
+ environment. It works by managing the actions in a queue and only calling `select_actions` when the
132
+ queue is empty.
133
+ """
134
+
135
+ batch = self.normalize_inputs(batch)
136
+ batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
137
+ batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
138
+ # Note: It's important that this happens after stacking the images into a single key.
139
+ self._queues = populate_queues(self._queues, batch)
140
+
141
+ if not self.vqbet.action_head.vqvae_model.discretized.item():
142
+ warnings.warn(
143
+ "To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ.",
144
+ stacklevel=1,
145
+ )
146
+
147
+ if len(self._queues["action"]) == 0:
148
+ batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
149
+ actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
150
+
151
+ # the dimension of returned action is (batch_size, action_chunk_size, action_dim)
152
+ actions = self.unnormalize_outputs({"action": actions})["action"]
153
+ # since the data in the action queue's dimension is (action_chunk_size, batch_size, action_dim), we transpose the action and fill the queue
154
+ self._queues["action"].extend(actions.transpose(0, 1))
155
+
156
+ action = self._queues["action"].popleft()
157
+ return action
158
+
159
+ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
160
+ """Run the batch through the model and compute the loss for training or validation."""
161
+ batch = self.normalize_inputs(batch)
162
+ batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
163
+ batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
164
+ batch = self.normalize_targets(batch)
165
+ # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
166
+ if not self.vqbet.action_head.vqvae_model.discretized.item():
167
+ # loss: total loss of training RVQ
168
+ # n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
169
+ # n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
170
+ loss, n_different_codes, n_different_combinations, recon_l1_error = (
171
+ self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
172
+ )
173
+ return loss, {
174
+ "n_different_codes": n_different_codes,
175
+ "n_different_combinations": n_different_combinations,
176
+ "recon_l1_error": recon_l1_error,
177
+ }
178
+ # if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
179
+ _, loss_dict = self.vqbet(batch, rollout=False)
180
+ loss = loss_dict.pop("loss")
181
+
182
+ return loss, loss_dict
183
+
184
+
185
+ class SpatialSoftmax(nn.Module):
186
+ """
187
+ Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
188
+ (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
189
+
190
+ At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
191
+ of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
192
+
193
+ Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
194
+ -----------------------------------------------------
195
+ | (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
196
+ | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
197
+ | ... | ... | ... | ... |
198
+ | (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) |
199
+ -----------------------------------------------------
200
+ This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
201
+ product with the coordinates (120x2) to get expected points of maximal activation (512x2).
202
+
203
+ The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
204
+ provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable
205
+ linear mapping (in_channels, H, W) -> (num_kp, H, W).
206
+ """
207
+
208
+ def __init__(self, input_shape, num_kp=None):
209
+ """
210
+ Args:
211
+ input_shape (list): (C, H, W) input feature map shape.
212
+ num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
213
+ """
214
+ super().__init__()
215
+
216
+ assert len(input_shape) == 3
217
+ self._in_c, self._in_h, self._in_w = input_shape
218
+
219
+ if num_kp is not None:
220
+ self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
221
+ self._out_c = num_kp
222
+ else:
223
+ self.nets = None
224
+ self._out_c = self._in_c
225
+
226
+ # we could use torch.linspace directly but that seems to behave slightly differently than numpy
227
+ # and causes a small degradation in pc_success of pre-trained models.
228
+ pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
229
+ pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
230
+ pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
231
+ # register as buffer so it's moved to the correct device.
232
+ self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
233
+
234
+ def forward(self, features: Tensor) -> Tensor:
235
+ """
236
+ Args:
237
+ features: (B, C, H, W) input feature maps.
238
+ Returns:
239
+ (B, K, 2) image-space coordinates of keypoints.
240
+ """
241
+ if self.nets is not None:
242
+ features = self.nets(features)
243
+
244
+ # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
245
+ features = features.reshape(-1, self._in_h * self._in_w)
246
+ # 2d softmax normalization
247
+ attention = F.softmax(features, dim=-1)
248
+ # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
249
+ expected_xy = attention @ self.pos_grid
250
+ # reshape to [B, K, 2]
251
+ feature_keypoints = expected_xy.view(-1, self._out_c, 2)
252
+
253
+ return feature_keypoints
254
+
255
+
256
+ class VQBeTModel(nn.Module):
257
+ """VQ-BeT: The underlying neural network for VQ-BeT
258
+
259
+ Note: In this code we use the terms `rgb_encoder`, 'policy', `action_head`. The meanings are as follows.
260
+ - The `rgb_encoder` process rgb-style image observations to one-dimensional embedding vectors
261
+ - A `policy` is a minGPT architecture, that takes observation sequences and action query tokens to generate `features`.
262
+ - These `features` pass through the action head, which passes through the code prediction, offset prediction head,
263
+ and finally generates a prediction for the action chunks.
264
+
265
+ -------------------------------** legend **-------------------------------
266
+ │ n = n_obs_steps, p = n_action_pred_token, c = action_chunk_size) │
267
+ │ o_{t} : visual observation at timestep {t} │
268
+ │ s_{t} : state observation at timestep {t} │
269
+ │ a_{t} : action at timestep {t} │
270
+ │ A_Q : action_query_token │
271
+ --------------------------------------------------------------------------
272
+
273
+
274
+ Training Phase 1. Discretize action using Residual VQ (for config.n_vqvae_training_steps steps)
275
+
276
+
277
+ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
278
+ │ │ │ │ │ │
279
+ │ RVQ encoder │ ─► │ Residual │ ─► │ RVQ Decoder │
280
+ │ (a_{t}~a_{t+p}) │ │ Code Quantizer │ │ │
281
+ │ │ │ │ │ │
282
+ └─────────────────┘ └─────────────────┘ └─────────────────┘
283
+
284
+ Training Phase 2.
285
+
286
+ timestep {t-n+1} timestep {t-n+2} timestep {t}
287
+ ┌─────┴─────┐ ┌─────┴─────┐ ┌─────┴─────┐
288
+
289
+ o_{t-n+1} o_{t-n+2} ... o_{t}
290
+ │ │ │
291
+ │ s_{t-n+1} │ s_{t-n+2} ... │ s_{t} p
292
+ │ │ │ │ │ │ ┌───────┴───────┐
293
+ │ │ A_Q │ │ A_Q ... │ │ A_Q ... A_Q
294
+ │ │ │ │ │ │ │ │ │ │
295
+ ┌───▼─────▼─────▼─────▼─────▼─────▼─────────────────▼─────▼─────▼───────────────▼───┐
296
+ │ │
297
+ │ GPT │ => policy
298
+ │ │
299
+ └───────────────▼─────────────────▼─────────────────────────────▼───────────────▼───┘
300
+ │ │ │ │
301
+ ┌───┴───┐ ┌───┴───┐ ┌───┴───┐ ┌───┴───┐
302
+ code offset code offset code offset code offset
303
+ ▼ │ ▼ │ ▼ │ ▼ │ => action_head
304
+ RVQ Decoder │ RVQ Decoder │ RVQ Decoder │ RVQ Decoder │
305
+ └── + ──┘ └── + ──┘ └── + ──┘ └── + ──┘
306
+ ▼ ▼ ▼ ▼
307
+ action chunk action chunk action chunk action chunk
308
+ a_{t-n+1} ~ a_{t-n+2} ~ a_{t} ~ ... a_{t+p-1} ~
309
+ a_{t-n+c} a_{t-n+c+1} a_{t+c-1} a_{t+p+c-1}
310
+
311
+
312
+ ONLY this chunk is used in rollout!
313
+ """
314
+
315
+ def __init__(self, config: VQBeTConfig):
316
+ super().__init__()
317
+ self.config = config
318
+
319
+ self.rgb_encoder = VQBeTRgbEncoder(config)
320
+ self.num_images = len(self.config.image_features)
321
+ # This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
322
+ # Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results.
323
+ self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
324
+
325
+ # To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
326
+ self.state_projector = MLP(
327
+ config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
328
+ )
329
+ self.rgb_feature_projector = MLP(
330
+ self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
331
+ )
332
+
333
+ # GPT part of VQ-BeT
334
+ self.policy = GPT(config)
335
+ # bin prediction head / offset prediction head part of VQ-BeT
336
+ self.action_head = VQBeTHead(config)
337
+
338
+ # Action tokens for: each observation step, the current action token, and all future action tokens.
339
+ num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
340
+ self.register_buffer(
341
+ "select_target_actions_indices",
342
+ torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
343
+ )
344
+
345
    def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
        """Run the VQ-BeT sequence model.

        Args:
            batch: must contain "observation.state" of shape (batch, n_obs_steps, state_dim)
                and "observation.images" of shape (batch, n_obs_steps, n_cameras, C, H, W);
                during training it must also contain "action".
            rollout: if True, return only the predicted action chunk for the current
                timestep; if False, return the full action-head output and the loss dict.
        """
        # Input validation.
        assert set(batch).issuperset({"observation.state", "observation.images"})
        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
        assert n_obs_steps == self.config.n_obs_steps

        # Extract image feature (first combine batch, sequence, and camera dims so the
        # encoder sees a flat batch of images).
        img_features = self.rgb_encoder(
            einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
        )
        # Separate batch and sequence dims.
        img_features = einops.rearrange(
            img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
        )

        # Arrange prior and current observation step tokens as shown in the class docstring.
        # First project features to token dimension.
        rgb_tokens = self.rgb_feature_projector(
            img_features
        )  # (batch, obs_step, number of different cameras, projection dims)
        # One token list entry per camera, then the state token, then the action query token.
        input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))]
        input_tokens.append(
            self.state_projector(batch["observation.state"])
        )  # (batch, obs_step, projection dims)
        input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
        # Interleave tokens by stacking and rearranging, so each observation step's tokens
        # (cameras..., state, A_Q) appear consecutively in the sequence.
        input_tokens = torch.stack(input_tokens, dim=2)
        input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")

        len_additional_action_token = self.config.n_action_pred_token - 1
        future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)

        # add additional action query tokens for predicting future action chunks
        input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)

        # get action features (pass through GPT)
        features = self.policy(input_tokens)
        # len(self.config.input_features) is the number of different observation modes.
        # This line gets the sequence index of the action-query token within each
        # observation step's group of (n_modes + 1) tokens.
        historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
            self.config.input_features
        )

        # only extract the output tokens at the position of action query:
        # Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
        # mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://arxiv.org/pdf/2206.11251).
        # Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
        if len_additional_action_token > 0:
            features = torch.cat(
                [features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
            )
        else:
            features = features[:, historical_act_pred_index]
        # pass through action head
        action_head_output = self.action_head(features)
        # if rollout, VQ-BeT doesn't calculate loss; only the chunk predicted at the
        # current timestep (index n_obs_steps - 1) is used.
        if rollout:
            return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
                batch_size, self.config.action_chunk_size, -1
            )
        # else, it calculates the overall loss (bin prediction loss, and offset loss)
        else:
            output = batch["action"][:, self.select_target_actions_indices]
            loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
            return action_head_output, loss
+
412
class VQBeTHead(nn.Module):
    def __init__(self, config: VQBeTConfig):
        """
        VQBeTHead takes output of GPT layers, and passes the feature through a bin prediction head (`self.map_to_cbet_preds_bin`), and an offset prediction head (`self.map_to_cbet_preds_offset`)

        self.map_to_cbet_preds_bin: outputs probability of each code (for each layer).
        The input dimension of `self.map_to_cbet_preds_bin` is same with the output of GPT,
        and the output dimension of `self.map_to_cbet_preds_bin` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed`.
        If the agent selects the code sequentially, we use self.map_to_cbet_preds_primary_bin and self.map_to_cbet_preds_secondary_bin instead of self.map_to_cbet_preds_bin.

        self.map_to_cbet_preds_offset: outputs the predicted offsets for all the codes in all the layers.
        The input dimension of `self.map_to_cbet_preds_offset` is same with the output of GPT,
        and the output dimension of `self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.action_feature.shape[0]`.
        """

        super().__init__()
        self.config = config
        # init vqvae (used both to label actions with GT codes and to decode sampled codes)
        self.vqvae_model = VqVae(config)
        if config.sequentially_select:
            # Two heads: the secondary head is conditioned on the one-hot primary code.
            self.map_to_cbet_preds_primary_bin = MLP(
                in_channels=config.gpt_output_dim,
                hidden_channels=[self.config.vqvae_n_embed],
            )
            self.map_to_cbet_preds_secondary_bin = MLP(
                in_channels=config.gpt_output_dim + self.config.vqvae_n_embed,
                hidden_channels=[self.config.vqvae_n_embed],
            )
        else:
            # Single head predicting logits for both RVQ layers at once.
            self.map_to_cbet_preds_bin = MLP(
                in_channels=config.gpt_output_dim,
                hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
            )
        self.map_to_cbet_preds_offset = MLP(
            in_channels=config.gpt_output_dim,
            hidden_channels=[
                self.vqvae_model.vqvae_num_layers
                * self.config.vqvae_n_embed
                * config.action_chunk_size
                * config.action_feature.shape[0],
            ],
        )
        # loss (focal loss with gamma=2 for code classification)
        self._focal_loss_fn = FocalLoss(gamma=2.0)

    def discretize(self, n_vqvae_training_steps, actions):
        """Run one training step of the residual VQ-VAE (training phase 1).

        Returns (loss, n_different_codes, n_different_combinations, recon_l1_error).
        After `n_vqvae_training_steps` calls, the RVQ is frozen.
        """
        # Resize the action sequence data to fit the action chunk size using a sliding window approach.
        actions = torch.cat(
            [
                actions[:, j : j + self.config.action_chunk_size, :]
                for j in range(actions.shape[1] + 1 - self.config.action_chunk_size)
            ],
            dim=0,
        )
        # `actions` is a tensor of shape (new_batch, action_chunk_size, action_dim) where new_batch is the number of possible chunks created from the original sequences using the sliding window.

        loss, metric = self.vqvae_model.vqvae_forward(actions)
        # metric[2] holds the assigned code indices; count usage diversity for logging.
        n_different_codes = sum(
            [len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)]
        )
        n_different_combinations = len(torch.unique(metric[2], dim=0))
        recon_l1_error = metric[0].detach().cpu().item()
        self.vqvae_model.optimized_steps += 1
        # if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part.
        if self.vqvae_model.optimized_steps >= n_vqvae_training_steps:
            self.vqvae_model.discretized = torch.tensor(True)
            self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True)
            print("Finished discretizing action data!")
            self.vqvae_model.eval()
            for param in self.vqvae_model.vq_layer.parameters():
                param.requires_grad = False
        return loss, n_different_codes, n_different_combinations, recon_l1_error

    def forward(self, x, **kwargs) -> dict:
        """Sample codes + offsets from GPT features and decode them into action chunks.

        Args:
            x: GPT output features of shape (N, T, gpt_output_dim).
        Returns:
            dict with "cbet_logits", "predicted_action", "sampled_centers", "decoded_action".
        """
        # N is the batch size, and T is number of action query tokens, which are process through same GPT
        N, T, _ = x.shape
        # we calculate N and T side parallelly. Thus, the dimensions would be
        # (batch size * number of action query tokens, action chunk size, action dimension)
        x = einops.rearrange(x, "N T WA -> (N T) WA")

        # sample offsets
        cbet_offsets = self.map_to_cbet_preds_offset(x)
        cbet_offsets = einops.rearrange(
            cbet_offsets,
            "(NT) (G C WA) -> (NT) G C WA",
            G=self.vqvae_model.vqvae_num_layers,
            C=self.config.vqvae_n_embed,
        )
        # if self.config.sequentially_select is True, bin prediction head first sample the primary code, and then sample secondary code
        if self.config.sequentially_select:
            cbet_primary_logits = self.map_to_cbet_preds_primary_bin(x)

            # select primary bin first
            cbet_primary_probs = torch.softmax(
                cbet_primary_logits / self.config.bet_softmax_temperature, dim=-1
            )
            NT, choices = cbet_primary_probs.shape
            sampled_primary_centers = einops.rearrange(
                torch.multinomial(cbet_primary_probs.view(-1, choices), num_samples=1),
                "(NT) 1 -> NT",
                NT=NT,
            )

            # the secondary head is conditioned on the sampled primary code (as a one-hot).
            cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
                torch.cat(
                    (x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
                    axis=1,
                )
            )
            cbet_secondary_probs = torch.softmax(
                cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
            )
            # Note: both codebooks share size vqvae_n_embed, so `choices` applies here too.
            sampled_secondary_centers = einops.rearrange(
                torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1),
                "(NT) 1 -> NT",
                NT=NT,
            )
            sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1)
            cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1)
        # if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once.
        else:
            cbet_logits = self.map_to_cbet_preds_bin(x)
            cbet_logits = einops.rearrange(
                cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
            )
            cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
            NT, G, choices = cbet_probs.shape
            sampled_centers = einops.rearrange(
                torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
                "(NT G) 1 -> NT G",
                NT=NT,
            )

        device = get_device_from_parameters(self)
        indices = (
            torch.arange(NT, device=device).unsqueeze(1),
            torch.arange(self.vqvae_model.vqvae_num_layers, device=device).unsqueeze(0),
            sampled_centers,
        )
        # Use advanced indexing to sample the values (Extract the only offsets corresponding to the sampled codes.)
        sampled_offsets = cbet_offsets[indices]
        # Then, sum the offsets over the RVQ layers to get a net offset for the bin prediction
        sampled_offsets = sampled_offsets.sum(dim=1)
        # Decoding is done without gradients: only the offsets (not the codebook decode
        # path) receive gradient through `predicted_action` below.
        with torch.no_grad():
            # Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder
            return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
            # pass the centroids through decoder to get actions.
            decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach()
        # reshaped extracted offset to match with decoded centroids
        sampled_offsets = einops.rearrange(
            sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
        )
        # add offset and decoded centroids
        predicted_action = decoded_action + sampled_offsets
        predicted_action = einops.rearrange(
            predicted_action,
            "(N T) W A -> N T (W A)",
            N=N,
            T=T,
            W=self.config.action_chunk_size,
        )

        return {
            "cbet_logits": cbet_logits,
            "predicted_action": predicted_action,
            "sampled_centers": sampled_centers,
            "decoded_action": decoded_action,
        }

    def loss_fn(self, pred, target, **kwargs):
        """
        for given ground truth action values (target), and prediction (pred) this function calculates the overall loss.

        predicted_action: predicted action chunk (offset + decoded centroids)
        sampled_centers: sampled centroids (code of RVQ)
        decoded_action: decoded action, which is produced by passing sampled_centers through RVQ decoder
        NT: batch size * T
        T: number of action query tokens, which are process through same GPT
        cbet_logits: probability of all codes in each layer
        """
        action_seq = target
        predicted_action = pred["predicted_action"]
        sampled_centers = pred["sampled_centers"]
        decoded_action = pred["decoded_action"]
        NT = predicted_action.shape[0] * predicted_action.shape[1]

        cbet_logits = pred["cbet_logits"]

        predicted_action = einops.rearrange(
            predicted_action, "N T (W A) -> (N T) W A", W=self.config.action_chunk_size
        )

        action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
        # Figure out the loss for the actions.
        # First, we need to find the closest cluster center for each ground truth action.
        with torch.no_grad():
            # NOTE(review): state_vq is unused here; only action_bins is needed.
            state_vq, action_bins = self.vqvae_model.get_code(action_seq)  # action_bins: NT, G

        # Now we can compute the loss.

        # offset loss is L1 distance between the predicted action and ground truth action
        offset_loss = F.l1_loss(action_seq, predicted_action)

        # calculate primary code prediction loss
        cbet_loss1 = self._focal_loss_fn(
            cbet_logits[:, 0, :],
            action_bins[:, 0],
        )
        # calculate secondary code prediction loss
        cbet_loss2 = self._focal_loss_fn(
            cbet_logits[:, 1, :],
            action_bins[:, 1],
        )
        # add all the prediction loss (weighted sum of the two layers' focal losses)
        cbet_loss = (
            cbet_loss1 * self.config.primary_code_loss_weight
            + cbet_loss2 * self.config.secondary_code_loss_weight
        )

        # Fraction of tokens where the sampled code matches the ground-truth code (metrics only).
        equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
        equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)

        action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
        vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
        offset_action_error = torch.mean(torch.abs(action_seq - predicted_action))
        action_error_max = torch.max(torch.abs(action_seq - predicted_action))

        loss = cbet_loss + self.config.offset_loss_weight * offset_loss

        loss_dict = {
            "loss": loss,
            "classification_loss": cbet_loss.detach().cpu().item(),
            "offset_loss": offset_loss.detach().cpu().item(),
            "equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
            "equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
            "vq_action_error": vq_action_error.detach().cpu().item(),
            "offset_action_error": offset_action_error.detach().cpu().item(),
            "action_error_max": action_error_max.detach().cpu().item(),
            "action_mse_error": action_mse_error.detach().cpu().item(),
        }
        return loss_dict
+
655
class VQBeTRgbEncoder(nn.Module):
    """Encode an RGB image into a 1D feature vector.

    Optionally normalizes/crops the image first.

    Same with DiffusionRgbEncoder from modeling_diffusion.py
    """

    def __init__(self, config: VQBeTConfig):
        super().__init__()
        # Optional cropping: random during training (if configured), center at eval.
        self.do_crop = config.crop_shape is not None
        if self.do_crop:
            self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
            self.maybe_random_crop = (
                torchvision.transforms.RandomCrop(config.crop_shape)
                if config.crop_is_random
                else self.center_crop
            )

        # Backbone: a torchvision model with its final avgpool + fc layers stripped.
        model_ctor = getattr(torchvision.models, config.vision_backbone)
        full_model = model_ctor(weights=config.pretrained_backbone_weights)
        # Note: This assumes that the layer4 feature map is children()[-3]
        # TODO(alexander-soare): Use a safer alternative.
        self.backbone = nn.Sequential(*list(full_model.children())[:-2])
        if config.use_group_norm:
            if config.pretrained_backbone_weights:
                raise ValueError(
                    "You can't replace BatchNorm in a pretrained model without ruining the weights!"
                )
            self.backbone = _replace_submodules(
                root_module=self.backbone,
                predicate=lambda m: isinstance(m, nn.BatchNorm2d),
                func=lambda m: nn.GroupNorm(num_groups=m.num_features // 16, num_channels=m.num_features),
            )

        # Dry run to discover the backbone's feature-map shape. The dummy input takes its
        # channel count from `config.image_features`, and its height/width from
        # `config.crop_shape` when cropping is enabled, else from `config.image_features`.
        image_shape = next(iter(config.image_features.values())).shape
        spatial_dims = config.crop_shape if config.crop_shape is not None else image_shape[1:]
        probe_shape = (1, image_shape[0], *spatial_dims)
        fmap_shape = get_output_shape(self.backbone, probe_shape)[1:]

        # Spatial-softmax keypoint pooling followed by a linear layer + ReLU.
        self.pool = SpatialSoftmax(fmap_shape, num_kp=config.spatial_softmax_num_keypoints)
        self.feature_dim = config.spatial_softmax_num_keypoints * 2
        self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: (B, C, H, W) image tensor with pixel values in [0, 1].
        Returns:
            (B, D) image feature.
        """
        # Preprocess: maybe crop (random in train mode, center crop in eval mode).
        if self.do_crop:
            x = self.maybe_random_crop(x) if self.training else self.center_crop(x)
        # Backbone feature map -> keypoint pooling -> flatten.
        pooled = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
        # Final linear layer with non-linearity.
        return self.relu(self.out(pooled))
+
731
+
732
+ def _replace_submodules(
733
+ root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
734
+ ) -> nn.Module:
735
+ """
736
+ Args:
737
+ root_module: The module for which the submodules need to be replaced
738
+ predicate: Takes a module as an argument and must return True if the that module is to be replaced.
739
+ func: Takes a module as an argument and returns a new module to replace it with.
740
+ Returns:
741
+ The root module with its submodules replaced.
742
+ """
743
+ if predicate(root_module):
744
+ return func(root_module)
745
+
746
+ replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
747
+ for *parents, k in replace_list:
748
+ parent_module = root_module
749
+ if len(parents) > 0:
750
+ parent_module = root_module.get_submodule(".".join(parents))
751
+ if isinstance(parent_module, nn.Sequential):
752
+ src_module = parent_module[int(k)]
753
+ else:
754
+ src_module = getattr(parent_module, k)
755
+ tgt_module = func(src_module)
756
+ if isinstance(parent_module, nn.Sequential):
757
+ parent_module[int(k)] = tgt_module
758
+ else:
759
+ setattr(parent_module, k, tgt_module)
760
+ # verify that all BN are replaced
761
+ assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
762
+ return root_module
763
+
764
+
765
class VqVae(nn.Module):
    def __init__(
        self,
        config: VQBeTConfig,
    ):
        """
        VQ-VAE is composed of three parts: encoder, vq_layer, and decoder.
        Encoder and decoder are MLPs consisting of an input, output layer, and hidden layer, respectively.
        The vq_layer uses residual VQs.

        This class contains functions for training the encoder and decoder along with the residual VQ layer (for training phase 1),
        as well as functions to help BeT training part in training phase 2.
        """

        super().__init__()
        self.config = config
        # 'discretized' indicates whether the Residual VQ part is trained or not. (After finishing the training, we set discretized=True)
        self.register_buffer("discretized", torch.tensor(False))
        self.optimized_steps = 0
        # we use the fixed number of layers for Residual VQ across all environments.
        self.vqvae_num_layers = 2

        self.vq_layer = ResidualVQ(
            dim=config.vqvae_embedding_dim,
            num_quantizers=self.vqvae_num_layers,
            codebook_size=config.vqvae_n_embed,
        )

        # Encoder maps a flattened action chunk to the VQ embedding space; decoder inverts it.
        self.encoder = MLP(
            in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
            hidden_channels=[
                config.vqvae_enc_hidden_dim,
                config.vqvae_enc_hidden_dim,
                config.vqvae_embedding_dim,
            ],
        )
        self.decoder = MLP(
            in_channels=config.vqvae_embedding_dim,
            hidden_channels=[
                config.vqvae_enc_hidden_dim,
                config.vqvae_enc_hidden_dim,
                self.config.action_feature.shape[0] * self.config.action_chunk_size,
            ],
        )

    def get_embeddings_from_code(self, encoding_indices):
        """Map RVQ code indices to the corresponding embedding vector."""
        with torch.no_grad():
            z_embed = self.vq_layer.get_codebook_vector_from_indices(encoding_indices)
            # since the RVQ has multiple layers, it adds the vectors in the axis of layers to provide a vector for that code combination.
            z_embed = z_embed.sum(dim=0)
        return z_embed

    def get_action_from_latent(self, latent):
        """Decode a latent vector into an action chunk of shape (N, T, A)."""
        output = self.decoder(latent)
        # Fix: the original code branched on `action_chunk_size == 1`, but both branches
        # were identical; the redundant branch has been removed.
        return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])

    def get_code(self, state):
        # in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)
        # this function outputs the `GT code` of given action using frozen encoder and quantization layers. (please refer to Figure 2. in the paper https://arxiv.org/pdf/2403.03181)
        state = einops.rearrange(state, "N T A -> N (T A)")
        with torch.no_grad():
            state_rep = self.encoder(state)
            state_rep_shape = state_rep.shape[:-1]
            state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
            state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
            state_vq = state_rep_flat.view(*state_rep_shape, -1)
            vq_code = vq_code.view(*state_rep_shape, -1)
            # Fix: removed an unused `torch.sum(vq_loss_state)` computation — the VQ loss
            # is not needed when only extracting ground-truth codes.
        return state_vq, vq_code

    def vqvae_forward(self, state):
        """Pass an action chunk through encoder -> Residual VQ -> decoder (training phase 1).

        Please refer to section 3.2 in the paper https://arxiv.org/pdf/2403.03181.
        Returns (rep_loss, metric) where metric is
        (recon_l1, vq_loss, code_indices, total_loss_value).
        """
        state = einops.rearrange(state, "N T A -> N (T A)")
        # We start with passing action (or action chunk) at:t+n through the encoder ϕ.
        state_rep = self.encoder(state)
        state_rep_shape = state_rep.shape[:-1]
        state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
        # The resulting latent embedding vector x = ϕ(at:t+n) is then mapped to an embedding vector in the codebook of the RVQ layers by the nearest neighbor look-up.
        state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
        state_vq = state_rep_flat.view(*state_rep_shape, -1)
        vq_code = vq_code.view(*state_rep_shape, -1)
        # since the RVQ has multiple layers, sum the per-layer losses into a single scalar.
        vq_loss_state = torch.sum(vq_loss_state)
        # Then, the discretized vector zq(x) is reconstructed as ψ(zq(x)) by passing through the decoder ψ.
        dec_out = self.decoder(state_vq)
        # Calculate L1 reconstruction loss
        encoder_loss = (state - dec_out).abs().mean()
        # add encoder reconstruction loss and commitment loss (VQ loss weighted by 5)
        rep_loss = encoder_loss + vq_loss_state * 5

        metric = (
            encoder_loss.clone().detach(),
            vq_loss_state.clone().detach(),
            vq_code,
            rep_loss.item(),
        )
        return rep_loss, metric
867
+
868
+
869
class FocalLoss(nn.Module):
    """
    Focal loss (https://arxiv.org/abs/1708.02002) over class logits.

    From https://github.com/notmahi/miniBET/blob/main/behavior_transformer/bet.py

    With gamma=0 this reduces to standard cross-entropy; larger gamma down-weights
    well-classified examples.
    """

    def __init__(self, gamma: float = 0, size_average: bool = True):
        """
        Args:
            gamma: focusing parameter; 0 disables the focal re-weighting.
            size_average: if True, return the mean loss; otherwise the sum.
        """
        super().__init__()
        self.gamma = gamma
        self.size_average = size_average

    def forward(self, input, target):
        """
        Args:
            input: class logits of shape (N, C) or (N, T, C).
            target: integer class indices of shape (N,) or (N, T) respectively.
        Returns:
            Scalar loss tensor (mean or sum, depending on `size_average`).
        Raises:
            ValueError: if `input` is not 2-D or 3-D.
        """
        if len(input.shape) == 3:
            N, T, _ = input.shape
            logpt = F.log_softmax(input, dim=-1)
            logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T)
        elif len(input.shape) == 2:
            logpt = F.log_softmax(input, dim=-1)
            logpt = logpt.gather(-1, target.view(-1, 1)).view(-1)
        else:
            # Fix: previously an unsupported rank fell through and raised a confusing
            # NameError on `logpt`; fail fast with a clear message instead.
            raise ValueError(f"Expected 2D or 3D logits, got input of shape {tuple(input.shape)}")
        pt = logpt.exp()

        # Focal re-weighting: well-classified examples (pt -> 1) are down-weighted.
        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()
894
+
895
+
896
class MLP(torch.nn.Sequential):
    """A simple multi-layer perceptron: Linear layers with ReLU activations between them.

    The last entry of `hidden_channels` is the output dimension; no activation is
    applied after the final Linear layer.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: list[int],
    ):
        """
        Args:
            in_channels: input feature dimension.
            hidden_channels: sizes of each Linear layer's output, in order; must be non-empty.
        Raises:
            ValueError: if `hidden_channels` is empty.
        """
        # Fix: an empty `hidden_channels` previously raised a bare IndexError below;
        # fail fast with a clear message instead.
        if not hidden_channels:
            raise ValueError("`hidden_channels` must contain at least one output dimension.")
        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim))
            layers.append(torch.nn.ReLU())
            in_dim = hidden_dim

        # Final projection without a trailing activation.
        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1]))

        super().__init__(*layers)
lerobot/common/robot_devices/cameras/configs.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import abc
16
+ from dataclasses import dataclass
17
+
18
+ import draccus
19
+
20
+
21
@dataclass
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
    """Abstract base class for camera configurations.

    Subclasses register themselves under a string name via
    `CameraConfig.register_subclass(...)` (draccus ChoiceRegistry), which allows
    selecting a concrete config by its `type` string.
    """

    @property
    def type(self) -> str:
        """The registry name under which this config class was registered."""
        return self.get_choice_name(self.__class__)
26
+
27
+
28
+ @CameraConfig.register_subclass("opencv")
29
+ @dataclass
30
+ class OpenCVCameraConfig(CameraConfig):
31
+ """
32
+ Example of tested options for Intel Real Sense D405:
33
+
34
+ ```python
35
+ OpenCVCameraConfig(0, 30, 640, 480)
36
+ OpenCVCameraConfig(0, 60, 640, 480)
37
+ OpenCVCameraConfig(0, 90, 640, 480)
38
+ OpenCVCameraConfig(0, 30, 1280, 720)
39
+ ```
40
+ """
41
+
42
+ camera_index: int
43
+ fps: int | None = None
44
+ width: int | None = None
45
+ height: int | None = None
46
+ color_mode: str = "rgb"
47
+ channels: int | None = None
48
+ rotation: int | None = None
49
+ mock: bool = False
50
+
51
+ def __post_init__(self):
52
+ if self.color_mode not in ["rgb", "bgr"]:
53
+ raise ValueError(
54
+ f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
55
+ )
56
+
57
+ self.channels = 3
58
+
59
+ if self.rotation not in [-90, None, 90, 180]:
60
+ raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
61
+
62
+
63
+ @CameraConfig.register_subclass("intelrealsense")
64
+ @dataclass
65
+ class IntelRealSenseCameraConfig(CameraConfig):
66
+ """
67
+ Example of tested options for Intel Real Sense D405:
68
+
69
+ ```python
70
+ IntelRealSenseCameraConfig(128422271347, 30, 640, 480)
71
+ IntelRealSenseCameraConfig(128422271347, 60, 640, 480)
72
+ IntelRealSenseCameraConfig(128422271347, 90, 640, 480)
73
+ IntelRealSenseCameraConfig(128422271347, 30, 1280, 720)
74
+ IntelRealSenseCameraConfig(128422271347, 30, 640, 480, use_depth=True)
75
+ IntelRealSenseCameraConfig(128422271347, 30, 640, 480, rotation=90)
76
+ ```
77
+ """
78
+
79
+ name: str | None = None
80
+ serial_number: int | None = None
81
+ fps: int | None = None
82
+ width: int | None = None
83
+ height: int | None = None
84
+ color_mode: str = "rgb"
85
+ channels: int | None = None
86
+ use_depth: bool = False
87
+ force_hardware_reset: bool = True
88
+ rotation: int | None = None
89
+ mock: bool = False
90
+
91
+ def __post_init__(self):
92
+ # bool is stronger than is None, since it works with empty strings
93
+ if bool(self.name) and bool(self.serial_number):
94
+ raise ValueError(
95
+ f"One of them must be set: name or serial_number, but {self.name=} and {self.serial_number=} provided."
96
+ )
97
+
98
+ if self.color_mode not in ["rgb", "bgr"]:
99
+ raise ValueError(
100
+ f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
101
+ )
102
+
103
+ self.channels = 3
104
+
105
+ at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
106
+ at_least_one_is_none = self.fps is None or self.width is None or self.height is None
107
+ if at_least_one_is_not_none and at_least_one_is_none:
108
+ raise ValueError(
109
+ "For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
110
+ f"but {self.fps=}, {self.width=}, {self.height=} were provided."
111
+ )
112
+
113
+ if self.rotation not in [-90, None, 90, 180]:
114
+ raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
lerobot/common/robot_devices/cameras/intelrealsense.py ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ This file contains utilities for recording frames from Intel Realsense cameras.
17
+ """
18
+
19
+ import argparse
20
+ import concurrent.futures
21
+ import logging
22
+ import math
23
+ import shutil
24
+ import threading
25
+ import time
26
+ import traceback
27
+ from collections import Counter
28
+ from pathlib import Path
29
+ from threading import Thread
30
+
31
+ import numpy as np
32
+ from PIL import Image
33
+
34
+ from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig
35
+ from lerobot.common.robot_devices.utils import (
36
+ RobotDeviceAlreadyConnectedError,
37
+ RobotDeviceNotConnectedError,
38
+ busy_wait,
39
+ )
40
+ from lerobot.common.utils.utils import capture_timestamp_utc
41
+
42
+ SERIAL_NUMBER_INDEX = 1
43
+
44
+
45
+ def find_cameras(raise_when_empty=True, mock=False) -> list[dict]:
46
+ """
47
+ Find the names and the serial numbers of the Intel RealSense cameras
48
+ connected to the computer.
49
+ """
50
+ if mock:
51
+ import tests.cameras.mock_pyrealsense2 as rs
52
+ else:
53
+ import pyrealsense2 as rs
54
+
55
+ cameras = []
56
+ for device in rs.context().query_devices():
57
+ serial_number = int(device.get_info(rs.camera_info(SERIAL_NUMBER_INDEX)))
58
+ name = device.get_info(rs.camera_info.name)
59
+ cameras.append(
60
+ {
61
+ "serial_number": serial_number,
62
+ "name": name,
63
+ }
64
+ )
65
+
66
+ if raise_when_empty and len(cameras) == 0:
67
+ raise OSError(
68
+ "Not a single camera was detected. Try re-plugging, or re-installing `librealsense` and its python wrapper `pyrealsense2`, or updating the firmware."
69
+ )
70
+
71
+ return cameras
72
+
73
+
74
def save_image(img_array, serial_number, frame_index, images_dir):
    """Write one frame to disk as a PNG; failures are logged, never raised."""
    try:
        target = images_dir / f"camera_{serial_number}_frame_{frame_index:06d}.png"
        target.parent.mkdir(parents=True, exist_ok=True)
        Image.fromarray(img_array).save(str(target), quality=100)
        logging.info(f"Saved image: {target}")
    except Exception as e:
        logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")
83
+
84
+
85
def save_images_from_cameras(
    images_dir: Path,
    serial_numbers: list[int] | None = None,
    fps=None,
    width=None,
    height=None,
    record_time_s=2,
    mock=False,
):
    """
    Initializes all the cameras and saves images to the directory. Useful to visually identify the camera
    associated to a given serial number.

    Args:
        images_dir: destination directory (wiped and recreated if it already exists).
        serial_numbers: cameras to record from; all detected cameras when None/empty.
        fps / width / height: stream settings; when fps is None, frames are read synchronously.
        record_time_s: recording duration in seconds.
        mock: use the test mocks instead of real hardware.
    """
    if serial_numbers is None or len(serial_numbers) == 0:
        camera_infos = find_cameras(mock=mock)
        serial_numbers = [cam["serial_number"] for cam in camera_infos]

    if mock:
        import tests.cameras.mock_cv2 as cv2
    else:
        import cv2

    print("Connecting cameras")
    cameras = []
    for cam_sn in serial_numbers:
        print(f"{cam_sn=}")
        config = IntelRealSenseCameraConfig(
            serial_number=cam_sn, fps=fps, width=width, height=height, mock=mock
        )
        camera = IntelRealSenseCamera(config)
        camera.connect()
        print(
            f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})"
        )
        cameras.append(camera)

    images_dir = Path(images_dir)
    if images_dir.exists():
        shutil.rmtree(
            images_dir,
        )
    images_dir.mkdir(parents=True, exist_ok=True)

    print(f"Saving images to {images_dir}")
    frame_index = 0
    start_time = time.perf_counter()
    try:
        # A single worker is enough: it decouples the (slow) PNG encoding from the capture loop.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            while True:
                now = time.perf_counter()

                for camera in cameras:
                    # If we use async_read when fps is None, the loop will go full speed, and we will end up
                    # saving the same images from the cameras multiple times until the RAM/disk is full.
                    image = camera.read() if fps is None else camera.async_read()
                    if image is None:
                        # Fix: previously only printed and then crashed in cv2.cvtColor(None, ...);
                        # skip this camera for the current tick instead.
                        print("No Frame")
                        continue

                    bgr_converted_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

                    executor.submit(
                        save_image,
                        bgr_converted_image,
                        camera.serial_number,
                        frame_index,
                        images_dir,
                    )

                if fps is not None:
                    # Sleep the remainder of the frame budget to hold the requested rate.
                    dt_s = time.perf_counter() - now
                    busy_wait(1 / fps - dt_s)

                if time.perf_counter() - start_time > record_time_s:
                    break

                print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")

                frame_index += 1
    finally:
        print(f"Images have been saved to {images_dir}")
        for camera in cameras:
            camera.disconnect()
167
+
168
+
169
class IntelRealSenseCamera:
    """
    The IntelRealSenseCamera class is similar to OpenCVCamera class but adds additional features for Intel Real Sense cameras:
    - is instantiated with the serial number of the camera - won't randomly change as it can be the case of OpenCVCamera for Linux,
    - can also be instantiated with the camera's name — if it's unique — using IntelRealSenseCameraConfig(name=...),
    - depth map can be returned.

    To find the camera indices of your cameras, you can run our utility script that will save a few frames for each camera:
    ```bash
    python lerobot/common/robot_devices/cameras/intelrealsense.py --images-dir outputs/images_from_intelrealsense_cameras
    ```

    When an IntelRealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
    of the given camera will be used.

    Example of instantiating with a serial number:
    ```python
    from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig

    config = IntelRealSenseCameraConfig(serial_number=128422271347)
    camera = IntelRealSenseCamera(config)
    camera.connect()
    color_image = camera.read()
    # when done using the camera, consider disconnecting
    camera.disconnect()
    ```

    Example of instantiating with a name if it's unique:
    ```
    config = IntelRealSenseCameraConfig(name="Intel RealSense D405")
    ```

    Example of changing default fps, width, height and color_mode:
    ```python
    config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=30, width=1280, height=720)
    config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480)
    config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480, color_mode="bgr")
    # Note: might error out upon `camera.connect()` if these settings are not compatible with the camera
    ```

    Example of returning depth:
    ```python
    config = IntelRealSenseCameraConfig(serial_number=128422271347, use_depth=True)
    camera = IntelRealSenseCamera(config)
    camera.connect()
    color_image, depth_map = camera.read()
    ```
    """

    def __init__(
        self,
        config: IntelRealSenseCameraConfig,
    ):
        self.config = config
        # Resolve the serial number either directly or from a (unique) camera name.
        if config.name is not None:
            self.serial_number = self.find_serial_number_from_name(config.name)
        else:
            self.serial_number = config.serial_number

        # Store the raw (capture) resolution from the config.
        self.capture_width = config.width
        self.capture_height = config.height

        # User-facing resolution: if rotated by ±90, width and height are swapped.
        if config.rotation in [-90, 90]:
            self.width = config.height
            self.height = config.width
        else:
            self.width = config.width
            self.height = config.height

        self.fps = config.fps
        self.channels = config.channels
        self.color_mode = config.color_mode
        self.use_depth = config.use_depth
        self.force_hardware_reset = config.force_hardware_reset
        self.mock = config.mock

        # Runtime state populated by connect()/async_read().
        self.camera = None
        self.is_connected = False
        self.thread = None
        self.stop_event = None
        self.color_image = None
        self.depth_map = None
        self.logs = {}

        if self.mock:
            import tests.cameras.mock_cv2 as cv2
        else:
            import cv2

        # Map the config rotation (in degrees) to the matching OpenCV rotation flag.
        self.rotation = None
        if config.rotation == -90:
            self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
        elif config.rotation == 90:
            self.rotation = cv2.ROTATE_90_CLOCKWISE
        elif config.rotation == 180:
            self.rotation = cv2.ROTATE_180

    def find_serial_number_from_name(self, name):
        """Return the serial number of the unique connected camera called `name`.

        Raises:
            ValueError: if several detected cameras share that name.
        """
        camera_infos = find_cameras()
        camera_names = [cam["name"] for cam in camera_infos]
        this_name_count = Counter(camera_names)[name]
        if this_name_count > 1:
            # TODO(aliberts): Test this with multiple identical cameras (Aloha)
            raise ValueError(
                f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them."
            )

        name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
        cam_sn = name_to_serial_dict[name]

        return cam_sn

    def connect(self):
        """Open the RealSense pipeline and validate the requested stream settings.

        Raises:
            RobotDeviceAlreadyConnectedError: if already connected.
            ValueError: if `serial_number` does not match any detected camera.
            OSError: if the camera cannot be opened or the actual fps/width/height
                differ from the requested configuration.
        """
        if self.is_connected:
            raise RobotDeviceAlreadyConnectedError(
                f"IntelRealSenseCamera({self.serial_number}) is already connected."
            )

        if self.mock:
            import tests.cameras.mock_pyrealsense2 as rs
        else:
            import pyrealsense2 as rs

        config = rs.config()
        config.enable_device(str(self.serial_number))

        if self.fps and self.capture_width and self.capture_height:
            # TODO(rcadene): can we set rgb8 directly?
            config.enable_stream(
                rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
            )
        else:
            config.enable_stream(rs.stream.color)

        if self.use_depth:
            if self.fps and self.capture_width and self.capture_height:
                config.enable_stream(
                    rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
                )
            else:
                config.enable_stream(rs.stream.depth)

        self.camera = rs.pipeline()
        try:
            profile = self.camera.start(config)
            is_camera_open = True
        except RuntimeError:
            is_camera_open = False
            traceback.print_exc()

        # If the camera doesn't work, display the camera indices corresponding to
        # valid cameras.
        if not is_camera_open:
            # Verify that the provided `serial_number` is valid before printing the traceback
            camera_infos = find_cameras()
            serial_numbers = [cam["serial_number"] for cam in camera_infos]
            if self.serial_number not in serial_numbers:
                raise ValueError(
                    f"`serial_number` is expected to be one of these available cameras {serial_numbers}, but {self.serial_number} is provided instead. "
                    "To find the serial number you should use, run `python lerobot/common/robot_devices/cameras/intelrealsense.py`."
                )

            raise OSError(f"Can't access IntelRealSenseCamera({self.serial_number}).")

        color_stream = profile.get_stream(rs.stream.color)
        color_profile = color_stream.as_video_stream_profile()
        actual_fps = color_profile.fps()
        actual_width = color_profile.width()
        actual_height = color_profile.height()

        # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
        if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
            # Using `OSError` since it's a broad that encompasses issues related to device communication
            raise OSError(
                f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
            )
        if self.capture_width is not None and self.capture_width != actual_width:
            raise OSError(
                f"Can't set {self.capture_width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
            )
        if self.capture_height is not None and self.capture_height != actual_height:
            raise OSError(
                f"Can't set {self.capture_height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
            )

        self.fps = round(actual_fps)
        self.capture_width = round(actual_width)
        self.capture_height = round(actual_height)

        self.is_connected = True

    def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3)
        of type `np.uint8`, contrarily to the pytorch format which is float channel first.

        When `use_depth=True`, returns a tuple `(color_image, depth_map)` with a depth map in the format
        height x width (e.g. 480 x 640) of type np.uint16.

        Note: Reading a frame is done every `camera.fps` times per second, and it is blocking.
        If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
            )

        if self.mock:
            import tests.cameras.mock_cv2 as cv2
        else:
            import cv2

        start_time = time.perf_counter()

        frame = self.camera.wait_for_frames(timeout_ms=5000)

        color_frame = frame.get_color_frame()

        if not color_frame:
            raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).")

        color_image = np.asanyarray(color_frame.get_data())

        requested_color_mode = self.color_mode if temporary_color is None else temporary_color
        if requested_color_mode not in ["rgb", "bgr"]:
            raise ValueError(
                f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
            )

        # IntelRealSense uses RGB format as default (red, green, blue).
        if requested_color_mode == "bgr":
            color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)

        h, w, _ = color_image.shape
        if h != self.capture_height or w != self.capture_width:
            # Fix: report the capture dimensions actually compared, not the
            # post-rotation self.height/self.width.
            raise OSError(
                f"Can't capture color image with expected height and width ({self.capture_height} x {self.capture_width}). ({h} x {w}) returned instead."
            )

        if self.rotation is not None:
            color_image = cv2.rotate(color_image, self.rotation)

        # log the number of seconds it took to read the image
        self.logs["delta_timestamp_s"] = time.perf_counter() - start_time

        # log the utc time at which the image was received
        self.logs["timestamp_utc"] = capture_timestamp_utc()

        if self.use_depth:
            depth_frame = frame.get_depth_frame()
            if not depth_frame:
                raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).")

            depth_map = np.asanyarray(depth_frame.get_data())

            h, w = depth_map.shape
            if h != self.capture_height or w != self.capture_width:
                # Fix: same as above — compare against and report the capture dimensions.
                raise OSError(
                    f"Can't capture depth map with expected height and width ({self.capture_height} x {self.capture_width}). ({h} x {w}) returned instead."
                )

            if self.rotation is not None:
                depth_map = cv2.rotate(depth_map, self.rotation)

            return color_image, depth_map
        else:
            return color_image

    def read_loop(self):
        """Background loop: keep `self.color_image` (and depth) up to date until stopped."""
        while not self.stop_event.is_set():
            if self.use_depth:
                self.color_image, self.depth_map = self.read()
            else:
                self.color_image = self.read()

    def async_read(self):
        """Access the latest color image (and depth map when `use_depth=True`)."""
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
            )

        # Lazily start the background reader thread on first call.
        if self.thread is None:
            self.stop_event = threading.Event()
            self.thread = Thread(target=self.read_loop, args=())
            self.thread.daemon = True
            self.thread.start()

        num_tries = 0
        # Fix: when use_depth is set, also wait for the first depth map so we never
        # return (color, None) right after the thread starts.
        while self.color_image is None or (self.use_depth and self.depth_map is None):
            # TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here
            num_tries += 1
            time.sleep(1 / self.fps)
            if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
                raise Exception(
                    "The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
                )

        if self.use_depth:
            return self.color_image, self.depth_map
        else:
            return self.color_image

    def disconnect(self):
        """Stop the background thread (if any) and release the RealSense pipeline."""
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
            )

        if self.thread is not None and self.thread.is_alive():
            # wait for the thread to finish
            self.stop_event.set()
            self.thread.join()
            self.thread = None
            self.stop_event = None

        self.camera.stop()
        self.camera = None

        self.is_connected = False

    def __del__(self):
        # Best-effort cleanup; getattr guards against partially-initialized instances.
        if getattr(self, "is_connected", False):
            self.disconnect()
494
+
495
+
496
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Save a few frames using `IntelRealSenseCamera` for all cameras connected to the computer, or a selected subset."
    )
    parser.add_argument(
        "--serial-numbers",
        type=int,
        nargs="*",
        default=None,
        help="List of serial numbers used to instantiate the `IntelRealSenseCamera`. If not provided, find and use all available camera indices.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=30,
        help="Set the number of frames recorded per seconds for all cameras. If not provided, use the default fps of each camera.",
    )
    parser.add_argument(
        "--width",
        # Fix: was type=str; width is compared to the camera's actual (int) width downstream.
        type=int,
        default=640,
        help="Set the width for all cameras. If not provided, use the default width of each camera.",
    )
    parser.add_argument(
        "--height",
        # Fix: was type=str; height is compared to the camera's actual (int) height downstream.
        type=int,
        default=480,
        help="Set the height for all cameras. If not provided, use the default height of each camera.",
    )
    parser.add_argument(
        "--images-dir",
        type=Path,
        default="outputs/images_from_intelrealsense_cameras",
        help="Set directory to save a few frames for each camera.",
    )
    parser.add_argument(
        "--record-time-s",
        type=float,
        default=2.0,
        help="Set the number of seconds used to record the frames. By default, 2 seconds.",
    )
    args = parser.parse_args()
    save_images_from_cameras(**vars(args))
lerobot/common/robot_devices/cameras/opencv.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring.
17
+ """
18
+
19
+ import argparse
20
+ import concurrent.futures
21
+ import math
22
+ import platform
23
+ import shutil
24
+ import threading
25
+ import time
26
+ from pathlib import Path
27
+ from threading import Thread
28
+
29
+ import numpy as np
30
+ from PIL import Image
31
+
32
+ from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
33
+ from lerobot.common.robot_devices.utils import (
34
+ RobotDeviceAlreadyConnectedError,
35
+ RobotDeviceNotConnectedError,
36
+ busy_wait,
37
+ )
38
+ from lerobot.common.utils.utils import capture_timestamp_utc
39
+
40
+ # The maximum opencv device index depends on your operating system. For instance,
41
+ # if you have 3 cameras, they should be associated to index 0, 1, and 2. This is the case
42
+ # on MacOS. However, on Ubuntu, the indices are different like 6, 16, 23.
43
+ # When you change the USB port or reboot the computer, the operating system might
44
+ # treat the same cameras as new devices. Thus we select a higher bound to search indices.
45
+ MAX_OPENCV_INDEX = 60
46
+
47
+
48
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
    """Return the available OpenCV cameras as dicts with "port" and "index" keys.

    On Linux, cameras are discovered by scanning the '/dev/video*' device ports;
    on other systems, indices from 0 to `max_index_search_range` are probed.

    Args:
        raise_when_empty: raise OSError when no camera is found.
        max_index_search_range: highest index probed on non-Linux systems.
        mock: use the test mock of cv2 instead of the real library.
    """
    cameras = []
    if platform.system() == "Linux":
        print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
        possible_ports = [str(port) for port in Path("/dev").glob("video*")]
        # Fix: forward `raise_when_empty` — it was previously accepted but never used.
        ports = _find_cameras(possible_ports, raise_when_empty=raise_when_empty, mock=mock)
        for port in ports:
            cameras.append(
                {
                    "port": port,
                    "index": int(port.removeprefix("/dev/video")),
                }
            )
    else:
        print(
            "Mac or Windows detected. Finding available camera indices through "
            # Fix: report the actual search range instead of the hard-coded constant.
            f"scanning all indices from 0 to {max_index_search_range}"
        )
        possible_indices = range(max_index_search_range)
        indices = _find_cameras(possible_indices, raise_when_empty=raise_when_empty, mock=mock)
        for index in indices:
            cameras.append(
                {
                    "port": None,
                    "index": index,
                }
            )

    return cameras
77
+
78
+
79
def _find_cameras(
    possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
) -> list[int | str]:
    """Probe each candidate camera id and keep the ones that open successfully."""
    if mock:
        import tests.cameras.mock_cv2 as cv2
    else:
        import cv2

    valid_ids = []
    for candidate in possible_camera_ids:
        capture = cv2.VideoCapture(candidate)
        opened = capture.isOpened()
        capture.release()
        if opened:
            print(f"Camera found at index {candidate}")
            valid_ids.append(candidate)

    if raise_when_empty and not valid_ids:
        raise OSError(
            "Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, "
            "or your camera driver, or make sure your camera is compatible with opencv2."
        )

    return valid_ids
104
+
105
+
106
def is_valid_unix_path(path: str) -> bool:
    """Note: if 'path' points to a symlink, this will return True only if the target exists"""
    candidate = Path(path)
    return candidate.is_absolute() and candidate.exists()
110
+
111
+
112
def get_camera_index_from_unix_port(port: Path) -> int:
    """Extract the numeric camera index from a (possibly symlinked) /dev/video* port."""
    resolved = str(port.resolve())
    return int(resolved.removeprefix("/dev/video"))
114
+
115
+
116
def save_image(img_array, camera_index, frame_index, images_dir):
    """Save one frame as a PNG named after its camera index and frame index."""
    target = images_dir / f"camera_{camera_index:02d}_frame_{frame_index:06d}.png"
    target.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(img_array).save(str(target), quality=100)
121
+
122
+
123
def save_images_from_cameras(
    images_dir: Path,
    camera_ids: list | None = None,
    fps=None,
    width=None,
    height=None,
    record_time_s=2,
    mock=False,
):
    """
    Initializes all the cameras and saves images to the directory. Useful to visually identify the camera
    associated to a given camera index.

    Args:
        images_dir: destination directory (wiped and recreated if it already exists).
        camera_ids: cameras to record from; all detected cameras when None/empty.
        fps / width / height: stream settings; when fps is None, frames are read synchronously.
        record_time_s: recording duration in seconds.
        mock: use the test mocks instead of real hardware.
    """
    if camera_ids is None or len(camera_ids) == 0:
        camera_infos = find_cameras(mock=mock)
        camera_ids = [cam["index"] for cam in camera_infos]

    print("Connecting cameras")
    cameras = []
    for cam_idx in camera_ids:
        config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
        camera = OpenCVCamera(config)
        camera.connect()
        print(
            f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.capture_width}, "
            f"height={camera.capture_height}, color_mode={camera.color_mode})"
        )
        cameras.append(camera)

    images_dir = Path(images_dir)
    if images_dir.exists():
        shutil.rmtree(
            images_dir,
        )
    images_dir.mkdir(parents=True, exist_ok=True)

    print(f"Saving images to {images_dir}")
    frame_index = 0
    start_time = time.perf_counter()
    try:
        # A single worker is enough: it decouples the (slow) PNG encoding from the capture loop.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            while True:
                now = time.perf_counter()

                for camera in cameras:
                    # If we use async_read when fps is None, the loop will go full speed, and we will end up
                    # saving the same images from the cameras multiple times until the RAM/disk is full.
                    image = camera.read() if fps is None else camera.async_read()

                    executor.submit(
                        save_image,
                        image,
                        camera.camera_index,
                        frame_index,
                        images_dir,
                    )

                if fps is not None:
                    # Sleep the remainder of the frame budget to hold the requested rate.
                    dt_s = time.perf_counter() - now
                    busy_wait(1 / fps - dt_s)

                print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")

                if time.perf_counter() - start_time > record_time_s:
                    break

                frame_index += 1
    finally:
        print(f"Images have been saved to {images_dir}")
        # Fix: release the cameras even on interruption (the intelrealsense twin of
        # this function disconnects its cameras; this one previously never did).
        for camera in cameras:
            camera.disconnect()
191
+
192
+
193
+ class OpenCVCamera:
194
+ """
195
+ The OpenCVCamera class allows to efficiently record images from cameras. It relies on opencv2 to communicate
196
+ with the cameras. Most cameras are compatible. For more info, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
197
+
198
+ An OpenCVCamera instance requires a camera index (e.g. `OpenCVCamera(camera_index=0)`). When you only have one camera
199
+ like a webcam of a laptop, the camera index is expected to be 0, but it might also be very different, and the camera index
200
+ might change if you reboot your computer or re-plug your camera. This behavior depends on your operation system.
201
+
202
+ To find the camera indices of your cameras, you can run our utility script that will save a few frames for each camera:
203
+ ```bash
204
+ python lerobot/common/robot_devices/cameras/opencv.py --images-dir outputs/images_from_opencv_cameras
205
+ ```
206
+
207
+ When an OpenCVCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
208
+ of the given camera will be used.
209
+
210
+ Example of usage:
211
+ ```python
212
+ from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
213
+
214
+ config = OpenCVCameraConfig(camera_index=0)
215
+ camera = OpenCVCamera(config)
216
+ camera.connect()
217
+ color_image = camera.read()
218
+ # when done using the camera, consider disconnecting
219
+ camera.disconnect()
220
+ ```
221
+
222
+ Example of changing default fps, width, height and color_mode:
223
+ ```python
224
+ config = OpenCVCameraConfig(camera_index=0, fps=30, width=1280, height=720)
225
+ config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480)
226
+ config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480, color_mode="bgr")
227
+ # Note: might error out upon `camera.connect()` if these settings are not compatible with the camera
228
+ ```
229
+ """
230
+
231
def __init__(self, config: OpenCVCameraConfig):
    """Store the camera configuration and resolve the platform-specific index/port."""
    self.config = config
    self.camera_index = config.camera_index
    self.port = None

    # Linux addresses cameras via /dev/video* device ports.
    if platform.system() == "Linux":
        if isinstance(self.camera_index, int):
            self.port = Path(f"/dev/video{self.camera_index}")
        elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index):
            self.port = Path(self.camera_index)
            # Retrieve the camera index from a potentially symlinked path.
            self.camera_index = get_camera_index_from_unix_port(self.port)
        else:
            raise ValueError(f"Please check the provided camera_index: {self.camera_index}")

    # Raw (capture) resolution as requested from the device.
    self.capture_width = config.width
    self.capture_height = config.height

    # User-facing resolution: a ±90 degree rotation swaps width and height.
    swapped = config.rotation in [-90, 90]
    self.width = config.height if swapped else config.width
    self.height = config.width if swapped else config.height

    self.fps = config.fps
    self.channels = config.channels
    self.color_mode = config.color_mode
    self.mock = config.mock

    # Runtime state populated by connect()/async_read().
    self.camera = None
    self.is_connected = False
    self.thread = None
    self.stop_event = None
    self.color_image = None
    self.logs = {}

    if self.mock:
        import tests.cameras.mock_cv2 as cv2
    else:
        import cv2

    # Translate the rotation (in degrees) into the matching cv2 flag; None otherwise.
    rotation_flags = {
        -90: cv2.ROTATE_90_COUNTERCLOCKWISE,
        90: cv2.ROTATE_90_CLOCKWISE,
        180: cv2.ROTATE_180,
    }
    self.rotation = rotation_flags.get(config.rotation)
283
+
284
+ def connect(self):
285
+ if self.is_connected:
286
+ raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
287
+
288
+ if self.mock:
289
+ import tests.cameras.mock_cv2 as cv2
290
+ else:
291
+ import cv2
292
+
293
+ # Use 1 thread to avoid blocking the main thread. Especially useful during data collection
294
+ # when other threads are used to save the images.
295
+ cv2.setNumThreads(1)
296
+
297
+ backend = (
298
+ cv2.CAP_V4L2
299
+ if platform.system() == "Linux"
300
+ else cv2.CAP_DSHOW
301
+ if platform.system() == "Windows"
302
+ else cv2.CAP_AVFOUNDATION
303
+ if platform.system() == "Darwin"
304
+ else cv2.CAP_ANY
305
+ )
306
+
307
+ camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
308
+ # First create a temporary camera trying to access `camera_index`,
309
+ # and verify it is a valid camera by calling `isOpened`.
310
+ tmp_camera = cv2.VideoCapture(camera_idx, backend)
311
+ is_camera_open = tmp_camera.isOpened()
312
+ # Release camera to make it accessible for `find_camera_indices`
313
+ tmp_camera.release()
314
+ del tmp_camera
315
+
316
+ # If the camera doesn't work, display the camera indices corresponding to
317
+ # valid cameras.
318
+ if not is_camera_open:
319
+ # Verify that the provided `camera_index` is valid before printing the traceback
320
+ cameras_info = find_cameras()
321
+ available_cam_ids = [cam["index"] for cam in cameras_info]
322
+ if self.camera_index not in available_cam_ids:
323
+ raise ValueError(
324
+ f"`camera_index` is expected to be one of these available cameras {available_cam_ids}, but {self.camera_index} is provided instead. "
325
+ "To find the camera index you should use, run `python lerobot/common/robot_devices/cameras/opencv.py`."
326
+ )
327
+
328
+ raise OSError(f"Can't access OpenCVCamera({camera_idx}).")
329
+
330
+ # Secondly, create the camera that will be used downstream.
331
+ # Note: For some unknown reason, calling `isOpened` blocks the camera which then
332
+ # needs to be re-created.
333
+ self.camera = cv2.VideoCapture(camera_idx, backend)
334
+
335
+ if self.fps is not None:
336
+ self.camera.set(cv2.CAP_PROP_FPS, self.fps)
337
+ if self.capture_width is not None:
338
+ self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.capture_width)
339
+ if self.capture_height is not None:
340
+ self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.capture_height)
341
+
342
+ actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
343
+ actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
344
+ actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
345
+
346
+ # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
347
+ if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
348
+ # Using `OSError` since it's a broad that encompasses issues related to device communication
349
+ raise OSError(
350
+ f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
351
+ )
352
+ if self.capture_width is not None and not math.isclose(
353
+ self.capture_width, actual_width, rel_tol=1e-3
354
+ ):
355
+ raise OSError(
356
+ f"Can't set {self.capture_width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
357
+ )
358
+ if self.capture_height is not None and not math.isclose(
359
+ self.capture_height, actual_height, rel_tol=1e-3
360
+ ):
361
+ raise OSError(
362
+ f"Can't set {self.capture_height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
363
+ )
364
+
365
+ self.fps = round(actual_fps)
366
+ self.capture_width = round(actual_width)
367
+ self.capture_height = round(actual_height)
368
+ self.is_connected = True
369
+
370
+ def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
371
+ """Read a frame from the camera returned in the format (height, width, channels)
372
+ (e.g. 480 x 640 x 3), contrarily to the pytorch format which is channel first.
373
+
374
+ Note: Reading a frame is done every `camera.fps` times per second, and it is blocking.
375
+ If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`.
376
+ """
377
+ if not self.is_connected:
378
+ raise RobotDeviceNotConnectedError(
379
+ f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
380
+ )
381
+
382
+ start_time = time.perf_counter()
383
+
384
+ ret, color_image = self.camera.read()
385
+
386
+ if not ret:
387
+ raise OSError(f"Can't capture color image from camera {self.camera_index}.")
388
+
389
+ requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode
390
+
391
+ if requested_color_mode not in ["rgb", "bgr"]:
392
+ raise ValueError(
393
+ f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
394
+ )
395
+
396
+ # OpenCV uses BGR format as default (blue, green, red) for all operations, including displaying images.
397
+ # However, Deep Learning framework such as LeRobot uses RGB format as default to train neural networks,
398
+ # so we convert the image color from BGR to RGB.
399
+ if requested_color_mode == "rgb":
400
+ if self.mock:
401
+ import tests.cameras.mock_cv2 as cv2
402
+ else:
403
+ import cv2
404
+
405
+ color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
406
+
407
+ h, w, _ = color_image.shape
408
+ if h != self.capture_height or w != self.capture_width:
409
+ raise OSError(
410
+ f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
411
+ )
412
+
413
+ if self.rotation is not None:
414
+ color_image = cv2.rotate(color_image, self.rotation)
415
+
416
+ # log the number of seconds it took to read the image
417
+ self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
418
+
419
+ # log the utc time at which the image was received
420
+ self.logs["timestamp_utc"] = capture_timestamp_utc()
421
+
422
+ self.color_image = color_image
423
+
424
+ return color_image
425
+
426
+ def read_loop(self):
427
+ while not self.stop_event.is_set():
428
+ try:
429
+ self.color_image = self.read()
430
+ except Exception as e:
431
+ print(f"Error reading in thread: {e}")
432
+
433
+ def async_read(self):
434
+ if not self.is_connected:
435
+ raise RobotDeviceNotConnectedError(
436
+ f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
437
+ )
438
+
439
+ if self.thread is None:
440
+ self.stop_event = threading.Event()
441
+ self.thread = Thread(target=self.read_loop, args=())
442
+ self.thread.daemon = True
443
+ self.thread.start()
444
+
445
+ num_tries = 0
446
+ while True:
447
+ if self.color_image is not None:
448
+ return self.color_image
449
+
450
+ time.sleep(1 / self.fps)
451
+ num_tries += 1
452
+ if num_tries > self.fps * 2:
453
+ raise TimeoutError("Timed out waiting for async_read() to start.")
454
+
455
+ def disconnect(self):
456
+ if not self.is_connected:
457
+ raise RobotDeviceNotConnectedError(
458
+ f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
459
+ )
460
+
461
+ if self.thread is not None:
462
+ self.stop_event.set()
463
+ self.thread.join() # wait for the thread to finish
464
+ self.thread = None
465
+ self.stop_event = None
466
+
467
+ self.camera.release()
468
+ self.camera = None
469
+ self.is_connected = False
470
+
471
+ def __del__(self):
472
+ if getattr(self, "is_connected", False):
473
+ self.disconnect()
474
+
475
+
476
if __name__ == "__main__":
    # CLI helper: record a few frames from each (or selected) OpenCV camera to disk.
    parser = argparse.ArgumentParser(
        description="Save a few frames using `OpenCVCamera` for all cameras connected to the computer, or a selected subset."
    )
    parser.add_argument(
        "--camera-ids",
        type=int,
        nargs="*",
        default=None,
        help="List of camera indices used to instantiate the `OpenCVCamera`. If not provided, find and use all available camera indices.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=None,
        help="Set the number of frames recorded per seconds for all cameras. If not provided, use the default fps of each camera.",
    )
    # Fix: width/height are numeric camera properties; parsing them as `str`
    # propagated strings into OpenCVCameraConfig and the cv2 property setters.
    parser.add_argument(
        "--width",
        type=int,
        default=None,
        help="Set the width for all cameras. If not provided, use the default width of each camera.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=None,
        help="Set the height for all cameras. If not provided, use the default height of each camera.",
    )
    parser.add_argument(
        "--images-dir",
        type=Path,
        default="outputs/images_from_opencv_cameras",
        help="Set directory to save a few frames for each camera.",
    )
    parser.add_argument(
        "--record-time-s",
        type=float,
        default=4.0,
        # Fix: help text said "2 seconds" while the actual default is 4.0.
        help="Set the number of seconds used to record the frames. By default, 4 seconds.",
    )
    args = parser.parse_args()
    save_images_from_cameras(**vars(args))
lerobot/common/robot_devices/cameras/utils.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Protocol
16
+
17
+ import numpy as np
18
+
19
+ from lerobot.common.robot_devices.cameras.configs import (
20
+ CameraConfig,
21
+ IntelRealSenseCameraConfig,
22
+ OpenCVCameraConfig,
23
+ )
24
+
25
+
26
+ # Defines a camera type
27
+ class Camera(Protocol):
28
+ def connect(self): ...
29
+ def read(self, temporary_color: str | None = None) -> np.ndarray: ...
30
+ def async_read(self) -> np.ndarray: ...
31
+ def disconnect(self): ...
32
+
33
+
34
def make_cameras_from_configs(camera_configs: dict[str, "CameraConfig"]) -> dict[str, "Camera"]:
    """Instantiate one camera driver per config entry, keyed like the input dict.

    Fix: the return annotation previously claimed `list[Camera]` while the
    function builds and returns a dict.

    Raises:
        ValueError: if a config's `type` is not a known camera kind.
    """
    cameras = {}

    for key, cfg in camera_configs.items():
        # Imports are local so that only the drivers actually used get imported.
        if cfg.type == "opencv":
            from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera

            cameras[key] = OpenCVCamera(cfg)

        elif cfg.type == "intelrealsense":
            from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera

            cameras[key] = IntelRealSenseCamera(cfg)
        else:
            raise ValueError(f"The camera type '{cfg.type}' is not valid.")

    return cameras
51
+
52
+
53
def make_camera(camera_type, **kwargs) -> "Camera":
    """Build a single camera driver from a type string and config keyword arguments.

    Raises:
        ValueError: if `camera_type` is not a known camera kind.
    """
    if camera_type == "opencv":
        from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera

        return OpenCVCamera(OpenCVCameraConfig(**kwargs))

    if camera_type == "intelrealsense":
        from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera

        return IntelRealSenseCamera(IntelRealSenseCameraConfig(**kwargs))

    raise ValueError(f"The camera type '{camera_type}' is not valid.")
lerobot/common/robot_devices/control_configs.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from pathlib import Path
17
+
18
+ import draccus
19
+
20
+ from lerobot.common.robot_devices.robots.configs import RobotConfig
21
+ from lerobot.configs import parser
22
+ from lerobot.configs.policies import PreTrainedConfig
23
+
24
+
25
@dataclass
class ControlConfig(draccus.ChoiceRegistry):
    """Base class for control-mode configurations, selected by name via draccus."""
28
+
29
+
30
@ControlConfig.register_subclass("calibrate")
@dataclass
class CalibrateControlConfig(ControlConfig):
    """Configuration for the arm-calibration control mode."""

    # List of arms to calibrate (e.g. `--arms='["left_follower","right_follower"]' left_leader`)
    arms: list[str] | None = None
35
+
36
+
37
@ControlConfig.register_subclass("teleoperate")
@dataclass
class TeleoperateControlConfig(ControlConfig):
    """Configuration for plain teleoperation (no data recording)."""

    # Limit the maximum frames per second. By default, no limit.
    fps: int | None = None
    # Duration of the teleoperation session in seconds (None = until interrupted).
    teleop_time_s: float | None = None
    # Display all cameras on screen
    display_cameras: bool = True
45
+
46
+
47
@ControlConfig.register_subclass("record")
@dataclass
class RecordControlConfig(ControlConfig):
    """Configuration for recording teleoperated or policy-driven episodes into a dataset."""

    # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
    repo_id: str
    # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
    single_task: str
    # Root directory where the dataset will be stored (e.g. 'dataset/path').
    root: str | Path | None = None
    # Optional policy driving the robot; filled in by __post_init__ when a path is given on the CLI.
    policy: PreTrainedConfig | None = None
    # Limit the frames per second. By default, uses the policy fps.
    fps: int | None = None
    # Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
    warmup_time_s: int | float = 10
    # Number of seconds for data recording for each episode.
    episode_time_s: int | float = 60
    # Number of seconds for resetting the environment after each episode.
    reset_time_s: int | float = 60
    # Number of episodes to record.
    num_episodes: int = 50
    # Encode frames in the dataset into video
    video: bool = True
    # Upload dataset to Hugging Face hub.
    push_to_hub: bool = True
    # Upload on private repository on the Hugging Face hub.
    private: bool = False
    # Add tags to your dataset on the hub.
    tags: list[str] | None = None
    # Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
    # set to >=1 to use subprocesses, each using threads to write images. The best number of processes
    # and threads depends on your system. We recommend 4 threads per camera with 0 processes.
    # If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
    num_image_writer_processes: int = 0
    # Number of threads writing the frames as png images on disk, per camera.
    # Too many threads might cause unstable teleoperation fps due to main thread being blocked.
    # Not enough threads might cause low camera fps.
    num_image_writer_threads_per_camera: int = 4
    # Display all cameras on screen
    display_cameras: bool = True
    # Use vocal synthesis to read events.
    play_sounds: bool = True
    # Resume recording on an existing dataset.
    resume: bool = False

    def __post_init__(self):
        # HACK: We parse again the cli args here to get the pretrained path if there was one.
        policy_path = parser.get_path_arg("control.policy")
        if not policy_path:
            return
        cli_overrides = parser.get_cli_overrides("control.policy")
        self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
        self.policy.pretrained_path = policy_path
98
+
99
+
100
@ControlConfig.register_subclass("replay")
@dataclass
class ReplayControlConfig(ControlConfig):
    """Configuration for replaying a previously recorded episode on the robot."""

    # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
    repo_id: str
    # Index of the episode to replay.
    episode: int
    # Root directory where the dataset will be stored (e.g. 'dataset/path').
    root: str | Path | None = None
    # Limit the frames per second. By default, uses the dataset fps.
    fps: int | None = None
    # Use vocal synthesis to read events.
    play_sounds: bool = True
113
+
114
+
115
@ControlConfig.register_subclass("remote_robot")
@dataclass
class RemoteRobotConfig(ControlConfig):
    """Configuration for driving a robot remotely."""

    # How many control iterations between log messages.
    log_interval: int = 100
119
+
120
+
121
@dataclass
class ControlPipelineConfig:
    """Top-level config pairing a robot definition with the control mode to run."""

    robot: RobotConfig
    control: ControlConfig

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["control.policy"]
lerobot/common/robot_devices/control_utils.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ ########################################################################################
16
+ # Utilities
17
+ ########################################################################################
18
+
19
+
20
+ import logging
21
+ import time
22
+ import traceback
23
+ from contextlib import nullcontext
24
+ from copy import copy
25
+ from functools import cache
26
+
27
+ import cv2
28
+ import torch
29
+ from deepdiff import DeepDiff
30
+ from termcolor import colored
31
+
32
+ from lerobot.common.datasets.image_writer import safe_stop_image_writer
33
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
34
+ from lerobot.common.datasets.utils import get_features_from_robot
35
+ from lerobot.common.policies.pretrained import PreTrainedPolicy
36
+ from lerobot.common.robot_devices.robots.utils import Robot
37
+ from lerobot.common.robot_devices.utils import busy_wait
38
+ from lerobot.common.utils.utils import get_safe_torch_device, has_method
39
+
40
+
41
def log_control_info(robot: "Robot", dt_s, episode_index=None, frame_index=None, fps=None):
    """Log one control-loop iteration: total step time plus per-device read/write latencies."""
    log_items = []
    if episode_index is not None:
        log_items.append(f"ep:{episode_index}")
    if frame_index is not None:
        log_items.append(f"frame:{frame_index}")

    def log_dt(shortname, dt_val_s):
        nonlocal log_items, fps
        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
        # Highlight iterations that fall noticeably below the requested rate.
        if fps is not None and (1 / dt_val_s) < fps - 1:
            info_str = colored(info_str, "yellow")
        log_items.append(info_str)

    # total step time displayed in milliseconds and its frequency
    log_dt("dt", dt_s)

    # TODO(aliberts): move robot-specific logs logic in robot.print_logs()
    if not robot.robot_type.startswith("stretch"):
        for name in robot.leader_arms:
            key = f"read_leader_{name}_pos_dt_s"
            if key in robot.logs:
                log_dt("dtRlead", robot.logs[key])

        for name in robot.follower_arms:
            write_key = f"write_follower_{name}_goal_pos_dt_s"
            if write_key in robot.logs:
                log_dt("dtWfoll", robot.logs[write_key])

            read_key = f"read_follower_{name}_pos_dt_s"
            if read_key in robot.logs:
                log_dt("dtRfoll", robot.logs[read_key])

        for name in robot.cameras:
            key = f"read_camera_{name}_dt_s"
            if key in robot.logs:
                log_dt(f"dtR{name}", robot.logs[key])

    logging.info(" ".join(log_items))
83
+
84
+
85
@cache
def is_headless():
    """Detects if python is running without a monitor."""
    try:
        import pynput  # noqa
    except Exception:
        # pynput needs a display/input server; failing to import it is our
        # signal that we are running headless.
        print(
            "Error trying to import pynput. Switching to headless mode. "
            "As a result, the video stream from the cameras won't be shown, "
            "and you won't be able to change the control flow with keyboards. "
            "For more info, see traceback below.\n"
        )
        traceback.print_exc()
        print()
        return True
    return False
102
+
103
+
104
def predict_action(observation, policy, device, use_amp):
    """Run the policy on a single observation and return the action as a CPU tensor."""
    observation = copy(observation)
    # Mixed precision is only meaningful on CUDA; elsewhere autocast is replaced by a no-op context.
    amp_ctx = torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext()

    with torch.inference_mode(), amp_ctx:
        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
        for name in observation:
            if "image" in name:
                observation[name] = observation[name].type(torch.float32) / 255
                observation[name] = observation[name].permute(2, 0, 1).contiguous()
            observation[name] = observation[name].unsqueeze(0)
            observation[name] = observation[name].to(device)

        # Compute the next action with the policy based on the current observation.
        action = policy.select_action(observation)

        # Remove batch dimension, then move to cpu if not already the case.
        action = action.squeeze(0)
        action = action.to("cpu")

    return action
129
+
130
+
131
def init_keyboard_listener():
    """Start a pynput listener mapping arrow/escape keys to control-flow events.

    Returns (listener, events); listener is None in headless environments.
    """
    # Allow to exit early while recording an episode or resetting the environment,
    # by tapping the right arrow key '->'. This might require a sudo permission
    # to allow your terminal to monitor keyboard events.
    events = {"exit_early": False, "rerecord_episode": False, "stop_recording": False}

    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        return None, events

    # Only import pynput if not in a headless environment
    from pynput import keyboard

    def on_press(key):
        try:
            if key == keyboard.Key.right:
                print("Right arrow key pressed. Exiting loop...")
                events["exit_early"] = True
            elif key == keyboard.Key.left:
                print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
                events["rerecord_episode"] = True
                events["exit_early"] = True
            elif key == keyboard.Key.esc:
                print("Escape key pressed. Stopping data recording...")
                events["stop_recording"] = True
                events["exit_early"] = True
        except Exception as e:
            print(f"Error handling key press: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()

    return listener, events
170
+
171
+
172
def warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps):
    """Run the control loop without a dataset so devices warm up and synchronize."""
    control_loop(
        robot=robot,
        control_time_s=warmup_time_s,
        display_cameras=display_cameras,
        events=events,
        fps=fps,
        teleoperate=enable_teleoperation,
    )
188
+
189
+
190
def record_episode(robot, dataset, events, episode_time_s, display_cameras, policy, fps, single_task):
    """Record one episode into `dataset`; teleoperate unless a policy drives the robot."""
    control_loop(
        robot=robot,
        control_time_s=episode_time_s,
        display_cameras=display_cameras,
        dataset=dataset,
        events=events,
        policy=policy,
        fps=fps,
        teleoperate=policy is None,
        single_task=single_task,
    )
211
+
212
+
213
@safe_stop_image_writer
def control_loop(
    robot,
    control_time_s=None,
    teleoperate=False,
    display_cameras=False,
    dataset: LeRobotDataset | None = None,
    events=None,
    policy: PreTrainedPolicy | None = None,
    fps: int | None = None,
    single_task: str | None = None,
):
    """Main robot loop: teleoperate and/or run a policy, optionally recording frames.

    Runs for `control_time_s` seconds (forever when None), pacing iterations to
    `fps` when given, and stops early when `events["exit_early"]` is set.

    Raises:
        ValueError: on incompatible argument combinations (teleop + policy,
            dataset without task, dataset/fps mismatch, dataset without an
            action source).
    """
    # TODO(rcadene): Add option to record logs
    if not robot.is_connected:
        robot.connect()

    if events is None:
        events = {"exit_early": False}

    if control_time_s is None:
        control_time_s = float("inf")

    if teleoperate and policy is not None:
        raise ValueError("When `teleoperate` is True, `policy` should be None.")

    if dataset is not None and single_task is None:
        raise ValueError("You need to provide a task as argument in `single_task`.")

    if dataset is not None and fps is not None and dataset.fps != fps:
        # Fix: `dataset` is a LeRobotDataset, not a dict; the old f-string used
        # `dataset['fps']` which would raise while formatting the error.
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")

    timestamp = 0
    start_episode_t = time.perf_counter()
    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()

        if teleoperate:
            observation, action = robot.teleop_step(record_data=True)
        else:
            observation = robot.capture_observation()
            action = None

            if policy is not None:
                pred_action = predict_action(
                    observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
                )
                # Action can eventually be clipped using `max_relative_target`,
                # so action actually sent is saved in the dataset.
                action = robot.send_action(pred_action)
                action = {"action": action}

        if dataset is not None:
            if action is None:
                # Fix: previously this path raised NameError on `action` when neither
                # teleoperation nor a policy produced one; fail with a clear message.
                raise ValueError("Recording a dataset requires either `teleoperate=True` or a `policy`.")
            frame = {**observation, **action, "task": single_task}
            dataset.add_frame(frame)

        if display_cameras and not is_headless():
            image_keys = [key for key in observation if "image" in key]
            for key in image_keys:
                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)

        if fps is not None:
            # Sleep the remainder of the frame budget to hold a steady rate.
            dt_s = time.perf_counter() - start_loop_t
            busy_wait(1 / fps - dt_s)

        dt_s = time.perf_counter() - start_loop_t
        log_control_info(robot, dt_s, fps=fps)

        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            # Reset the flag so subsequent loops (e.g. env reset) are not skipped.
            events["exit_early"] = False
            break
284
+
285
+
286
def reset_environment(robot, events, reset_time_s, fps):
    """Teleoperate the robot while the operator resets the scene between episodes."""
    # TODO(rcadene): refactor warmup_record and reset_environment
    if has_method(robot, "teleop_safety_stop"):
        robot.teleop_safety_stop()

    control_loop(robot=robot, control_time_s=reset_time_s, events=events, fps=fps, teleoperate=True)
298
+
299
+
300
def stop_recording(robot, listener, display_cameras):
    """Disconnect the robot and tear down the keyboard listener / display windows."""
    robot.disconnect()

    if is_headless():
        # Nothing to tear down: no listener was started and no windows were opened.
        return

    if listener is not None:
        listener.stop()

    if display_cameras:
        cv2.destroyAllWindows()
309
+
310
+
311
def sanity_check_dataset_name(repo_id, policy_cfg):
    """Enforce the convention that policy-recorded datasets (and only those) are named `eval_*`.

    Raises:
        ValueError: when an `eval_` dataset has no policy, or a non-`eval_`
            dataset is recorded with one.
    """
    _, dataset_name = repo_id.split("/")
    # either repo_id doesnt start with "eval_" and there is no policy
    # or repo_id starts with "eval_" and there is a policy

    # Check if dataset_name starts with "eval_" but policy is missing
    if dataset_name.startswith("eval_") and policy_cfg is None:
        # Fix: the previous message interpolated `policy_cfg.type` while
        # `policy_cfg` is None on this branch, raising AttributeError
        # instead of the intended ValueError.
        raise ValueError(
            f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
        )

    # Check if dataset_name does not start with "eval_" but policy is provided
    if not dataset_name.startswith("eval_") and policy_cfg is not None:
        raise ValueError(
            f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})."
        )
327
+
328
+
329
def sanity_check_dataset_robot_compatibility(
    dataset: "LeRobotDataset", robot: "Robot", fps: int, use_videos: bool
) -> None:
    """Raise if the dataset's stored metadata disagrees with the current robot setup."""
    checks = [
        ("robot_type", dataset.meta.robot_type, robot.robot_type),
        ("fps", dataset.fps, fps),
        ("features", dataset.features, get_features_from_robot(robot, use_videos)),
    ]

    mismatches = []
    for field, stored_value, current_value in checks:
        # The 'info' sub-entries are allowed to differ between dataset and robot.
        diff = DeepDiff(stored_value, current_value, exclude_regex_paths=[r".*\['info'\]$"])
        if diff:
            mismatches.append(f"{field}: expected {current_value}, got {stored_value}")

    if mismatches:
        raise ValueError(
            "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
        )
lerobot/common/robot_devices/motors/configs.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import abc
16
+ from dataclasses import dataclass
17
+
18
+ import draccus
19
+
20
+
21
@dataclass
class MotorsBusConfig(draccus.ChoiceRegistry, abc.ABC):
    """Abstract base for motor-bus configs; subclasses register under a choice name."""

    @property
    def type(self) -> str:
        # The registered choice name (e.g. "dynamixel") doubles as the bus type.
        return self.get_choice_name(self.__class__)
26
+
27
+
28
@MotorsBusConfig.register_subclass("dynamixel")
@dataclass
class DynamixelMotorsBusConfig(MotorsBusConfig):
    """Configuration for a Dynamixel motor bus."""

    # Serial port the bus is attached to (e.g. "/dev/ttyUSB0").
    port: str
    # Mapping from motor name to (motor id, model name).
    motors: dict[str, tuple[int, str]]
    # Replace hardware access with test doubles when True.
    mock: bool = False
34
+
35
+
36
@MotorsBusConfig.register_subclass("feetech")
@dataclass
class FeetechMotorsBusConfig(MotorsBusConfig):
    """Configuration for a Feetech motor bus."""

    # Serial port the bus is attached to (e.g. "/dev/ttyUSB0").
    port: str
    # Mapping from motor name to (motor id, model name).
    motors: dict[str, tuple[int, str]]
    # Replace hardware access with test doubles when True.
    mock: bool = False
lerobot/common/robot_devices/motors/dynamixel.py ADDED
@@ -0,0 +1,873 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import enum
16
+ import logging
17
+ import math
18
+ import time
19
+ import traceback
20
+ from copy import deepcopy
21
+
22
+ import numpy as np
23
+ import tqdm
24
+
25
+ from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
26
+ from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
27
+ from lerobot.common.utils.utils import capture_timestamp_utc
28
+
29
+ PROTOCOL_VERSION = 2.0
30
+ BAUDRATE = 1_000_000
31
+ TIMEOUT_MS = 1000
32
+
33
+ MAX_ID_RANGE = 252
34
+
35
+ # The following bounds define the lower and upper joints range (after calibration).
36
+ # For joints in degree (i.e. revolute joints), their nominal range is [-180, 180] degrees
37
+ # which corresponds to a half rotation on the left and half rotation on the right.
38
+ # Some joints might require higher range, so we allow up to [-270, 270] degrees until
39
+ # an error is raised.
40
+ LOWER_BOUND_DEGREE = -270
41
+ UPPER_BOUND_DEGREE = 270
42
+ # For joints in percentage (i.e. joints that move linearly like the prismatic joint of a gripper),
43
+ # their nominal range is [0, 100] %. For instance, for Aloha gripper, 0% is fully
44
+ # closed, and 100% is fully open. To account for slight calibration issue, we allow up to
45
+ # [-10, 110] until an error is raised.
46
+ LOWER_BOUND_LINEAR = -10
47
+ UPPER_BOUND_LINEAR = 110
48
+
49
+ HALF_TURN_DEGREE = 180
50
+
51
+ # https://emanual.robotis.com/docs/en/dxl/x/xl330-m077
52
+ # https://emanual.robotis.com/docs/en/dxl/x/xl330-m288
53
+ # https://emanual.robotis.com/docs/en/dxl/x/xl430-w250
54
+ # https://emanual.robotis.com/docs/en/dxl/x/xm430-w350
55
+ # https://emanual.robotis.com/docs/en/dxl/x/xm540-w270
56
+ # https://emanual.robotis.com/docs/en/dxl/x/xc430-w150
57
+
58
+ # data_name: (address, size_byte)
59
+ X_SERIES_CONTROL_TABLE = {
60
+ "Model_Number": (0, 2),
61
+ "Model_Information": (2, 4),
62
+ "Firmware_Version": (6, 1),
63
+ "ID": (7, 1),
64
+ "Baud_Rate": (8, 1),
65
+ "Return_Delay_Time": (9, 1),
66
+ "Drive_Mode": (10, 1),
67
+ "Operating_Mode": (11, 1),
68
+ "Secondary_ID": (12, 1),
69
+ "Protocol_Type": (13, 1),
70
+ "Homing_Offset": (20, 4),
71
+ "Moving_Threshold": (24, 4),
72
+ "Temperature_Limit": (31, 1),
73
+ "Max_Voltage_Limit": (32, 2),
74
+ "Min_Voltage_Limit": (34, 2),
75
+ "PWM_Limit": (36, 2),
76
+ "Current_Limit": (38, 2),
77
+ "Acceleration_Limit": (40, 4),
78
+ "Velocity_Limit": (44, 4),
79
+ "Max_Position_Limit": (48, 4),
80
+ "Min_Position_Limit": (52, 4),
81
+ "Shutdown": (63, 1),
82
+ "Torque_Enable": (64, 1),
83
+ "LED": (65, 1),
84
+ "Status_Return_Level": (68, 1),
85
+ "Registered_Instruction": (69, 1),
86
+ "Hardware_Error_Status": (70, 1),
87
+ "Velocity_I_Gain": (76, 2),
88
+ "Velocity_P_Gain": (78, 2),
89
+ "Position_D_Gain": (80, 2),
90
+ "Position_I_Gain": (82, 2),
91
+ "Position_P_Gain": (84, 2),
92
+ "Feedforward_2nd_Gain": (88, 2),
93
+ "Feedforward_1st_Gain": (90, 2),
94
+ "Bus_Watchdog": (98, 1),
95
+ "Goal_PWM": (100, 2),
96
+ "Goal_Current": (102, 2),
97
+ "Goal_Velocity": (104, 4),
98
+ "Profile_Acceleration": (108, 4),
99
+ "Profile_Velocity": (112, 4),
100
+ "Goal_Position": (116, 4),
101
+ "Realtime_Tick": (120, 2),
102
+ "Moving": (122, 1),
103
+ "Moving_Status": (123, 1),
104
+ "Present_PWM": (124, 2),
105
+ "Present_Current": (126, 2),
106
+ "Present_Velocity": (128, 4),
107
+ "Present_Position": (132, 4),
108
+ "Velocity_Trajectory": (136, 4),
109
+ "Position_Trajectory": (140, 4),
110
+ "Present_Input_Voltage": (144, 2),
111
+ "Present_Temperature": (146, 1),
112
+ }
113
+
114
+ X_SERIES_BAUDRATE_TABLE = {
115
+ 0: 9_600,
116
+ 1: 57_600,
117
+ 2: 115_200,
118
+ 3: 1_000_000,
119
+ 4: 2_000_000,
120
+ 5: 3_000_000,
121
+ 6: 4_000_000,
122
+ }
123
+
124
+ CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
125
+ CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
126
+
127
+ MODEL_CONTROL_TABLE = {
128
+ "x_series": X_SERIES_CONTROL_TABLE,
129
+ "xl330-m077": X_SERIES_CONTROL_TABLE,
130
+ "xl330-m288": X_SERIES_CONTROL_TABLE,
131
+ "xl430-w250": X_SERIES_CONTROL_TABLE,
132
+ "xm430-w350": X_SERIES_CONTROL_TABLE,
133
+ "xm540-w270": X_SERIES_CONTROL_TABLE,
134
+ "xc430-w150": X_SERIES_CONTROL_TABLE,
135
+ }
136
+
137
+ MODEL_RESOLUTION = {
138
+ "x_series": 4096,
139
+ "xl330-m077": 4096,
140
+ "xl330-m288": 4096,
141
+ "xl430-w250": 4096,
142
+ "xm430-w350": 4096,
143
+ "xm540-w270": 4096,
144
+ "xc430-w150": 4096,
145
+ }
146
+
147
+ MODEL_BAUDRATE_TABLE = {
148
+ "x_series": X_SERIES_BAUDRATE_TABLE,
149
+ "xl330-m077": X_SERIES_BAUDRATE_TABLE,
150
+ "xl330-m288": X_SERIES_BAUDRATE_TABLE,
151
+ "xl430-w250": X_SERIES_BAUDRATE_TABLE,
152
+ "xm430-w350": X_SERIES_BAUDRATE_TABLE,
153
+ "xm540-w270": X_SERIES_BAUDRATE_TABLE,
154
+ "xc430-w150": X_SERIES_BAUDRATE_TABLE,
155
+ }
156
+
157
+ NUM_READ_RETRY = 10
158
+ NUM_WRITE_RETRY = 10
159
+
160
+
161
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
    """This function converts the degree range to the step range for indicating motors rotation.
    It assumes a motor achieves a full rotation by going from -180 degree position to +180.
    The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.

    Args:
        degrees: Angle(s) in degrees.
        models: A single model name, or one model name per joint.

    Returns:
        Steps as an integer numpy array (one entry per model).
    """
    # Accept a single model name, as advertised by the type hint. Iterating a bare
    # string would otherwise look up each *character* in MODEL_RESOLUTION.
    if isinstance(models, str):
        models = [models]
    resolutions = [MODEL_RESOLUTION[model] for model in models]
    # Half a rotation (180 degrees) corresponds to resolution // 2 steps.
    steps = degrees / 180 * np.array(resolutions) / 2
    steps = steps.astype(int)
    return steps
170
+
171
+
172
def convert_to_bytes(value, bytes, mock=False):
    """Split an integer value into the little-endian byte list expected by the SDK.

    When `mock` is True the value is returned untouched (the mocked SDK takes
    plain integers).
    """
    if mock:
        return value

    import dynamixel_sdk as dxl

    # Note: No need to convert back into unsigned int, since this byte preprocessing
    # already handles it for us.
    low_word = dxl.DXL_LOWORD(value)
    if bytes == 1:
        return [dxl.DXL_LOBYTE(low_word)]
    if bytes == 2:
        return [dxl.DXL_LOBYTE(low_word), dxl.DXL_HIBYTE(low_word)]
    if bytes == 4:
        high_word = dxl.DXL_HIWORD(value)
        return [
            dxl.DXL_LOBYTE(low_word),
            dxl.DXL_HIBYTE(low_word),
            dxl.DXL_LOBYTE(high_word),
            dxl.DXL_HIBYTE(high_word),
        ]
    raise NotImplementedError(
        f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
        f"{bytes} is provided instead."
    )
202
+
203
+
204
def get_group_sync_key(data_name, motor_names):
    """Build the key identifying a (data_name, motor set) sync group."""
    return data_name + "_" + "_".join(motor_names)
207
+
208
+
209
def get_result_name(fn_name, data_name, motor_names):
    """Name under which the result of `fn_name` on a sync group is stored."""
    return f"{fn_name}_{get_group_sync_key(data_name, motor_names)}"
213
+
214
+
215
def get_queue_name(fn_name, data_name, motor_names):
    """Name of the queue used by `fn_name` for a given sync group."""
    return f"{fn_name}_{get_group_sync_key(data_name, motor_names)}"
219
+
220
+
221
def get_log_name(var_name, fn_name, data_name, motor_names):
    """Name of the log entry for `var_name` produced by `fn_name` on a sync group."""
    return f"{var_name}_{fn_name}_{get_group_sync_key(data_name, motor_names)}"
225
+
226
+
227
+ def assert_same_address(model_ctrl_table, motor_models, data_name):
228
+ all_addr = []
229
+ all_bytes = []
230
+ for model in motor_models:
231
+ addr, bytes = model_ctrl_table[model][data_name]
232
+ all_addr.append(addr)
233
+ all_bytes.append(bytes)
234
+
235
+ if len(set(all_addr)) != 1:
236
+ raise NotImplementedError(
237
+ f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
238
+ )
239
+
240
+ if len(set(all_bytes)) != 1:
241
+ raise NotImplementedError(
242
+ f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
243
+ )
244
+
245
+
246
class TorqueMode(enum.Enum):
    """Values written to the `Torque_Enable` register to lock/unlock a motor."""

    ENABLED = 1
    DISABLED = 0
249
+
250
+
251
class DriveMode(enum.Enum):
    """Rotation direction convention of a motor (INVERTED motors spin the opposite way)."""

    NON_INVERTED = 0
    INVERTED = 1
254
+
255
+
256
class CalibrationMode(enum.Enum):
    """How a joint's raw motor values are mapped to a nominal range."""

    # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
    DEGREE = 0
    # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
    LINEAR = 1
261
+
262
+
263
class JointOutOfRangeError(Exception):
    """Raised when a calibrated joint value falls outside its permitted bounds."""

    def __init__(self, message="Joint is out of range"):
        super().__init__(message)
        self.message = message
267
+
268
+
269
+ class DynamixelMotorsBus:
270
+ """
271
+ The DynamixelMotorsBus class allows to efficiently read and write to the attached motors. It relies on
272
+ the python dynamixel sdk to communicate with the motors. For more info, see the [Dynamixel SDK Documentation](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20).
273
+
274
+ A DynamixelMotorsBus instance requires a port (e.g. `DynamixelMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
275
+ To find the port, you can run our utility script:
276
+ ```bash
277
+ python lerobot/scripts/find_motors_bus_port.py
278
+ >>> Finding all available ports for the MotorBus.
279
+ >>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
280
+ >>> Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
281
+ >>> The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751.
282
+ >>> Reconnect the usb cable.
283
+ ```
284
+
285
+ Example of usage for 1 motor connected to the bus:
286
+ ```python
287
+ motor_name = "gripper"
288
+ motor_index = 6
289
+ motor_model = "xl330-m288"
290
+
291
+ config = DynamixelMotorsBusConfig(
292
+ port="/dev/tty.usbmodem575E0031751",
293
+ motors={motor_name: (motor_index, motor_model)},
294
+ )
295
+ motors_bus = DynamixelMotorsBus(config)
296
+ motors_bus.connect()
297
+
298
+ position = motors_bus.read("Present_Position")
299
+
300
+ # move from a few motor steps as an example
301
+ few_steps = 30
302
+ motors_bus.write("Goal_Position", position + few_steps)
303
+
304
+ # when done, consider disconnecting
305
+ motors_bus.disconnect()
306
+ ```
307
+ """
308
+
309
    def __init__(
        self,
        config: DynamixelMotorsBusConfig,
    ):
        """Store the config and initialize state; no hardware I/O happens until `connect()`."""
        self.port = config.port
        self.motors = config.motors
        self.mock = config.mock

        # Per-instance copies so calibration autocorrection (or tests) can mutate
        # them without touching the shared module-level tables.
        self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
        self.model_resolution = deepcopy(MODEL_RESOLUTION)

        # Set by connect(); handlers come from the (possibly mocked) dynamixel sdk.
        self.port_handler = None
        self.packet_handler = None
        self.calibration = None
        self.is_connected = False
        # Cached GroupSyncRead/GroupSyncWrite objects, keyed by the group key
        # returned by get_group_sync_key(data_name, motor_names).
        self.group_readers = {}
        self.group_writers = {}
        # Timing logs (read/write durations and UTC timestamps).
        self.logs = {}
327
+
328
    def connect(self):
        """Open the serial port and set up the SDK packet/port handlers.

        Raises:
            RobotDeviceAlreadyConnectedError: If `connect()` was already called.
            OSError: If the serial port cannot be opened.
        """
        if self.is_connected:
            raise RobotDeviceAlreadyConnectedError(
                f"DynamixelMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice."
            )

        # The SDK import is deferred so the mock can be swapped in for tests.
        if self.mock:
            import tests.motors.mock_dynamixel_sdk as dxl
        else:
            import dynamixel_sdk as dxl

        self.port_handler = dxl.PortHandler(self.port)
        self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)

        try:
            if not self.port_handler.openPort():
                raise OSError(f"Failed to open port '{self.port}'.")
        except Exception:
            # Print a troubleshooting hint before re-raising the original error.
            traceback.print_exc()
            print(
                "\nTry running `python lerobot/scripts/find_motors_bus_port.py` to make sure you are using the correct port.\n"
            )
            raise

        # Allow to read and write
        self.is_connected = True

        self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
356
+
357
    def reconnect(self):
        """Re-open the serial port after a disconnect.

        Unlike `connect()`, this does not guard against being already connected
        and does not reset the packet timeout.
        """
        if self.mock:
            import tests.motors.mock_dynamixel_sdk as dxl
        else:
            import dynamixel_sdk as dxl

        self.port_handler = dxl.PortHandler(self.port)
        self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)

        if not self.port_handler.openPort():
            raise OSError(f"Failed to open port '{self.port}'.")

        self.is_connected = True
370
+
371
    def are_motors_configured(self):
        """Return True if the IDs read from the bus match the configured motor indices.

        Communication failures are printed and reported as "not configured".
        """
        # Only check the motor indices and not baudrate, since if the motor baudrates are incorrect,
        # a ConnectionError will be raised anyway.
        try:
            return (self.motor_indices == self.read("ID")).all()
        except ConnectionError as e:
            print(e)
            return False
379
+
380
    def find_motor_indices(self, possible_ids=None, num_retry=2):
        """Scan the bus and return the list of motor ids that respond.

        Args:
            possible_ids: Ids to probe; defaults to the full protocol range.
            num_retry: Read retries per id before the id is considered absent.

        Raises:
            OSError: If a motor reports an id different from the one used to
                address it (likely corrupted motor memory).
        """
        if possible_ids is None:
            possible_ids = range(MAX_ID_RANGE)

        indices = []
        for idx in tqdm.tqdm(possible_ids):
            # Ids that do not answer raise ConnectionError and are simply skipped.
            try:
                present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
            except ConnectionError:
                continue

            if idx != present_idx:
                # sanity check
                raise OSError(
                    "Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged."
                )
            indices.append(idx)

        return indices
399
+
400
    def set_bus_baudrate(self, baudrate):
        """Set the serial port baud rate, verifying the change took effect.

        Raises:
            OSError: If the port does not report the requested baud rate afterwards.
        """
        present_bus_baudrate = self.port_handler.getBaudRate()
        if present_bus_baudrate != baudrate:
            print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
            self.port_handler.setBaudRate(baudrate)

            if self.port_handler.getBaudRate() != baudrate:
                raise OSError("Failed to write bus baud rate.")
408
+
409
+ @property
410
+ def motor_names(self) -> list[str]:
411
+ return list(self.motors.keys())
412
+
413
+ @property
414
+ def motor_models(self) -> list[str]:
415
+ return [model for _, model in self.motors.values()]
416
+
417
+ @property
418
+ def motor_indices(self) -> list[int]:
419
+ return [idx for idx, _ in self.motors.values()]
420
+
421
    def set_calibration(self, calibration: dict[str, list]):
        """Store the calibration used by `apply_calibration`/`revert_calibration`.

        The dict is expected to contain "motor_names" and "calib_mode", plus
        per-mode entries: "drive_mode"/"homing_offset" for DEGREE joints and
        "start_pos"/"end_pos" for LINEAR joints (see `apply_calibration`).
        """
        self.calibration = calibration
423
+
424
+ def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
425
+ """This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.
426
+
427
+ For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
428
+ """
429
+ try:
430
+ values = self.apply_calibration(values, motor_names)
431
+ except JointOutOfRangeError as e:
432
+ print(e)
433
+ self.autocorrect_calibration(values, motor_names)
434
+ values = self.apply_calibration(values, motor_names)
435
+ return values
436
+
437
    def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
        """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
        a "zero position" at 0 degree.

        Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor
        rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range.

        Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation
        when given a goal position that is + or - their resolution. For instance, dynamixel xl330-m077 have a resolution of 4096, and
        at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830,
        or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor.
        To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work
        in the centered nominal degree range ]-180, 180[.

        Raises:
            JointOutOfRangeError: When a converted value falls outside the
                permitted bounds for its calibration mode.
        """
        if motor_names is None:
            motor_names = self.motor_names

        # Convert from unsigned int32 original range [0, 2**32] to signed float32 range
        # NOTE(review): the annotation admits a plain list, but `.astype` assumes a
        # numpy array — confirm callers always pass ndarrays.
        values = values.astype(np.float32)

        for i, name in enumerate(motor_names):
            calib_idx = self.calibration["motor_names"].index(name)
            calib_mode = self.calibration["calib_mode"][calib_idx]

            if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
                drive_mode = self.calibration["drive_mode"][calib_idx]
                homing_offset = self.calibration["homing_offset"][calib_idx]
                _, model = self.motors[name]
                resolution = self.model_resolution[model]

                # Update direction of rotation of the motor to match between leader and follower.
                # In fact, the motor of the leader for a given joint can be assembled in an
                # opposite direction in term of rotation than the motor of the follower on the same joint.
                if drive_mode:
                    values[i] *= -1

                # Convert from range [-2**31, 2**31] to
                # nominal range [-resolution//2, resolution//2] (e.g. [-2048, 2048])
                values[i] += homing_offset

                # Convert from range [-resolution//2, resolution//2] to
                # universal float32 centered degree range [-180, 180]
                # (e.g. 2048 / (4096 // 2) * 180 = 180)
                values[i] = values[i] / (resolution // 2) * HALF_TURN_DEGREE

                if (values[i] < LOWER_BOUND_DEGREE) or (values[i] > UPPER_BOUND_DEGREE):
                    raise JointOutOfRangeError(
                        f"Wrong motor position range detected for {name}. "
                        f"Expected to be in nominal range of [-{HALF_TURN_DEGREE}, {HALF_TURN_DEGREE}] degrees (a full rotation), "
                        f"with a maximum range of [{LOWER_BOUND_DEGREE}, {UPPER_BOUND_DEGREE}] degrees to account for joints that can rotate a bit more, "
                        f"but present value is {values[i]} degree. "
                        "This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. "
                        "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
                    )

            elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
                start_pos = self.calibration["start_pos"][calib_idx]
                end_pos = self.calibration["end_pos"][calib_idx]

                # Rescale the present position to a nominal range [0, 100] %,
                # useful for joints with linear motions like Aloha gripper
                values[i] = (values[i] - start_pos) / (end_pos - start_pos) * 100

                if (values[i] < LOWER_BOUND_LINEAR) or (values[i] > UPPER_BOUND_LINEAR):
                    raise JointOutOfRangeError(
                        f"Wrong motor position range detected for {name}. "
                        f"Expected to be in nominal range of [0, 100] % (a full linear translation), "
                        f"with a maximum range of [{LOWER_BOUND_LINEAR}, {UPPER_BOUND_LINEAR}] % to account for some imprecision during calibration, "
                        f"but present value is {values[i]} %. "
                        "This might be due to a cable connection issue creating an artificial jump in motor values. "
                        "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
                    )

        return values
511
+
512
    def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
        """This function automatically detects issues with values of motors after calibration, and correct for these issues.

        Some motors might have values outside of expected maximum bounds after calibration.
        For instance, for a joint in degree, its value can be outside [-270, 270] degrees, which is totally unexpected given
        a nominal range of [-180, 180] degrees, which represents half a turn to the left or right starting from zero position.

        Known issues:
        #1: Motor value randomly shifts of a full turn, caused by hardware/connection errors.
        #2: Motor internal homing offset is shifted by a full turn, caused by using default calibration (e.g Aloha).
        #3: motor internal homing offset is shifted by less or more than a full turn, caused by using default calibration
        or by human error during manual calibration.

        Issues #1 and #2 can be solved by shifting the calibration homing offset by a full turn.
        Issue #3 will be visually detected by user and potentially captured by the safety feature `max_relative_target`,
        that will slow down the motor, raise an error asking to recalibrate. Manual recalibrating will solve the issue.

        Note: A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
        """
        if motor_names is None:
            motor_names = self.motor_names

        # Convert from unsigned int32 original range [0, 2**32] to signed float32 range
        values = values.astype(np.float32)

        for i, name in enumerate(motor_names):
            calib_idx = self.calibration["motor_names"].index(name)
            calib_mode = self.calibration["calib_mode"][calib_idx]

            if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
                drive_mode = self.calibration["drive_mode"][calib_idx]
                homing_offset = self.calibration["homing_offset"][calib_idx]
                _, model = self.motors[name]
                resolution = self.model_resolution[model]

                # Update direction of rotation of the motor to match between leader and follower.
                # In fact, the motor of the leader for a given joint can be assembled in an
                # opposite direction in term of rotation than the motor of the follower on the same joint.
                if drive_mode:
                    values[i] *= -1

                # Convert from initial range to range [-180, 180] degrees
                calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
                in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)

                # Solve this inequality to find the factor to shift the range into [-180, 180] degrees
                # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
                # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
                # (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution
                low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution
                upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution

            elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
                start_pos = self.calibration["start_pos"][calib_idx]
                end_pos = self.calibration["end_pos"][calib_idx]

                # Convert from initial range to range [0, 100] in %
                calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
                in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)

                # Solve this inequality to find the factor to shift the range into [0, 100] %
                # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
                # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100
                # 0 <= (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 <= 100
                # (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
                # NOTE(review): `resolution` here was assigned in the DEGREE branch of a
                # previous loop iteration (or is undefined on the first iteration) — the
                # LINEAR branch never sets it. Confirm whether LINEAR joints are expected
                # to reach this autocorrection path at all.
                low_factor = (start_pos - values[i]) / resolution
                upp_factor = (end_pos - values[i]) / resolution

            if not in_range:
                # Get first integer between the two bounds
                if low_factor < upp_factor:
                    factor = math.ceil(low_factor)

                    if factor > upp_factor:
                        raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
                else:
                    factor = math.ceil(upp_factor)

                    if factor > low_factor:
                        raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")

                # NOTE(review): `out_of_range_str` and `in_range_str` are built with
                # identical contents — the log message likely intended to show the
                # corrected value in `in_range_str`.
                if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
                    out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
                    in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
                elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
                    out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
                    in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"

                logging.warning(
                    f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
                    f"from '{out_of_range_str}' to '{in_range_str}'."
                )

                # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
                self.calibration["homing_offset"][calib_idx] += resolution * factor
607
+
608
    def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
        """Inverse of `apply_calibration`.

        Maps nominal values (degrees for DEGREE joints, percent for LINEAR joints)
        back to raw motor steps, returned as a rounded int32 array.
        """
        if motor_names is None:
            motor_names = self.motor_names

        for i, name in enumerate(motor_names):
            calib_idx = self.calibration["motor_names"].index(name)
            calib_mode = self.calibration["calib_mode"][calib_idx]

            if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
                drive_mode = self.calibration["drive_mode"][calib_idx]
                homing_offset = self.calibration["homing_offset"][calib_idx]
                _, model = self.motors[name]
                resolution = self.model_resolution[model]

                # Convert from nominal 0-centered degree range [-180, 180] to
                # 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
                values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)

                # Subtract the homing offsets to come back to actual motor range of values
                # which can be arbitrary.
                values[i] -= homing_offset

                # Remove drive mode, which is the rotation direction of the motor, to come back to
                # actual motor rotation direction which can be arbitrary.
                if drive_mode:
                    values[i] *= -1

            elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
                start_pos = self.calibration["start_pos"][calib_idx]
                end_pos = self.calibration["end_pos"][calib_idx]

                # Convert from nominal linear range of [0, 100] % to
                # actual motor range of values which can be arbitrary.
                values[i] = values[i] / 100 * (end_pos - start_pos) + start_pos

        values = np.round(values).astype(np.int32)
        return values
646
+
647
+ def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
648
+ if self.mock:
649
+ import tests.motors.mock_dynamixel_sdk as dxl
650
+ else:
651
+ import dynamixel_sdk as dxl
652
+
653
+ return_list = True
654
+ if not isinstance(motor_ids, list):
655
+ return_list = False
656
+ motor_ids = [motor_ids]
657
+
658
+ assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
659
+ addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
660
+ group = dxl.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
661
+ for idx in motor_ids:
662
+ group.addParam(idx)
663
+
664
+ for _ in range(num_retry):
665
+ comm = group.txRxPacket()
666
+ if comm == dxl.COMM_SUCCESS:
667
+ break
668
+
669
+ if comm != dxl.COMM_SUCCESS:
670
+ raise ConnectionError(
671
+ f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
672
+ f"{self.packet_handler.getTxRxResult(comm)}"
673
+ )
674
+
675
+ values = []
676
+ for idx in motor_ids:
677
+ value = group.getData(idx, addr, bytes)
678
+ values.append(value)
679
+
680
+ if return_list:
681
+ return values
682
+ else:
683
+ return values[0]
684
+
685
+ def read(self, data_name, motor_names: str | list[str] | None = None):
686
+ if not self.is_connected:
687
+ raise RobotDeviceNotConnectedError(
688
+ f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
689
+ )
690
+
691
+ start_time = time.perf_counter()
692
+
693
+ if self.mock:
694
+ import tests.motors.mock_dynamixel_sdk as dxl
695
+ else:
696
+ import dynamixel_sdk as dxl
697
+
698
+ if motor_names is None:
699
+ motor_names = self.motor_names
700
+
701
+ if isinstance(motor_names, str):
702
+ motor_names = [motor_names]
703
+
704
+ motor_ids = []
705
+ models = []
706
+ for name in motor_names:
707
+ motor_idx, model = self.motors[name]
708
+ motor_ids.append(motor_idx)
709
+ models.append(model)
710
+
711
+ assert_same_address(self.model_ctrl_table, models, data_name)
712
+ addr, bytes = self.model_ctrl_table[model][data_name]
713
+ group_key = get_group_sync_key(data_name, motor_names)
714
+
715
+ if data_name not in self.group_readers:
716
+ # create new group reader
717
+ self.group_readers[group_key] = dxl.GroupSyncRead(
718
+ self.port_handler, self.packet_handler, addr, bytes
719
+ )
720
+ for idx in motor_ids:
721
+ self.group_readers[group_key].addParam(idx)
722
+
723
+ for _ in range(NUM_READ_RETRY):
724
+ comm = self.group_readers[group_key].txRxPacket()
725
+ if comm == dxl.COMM_SUCCESS:
726
+ break
727
+
728
+ if comm != dxl.COMM_SUCCESS:
729
+ raise ConnectionError(
730
+ f"Read failed due to communication error on port {self.port} for group_key {group_key}: "
731
+ f"{self.packet_handler.getTxRxResult(comm)}"
732
+ )
733
+
734
+ values = []
735
+ for idx in motor_ids:
736
+ value = self.group_readers[group_key].getData(idx, addr, bytes)
737
+ values.append(value)
738
+
739
+ values = np.array(values)
740
+
741
+ # Convert to signed int to use range [-2048, 2048] for our motor positions.
742
+ if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
743
+ values = values.astype(np.int32)
744
+
745
+ if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
746
+ values = self.apply_calibration_autocorrect(values, motor_names)
747
+
748
+ # log the number of seconds it took to read the data from the motors
749
+ delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
750
+ self.logs[delta_ts_name] = time.perf_counter() - start_time
751
+
752
+ # log the utc time at which the data was received
753
+ ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names)
754
+ self.logs[ts_utc_name] = capture_timestamp_utc()
755
+
756
+ return values
757
+
758
+ def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
759
+ if self.mock:
760
+ import tests.motors.mock_dynamixel_sdk as dxl
761
+ else:
762
+ import dynamixel_sdk as dxl
763
+
764
+ if not isinstance(motor_ids, list):
765
+ motor_ids = [motor_ids]
766
+ if not isinstance(values, list):
767
+ values = [values]
768
+
769
+ assert_same_address(self.model_ctrl_table, motor_models, data_name)
770
+ addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
771
+ group = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
772
+ for idx, value in zip(motor_ids, values, strict=True):
773
+ data = convert_to_bytes(value, bytes, self.mock)
774
+ group.addParam(idx, data)
775
+
776
+ for _ in range(num_retry):
777
+ comm = group.txPacket()
778
+ if comm == dxl.COMM_SUCCESS:
779
+ break
780
+
781
+ if comm != dxl.COMM_SUCCESS:
782
+ raise ConnectionError(
783
+ f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
784
+ f"{self.packet_handler.getTxRxResult(comm)}"
785
+ )
786
+
787
+ def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
788
+ if not self.is_connected:
789
+ raise RobotDeviceNotConnectedError(
790
+ f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
791
+ )
792
+
793
+ start_time = time.perf_counter()
794
+
795
+ if self.mock:
796
+ import tests.motors.mock_dynamixel_sdk as dxl
797
+ else:
798
+ import dynamixel_sdk as dxl
799
+
800
+ if motor_names is None:
801
+ motor_names = self.motor_names
802
+
803
+ if isinstance(motor_names, str):
804
+ motor_names = [motor_names]
805
+
806
+ if isinstance(values, (int, float, np.integer)):
807
+ values = [int(values)] * len(motor_names)
808
+
809
+ values = np.array(values)
810
+
811
+ motor_ids = []
812
+ models = []
813
+ for name in motor_names:
814
+ motor_idx, model = self.motors[name]
815
+ motor_ids.append(motor_idx)
816
+ models.append(model)
817
+
818
+ if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
819
+ values = self.revert_calibration(values, motor_names)
820
+
821
+ values = values.tolist()
822
+
823
+ assert_same_address(self.model_ctrl_table, models, data_name)
824
+ addr, bytes = self.model_ctrl_table[model][data_name]
825
+ group_key = get_group_sync_key(data_name, motor_names)
826
+
827
+ init_group = data_name not in self.group_readers
828
+ if init_group:
829
+ self.group_writers[group_key] = dxl.GroupSyncWrite(
830
+ self.port_handler, self.packet_handler, addr, bytes
831
+ )
832
+
833
+ for idx, value in zip(motor_ids, values, strict=True):
834
+ data = convert_to_bytes(value, bytes, self.mock)
835
+ if init_group:
836
+ self.group_writers[group_key].addParam(idx, data)
837
+ else:
838
+ self.group_writers[group_key].changeParam(idx, data)
839
+
840
+ comm = self.group_writers[group_key].txPacket()
841
+ if comm != dxl.COMM_SUCCESS:
842
+ raise ConnectionError(
843
+ f"Write failed due to communication error on port {self.port} for group_key {group_key}: "
844
+ f"{self.packet_handler.getTxRxResult(comm)}"
845
+ )
846
+
847
+ # log the number of seconds it took to write the data to the motors
848
+ delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
849
+ self.logs[delta_ts_name] = time.perf_counter() - start_time
850
+
851
+ # TODO(rcadene): should we log the time before sending the write command?
852
+ # log the utc time when the write has been completed
853
+ ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
854
+ self.logs[ts_utc_name] = capture_timestamp_utc()
855
+
856
+ def disconnect(self):
857
+ if not self.is_connected:
858
+ raise RobotDeviceNotConnectedError(
859
+ f"DynamixelMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first."
860
+ )
861
+
862
+ if self.port_handler is not None:
863
+ self.port_handler.closePort()
864
+ self.port_handler = None
865
+
866
+ self.packet_handler = None
867
+ self.group_readers = {}
868
+ self.group_writers = {}
869
+ self.is_connected = False
870
+
871
+ def __del__(self):
872
+ if getattr(self, "is_connected", False):
873
+ self.disconnect()
lerobot/common/robot_devices/motors/feetech.py ADDED
@@ -0,0 +1,898 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import enum
16
+ import logging
17
+ import math
18
+ import time
19
+ import traceback
20
+ from copy import deepcopy
21
+
22
+ import numpy as np
23
+ import tqdm
24
+
25
+ from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
26
+ from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
27
+ from lerobot.common.utils.utils import capture_timestamp_utc
28
+
29
+ PROTOCOL_VERSION = 0
30
+ BAUDRATE = 1_000_000
31
+ TIMEOUT_MS = 1000
32
+
33
+ MAX_ID_RANGE = 252
34
+
35
+ # The following bounds define the lower and upper joints range (after calibration).
36
+ # For joints in degree (i.e. revolute joints), their nominal range is [-180, 180] degrees
37
+ # which corresponds to a half rotation on the left and half rotation on the right.
38
+ # Some joints might require higher range, so we allow up to [-270, 270] degrees until
39
+ # an error is raised.
40
+ LOWER_BOUND_DEGREE = -270
41
+ UPPER_BOUND_DEGREE = 270
42
+ # For joints in percentage (i.e. joints that move linearly like the prismatic joint of a gripper),
43
+ # their nominal range is [0, 100] %. For instance, for Aloha gripper, 0% is fully
44
+ # closed, and 100% is fully open. To account for slight calibration issue, we allow up to
45
+ # [-10, 110] until an error is raised.
46
+ LOWER_BOUND_LINEAR = -10
47
+ UPPER_BOUND_LINEAR = 110
48
+
49
+ HALF_TURN_DEGREE = 180
50
+
51
+
52
+ # See this link for STS3215 Memory Table:
53
+ # https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true
54
+ # data_name: (address, size_byte)
55
+ SCS_SERIES_CONTROL_TABLE = {
56
+ "Model": (3, 2),
57
+ "ID": (5, 1),
58
+ "Baud_Rate": (6, 1),
59
+ "Return_Delay": (7, 1),
60
+ "Response_Status_Level": (8, 1),
61
+ "Min_Angle_Limit": (9, 2),
62
+ "Max_Angle_Limit": (11, 2),
63
+ "Max_Temperature_Limit": (13, 1),
64
+ "Max_Voltage_Limit": (14, 1),
65
+ "Min_Voltage_Limit": (15, 1),
66
+ "Max_Torque_Limit": (16, 2),
67
+ "Phase": (18, 1),
68
+ "Unloading_Condition": (19, 1),
69
+ "LED_Alarm_Condition": (20, 1),
70
+ "P_Coefficient": (21, 1),
71
+ "D_Coefficient": (22, 1),
72
+ "I_Coefficient": (23, 1),
73
+ "Minimum_Startup_Force": (24, 2),
74
+ "CW_Dead_Zone": (26, 1),
75
+ "CCW_Dead_Zone": (27, 1),
76
+ "Protection_Current": (28, 2),
77
+ "Angular_Resolution": (30, 1),
78
+ "Offset": (31, 2),
79
+ "Mode": (33, 1),
80
+ "Protective_Torque": (34, 1),
81
+ "Protection_Time": (35, 1),
82
+ "Overload_Torque": (36, 1),
83
+ "Speed_closed_loop_P_proportional_coefficient": (37, 1),
84
+ "Over_Current_Protection_Time": (38, 1),
85
+ "Velocity_closed_loop_I_integral_coefficient": (39, 1),
86
+ "Torque_Enable": (40, 1),
87
+ "Acceleration": (41, 1),
88
+ "Goal_Position": (42, 2),
89
+ "Goal_Time": (44, 2),
90
+ "Goal_Speed": (46, 2),
91
+ "Torque_Limit": (48, 2),
92
+ "Lock": (55, 1),
93
+ "Present_Position": (56, 2),
94
+ "Present_Speed": (58, 2),
95
+ "Present_Load": (60, 2),
96
+ "Present_Voltage": (62, 1),
97
+ "Present_Temperature": (63, 1),
98
+ "Status": (65, 1),
99
+ "Moving": (66, 1),
100
+ "Present_Current": (69, 2),
101
+ # Not in the Memory Table
102
+ "Maximum_Acceleration": (85, 2),
103
+ }
104
+
105
+ SCS_SERIES_BAUDRATE_TABLE = {
106
+ 0: 1_000_000,
107
+ 1: 500_000,
108
+ 2: 250_000,
109
+ 3: 128_000,
110
+ 4: 115_200,
111
+ 5: 57_600,
112
+ 6: 38_400,
113
+ 7: 19_200,
114
+ }
115
+
116
+ CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
117
+ CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
118
+
119
+
120
+ MODEL_CONTROL_TABLE = {
121
+ "scs_series": SCS_SERIES_CONTROL_TABLE,
122
+ "sts3215": SCS_SERIES_CONTROL_TABLE,
123
+ }
124
+
125
+ MODEL_RESOLUTION = {
126
+ "scs_series": 4096,
127
+ "sts3215": 4096,
128
+ }
129
+
130
+ MODEL_BAUDRATE_TABLE = {
131
+ "scs_series": SCS_SERIES_BAUDRATE_TABLE,
132
+ "sts3215": SCS_SERIES_BAUDRATE_TABLE,
133
+ }
134
+
135
+ # High number of retries is needed for feetech compared to dynamixel motors.
136
+ NUM_READ_RETRY = 20
137
+ NUM_WRITE_RETRY = 20
138
+
139
+
140
+ def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
141
+ """This function converts the degree range to the step range for indicating motors rotation.
142
+ It assumes a motor achieves a full rotation by going from -180 degree position to +180.
143
+ The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
144
+ """
145
+ resolutions = [MODEL_RESOLUTION[model] for model in models]
146
+ steps = degrees / 180 * np.array(resolutions) / 2
147
+ steps = steps.astype(int)
148
+ return steps
149
+
150
+
151
+ def convert_to_bytes(value, bytes, mock=False):
152
+ if mock:
153
+ return value
154
+
155
+ import scservo_sdk as scs
156
+
157
+ # Note: No need to convert back into unsigned int, since this byte preprocessing
158
+ # already handles it for us.
159
+ if bytes == 1:
160
+ data = [
161
+ scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
162
+ ]
163
+ elif bytes == 2:
164
+ data = [
165
+ scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
166
+ scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
167
+ ]
168
+ elif bytes == 4:
169
+ data = [
170
+ scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
171
+ scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
172
+ scs.SCS_LOBYTE(scs.SCS_HIWORD(value)),
173
+ scs.SCS_HIBYTE(scs.SCS_HIWORD(value)),
174
+ ]
175
+ else:
176
+ raise NotImplementedError(
177
+ f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
178
+ f"{bytes} is provided instead."
179
+ )
180
+ return data
181
+
182
+
183
+ def get_group_sync_key(data_name, motor_names):
184
+ group_key = f"{data_name}_" + "_".join(motor_names)
185
+ return group_key
186
+
187
+
188
+ def get_result_name(fn_name, data_name, motor_names):
189
+ group_key = get_group_sync_key(data_name, motor_names)
190
+ rslt_name = f"{fn_name}_{group_key}"
191
+ return rslt_name
192
+
193
+
194
+ def get_queue_name(fn_name, data_name, motor_names):
195
+ group_key = get_group_sync_key(data_name, motor_names)
196
+ queue_name = f"{fn_name}_{group_key}"
197
+ return queue_name
198
+
199
+
200
+ def get_log_name(var_name, fn_name, data_name, motor_names):
201
+ group_key = get_group_sync_key(data_name, motor_names)
202
+ log_name = f"{var_name}_{fn_name}_{group_key}"
203
+ return log_name
204
+
205
+
206
+ def assert_same_address(model_ctrl_table, motor_models, data_name):
207
+ all_addr = []
208
+ all_bytes = []
209
+ for model in motor_models:
210
+ addr, bytes = model_ctrl_table[model][data_name]
211
+ all_addr.append(addr)
212
+ all_bytes.append(bytes)
213
+
214
+ if len(set(all_addr)) != 1:
215
+ raise NotImplementedError(
216
+ f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
217
+ )
218
+
219
+ if len(set(all_bytes)) != 1:
220
+ raise NotImplementedError(
221
+ f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
222
+ )
223
+
224
+
225
+ class TorqueMode(enum.Enum):
226
+ ENABLED = 1
227
+ DISABLED = 0
228
+
229
+
230
+ class DriveMode(enum.Enum):
231
+ NON_INVERTED = 0
232
+ INVERTED = 1
233
+
234
+
235
+ class CalibrationMode(enum.Enum):
236
+ # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
237
+ DEGREE = 0
238
+ # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
239
+ LINEAR = 1
240
+
241
+
242
+ class JointOutOfRangeError(Exception):
243
+ def __init__(self, message="Joint is out of range"):
244
+ self.message = message
245
+ super().__init__(self.message)
246
+
247
+
248
+ class FeetechMotorsBus:
249
+ """
250
+ The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on
251
+ the python feetech sdk to communicate with the motors. For more info, see the [feetech SDK Documentation](https://emanual.robotis.com/docs/en/software/feetech/feetech_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20).
252
+
253
+ A FeetechMotorsBus instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
254
+ To find the port, you can run our utility script:
255
+ ```bash
256
+ python lerobot/scripts/find_motors_bus_port.py
257
+ >>> Finding all available ports for the MotorsBus.
258
+ >>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
259
+ >>> Remove the usb cable from your FeetechMotorsBus and press Enter when done.
260
+ >>> The port of this FeetechMotorsBus is /dev/tty.usbmodem575E0031751.
261
+ >>> Reconnect the usb cable.
262
+ ```
263
+
264
+ Example of usage for 1 motor connected to the bus:
265
+ ```python
266
+ motor_name = "gripper"
267
+ motor_index = 6
268
+ motor_model = "sts3215"
269
+
270
+ config = FeetechMotorsBusConfig(
271
+ port="/dev/tty.usbmodem575E0031751",
272
+ motors={motor_name: (motor_index, motor_model)},
273
+ )
274
+ motors_bus = FeetechMotorsBus(config)
275
+ motors_bus.connect()
276
+
277
+ position = motors_bus.read("Present_Position")
278
+
279
+ # move from a few motor steps as an example
280
+ few_steps = 30
281
+ motors_bus.write("Goal_Position", position + few_steps)
282
+
283
+ # when done, consider disconnecting
284
+ motors_bus.disconnect()
285
+ ```
286
+ """
287
+
288
+ def __init__(
289
+ self,
290
+ config: FeetechMotorsBusConfig,
291
+ ):
292
+ self.port = config.port
293
+ self.motors = config.motors
294
+ self.mock = config.mock
295
+
296
+ self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
297
+ self.model_resolution = deepcopy(MODEL_RESOLUTION)
298
+
299
+ self.port_handler = None
300
+ self.packet_handler = None
301
+ self.calibration = None
302
+ self.is_connected = False
303
+ self.group_readers = {}
304
+ self.group_writers = {}
305
+ self.logs = {}
306
+
307
+ self.track_positions = {}
308
+
309
+ def connect(self):
310
+ if self.is_connected:
311
+ raise RobotDeviceAlreadyConnectedError(
312
+ f"FeetechMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice."
313
+ )
314
+
315
+ if self.mock:
316
+ import tests.motors.mock_scservo_sdk as scs
317
+ else:
318
+ import scservo_sdk as scs
319
+
320
+ self.port_handler = scs.PortHandler(self.port)
321
+ self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION)
322
+
323
+ try:
324
+ if not self.port_handler.openPort():
325
+ raise OSError(f"Failed to open port '{self.port}'.")
326
+ except Exception:
327
+ traceback.print_exc()
328
+ print(
329
+ "\nTry running `python lerobot/scripts/find_motors_bus_port.py` to make sure you are using the correct port.\n"
330
+ )
331
+ raise
332
+
333
+ # Allow to read and write
334
+ self.is_connected = True
335
+
336
+ self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
337
+
338
+ def reconnect(self):
339
+ if self.mock:
340
+ import tests.motors.mock_scservo_sdk as scs
341
+ else:
342
+ import scservo_sdk as scs
343
+
344
+ self.port_handler = scs.PortHandler(self.port)
345
+ self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION)
346
+
347
+ if not self.port_handler.openPort():
348
+ raise OSError(f"Failed to open port '{self.port}'.")
349
+
350
+ self.is_connected = True
351
+
352
+ def are_motors_configured(self):
353
+ # Only check the motor indices and not baudrate, since if the motor baudrates are incorrect,
354
+ # a ConnectionError will be raised anyway.
355
+ try:
356
+ return (self.motor_indices == self.read("ID")).all()
357
+ except ConnectionError as e:
358
+ print(e)
359
+ return False
360
+
361
+ def find_motor_indices(self, possible_ids=None, num_retry=2):
362
+ if possible_ids is None:
363
+ possible_ids = range(MAX_ID_RANGE)
364
+
365
+ indices = []
366
+ for idx in tqdm.tqdm(possible_ids):
367
+ try:
368
+ present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
369
+ except ConnectionError:
370
+ continue
371
+
372
+ if idx != present_idx:
373
+ # sanity check
374
+ raise OSError(
375
+ "Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged."
376
+ )
377
+ indices.append(idx)
378
+
379
+ return indices
380
+
381
+ def set_bus_baudrate(self, baudrate):
382
+ present_bus_baudrate = self.port_handler.getBaudRate()
383
+ if present_bus_baudrate != baudrate:
384
+ print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
385
+ self.port_handler.setBaudRate(baudrate)
386
+
387
+ if self.port_handler.getBaudRate() != baudrate:
388
+ raise OSError("Failed to write bus baud rate.")
389
+
390
+ @property
391
+ def motor_names(self) -> list[str]:
392
+ return list(self.motors.keys())
393
+
394
+ @property
395
+ def motor_models(self) -> list[str]:
396
+ return [model for _, model in self.motors.values()]
397
+
398
+ @property
399
+ def motor_indices(self) -> list[int]:
400
+ return [idx for idx, _ in self.motors.values()]
401
+
402
+ def set_calibration(self, calibration: dict[str, list]):
403
+ self.calibration = calibration
404
+
405
+ def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
406
+ """This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct.
407
+
408
+ For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
409
+ """
410
+ try:
411
+ values = self.apply_calibration(values, motor_names)
412
+ except JointOutOfRangeError as e:
413
+ print(e)
414
+ self.autocorrect_calibration(values, motor_names)
415
+ values = self.apply_calibration(values, motor_names)
416
+ return values
417
+
418
+ def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
419
+ """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
420
+ a "zero position" at 0 degree.
421
+
422
+ Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor
423
+ rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range.
424
+
425
+ Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation
426
+ when given a goal position that is + or - their resolution. For instance, feetech xl330-m077 have a resolution of 4096, and
427
+ at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830,
428
+ or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor.
429
+ To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work
430
+ in the centered nominal degree range ]-180, 180[.
431
+ """
432
+ if motor_names is None:
433
+ motor_names = self.motor_names
434
+
435
+ # Convert from unsigned int32 original range [0, 2**32] to signed float32 range
436
+ values = values.astype(np.float32)
437
+
438
+ for i, name in enumerate(motor_names):
439
+ calib_idx = self.calibration["motor_names"].index(name)
440
+ calib_mode = self.calibration["calib_mode"][calib_idx]
441
+
442
+ if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
443
+ drive_mode = self.calibration["drive_mode"][calib_idx]
444
+ homing_offset = self.calibration["homing_offset"][calib_idx]
445
+ _, model = self.motors[name]
446
+ resolution = self.model_resolution[model]
447
+
448
+ # Update direction of rotation of the motor to match between leader and follower.
449
+ # In fact, the motor of the leader for a given joint can be assembled in an
450
+ # opposite direction in term of rotation than the motor of the follower on the same joint.
451
+ if drive_mode:
452
+ values[i] *= -1
453
+
454
+ # Convert from range [-2**31, 2**31[ to
455
+ # nominal range ]-resolution, resolution[ (e.g. ]-2048, 2048[)
456
+ values[i] += homing_offset
457
+
458
+ # Convert from range ]-resolution, resolution[ to
459
+ # universal float32 centered degree range ]-180, 180[
460
+ values[i] = values[i] / (resolution // 2) * HALF_TURN_DEGREE
461
+
462
+ if (values[i] < LOWER_BOUND_DEGREE) or (values[i] > UPPER_BOUND_DEGREE):
463
+ raise JointOutOfRangeError(
464
+ f"Wrong motor position range detected for {name}. "
465
+ f"Expected to be in nominal range of [-{HALF_TURN_DEGREE}, {HALF_TURN_DEGREE}] degrees (a full rotation), "
466
+ f"with a maximum range of [{LOWER_BOUND_DEGREE}, {UPPER_BOUND_DEGREE}] degrees to account for joints that can rotate a bit more, "
467
+ f"but present value is {values[i]} degree. "
468
+ "This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. "
469
+ "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
470
+ )
471
+
472
+ elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
473
+ start_pos = self.calibration["start_pos"][calib_idx]
474
+ end_pos = self.calibration["end_pos"][calib_idx]
475
+
476
+ # Rescale the present position to a nominal range [0, 100] %,
477
+ # useful for joints with linear motions like Aloha gripper
478
+ values[i] = (values[i] - start_pos) / (end_pos - start_pos) * 100
479
+
480
+ if (values[i] < LOWER_BOUND_LINEAR) or (values[i] > UPPER_BOUND_LINEAR):
481
+ raise JointOutOfRangeError(
482
+ f"Wrong motor position range detected for {name}. "
483
+ f"Expected to be in nominal range of [0, 100] % (a full linear translation), "
484
+ f"with a maximum range of [{LOWER_BOUND_LINEAR}, {UPPER_BOUND_LINEAR}] % to account for some imprecision during calibration, "
485
+ f"but present value is {values[i]} %. "
486
+ "This might be due to a cable connection issue creating an artificial jump in motor values. "
487
+ "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
488
+ )
489
+
490
+ return values
491
+
492
+ def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
493
+ """This function automatically detects issues with values of motors after calibration, and correct for these issues.
494
+
495
+ Some motors might have values outside of expected maximum bounds after calibration.
496
+ For instance, for a joint in degree, its value can be outside [-270, 270] degrees, which is totally unexpected given
497
+ a nominal range of [-180, 180] degrees, which represents half a turn to the left or right starting from zero position.
498
+
499
+ Known issues:
500
+ #1: Motor value randomly shifts of a full turn, caused by hardware/connection errors.
501
+ #2: Motor internal homing offset is shifted of a full turn, caused by using default calibration (e.g Aloha).
502
+ #3: motor internal homing offset is shifted of less or more than a full turn, caused by using default calibration
503
+ or by human error during manual calibration.
504
+
505
+ Issues #1 and #2 can be solved by shifting the calibration homing offset by a full turn.
506
+ Issue #3 will be visually detected by user and potentially captured by the safety feature `max_relative_target`,
507
+ that will slow down the motor, raise an error asking to recalibrate. Manual recalibrating will solve the issue.
508
+
509
+ Note: A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
510
+ """
511
+ if motor_names is None:
512
+ motor_names = self.motor_names
513
+
514
+ # Convert from unsigned int32 original range [0, 2**32] to signed float32 range
515
+ values = values.astype(np.float32)
516
+
517
+ for i, name in enumerate(motor_names):
518
+ calib_idx = self.calibration["motor_names"].index(name)
519
+ calib_mode = self.calibration["calib_mode"][calib_idx]
520
+
521
+ if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
522
+ drive_mode = self.calibration["drive_mode"][calib_idx]
523
+ homing_offset = self.calibration["homing_offset"][calib_idx]
524
+ _, model = self.motors[name]
525
+ resolution = self.model_resolution[model]
526
+
527
+ if drive_mode:
528
+ values[i] *= -1
529
+
530
+ # Convert from initial range to range [-180, 180] degrees
531
+ calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
532
+ in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
533
+
534
+ # Solve this inequality to find the factor to shift the range into [-180, 180] degrees
535
+ # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
536
+ # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
537
+ # (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution
538
+ low_factor = (
539
+ -HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
540
+ ) / resolution
541
+ upp_factor = (
542
+ HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
543
+ ) / resolution
544
+
545
+ elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
546
+ start_pos = self.calibration["start_pos"][calib_idx]
547
+ end_pos = self.calibration["end_pos"][calib_idx]
548
+
549
+ # Convert from initial range to range [0, 100] in %
550
+ calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
551
+ in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
552
+
553
+ # Solve this inequality to find the factor to shift the range into [0, 100] %
554
+ # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
555
+ # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100
556
+ # 0 <= (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 <= 100
557
+ # (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
558
+ low_factor = (start_pos - values[i]) / resolution
559
+ upp_factor = (end_pos - values[i]) / resolution
560
+
561
+ if not in_range:
562
+ # Get first integer between the two bounds
563
+ if low_factor < upp_factor:
564
+ factor = math.ceil(low_factor)
565
+
566
+ if factor > upp_factor:
567
+ raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
568
+ else:
569
+ factor = math.ceil(upp_factor)
570
+
571
+ if factor > low_factor:
572
+ raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
573
+
574
+ if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
575
+ out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
576
+ in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
577
+ elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
578
+ out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
579
+ in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
580
+
581
+ logging.warning(
582
+ f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
583
+ f"from '{out_of_range_str}' to '{in_range_str}'."
584
+ )
585
+
586
+ # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
587
+ self.calibration["homing_offset"][calib_idx] += resolution * factor
588
+
589
    def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
        """Inverse of `apply_calibration`.

        Maps nominal calibrated values (degrees for DEGREE mode, percent for
        LINEAR mode) back to raw motor steps, returned rounded to int32.
        Note: `values` is modified in place before the final conversion.
        """
        if motor_names is None:
            motor_names = self.motor_names

        for i, name in enumerate(motor_names):
            calib_idx = self.calibration["motor_names"].index(name)
            calib_mode = self.calibration["calib_mode"][calib_idx]

            if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
                drive_mode = self.calibration["drive_mode"][calib_idx]
                homing_offset = self.calibration["homing_offset"][calib_idx]
                _, model = self.motors[name]
                resolution = self.model_resolution[model]

                # Convert from nominal 0-centered degree range [-180, 180] to
                # 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
                values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)

                # Subtract the homing offsets to come back to actual motor range of values
                # which can be arbitrary.
                values[i] -= homing_offset

                # Remove drive mode, which is the rotation direction of the motor, to come back to
                # actual motor rotation direction which can be arbitrary.
                if drive_mode:
                    values[i] *= -1

            elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
                start_pos = self.calibration["start_pos"][calib_idx]
                end_pos = self.calibration["end_pos"][calib_idx]

                # Convert from nominal linear range of [0, 100] % to
                # actual motor range of values which can be arbitrary.
                values[i] = values[i] / 100 * (end_pos - start_pos) + start_pos

        values = np.round(values).astype(np.int32)
        return values
627
+
628
+ def avoid_rotation_reset(self, values, motor_names, data_name):
629
+ if data_name not in self.track_positions:
630
+ self.track_positions[data_name] = {
631
+ "prev": [None] * len(self.motor_names),
632
+ # Assume False at initialization
633
+ "below_zero": [False] * len(self.motor_names),
634
+ "above_max": [False] * len(self.motor_names),
635
+ }
636
+
637
+ track = self.track_positions[data_name]
638
+
639
+ if motor_names is None:
640
+ motor_names = self.motor_names
641
+
642
+ for i, name in enumerate(motor_names):
643
+ idx = self.motor_names.index(name)
644
+
645
+ if track["prev"][idx] is None:
646
+ track["prev"][idx] = values[i]
647
+ continue
648
+
649
+ # Detect a full rotation occurred
650
+ if abs(track["prev"][idx] - values[i]) > 2048:
651
+ # Position went below 0 and got reset to 4095
652
+ if track["prev"][idx] < values[i]:
653
+ # So we set negative value by adding a full rotation
654
+ values[i] -= 4096
655
+
656
+ # Position went above 4095 and got reset to 0
657
+ elif track["prev"][idx] > values[i]:
658
+ # So we add a full rotation
659
+ values[i] += 4096
660
+
661
+ track["prev"][idx] = values[i]
662
+
663
+ return values
664
+
665
    def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
        """Read `data_name` directly from raw motor ids, bypassing the named-motor mapping.

        Unlike `read`, this builds a one-shot GroupSyncRead (no caching), applies
        no calibration, and returns plain ints (a list, or a scalar when
        `motor_ids` was given as a scalar).
        """
        if self.mock:
            import tests.motors.mock_scservo_sdk as scs
        else:
            import scservo_sdk as scs

        # Accept a single id as a scalar; remember to unwrap the result on return.
        return_list = True
        if not isinstance(motor_ids, list):
            return_list = False
            motor_ids = [motor_ids]

        # NOTE(review): this checks `self.motor_models` while the address below comes
        # from the `motor_models` argument — confirm this asymmetry is intended.
        assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
        addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
        group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
        for idx in motor_ids:
            group.addParam(idx)

        # Retry the whole bus transaction a few times before giving up.
        for _ in range(num_retry):
            comm = group.txRxPacket()
            if comm == scs.COMM_SUCCESS:
                break

        if comm != scs.COMM_SUCCESS:
            raise ConnectionError(
                f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
                f"{self.packet_handler.getTxRxResult(comm)}"
            )

        values = []
        for idx in motor_ids:
            value = group.getData(idx, addr, bytes)
            values.append(value)

        if return_list:
            return values
        else:
            return values[0]
702
+
703
+ def read(self, data_name, motor_names: str | list[str] | None = None):
704
+ if self.mock:
705
+ import tests.motors.mock_scservo_sdk as scs
706
+ else:
707
+ import scservo_sdk as scs
708
+
709
+ if not self.is_connected:
710
+ raise RobotDeviceNotConnectedError(
711
+ f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
712
+ )
713
+
714
+ start_time = time.perf_counter()
715
+
716
+ if motor_names is None:
717
+ motor_names = self.motor_names
718
+
719
+ if isinstance(motor_names, str):
720
+ motor_names = [motor_names]
721
+
722
+ motor_ids = []
723
+ models = []
724
+ for name in motor_names:
725
+ motor_idx, model = self.motors[name]
726
+ motor_ids.append(motor_idx)
727
+ models.append(model)
728
+
729
+ assert_same_address(self.model_ctrl_table, models, data_name)
730
+ addr, bytes = self.model_ctrl_table[model][data_name]
731
+ group_key = get_group_sync_key(data_name, motor_names)
732
+
733
+ if data_name not in self.group_readers:
734
+ # Very Important to flush the buffer!
735
+ self.port_handler.ser.reset_output_buffer()
736
+ self.port_handler.ser.reset_input_buffer()
737
+
738
+ # create new group reader
739
+ self.group_readers[group_key] = scs.GroupSyncRead(
740
+ self.port_handler, self.packet_handler, addr, bytes
741
+ )
742
+ for idx in motor_ids:
743
+ self.group_readers[group_key].addParam(idx)
744
+
745
+ for _ in range(NUM_READ_RETRY):
746
+ comm = self.group_readers[group_key].txRxPacket()
747
+ if comm == scs.COMM_SUCCESS:
748
+ break
749
+
750
+ if comm != scs.COMM_SUCCESS:
751
+ raise ConnectionError(
752
+ f"Read failed due to communication error on port {self.port} for group_key {group_key}: "
753
+ f"{self.packet_handler.getTxRxResult(comm)}"
754
+ )
755
+
756
+ values = []
757
+ for idx in motor_ids:
758
+ value = self.group_readers[group_key].getData(idx, addr, bytes)
759
+ values.append(value)
760
+
761
+ values = np.array(values)
762
+
763
+ # Convert to signed int to use range [-2048, 2048] for our motor positions.
764
+ if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
765
+ values = values.astype(np.int32)
766
+
767
+ if data_name in CALIBRATION_REQUIRED:
768
+ values = self.avoid_rotation_reset(values, motor_names, data_name)
769
+
770
+ if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
771
+ values = self.apply_calibration_autocorrect(values, motor_names)
772
+
773
+ # log the number of seconds it took to read the data from the motors
774
+ delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
775
+ self.logs[delta_ts_name] = time.perf_counter() - start_time
776
+
777
+ # log the utc time at which the data was received
778
+ ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names)
779
+ self.logs[ts_utc_name] = capture_timestamp_utc()
780
+
781
+ return values
782
+
783
    def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
        """Write `values` for `data_name` directly to raw motor ids.

        Unlike `write`, this bypasses the named-motor mapping, builds a one-shot
        GroupSyncWrite (no caching) and does not revert calibration.
        """
        if self.mock:
            import tests.motors.mock_scservo_sdk as scs
        else:
            import scservo_sdk as scs

        # Accept scalars for convenience; ids and values must end up 1:1.
        if not isinstance(motor_ids, list):
            motor_ids = [motor_ids]
        if not isinstance(values, list):
            values = [values]

        assert_same_address(self.model_ctrl_table, motor_models, data_name)
        addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
        group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
        for idx, value in zip(motor_ids, values, strict=True):
            data = convert_to_bytes(value, bytes, self.mock)
            group.addParam(idx, data)

        # Retry the whole bus transaction a few times before giving up.
        for _ in range(num_retry):
            comm = group.txPacket()
            if comm == scs.COMM_SUCCESS:
                break

        if comm != scs.COMM_SUCCESS:
            raise ConnectionError(
                f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
                f"{self.packet_handler.getTxRxResult(comm)}"
            )
811
+
812
+ def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
813
+ if not self.is_connected:
814
+ raise RobotDeviceNotConnectedError(
815
+ f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
816
+ )
817
+
818
+ start_time = time.perf_counter()
819
+
820
+ if self.mock:
821
+ import tests.motors.mock_scservo_sdk as scs
822
+ else:
823
+ import scservo_sdk as scs
824
+
825
+ if motor_names is None:
826
+ motor_names = self.motor_names
827
+
828
+ if isinstance(motor_names, str):
829
+ motor_names = [motor_names]
830
+
831
+ if isinstance(values, (int, float, np.integer)):
832
+ values = [int(values)] * len(motor_names)
833
+
834
+ values = np.array(values)
835
+
836
+ motor_ids = []
837
+ models = []
838
+ for name in motor_names:
839
+ motor_idx, model = self.motors[name]
840
+ motor_ids.append(motor_idx)
841
+ models.append(model)
842
+
843
+ if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
844
+ values = self.revert_calibration(values, motor_names)
845
+
846
+ values = values.tolist()
847
+
848
+ assert_same_address(self.model_ctrl_table, models, data_name)
849
+ addr, bytes = self.model_ctrl_table[model][data_name]
850
+ group_key = get_group_sync_key(data_name, motor_names)
851
+
852
+ init_group = data_name not in self.group_readers
853
+ if init_group:
854
+ self.group_writers[group_key] = scs.GroupSyncWrite(
855
+ self.port_handler, self.packet_handler, addr, bytes
856
+ )
857
+
858
+ for idx, value in zip(motor_ids, values, strict=True):
859
+ data = convert_to_bytes(value, bytes, self.mock)
860
+ if init_group:
861
+ self.group_writers[group_key].addParam(idx, data)
862
+ else:
863
+ self.group_writers[group_key].changeParam(idx, data)
864
+
865
+ comm = self.group_writers[group_key].txPacket()
866
+ if comm != scs.COMM_SUCCESS:
867
+ raise ConnectionError(
868
+ f"Write failed due to communication error on port {self.port} for group_key {group_key}: "
869
+ f"{self.packet_handler.getTxRxResult(comm)}"
870
+ )
871
+
872
+ # log the number of seconds it took to write the data to the motors
873
+ delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
874
+ self.logs[delta_ts_name] = time.perf_counter() - start_time
875
+
876
+ # TODO(rcadene): should we log the time before sending the write command?
877
+ # log the utc time when the write has been completed
878
+ ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
879
+ self.logs[ts_utc_name] = capture_timestamp_utc()
880
+
881
    def disconnect(self):
        """Close the serial port and drop all cached handlers and sync groups.

        Raises:
            RobotDeviceNotConnectedError: if the bus was never connected.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                f"FeetechMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first."
            )

        if self.port_handler is not None:
            self.port_handler.closePort()
            self.port_handler = None

        self.packet_handler = None
        self.group_readers = {}
        self.group_writers = {}
        self.is_connected = False
895
+
896
    def __del__(self):
        # Best-effort cleanup at garbage collection. `getattr` guards against
        # partially-initialized instances where `is_connected` was never set.
        if getattr(self, "is_connected", False):
            self.disconnect()
lerobot/common/robot_devices/motors/utils.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Protocol
16
+
17
+ from lerobot.common.robot_devices.motors.configs import (
18
+ DynamixelMotorsBusConfig,
19
+ FeetechMotorsBusConfig,
20
+ MotorsBusConfig,
21
+ )
22
+
23
+
24
class MotorsBus(Protocol):
    """Structural interface shared by motor bus implementations (Dynamixel, Feetech)."""

    def motor_names(self): ...
    def set_calibration(self): ...
    def apply_calibration(self): ...
    def revert_calibration(self): ...
    def read(self): ...
    def write(self): ...
31
+
32
+
33
def make_motors_buses_from_configs(motors_bus_configs: dict[str, "MotorsBusConfig"]) -> dict[str, "MotorsBus"]:
    """Instantiate one motors bus per config entry, keyed like the input dict.

    Fix: the return annotation previously claimed `list[MotorsBus]` although the
    function has always returned a dict keyed by the config names.

    Raises:
        ValueError: if a config has an unknown `type`.
    """
    motors_buses = {}

    for key, cfg in motors_bus_configs.items():
        if cfg.type == "dynamixel":
            # Imported lazily so only the SDKs that are actually used are required.
            from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus

            motors_buses[key] = DynamixelMotorsBus(cfg)

        elif cfg.type == "feetech":
            from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus

            motors_buses[key] = FeetechMotorsBus(cfg)

        else:
            raise ValueError(f"The motor type '{cfg.type}' is not valid.")

    return motors_buses
51
+
52
+
53
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
    """Build a single motors bus of the requested type from keyword arguments.

    Raises:
        ValueError: if `motor_type` is not a known bus type.
    """
    if motor_type == "dynamixel":
        from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus

        return DynamixelMotorsBus(DynamixelMotorsBusConfig(**kwargs))

    if motor_type == "feetech":
        from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus

        return FeetechMotorsBus(FeetechMotorsBusConfig(**kwargs))

    raise ValueError(f"The motor type '{motor_type}' is not valid.")
lerobot/common/robot_devices/robots/configs.py ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import abc
16
+ from dataclasses import dataclass, field
17
+ from typing import Sequence
18
+
19
+ import draccus
20
+
21
+ from lerobot.common.robot_devices.cameras.configs import (
22
+ CameraConfig,
23
+ IntelRealSenseCameraConfig,
24
+ OpenCVCameraConfig,
25
+ )
26
+ from lerobot.common.robot_devices.motors.configs import (
27
+ DynamixelMotorsBusConfig,
28
+ FeetechMotorsBusConfig,
29
+ MotorsBusConfig,
30
+ )
31
+
32
+
33
@dataclass
class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
    """Base class for robot configs, registered as a draccus choice (see `type`)."""

    @property
    def type(self) -> str:
        # Name under which this subclass was registered, e.g. "so100" or "aloha".
        return self.get_choice_name(self.__class__)
38
+
39
+
40
# TODO(rcadene, aliberts): remove ManipulatorRobotConfig abstraction
@dataclass
class ManipulatorRobotConfig(RobotConfig):
    """Shared config for arm-based robots: leader/follower motor buses plus cameras."""

    leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
    follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
    cameras: dict[str, CameraConfig] = field(default_factory=lambda: {})

    # Optionally limit the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length
    # as the number of motors in your follower arms (assumes all follower arms have the same number of
    # motors).
    max_relative_target: list[float] | float | None = None

    # Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it
    # possible to squeeze the gripper and have it spring back to an open position on its own. If None, the
    # gripper is not put in torque mode.
    gripper_open_degree: float | None = None

    mock: bool = False

    def __post_init__(self):
        # Propagate `mock` down to every device config so nothing touches real hardware.
        if self.mock:
            for arm in self.leader_arms.values():
                if not arm.mock:
                    arm.mock = True
            for arm in self.follower_arms.values():
                if not arm.mock:
                    arm.mock = True
            for cam in self.cameras.values():
                if not cam.mock:
                    cam.mock = True

        # A per-motor safety limit must provide exactly one value per follower motor.
        if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
            for name in self.follower_arms:
                if len(self.follower_arms[name].motors) != len(self.max_relative_target):
                    raise ValueError(
                        f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
                        f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "
                        f"`max_relative_target` list has as many parameters as there are motors per arm. "
                        "Note: This feature does not yet work with robots where different follower arms have "
                        "different numbers of motors."
                    )
82
+
83
+
84
@RobotConfig.register_subclass("aloha")
@dataclass
class AlohaRobotConfig(ManipulatorRobotConfig):
    """Config for the bimanual Aloha robot: two Dynamixel leader/follower arm pairs
    and four Intel RealSense cameras."""

    # Specific to Aloha, LeRobot comes with default calibration files. Assuming the motors have been
    # properly assembled, no manual calibration step is expected. If you need to run manual calibration,
    # simply update this path to ".cache/calibration/aloha"
    calibration_dir: str = ".cache/calibration/aloha_default"

    # /!\ FOR SAFETY, READ THIS /!\
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    # For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
    # When you feel more confident with teleoperation or running the policy, you can extend
    # this safety limit and even remove it by setting it to `null`.
    # Also, everything is expected to work safely out-of-the-box, but we highly advise to
    # first try to teleoperate the grippers only (by commenting out the rest of the motors in this config),
    # then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
    max_relative_target: int | None = 5

    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "left": DynamixelMotorsBusConfig(
                # window_x
                port="/dev/ttyDXL_leader_left",
                motors={
                    # name: (index, model)
                    "waist": [1, "xm430-w350"],
                    "shoulder": [2, "xm430-w350"],
                    "shoulder_shadow": [3, "xm430-w350"],
                    "elbow": [4, "xm430-w350"],
                    "elbow_shadow": [5, "xm430-w350"],
                    "forearm_roll": [6, "xm430-w350"],
                    "wrist_angle": [7, "xm430-w350"],
                    "wrist_rotate": [8, "xl430-w250"],
                    "gripper": [9, "xc430-w150"],
                },
            ),
            "right": DynamixelMotorsBusConfig(
                # window_x
                port="/dev/ttyDXL_leader_right",
                motors={
                    # name: (index, model)
                    "waist": [1, "xm430-w350"],
                    "shoulder": [2, "xm430-w350"],
                    "shoulder_shadow": [3, "xm430-w350"],
                    "elbow": [4, "xm430-w350"],
                    "elbow_shadow": [5, "xm430-w350"],
                    "forearm_roll": [6, "xm430-w350"],
                    "wrist_angle": [7, "xm430-w350"],
                    "wrist_rotate": [8, "xl430-w250"],
                    "gripper": [9, "xc430-w150"],
                },
            ),
        }
    )

    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "left": DynamixelMotorsBusConfig(
                port="/dev/ttyDXL_follower_left",
                motors={
                    # name: (index, model)
                    "waist": [1, "xm540-w270"],
                    "shoulder": [2, "xm540-w270"],
                    "shoulder_shadow": [3, "xm540-w270"],
                    "elbow": [4, "xm540-w270"],
                    "elbow_shadow": [5, "xm540-w270"],
                    "forearm_roll": [6, "xm540-w270"],
                    "wrist_angle": [7, "xm540-w270"],
                    "wrist_rotate": [8, "xm430-w350"],
                    "gripper": [9, "xm430-w350"],
                },
            ),
            "right": DynamixelMotorsBusConfig(
                port="/dev/ttyDXL_follower_right",
                motors={
                    # name: (index, model)
                    "waist": [1, "xm540-w270"],
                    "shoulder": [2, "xm540-w270"],
                    "shoulder_shadow": [3, "xm540-w270"],
                    "elbow": [4, "xm540-w270"],
                    "elbow_shadow": [5, "xm540-w270"],
                    "forearm_roll": [6, "xm540-w270"],
                    "wrist_angle": [7, "xm540-w270"],
                    "wrist_rotate": [8, "xm430-w350"],
                    "gripper": [9, "xm430-w350"],
                },
            ),
        }
    )

    # Troubleshooting: If one of your IntelRealSense cameras freeze during
    # data recording due to bandwidth limit, you might need to plug the camera
    # on another USB hub or PCIe card.
    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "cam_high": IntelRealSenseCameraConfig(
                serial_number=128422271347,
                fps=30,
                width=640,
                height=480,
            ),
            "cam_low": IntelRealSenseCameraConfig(
                serial_number=130322270656,
                fps=30,
                width=640,
                height=480,
            ),
            "cam_left_wrist": IntelRealSenseCameraConfig(
                serial_number=218622272670,
                fps=30,
                width=640,
                height=480,
            ),
            "cam_right_wrist": IntelRealSenseCameraConfig(
                serial_number=130322272300,
                fps=30,
                width=640,
                height=480,
            ),
        }
    )

    mock: bool = False
209
+
210
+
211
@RobotConfig.register_subclass("koch")
@dataclass
class KochRobotConfig(ManipulatorRobotConfig):
    """Config for the single-arm Koch robot: one Dynamixel leader/follower pair
    and two OpenCV cameras."""

    calibration_dir: str = ".cache/calibration/koch"
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": DynamixelMotorsBusConfig(
                port="/dev/tty.usbmodem585A0085511",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "xl330-m077"],
                    "shoulder_lift": [2, "xl330-m077"],
                    "elbow_flex": [3, "xl330-m077"],
                    "wrist_flex": [4, "xl330-m077"],
                    "wrist_roll": [5, "xl330-m077"],
                    "gripper": [6, "xl330-m077"],
                },
            ),
        }
    )

    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": DynamixelMotorsBusConfig(
                port="/dev/tty.usbmodem585A0076891",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "xl430-w250"],
                    "shoulder_lift": [2, "xl430-w250"],
                    "elbow_flex": [3, "xl330-m288"],
                    "wrist_flex": [4, "xl330-m288"],
                    "wrist_roll": [5, "xl330-m288"],
                    "gripper": [6, "xl330-m288"],
                },
            ),
        }
    )

    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "laptop": OpenCVCameraConfig(
                camera_index=0,
                fps=30,
                width=640,
                height=480,
            ),
            "phone": OpenCVCameraConfig(
                camera_index=1,
                fps=30,
                width=640,
                height=480,
            ),
        }
    )

    # ~ Koch specific settings ~
    # Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
    # to squeeze the gripper and have it spring back to an open position on its own.
    gripper_open_degree: float = 35.156

    mock: bool = False
277
+
278
+
279
@RobotConfig.register_subclass("koch_bimanual")
@dataclass
class KochBimanualRobotConfig(ManipulatorRobotConfig):
    """Config for the bimanual Koch robot: two Dynamixel leader/follower arm pairs
    and two OpenCV cameras."""

    calibration_dir: str = ".cache/calibration/koch_bimanual"
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "left": DynamixelMotorsBusConfig(
                port="/dev/tty.usbmodem585A0085511",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "xl330-m077"],
                    "shoulder_lift": [2, "xl330-m077"],
                    "elbow_flex": [3, "xl330-m077"],
                    "wrist_flex": [4, "xl330-m077"],
                    "wrist_roll": [5, "xl330-m077"],
                    "gripper": [6, "xl330-m077"],
                },
            ),
            "right": DynamixelMotorsBusConfig(
                port="/dev/tty.usbmodem575E0031751",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "xl330-m077"],
                    "shoulder_lift": [2, "xl330-m077"],
                    "elbow_flex": [3, "xl330-m077"],
                    "wrist_flex": [4, "xl330-m077"],
                    "wrist_roll": [5, "xl330-m077"],
                    "gripper": [6, "xl330-m077"],
                },
            ),
        }
    )

    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "left": DynamixelMotorsBusConfig(
                port="/dev/tty.usbmodem585A0076891",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "xl430-w250"],
                    "shoulder_lift": [2, "xl430-w250"],
                    "elbow_flex": [3, "xl330-m288"],
                    "wrist_flex": [4, "xl330-m288"],
                    "wrist_roll": [5, "xl330-m288"],
                    "gripper": [6, "xl330-m288"],
                },
            ),
            "right": DynamixelMotorsBusConfig(
                port="/dev/tty.usbmodem575E0032081",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "xl430-w250"],
                    "shoulder_lift": [2, "xl430-w250"],
                    "elbow_flex": [3, "xl330-m288"],
                    "wrist_flex": [4, "xl330-m288"],
                    "wrist_roll": [5, "xl330-m288"],
                    "gripper": [6, "xl330-m288"],
                },
            ),
        }
    )

    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "laptop": OpenCVCameraConfig(
                camera_index=0,
                fps=30,
                width=640,
                height=480,
            ),
            "phone": OpenCVCameraConfig(
                camera_index=1,
                fps=30,
                width=640,
                height=480,
            ),
        }
    )

    # ~ Koch specific settings ~
    # Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
    # to squeeze the gripper and have it spring back to an open position on its own.
    gripper_open_degree: float = 35.156

    mock: bool = False
369
+
370
+
371
@RobotConfig.register_subclass("moss")
@dataclass
class MossRobotConfig(ManipulatorRobotConfig):
    """Config for the Moss v1 robot: one Feetech leader/follower arm pair
    and two OpenCV cameras."""

    calibration_dir: str = ".cache/calibration/moss"
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/tty.usbmodem58760431091",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                },
            ),
        }
    )

    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/tty.usbmodem585A0076891",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                },
            ),
        }
    )

    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "laptop": OpenCVCameraConfig(
                camera_index=0,
                fps=30,
                width=640,
                height=480,
            ),
            "phone": OpenCVCameraConfig(
                camera_index=1,
                fps=30,
                width=640,
                height=480,
            ),
        }
    )

    mock: bool = False
432
+
433
+
434
@RobotConfig.register_subclass("so100")
@dataclass
class So100RobotConfig(ManipulatorRobotConfig):
    """Config for the SO-100 robot: one Feetech leader/follower arm pair
    and two OpenCV cameras."""

    calibration_dir: str = ".cache/calibration/so100"
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/tty.usbmodem58760431091",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                },
            ),
        }
    )

    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/tty.usbmodem585A0076891",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                },
            ),
        }
    )

    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "laptop": OpenCVCameraConfig(
                camera_index=0,
                fps=30,
                width=640,
                height=480,
            ),
            "phone": OpenCVCameraConfig(
                camera_index=1,
                fps=30,
                width=640,
                height=480,
            ),
        }
    )

    mock: bool = False
495
+
496
+
497
@RobotConfig.register_subclass("stretch")
@dataclass
class StretchRobotConfig(RobotConfig):
    """Config for the Hello Robot Stretch: three cameras, no LeRobot-managed motor buses."""

    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "navigation": OpenCVCameraConfig(
                camera_index="/dev/hello-nav-head-camera",
                fps=10,
                width=1280,
                height=720,
                rotation=-90,
            ),
            "head": IntelRealSenseCameraConfig(
                name="Intel RealSense D435I",
                fps=30,
                width=640,
                height=480,
                rotation=90,
            ),
            "wrist": IntelRealSenseCameraConfig(
                name="Intel RealSense D405",
                fps=30,
                width=640,
                height=480,
            ),
        }
    )

    mock: bool = False
531
+
532
+
533
@RobotConfig.register_subclass("lekiwi")
@dataclass
class LeKiwiRobotConfig(RobotConfig):
    """Config for the mobile LeKiwi robot: a Feetech arm plus three wheel motors,
    driven remotely over the network, with two onboard OpenCV cameras."""

    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    # Network Configuration
    ip: str = "192.168.0.193"
    port: int = 5555
    video_port: int = 5556

    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "front": OpenCVCameraConfig(
                camera_index="/dev/video0", fps=30, width=640, height=480, rotation=90
            ),
            "wrist": OpenCVCameraConfig(
                camera_index="/dev/video2", fps=30, width=640, height=480, rotation=180
            ),
        }
    )

    calibration_dir: str = ".cache/calibration/lekiwi"

    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/tty.usbmodem585A0077581",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                },
            ),
        }
    )

    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "main": FeetechMotorsBusConfig(
                port="/dev/ttyACM0",
                motors={
                    # name: (index, model)
                    "shoulder_pan": [1, "sts3215"],
                    "shoulder_lift": [2, "sts3215"],
                    "elbow_flex": [3, "sts3215"],
                    "wrist_flex": [4, "sts3215"],
                    "wrist_roll": [5, "sts3215"],
                    "gripper": [6, "sts3215"],
                    # Consistency fix: the wheel entries were tuples while every other
                    # motors mapping in this file uses [index, model] lists.
                    "left_wheel": [7, "sts3215"],
                    "back_wheel": [8, "sts3215"],
                    "right_wheel": [9, "sts3215"],
                },
            ),
        }
    )

    teleop_keys: dict[str, str] = field(
        default_factory=lambda: {
            # Movement
            "forward": "w",
            "backward": "s",
            "left": "a",
            "right": "d",
            "rotate_left": "z",
            "rotate_right": "x",
            # Speed control
            "speed_up": "r",
            "speed_down": "f",
            # quit teleop
            "quit": "q",
        }
    )

    mock: bool = False
lerobot/common/robot_devices/robots/dynamixel_calibration.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Logic to calibrate a robot arm built with dynamixel motors"""
16
+ # TODO(rcadene, aliberts): move this logic into the robot code when refactoring
17
+
18
+ import numpy as np
19
+
20
+ from lerobot.common.robot_devices.motors.dynamixel import (
21
+ CalibrationMode,
22
+ TorqueMode,
23
+ convert_degrees_to_steps,
24
+ )
25
+ from lerobot.common.robot_devices.motors.utils import MotorsBus
26
+
27
+ URL_TEMPLATE = (
28
+ "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
29
+ )
30
+
31
+ # The following positions are provided in nominal degree range ]-180, +180[
32
+ # For more info on these constants, see comments in the code where they get used.
33
+ ZERO_POSITION_DEGREE = 0
34
+ ROTATED_POSITION_DEGREE = 90
35
+
36
+
37
+ def assert_drive_mode(drive_mode):
38
+ # `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
39
+ if not np.all(np.isin(drive_mode, [0, 1])):
40
+ raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
41
+
42
+
43
+ def apply_drive_mode(position, drive_mode):
44
+ assert_drive_mode(drive_mode)
45
+ # Convert `drive_mode` from [0, 1] with 0 indicates original rotation direction and 1 inverted,
46
+ # to [-1, 1] with 1 indicates original rotation direction and -1 inverted.
47
+ signed_drive_mode = -(drive_mode * 2 - 1)
48
+ position *= signed_drive_mode
49
+ return position
50
+
51
+
52
+ def compute_nearest_rounded_position(position, models):
53
+ delta_turn = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, models)
54
+ nearest_pos = np.round(position.astype(float) / delta_turn) * delta_turn
55
+ return nearest_pos.astype(position.dtype)
56
+
57
+
58
+ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
59
+ """This function ensures that a neural network trained on data collected on a given robot
60
+ can work on another robot. For instance before calibration, setting a same goal position
61
+ for each motor of two different robots will get two very different positions. But after calibration,
62
+ the two robots will move to the same position.To this end, this function computes the homing offset
63
+ and the drive mode for each motor of a given robot.
64
+
65
+ Homing offset is used to shift the motor position to a ]-2048, +2048[ nominal range (when the motor uses 2048 steps
66
+ to complete a half a turn). This range is set around an arbitrary "zero position" corresponding to all motor positions
67
+ being 0. During the calibration process, you will need to manually move the robot to this "zero position".
68
+
69
+ Drive mode is used to invert the rotation direction of the motor. This is useful when some motors have been assembled
70
+ in the opposite orientation for some robots. During the calibration process, you will need to manually move the robot
71
+ to the "rotated position".
72
+
73
+ After calibration, the homing offsets and drive modes are stored in a cache.
74
+
75
+ Example of usage:
76
+ ```python
77
+ run_arm_calibration(arm, "koch", "left", "follower")
78
+ ```
79
+ """
80
+ if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
81
+ raise ValueError("To run calibration, the torque must be disabled on all motors.")
82
+
83
+ print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
84
+
85
+ print("\nMove arm to zero position")
86
+ print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
87
+ input("Press Enter to continue...")
88
+
89
+ # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
90
+ # It is easy to identify and all motors are in a "quarter turn" position. Once calibration is done, this position will
91
+ # correspond to every motor angle being 0. If you set all 0 as Goal Position, the arm will move in this position.
92
+ zero_target_pos = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models)
93
+
94
+ # Compute homing offset so that `present_position + homing_offset ~= target_position`.
95
+ zero_pos = arm.read("Present_Position")
96
+ zero_nearest_pos = compute_nearest_rounded_position(zero_pos, arm.motor_models)
97
+ homing_offset = zero_target_pos - zero_nearest_pos
98
+
99
+ # The rotated target position corresponds to a rotation of a quarter turn from the zero position.
100
+ # This allows to identify the rotation direction of each motor.
101
+ # For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
102
+ # is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
103
+ # Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
104
+ # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
105
+ # of the previous motor in the kinetic chain.
106
+ print("\nMove arm to rotated target position")
107
+ print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
108
+ input("Press Enter to continue...")
109
+
110
+ rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
111
+
112
+ # Find drive mode by rotating each motor by a quarter of a turn.
113
+ # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
114
+ rotated_pos = arm.read("Present_Position")
115
+ drive_mode = (rotated_pos < zero_pos).astype(np.int32)
116
+
117
+ # Re-compute homing offset to take into account drive mode
118
+ rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
119
+ rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models)
120
+ homing_offset = rotated_target_pos - rotated_nearest_pos
121
+
122
+ print("\nMove arm to rest position")
123
+ print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
124
+ input("Press Enter to continue...")
125
+ print()
126
+
127
+ # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
128
+ calib_mode = [CalibrationMode.DEGREE.name] * len(arm.motor_names)
129
+
130
+ # TODO(rcadene): make type of joints (DEGREE or LINEAR) configurable from yaml?
131
+ if robot_type in ["aloha"] and "gripper" in arm.motor_names:
132
+ # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
133
+ calib_idx = arm.motor_names.index("gripper")
134
+ calib_mode[calib_idx] = CalibrationMode.LINEAR.name
135
+
136
+ calib_data = {
137
+ "homing_offset": homing_offset.tolist(),
138
+ "drive_mode": drive_mode.tolist(),
139
+ "start_pos": zero_pos.tolist(),
140
+ "end_pos": rotated_pos.tolist(),
141
+ "calib_mode": calib_mode,
142
+ "motor_names": arm.motor_names,
143
+ }
144
+ return calib_data
lerobot/common/robot_devices/robots/feetech_calibration.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Logic to calibrate a robot arm built with feetech motors"""
16
+ # TODO(rcadene, aliberts): move this logic into the robot code when refactoring
17
+
18
+ import time
19
+
20
+ import numpy as np
21
+
22
+ from lerobot.common.robot_devices.motors.feetech import (
23
+ CalibrationMode,
24
+ TorqueMode,
25
+ convert_degrees_to_steps,
26
+ )
27
+ from lerobot.common.robot_devices.motors.utils import MotorsBus
28
+
29
+ URL_TEMPLATE = (
30
+ "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
31
+ )
32
+
33
+ # The following positions are provided in nominal degree range ]-180, +180[
34
+ # For more info on these constants, see comments in the code where they get used.
35
+ ZERO_POSITION_DEGREE = 0
36
+ ROTATED_POSITION_DEGREE = 90
37
+
38
+
39
+ def assert_drive_mode(drive_mode):
40
+ # `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
41
+ if not np.all(np.isin(drive_mode, [0, 1])):
42
+ raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
43
+
44
+
45
+ def apply_drive_mode(position, drive_mode):
46
+ assert_drive_mode(drive_mode)
47
+ # Convert `drive_mode` from [0, 1] with 0 indicates original rotation direction and 1 inverted,
48
+ # to [-1, 1] with 1 indicates original rotation direction and -1 inverted.
49
+ signed_drive_mode = -(drive_mode * 2 - 1)
50
+ position *= signed_drive_mode
51
+ return position
52
+
53
+
54
+ def move_until_block(arm, motor_name, positive_direction=True, while_move_hook=None):
55
+ count = 0
56
+ while True:
57
+ present_pos = arm.read("Present_Position", motor_name)
58
+ if positive_direction:
59
+ # Move +100 steps every time. Lower the steps to lower the speed at which the arm moves.
60
+ arm.write("Goal_Position", present_pos + 100, motor_name)
61
+ else:
62
+ arm.write("Goal_Position", present_pos - 100, motor_name)
63
+
64
+ if while_move_hook is not None:
65
+ while_move_hook()
66
+
67
+ present_pos = arm.read("Present_Position", motor_name).item()
68
+ present_speed = arm.read("Present_Speed", motor_name).item()
69
+ present_current = arm.read("Present_Current", motor_name).item()
70
+ # present_load = arm.read("Present_Load", motor_name).item()
71
+ # present_voltage = arm.read("Present_Voltage", motor_name).item()
72
+ # present_temperature = arm.read("Present_Temperature", motor_name).item()
73
+
74
+ # print(f"{present_pos=}")
75
+ # print(f"{present_speed=}")
76
+ # print(f"{present_current=}")
77
+ # print(f"{present_load=}")
78
+ # print(f"{present_voltage=}")
79
+ # print(f"{present_temperature=}")
80
+
81
+ if present_speed == 0 and present_current > 40:
82
+ count += 1
83
+ if count > 100 or present_current > 300:
84
+ return present_pos
85
+ else:
86
+ count = 0
87
+
88
+
89
+ def move_to_calibrate(
90
+ arm,
91
+ motor_name,
92
+ invert_drive_mode=False,
93
+ positive_first=True,
94
+ in_between_move_hook=None,
95
+ while_move_hook=None,
96
+ ):
97
+ initial_pos = arm.read("Present_Position", motor_name)
98
+
99
+ if positive_first:
100
+ p_present_pos = move_until_block(
101
+ arm, motor_name, positive_direction=True, while_move_hook=while_move_hook
102
+ )
103
+ else:
104
+ n_present_pos = move_until_block(
105
+ arm, motor_name, positive_direction=False, while_move_hook=while_move_hook
106
+ )
107
+
108
+ if in_between_move_hook is not None:
109
+ in_between_move_hook()
110
+
111
+ if positive_first:
112
+ n_present_pos = move_until_block(
113
+ arm, motor_name, positive_direction=False, while_move_hook=while_move_hook
114
+ )
115
+ else:
116
+ p_present_pos = move_until_block(
117
+ arm, motor_name, positive_direction=True, while_move_hook=while_move_hook
118
+ )
119
+
120
+ zero_pos = (n_present_pos + p_present_pos) / 2
121
+
122
+ calib_data = {
123
+ "initial_pos": initial_pos,
124
+ "homing_offset": zero_pos if invert_drive_mode else -zero_pos,
125
+ "invert_drive_mode": invert_drive_mode,
126
+ "drive_mode": -1 if invert_drive_mode else 0,
127
+ "zero_pos": zero_pos,
128
+ "start_pos": n_present_pos if invert_drive_mode else p_present_pos,
129
+ "end_pos": p_present_pos if invert_drive_mode else n_present_pos,
130
+ }
131
+ return calib_data
132
+
133
+
134
+ def apply_offset(calib, offset):
135
+ calib["zero_pos"] += offset
136
+ if calib["drive_mode"]:
137
+ calib["homing_offset"] += offset
138
+ else:
139
+ calib["homing_offset"] -= offset
140
+ return calib
141
+
142
+
143
+ def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
144
+ if robot_type == "so100":
145
+ return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type)
146
+ elif robot_type == "moss":
147
+ return run_arm_auto_calibration_moss(arm, robot_type, arm_name, arm_type)
148
+ else:
149
+ raise ValueError(robot_type)
150
+
151
+
152
+ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
153
+ """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
154
+ if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
155
+ raise ValueError("To run calibration, the torque must be disabled on all motors.")
156
+
157
+ if not (robot_type == "so100" and arm_type == "follower"):
158
+ raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.")
159
+
160
+ print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
161
+
162
+ print("\nMove arm to initial position")
163
+ print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
164
+ input("Press Enter to continue...")
165
+
166
+ # Lower the acceleration of the motors (in [0,254])
167
+ initial_acceleration = arm.read("Acceleration")
168
+ arm.write("Lock", 0)
169
+ arm.write("Acceleration", 10)
170
+ time.sleep(1)
171
+
172
+ arm.write("Torque_Enable", TorqueMode.ENABLED.value)
173
+
174
+ print(f'{arm.read("Present_Position", "elbow_flex")=}')
175
+
176
+ calib = {}
177
+
178
+ init_wf_pos = arm.read("Present_Position", "wrist_flex")
179
+ init_sl_pos = arm.read("Present_Position", "shoulder_lift")
180
+ init_ef_pos = arm.read("Present_Position", "elbow_flex")
181
+ arm.write("Goal_Position", init_wf_pos - 800, "wrist_flex")
182
+ arm.write("Goal_Position", init_sl_pos + 150 + 1024, "shoulder_lift")
183
+ arm.write("Goal_Position", init_ef_pos - 2048, "elbow_flex")
184
+ time.sleep(2)
185
+
186
+ print("Calibrate shoulder_pan")
187
+ calib["shoulder_pan"] = move_to_calibrate(arm, "shoulder_pan")
188
+ arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
189
+ time.sleep(1)
190
+
191
+ print("Calibrate gripper")
192
+ calib["gripper"] = move_to_calibrate(arm, "gripper", invert_drive_mode=True)
193
+ time.sleep(1)
194
+
195
+ print("Calibrate wrist_flex")
196
+ calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex")
197
+ calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=80)
198
+
199
+ def in_between_move_hook():
200
+ nonlocal arm, calib
201
+ time.sleep(2)
202
+ ef_pos = arm.read("Present_Position", "elbow_flex")
203
+ sl_pos = arm.read("Present_Position", "shoulder_lift")
204
+ arm.write("Goal_Position", ef_pos + 1024, "elbow_flex")
205
+ arm.write("Goal_Position", sl_pos - 1024, "shoulder_lift")
206
+ time.sleep(2)
207
+
208
+ print("Calibrate elbow_flex")
209
+ calib["elbow_flex"] = move_to_calibrate(
210
+ arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook
211
+ )
212
+ calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
213
+
214
+ arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
215
+ time.sleep(1)
216
+
217
+ def in_between_move_hook():
218
+ nonlocal arm, calib
219
+ arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"], "elbow_flex")
220
+
221
+ print("Calibrate shoulder_lift")
222
+ calib["shoulder_lift"] = move_to_calibrate(
223
+ arm,
224
+ "shoulder_lift",
225
+ invert_drive_mode=True,
226
+ positive_first=False,
227
+ in_between_move_hook=in_between_move_hook,
228
+ )
229
+ # add an 30 steps as offset to align with body
230
+ calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=1024 - 50)
231
+
232
+ def while_move_hook():
233
+ nonlocal arm, calib
234
+ positions = {
235
+ "shoulder_lift": round(calib["shoulder_lift"]["zero_pos"] - 1600),
236
+ "elbow_flex": round(calib["elbow_flex"]["zero_pos"] + 1700),
237
+ "wrist_flex": round(calib["wrist_flex"]["zero_pos"] + 800),
238
+ "gripper": round(calib["gripper"]["end_pos"]),
239
+ }
240
+ arm.write("Goal_Position", list(positions.values()), list(positions.keys()))
241
+
242
+ arm.write("Goal_Position", round(calib["shoulder_lift"]["zero_pos"] - 1600), "shoulder_lift")
243
+ time.sleep(2)
244
+ arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
245
+ time.sleep(2)
246
+ arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex")
247
+ time.sleep(2)
248
+ arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper")
249
+ time.sleep(2)
250
+
251
+ print("Calibrate wrist_roll")
252
+ calib["wrist_roll"] = move_to_calibrate(
253
+ arm, "wrist_roll", invert_drive_mode=True, positive_first=False, while_move_hook=while_move_hook
254
+ )
255
+
256
+ arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
257
+ time.sleep(1)
258
+ arm.write("Goal_Position", calib["gripper"]["start_pos"], "gripper")
259
+ time.sleep(1)
260
+ arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
261
+ time.sleep(1)
262
+ arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex")
263
+ arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift")
264
+ time.sleep(1)
265
+ arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
266
+ time.sleep(1)
267
+
268
+ calib_modes = []
269
+ for name in arm.motor_names:
270
+ if name == "gripper":
271
+ calib_modes.append(CalibrationMode.LINEAR.name)
272
+ else:
273
+ calib_modes.append(CalibrationMode.DEGREE.name)
274
+
275
+ calib_dict = {
276
+ "homing_offset": [calib[name]["homing_offset"] for name in arm.motor_names],
277
+ "drive_mode": [calib[name]["drive_mode"] for name in arm.motor_names],
278
+ "start_pos": [calib[name]["start_pos"] for name in arm.motor_names],
279
+ "end_pos": [calib[name]["end_pos"] for name in arm.motor_names],
280
+ "calib_mode": calib_modes,
281
+ "motor_names": arm.motor_names,
282
+ }
283
+
284
+ # Re-enable original accerlation
285
+ arm.write("Lock", 0)
286
+ arm.write("Acceleration", initial_acceleration)
287
+ time.sleep(1)
288
+
289
+ return calib_dict
290
+
291
+
292
+ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
293
+ """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
294
+ if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
295
+ raise ValueError("To run calibration, the torque must be disabled on all motors.")
296
+
297
+ if not (robot_type == "moss" and arm_type == "follower"):
298
+ raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.")
299
+
300
+ print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
301
+
302
+ print("\nMove arm to initial position")
303
+ print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
304
+ input("Press Enter to continue...")
305
+
306
+ # Lower the acceleration of the motors (in [0,254])
307
+ initial_acceleration = arm.read("Acceleration")
308
+ arm.write("Lock", 0)
309
+ arm.write("Acceleration", 10)
310
+ time.sleep(1)
311
+
312
+ arm.write("Torque_Enable", TorqueMode.ENABLED.value)
313
+
314
+ sl_pos = arm.read("Present_Position", "shoulder_lift")
315
+ arm.write("Goal_Position", sl_pos - 1024 - 450, "shoulder_lift")
316
+ ef_pos = arm.read("Present_Position", "elbow_flex")
317
+ arm.write("Goal_Position", ef_pos + 1024 + 450, "elbow_flex")
318
+ time.sleep(2)
319
+
320
+ calib = {}
321
+
322
+ print("Calibrate shoulder_pan")
323
+ calib["shoulder_pan"] = move_to_calibrate(arm, "shoulder_pan")
324
+ arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
325
+ time.sleep(1)
326
+
327
+ print("Calibrate gripper")
328
+ calib["gripper"] = move_to_calibrate(arm, "gripper", invert_drive_mode=True)
329
+ time.sleep(1)
330
+
331
+ print("Calibrate wrist_flex")
332
+ calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex", invert_drive_mode=True)
333
+ calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=-210 + 1024)
334
+
335
+ wr_pos = arm.read("Present_Position", "wrist_roll")
336
+ arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
337
+ time.sleep(1)
338
+ arm.write("Goal_Position", wr_pos - 1024, "wrist_roll")
339
+ time.sleep(1)
340
+ arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 2048, "wrist_flex")
341
+ time.sleep(1)
342
+ arm.write("Goal_Position", calib["gripper"]["end_pos"], "gripper")
343
+ time.sleep(1)
344
+
345
+ print("Calibrate wrist_roll")
346
+ calib["wrist_roll"] = move_to_calibrate(arm, "wrist_roll", invert_drive_mode=True)
347
+ calib["wrist_roll"] = apply_offset(calib["wrist_roll"], offset=790)
348
+
349
+ arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"] - 1024, "wrist_roll")
350
+ arm.write("Goal_Position", calib["gripper"]["start_pos"], "gripper")
351
+ arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
352
+ time.sleep(1)
353
+ arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
354
+ arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 2048, "wrist_flex")
355
+
356
+ def in_between_move_elbow_flex_hook():
357
+ nonlocal arm, calib
358
+ arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
359
+
360
+ print("Calibrate elbow_flex")
361
+ calib["elbow_flex"] = move_to_calibrate(
362
+ arm,
363
+ "elbow_flex",
364
+ invert_drive_mode=True,
365
+ in_between_move_hook=in_between_move_elbow_flex_hook,
366
+ )
367
+ arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
368
+
369
+ def in_between_move_shoulder_lift_hook():
370
+ nonlocal arm, calib
371
+ sl = arm.read("Present_Position", "shoulder_lift")
372
+ arm.write("Goal_Position", sl - 1500, "shoulder_lift")
373
+ time.sleep(1)
374
+ arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1536, "elbow_flex")
375
+ time.sleep(1)
376
+ arm.write("Goal_Position", calib["wrist_flex"]["start_pos"], "wrist_flex")
377
+ time.sleep(1)
378
+
379
+ print("Calibrate shoulder_lift")
380
+ calib["shoulder_lift"] = move_to_calibrate(
381
+ arm, "shoulder_lift", in_between_move_hook=in_between_move_shoulder_lift_hook
382
+ )
383
+ calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=-1024)
384
+
385
+ arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
386
+ time.sleep(1)
387
+ arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift")
388
+ arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex")
389
+ time.sleep(2)
390
+
391
+ calib_modes = []
392
+ for name in arm.motor_names:
393
+ if name == "gripper":
394
+ calib_modes.append(CalibrationMode.LINEAR.name)
395
+ else:
396
+ calib_modes.append(CalibrationMode.DEGREE.name)
397
+
398
+ calib_dict = {
399
+ "homing_offset": [calib[name]["homing_offset"] for name in arm.motor_names],
400
+ "drive_mode": [calib[name]["drive_mode"] for name in arm.motor_names],
401
+ "start_pos": [calib[name]["start_pos"] for name in arm.motor_names],
402
+ "end_pos": [calib[name]["end_pos"] for name in arm.motor_names],
403
+ "calib_mode": calib_modes,
404
+ "motor_names": arm.motor_names,
405
+ }
406
+
407
+ # Re-enable original accerlation
408
+ arm.write("Lock", 0)
409
+ arm.write("Acceleration", initial_acceleration)
410
+ time.sleep(1)
411
+
412
+ return calib_dict
413
+
414
+
415
+ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
416
+ """This function ensures that a neural network trained on data collected on a given robot
417
+ can work on another robot. For instance before calibration, setting a same goal position
418
+ for each motor of two different robots will get two very different positions. But after calibration,
419
+ the two robots will move to the same position.To this end, this function computes the homing offset
420
+ and the drive mode for each motor of a given robot.
421
+
422
+ Homing offset is used to shift the motor position to a ]-2048, +2048[ nominal range (when the motor uses 2048 steps
423
+ to complete a half a turn). This range is set around an arbitrary "zero position" corresponding to all motor positions
424
+ being 0. During the calibration process, you will need to manually move the robot to this "zero position".
425
+
426
+ Drive mode is used to invert the rotation direction of the motor. This is useful when some motors have been assembled
427
+ in the opposite orientation for some robots. During the calibration process, you will need to manually move the robot
428
+ to the "rotated position".
429
+
430
+ After calibration, the homing offsets and drive modes are stored in a cache.
431
+
432
+ Example of usage:
433
+ ```python
434
+ run_arm_calibration(arm, "so100", "left", "follower")
435
+ ```
436
+ """
437
+ if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
438
+ raise ValueError("To run calibration, the torque must be disabled on all motors.")
439
+
440
+ print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
441
+
442
+ print("\nMove arm to zero position")
443
+ print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
444
+ input("Press Enter to continue...")
445
+
446
+ # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
447
+ # It is easy to identify and all motors are in a "quarter turn" position. Once calibration is done, this position will
448
+ # correspond to every motor angle being 0. If you set all 0 as Goal Position, the arm will move in this position.
449
+ zero_target_pos = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models)
450
+
451
+ # Compute homing offset so that `present_position + homing_offset ~= target_position`.
452
+ zero_pos = arm.read("Present_Position")
453
+ homing_offset = zero_target_pos - zero_pos
454
+
455
+ # The rotated target position corresponds to a rotation of a quarter turn from the zero position.
456
+ # This allows to identify the rotation direction of each motor.
457
+ # For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
458
+ # is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
459
+ # Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
460
+ # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
461
+ # of the previous motor in the kinetic chain.
462
+ print("\nMove arm to rotated target position")
463
+ print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
464
+ input("Press Enter to continue...")
465
+
466
+ rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
467
+
468
+ # Find drive mode by rotating each motor by a quarter of a turn.
469
+ # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
470
+ rotated_pos = arm.read("Present_Position")
471
+ drive_mode = (rotated_pos < zero_pos).astype(np.int32)
472
+
473
+ # Re-compute homing offset to take into account drive mode
474
+ rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
475
+ homing_offset = rotated_target_pos - rotated_drived_pos
476
+
477
+ print("\nMove arm to rest position")
478
+ print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
479
+ input("Press Enter to continue...")
480
+ print()
481
+
482
+ # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
483
+ calib_modes = []
484
+ for name in arm.motor_names:
485
+ if name == "gripper":
486
+ calib_modes.append(CalibrationMode.LINEAR.name)
487
+ else:
488
+ calib_modes.append(CalibrationMode.DEGREE.name)
489
+
490
+ calib_dict = {
491
+ "homing_offset": homing_offset.tolist(),
492
+ "drive_mode": drive_mode.tolist(),
493
+ "start_pos": zero_pos.tolist(),
494
+ "end_pos": rotated_pos.tolist(),
495
+ "calib_mode": calib_modes,
496
+ "motor_names": arm.motor_names,
497
+ }
498
+ return calib_dict
lerobot/common/robot_devices/robots/lekiwi_remote.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import base64
16
+ import json
17
+ import threading
18
+ import time
19
+ from pathlib import Path
20
+
21
+ import cv2
22
+ import zmq
23
+
24
+ from lerobot.common.robot_devices.robots.mobile_manipulator import LeKiwi
25
+
26
+
27
def setup_zmq_sockets(config):
    """Create the ZeroMQ context and server sockets for the LeKiwi host.

    Returns a tuple ``(context, cmd_socket, video_socket)``:
    a PULL socket bound on ``config.port`` for incoming commands, and a
    PUSH socket bound on ``config.video_port`` for outgoing observations.
    Both use CONFLATE so only the newest message is kept in the queue.
    """
    context = zmq.Context()

    pull_sock = context.socket(zmq.PULL)
    pull_sock.setsockopt(zmq.CONFLATE, 1)
    pull_sock.bind(f"tcp://*:{config.port}")

    push_sock = context.socket(zmq.PUSH)
    push_sock.setsockopt(zmq.CONFLATE, 1)
    push_sock.bind(f"tcp://*:{config.video_port}")

    return context, pull_sock, push_sock
38
+
39
+
40
def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
    """Continuously grab frames from every camera, JPEG-encode them, and
    publish the base64-encoded bytes into ``latest_images_dict``.

    Writes to the shared dict are guarded by ``images_lock``. Runs until
    ``stop_event`` is set; intended as the target of a daemon thread.
    """
    # JPEG quality 90; hoisted since it never changes.
    encode_params = [int(cv2.IMWRITE_JPEG_QUALITY), 90]
    while not stop_event.is_set():
        encoded = {}
        for cam_name, camera in cameras.items():
            frame = camera.async_read()
            ok, jpeg = cv2.imencode(".jpg", frame, encode_params)
            # On encoding failure, publish an empty string rather than dropping
            # the key, so consumers always see an entry per camera.
            encoded[cam_name] = base64.b64encode(jpeg).decode("utf-8") if ok else ""
        with images_lock:
            latest_images_dict.update(encoded)
        time.sleep(0.01)
53
+
54
+
55
def calibrate_follower_arm(motors_bus, calib_dir_str):
    """
    Calibrates the follower arm. Attempts to load an existing calibration file;
    if not found, runs manual calibration and saves the result. The calibration
    is then applied to the motors bus.

    Args:
        motors_bus: Motors bus exposing ``set_calibration(dict)``.
        calib_dir_str: Directory (str or Path) holding ``main_follower.json``.
    """
    calib_dir = Path(calib_dir_str)
    calib_dir.mkdir(parents=True, exist_ok=True)
    calib_file = calib_dir / "main_follower.json"

    if calib_file.exists():
        # Fix: an existing calibration file can (and should) be loaded and
        # applied even when the manual-calibration helper is unavailable, so
        # do not require the import on this path. The original code returned
        # early on ImportError and never applied a saved calibration.
        with open(calib_file) as f:
            calibration = json.load(f)
        print(f"[INFO] Loaded calibration from {calib_file}")
    else:
        # The helper import is only needed when we must calibrate manually.
        try:
            from lerobot.common.robot_devices.robots.feetech_calibration import (
                run_arm_manual_calibration,
            )
        except ImportError:
            print("[WARNING] Calibration function not available. Skipping calibration.")
            return
        print("[INFO] Calibration file not found. Running manual calibration...")
        calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
        print(f"[INFO] Calibration complete. Saving to {calib_file}")
        with open(calib_file, "w") as f:
            json.dump(calibration, f)

    # Best-effort apply: a bus that rejects the calibration should not crash
    # the server, matching the original behavior.
    try:
        motors_bus.set_calibration(calibration)
        print("[INFO] Applied calibration for follower arm.")
    except Exception as e:
        print(f"[WARNING] Could not apply calibration: {e}")
84
+
85
+
86
def run_lekiwi(robot_config):
    """
    Runs the LeKiwi robot:
    - Sets up cameras and connects them.
    - Initializes the follower arm motors.
    - Calibrates the follower arm if necessary.
    - Creates ZeroMQ sockets for receiving commands and streaming observations.
    - Processes incoming commands (arm and wheel commands) and sends back sensor and camera data.

    Runs until KeyboardInterrupt at roughly 30 Hz (0.033 s loop budget); a
    0.5 s watchdog stops the base when no command has arrived.
    """
    # Import helper functions and classes
    from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
    from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode

    # Initialize cameras from the robot configuration.
    cameras = make_cameras_from_configs(robot_config.cameras)
    for cam in cameras.values():
        cam.connect()

    # Initialize the motors bus using the follower arm configuration.
    motor_config = robot_config.follower_arms.get("main")
    if motor_config is None:
        print("[ERROR] Follower arm 'main' configuration not found.")
        return
    motors_bus = FeetechMotorsBus(motor_config)
    motors_bus.connect()

    # Calibrate the follower arm.
    calibrate_follower_arm(motors_bus, robot_config.calibration_dir)

    # Create the LeKiwi robot instance.
    robot = LeKiwi(motors_bus)

    # Define the expected arm motor IDs.
    arm_motor_ids = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]

    # Disable torque for each arm motor.
    # Torque is re-applied implicitly via Goal_Position writes below.
    for motor in arm_motor_ids:
        motors_bus.write("Torque_Enable", TorqueMode.DISABLED.value, motor)

    # Set up ZeroMQ sockets.
    context, cmd_socket, video_socket = setup_zmq_sockets(robot_config)

    # Start the camera capture thread.
    # latest_images_dict is shared with the capture thread; guard with images_lock.
    latest_images_dict = {}
    images_lock = threading.Lock()
    stop_event = threading.Event()
    cam_thread = threading.Thread(
        target=run_camera_capture, args=(cameras, images_lock, latest_images_dict, stop_event), daemon=True
    )
    cam_thread.start()

    last_cmd_time = time.time()
    print("LeKiwi robot server started. Waiting for commands...")

    try:
        while True:
            loop_start_time = time.time()

            # Process incoming commands (non-blocking).
            # Drain the socket; with CONFLATE set, at most the newest message waits.
            while True:
                try:
                    msg = cmd_socket.recv_string(zmq.NOBLOCK)
                except zmq.Again:
                    break
                try:
                    data = json.loads(msg)
                    # Process arm position commands.
                    if "arm_positions" in data:
                        arm_positions = data["arm_positions"]
                        if not isinstance(arm_positions, list):
                            print(f"[ERROR] Invalid arm_positions: {arm_positions}")
                        elif len(arm_positions) < len(arm_motor_ids):
                            print(
                                f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}"
                            )
                        else:
                            for motor, pos in zip(arm_motor_ids, arm_positions, strict=False):
                                motors_bus.write("Goal_Position", pos, motor)
                    # Process wheel (base) commands.
                    if "raw_velocity" in data:
                        raw_command = data["raw_velocity"]
                        # Expect keys: "left_wheel", "back_wheel", "right_wheel".
                        command_speeds = [
                            int(raw_command.get("left_wheel", 0)),
                            int(raw_command.get("back_wheel", 0)),
                            int(raw_command.get("right_wheel", 0)),
                        ]
                        robot.set_velocity(command_speeds)
                        # Only base commands refresh the watchdog timer.
                        last_cmd_time = time.time()
                except Exception as e:
                    print(f"[ERROR] Parsing message failed: {e}")

            # Watchdog: stop the robot if no command is received for over 0.5 seconds.
            now = time.time()
            if now - last_cmd_time > 0.5:
                robot.stop()
                last_cmd_time = now

            # Read current wheel speeds from the robot.
            current_velocity = robot.read_velocity()

            # Read the follower arm state from the motors bus.
            # NOTE(review): a failed read drops that motor's entry, so the list
            # may be shorter than arm_motor_ids — confirm consumers tolerate this.
            follower_arm_state = []
            for motor in arm_motor_ids:
                try:
                    pos = motors_bus.read("Present_Position", motor)
                    # Convert the position to a float (or use as is if already numeric).
                    follower_arm_state.append(float(pos) if not isinstance(pos, (int, float)) else pos)
                except Exception as e:
                    print(f"[ERROR] Reading motor {motor} failed: {e}")

            # Get the latest camera images (copy under lock, release quickly).
            with images_lock:
                images_dict_copy = dict(latest_images_dict)

            # Build the observation dictionary.
            observation = {
                "images": images_dict_copy,
                "present_speed": current_velocity,
                "follower_arm_state": follower_arm_state,
            }
            # Send the observation over the video socket.
            video_socket.send_string(json.dumps(observation))

            # Ensure a short sleep to avoid overloading the CPU.
            elapsed = time.time() - loop_start_time
            time.sleep(
                max(0.033 - elapsed, 0)
            )  # If robot jitters increase the sleep and monitor cpu load with `top` in cmd
    except KeyboardInterrupt:
        print("Shutting down LeKiwi server.")
    finally:
        # Tear down in dependency order: capture thread, robot, bus, sockets, context.
        stop_event.set()
        cam_thread.join()
        robot.stop()
        motors_bus.disconnect()
        cmd_socket.close()
        video_socket.close()
        context.term()
lerobot/common/robot_devices/robots/manipulator.py ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Contains logic to instantiate a robot, read information from its motors and cameras,
16
+ and send orders to its motors.
17
+ """
18
+ # TODO(rcadene, aliberts): reorganize the codebase into one file per robot, with the associated
19
+ # calibration procedure, to make it easy for people to add their own robot.
20
+
21
+ import json
22
+ import logging
23
+ import time
24
+ import warnings
25
+ from pathlib import Path
26
+
27
+ import numpy as np
28
+ import torch
29
+
30
+ from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
31
+ from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
32
+ from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
33
+ from lerobot.common.robot_devices.robots.utils import get_arm_id
34
+ from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
35
+
36
+
37
+ def ensure_safe_goal_position(
38
+ goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
39
+ ):
40
+ # Cap relative action target magnitude for safety.
41
+ diff = goal_pos - present_pos
42
+ max_relative_target = torch.tensor(max_relative_target)
43
+ safe_diff = torch.minimum(diff, max_relative_target)
44
+ safe_diff = torch.maximum(safe_diff, -max_relative_target)
45
+ safe_goal_pos = present_pos + safe_diff
46
+
47
+ if not torch.allclose(goal_pos, safe_goal_pos):
48
+ logging.warning(
49
+ "Relative goal position magnitude had to be clamped to be safe.\n"
50
+ f" requested relative goal position target: {diff}\n"
51
+ f" clamped relative goal position target: {safe_diff}"
52
+ )
53
+
54
+ return safe_goal_pos
55
+
56
+
57
+ class ManipulatorRobot:
58
+ # TODO(rcadene): Implement force feedback
59
+ """This class allows to control any manipulator robot of various number of motors.
60
+
61
+ Non exhaustive list of robots:
62
+ - [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow expansion, developed
63
+ by Alexander Koch from [Tau Robotics](https://tau-robotics.com)
64
+ - [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss
65
+ - [Aloha](https://www.trossenrobotics.com/aloha-kits) developed by Trossen Robotics
66
+
67
+ Example of instantiation, a pre-defined robot config is required:
68
+ ```python
69
+ robot = ManipulatorRobot(KochRobotConfig())
70
+ ```
71
+
72
+ Example of overwriting motors during instantiation:
73
+ ```python
74
+ # Defines how to communicate with the motors of the leader and follower arms
75
+ leader_arms = {
76
+ "main": DynamixelMotorsBusConfig(
77
+ port="/dev/tty.usbmodem575E0031751",
78
+ motors={
79
+ # name: (index, model)
80
+ "shoulder_pan": (1, "xl330-m077"),
81
+ "shoulder_lift": (2, "xl330-m077"),
82
+ "elbow_flex": (3, "xl330-m077"),
83
+ "wrist_flex": (4, "xl330-m077"),
84
+ "wrist_roll": (5, "xl330-m077"),
85
+ "gripper": (6, "xl330-m077"),
86
+ },
87
+ ),
88
+ }
89
+ follower_arms = {
90
+ "main": DynamixelMotorsBusConfig(
91
+ port="/dev/tty.usbmodem575E0032081",
92
+ motors={
93
+ # name: (index, model)
94
+ "shoulder_pan": (1, "xl430-w250"),
95
+ "shoulder_lift": (2, "xl430-w250"),
96
+ "elbow_flex": (3, "xl330-m288"),
97
+ "wrist_flex": (4, "xl330-m288"),
98
+ "wrist_roll": (5, "xl330-m288"),
99
+ "gripper": (6, "xl330-m288"),
100
+ },
101
+ ),
102
+ }
103
+ robot_config = KochRobotConfig(leader_arms=leader_arms, follower_arms=follower_arms)
104
+ robot = ManipulatorRobot(robot_config)
105
+ ```
106
+
107
+ Example of overwriting cameras during instantiation:
108
+ ```python
109
+ # Defines how to communicate with 2 cameras connected to the computer.
110
+ # Here, the webcam of the laptop and the phone (connected in USB to the laptop)
111
+ # can be reached respectively using the camera indices 0 and 1. These indices can be
112
+ # arbitrary. See the documentation of `OpenCVCamera` to find your own camera indices.
113
+ cameras = {
114
+ "laptop": OpenCVCamera(camera_index=0, fps=30, width=640, height=480),
115
+ "phone": OpenCVCamera(camera_index=1, fps=30, width=640, height=480),
116
+ }
117
+ robot = ManipulatorRobot(KochRobotConfig(cameras=cameras))
118
+ ```
119
+
120
+ Once the robot is instantiated, connect motors buses and cameras if any (Required):
121
+ ```python
122
+ robot.connect()
123
+ ```
124
+
125
+ Example of highest frequency teleoperation, which doesn't require cameras:
126
+ ```python
127
+ while True:
128
+ robot.teleop_step()
129
+ ```
130
+
131
+ Example of highest frequency data collection from motors and cameras (if any):
132
+ ```python
133
+ while True:
134
+ observation, action = robot.teleop_step(record_data=True)
135
+ ```
136
+
137
+ Example of controlling the robot with a policy:
138
+ ```python
139
+ while True:
140
+ # Uses the follower arms and cameras to capture an observation
141
+ observation = robot.capture_observation()
142
+
143
+ # Assumes a policy has been instantiated
144
+ with torch.inference_mode():
145
+ action = policy.select_action(observation)
146
+
147
+ # Orders the robot to move
148
+ robot.send_action(action)
149
+ ```
150
+
151
+ Example of disconnecting which is not mandatory since we disconnect when the object is deleted:
152
+ ```python
153
+ robot.disconnect()
154
+ ```
155
+ """
156
+
157
    def __init__(
        self,
        config: ManipulatorRobotConfig,
    ):
        """Build the robot from its config.

        Instantiates (but does not connect) the leader/follower motor buses and
        the cameras. `connect()` must be called before any device I/O.
        """
        self.config = config
        self.robot_type = self.config.type
        self.calibration_dir = Path(self.config.calibration_dir)
        # Buses and cameras are constructed here but remain disconnected.
        self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
        self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
        self.cameras = make_cameras_from_configs(self.config.cameras)
        self.is_connected = False
        # Per-step timing/debug measurements, keyed by operation name.
        self.logs = {}
169
+
170
+ def get_motor_names(self, arm: dict[str, MotorsBus]) -> list:
171
+ return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors]
172
+
173
+ @property
174
+ def camera_features(self) -> dict:
175
+ cam_ft = {}
176
+ for cam_key, cam in self.cameras.items():
177
+ key = f"observation.images.{cam_key}"
178
+ cam_ft[key] = {
179
+ "shape": (cam.height, cam.width, cam.channels),
180
+ "names": ["height", "width", "channels"],
181
+ "info": None,
182
+ }
183
+ return cam_ft
184
+
185
    @property
    def motor_features(self) -> dict:
        """Dataset feature spec for the 'action' and 'observation.state' vectors.

        NOTE(review): both action and state names are derived from the leader
        arms — presumably leader and follower motor layouts match; confirm.
        """
        action_names = self.get_motor_names(self.leader_arms)
        state_names = self.get_motor_names(self.leader_arms)
        return {
            "action": {
                "dtype": "float32",
                "shape": (len(action_names),),
                "names": action_names,
            },
            "observation.state": {
                "dtype": "float32",
                "shape": (len(state_names),),
                "names": state_names,
            },
        }
201
+
202
+ @property
203
+ def features(self):
204
+ return {**self.motor_features, **self.camera_features}
205
+
206
+ @property
207
+ def has_camera(self):
208
+ return len(self.cameras) > 0
209
+
210
    @property
    def num_cameras(self):
        """Number of configured cameras."""
        return len(self.cameras)
213
+
214
+ @property
215
+ def available_arms(self):
216
+ available_arms = []
217
+ for name in self.follower_arms:
218
+ arm_id = get_arm_id(name, "follower")
219
+ available_arms.append(arm_id)
220
+ for name in self.leader_arms:
221
+ arm_id = get_arm_id(name, "leader")
222
+ available_arms.append(arm_id)
223
+ return available_arms
224
+
225
+ def connect(self):
226
+ if self.is_connected:
227
+ raise RobotDeviceAlreadyConnectedError(
228
+ "ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
229
+ )
230
+
231
+ if not self.leader_arms and not self.follower_arms and not self.cameras:
232
+ raise ValueError(
233
+ "ManipulatorRobot doesn't have any device to connect. See example of usage in docstring of the class."
234
+ )
235
+
236
+ # Connect the arms
237
+ for name in self.follower_arms:
238
+ print(f"Connecting {name} follower arm.")
239
+ self.follower_arms[name].connect()
240
+ for name in self.leader_arms:
241
+ print(f"Connecting {name} leader arm.")
242
+ self.leader_arms[name].connect()
243
+
244
+ if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
245
+ from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
246
+ elif self.robot_type in ["so100", "moss", "lekiwi"]:
247
+ from lerobot.common.robot_devices.motors.feetech import TorqueMode
248
+
249
+ # We assume that at connection time, arms are in a rest position, and torque can
250
+ # be safely disabled to run calibration and/or set robot preset configurations.
251
+ for name in self.follower_arms:
252
+ self.follower_arms[name].write("Torque_Enable", TorqueMode.DISABLED.value)
253
+ for name in self.leader_arms:
254
+ self.leader_arms[name].write("Torque_Enable", TorqueMode.DISABLED.value)
255
+
256
+ self.activate_calibration()
257
+
258
+ # Set robot preset (e.g. torque in leader gripper for Koch v1.1)
259
+ if self.robot_type in ["koch", "koch_bimanual"]:
260
+ self.set_koch_robot_preset()
261
+ elif self.robot_type == "aloha":
262
+ self.set_aloha_robot_preset()
263
+ elif self.robot_type in ["so100", "moss", "lekiwi"]:
264
+ self.set_so100_robot_preset()
265
+
266
+ # Enable torque on all motors of the follower arms
267
+ for name in self.follower_arms:
268
+ print(f"Activating torque on {name} follower arm.")
269
+ self.follower_arms[name].write("Torque_Enable", 1)
270
+
271
+ if self.config.gripper_open_degree is not None:
272
+ if self.robot_type not in ["koch", "koch_bimanual"]:
273
+ raise NotImplementedError(
274
+ f"{self.robot_type} does not support position AND current control in the handle, which is require to set the gripper open."
275
+ )
276
+ # Set the leader arm in torque mode with the gripper motor set to an angle. This makes it possible
277
+ # to squeeze the gripper and have it spring back to an open position on its own.
278
+ for name in self.leader_arms:
279
+ self.leader_arms[name].write("Torque_Enable", 1, "gripper")
280
+ self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
281
+
282
+ # Check both arms can be read
283
+ for name in self.follower_arms:
284
+ self.follower_arms[name].read("Present_Position")
285
+ for name in self.leader_arms:
286
+ self.leader_arms[name].read("Present_Position")
287
+
288
+ # Connect the cameras
289
+ for name in self.cameras:
290
+ self.cameras[name].connect()
291
+
292
+ self.is_connected = True
293
+
294
    def activate_calibration(self):
        """After calibration all motors function in human interpretable ranges.
        Rotations are expressed in degrees in nominal range of [-180, 180],
        and linear motions (like gripper of Aloha) in nominal range of [0, 100].

        Loads each arm's calibration from `self.calibration_dir` (running the
        robot-specific calibration procedure and saving the file when missing),
        then applies it to the corresponding motors bus.
        """

        def load_or_run_calibration_(name, arm, arm_type):
            # Calibration files are keyed per arm as "<name>_<arm_type>.json".
            arm_id = get_arm_id(name, arm_type)
            arm_calib_path = self.calibration_dir / f"{arm_id}.json"

            if arm_calib_path.exists():
                with open(arm_calib_path) as f:
                    calibration = json.load(f)
            else:
                # TODO(rcadene): display a warning in __init__ if calibration file not available
                print(f"Missing calibration file '{arm_calib_path}'")

                if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
                    from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration

                    calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)

                elif self.robot_type in ["so100", "moss", "lekiwi"]:
                    from lerobot.common.robot_devices.robots.feetech_calibration import (
                        run_arm_manual_calibration,
                    )

                    calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
                # NOTE(review): an unrecognized robot_type leaves `calibration`
                # unbound here and would raise UnboundLocalError below — confirm
                # all supported types are covered by the branches above.

                print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
                arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
                with open(arm_calib_path, "w") as f:
                    json.dump(calibration, f)

            return calibration

        for name, arm in self.follower_arms.items():
            calibration = load_or_run_calibration_(name, arm, "follower")
            arm.set_calibration(calibration)
        for name, arm in self.leader_arms.items():
            calibration = load_or_run_calibration_(name, arm, "leader")
            arm.set_calibration(calibration)
336
+
337
    def set_koch_robot_preset(self):
        """Apply Koch-specific motor presets: operating modes per motor and
        tuned PID gains on the follower elbow.

        Requires torque to be disabled on every motor of each arm (checked per
        arm before writing operating modes).
        """

        def set_operating_mode_(arm):
            from lerobot.common.robot_devices.motors.dynamixel import TorqueMode

            # Operating modes can only be changed while torque is off.
            if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
                raise ValueError("To run set robot preset, the torque must be disabled on all motors.")

            # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
            # rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
            # you could end up with a servo with a position 0 or 4095 at a crucial point See [
            # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
            all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
            if len(all_motors_except_gripper) > 0:
                # 4 corresponds to Extended Position on Koch motors
                arm.write("Operating_Mode", 4, all_motors_except_gripper)

            # Use 'position control current based' for gripper to be limited by the limit of the current.
            # For the follower gripper, it means it can grasp an object without forcing too much even tho,
            # it's goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
            # For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger
            # to make it move, and it will move back to its original target position when we release the force.
            # 5 corresponds to Current Controlled Position on Koch gripper motors "xl330-m077, xl330-m288"
            arm.write("Operating_Mode", 5, "gripper")

        for name in self.follower_arms:
            set_operating_mode_(self.follower_arms[name])

            # Set better PID values to close the gap between recorded states and actions
            # TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor
            self.follower_arms[name].write("Position_P_Gain", 1500, "elbow_flex")
            self.follower_arms[name].write("Position_I_Gain", 0, "elbow_flex")
            self.follower_arms[name].write("Position_D_Gain", 600, "elbow_flex")

        if self.config.gripper_open_degree is not None:
            for name in self.leader_arms:
                set_operating_mode_(self.leader_arms[name])

                # Enable torque on the gripper of the leader arms, and move it to 45 degrees,
                # so that we can use it as a trigger to close the gripper of the follower arms.
                self.leader_arms[name].write("Torque_Enable", 1, "gripper")
                self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
378
+
379
    def set_aloha_robot_preset(self):
        """Apply Aloha-specific presets: shadow-motor pairing for dual-motor
        joints, a follower velocity limit, and per-motor operating modes.
        """

        def set_shadow_(arm):
            # Set secondary/shadow ID for shoulder and elbow. These joints have two motors.
            # As a result, if only one of them is required to move to a certain position,
            # the other will follow. This is to avoid breaking the motors.
            if "shoulder_shadow" in arm.motor_names:
                shoulder_idx = arm.read("ID", "shoulder")
                arm.write("Secondary_ID", shoulder_idx, "shoulder_shadow")

            if "elbow_shadow" in arm.motor_names:
                elbow_idx = arm.read("ID", "elbow")
                arm.write("Secondary_ID", elbow_idx, "elbow_shadow")

        for name in self.follower_arms:
            set_shadow_(self.follower_arms[name])

        for name in self.leader_arms:
            set_shadow_(self.leader_arms[name])

        for name in self.follower_arms:
            # Set a velocity limit of 131 as advised by Trossen Robotics
            self.follower_arms[name].write("Velocity_Limit", 131)

            # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
            # rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
            # you could end up with a servo with a position 0 or 4095 at a crucial point See [
            # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
            # NOTE(review): the comprehension variable shadows the outer `name`;
            # harmless in Python 3 (comprehension scope) but worth renaming.
            all_motors_except_gripper = [
                name for name in self.follower_arms[name].motor_names if name != "gripper"
            ]
            if len(all_motors_except_gripper) > 0:
                # 4 corresponds to Extended Position on Aloha motors
                self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper)

            # Use 'position control current based' for follower gripper to be limited by the limit of the current.
            # It can grasp an object without forcing too much even tho,
            # it's goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
            # 5 corresponds to Current Controlled Position on Aloha gripper follower "xm430-w350"
            self.follower_arms[name].write("Operating_Mode", 5, "gripper")

        # Note: We can't enable torque on the leader gripper since "xc430-w150" doesn't have
        # a Current Controlled Position mode.

        if self.config.gripper_open_degree is not None:
            warnings.warn(
                f"`gripper_open_degree` is set to {self.config.gripper_open_degree}, but None is expected for Aloha instead",
                stacklevel=1,
            )
427
+
428
+ def set_so100_robot_preset(self):
429
+ for name in self.follower_arms:
430
+ # Mode=0 for Position Control
431
+ self.follower_arms[name].write("Mode", 0)
432
+ # Set P_Coefficient to lower value to avoid shakiness (Default is 32)
433
+ self.follower_arms[name].write("P_Coefficient", 16)
434
+ # Set I_Coefficient and D_Coefficient to default value 0 and 32
435
+ self.follower_arms[name].write("I_Coefficient", 0)
436
+ self.follower_arms[name].write("D_Coefficient", 32)
437
+ # Close the write lock so that Maximum_Acceleration gets written to EPROM address,
438
+ # which is mandatory for Maximum_Acceleration to take effect after rebooting.
439
+ self.follower_arms[name].write("Lock", 0)
440
+ # Set Maximum_Acceleration to 254 to speedup acceleration and deceleration of
441
+ # the motors. Note: this configuration is not in the official STS3215 Memory Table
442
+ self.follower_arms[name].write("Maximum_Acceleration", 254)
443
+ self.follower_arms[name].write("Acceleration", 254)
444
+
445
    def teleop_step(
        self, record_data=False
    ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """Mirror the leader arms onto the follower arms for one control step.

        When `record_data` is True, additionally reads follower positions and
        cameras and returns `(obs_dict, action_dict)`; otherwise returns None.
        Per-device latencies are recorded in `self.logs`.

        Raises:
            RobotDeviceNotConnectedError: if `connect()` was not called first.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                "ManipulatorRobot is not connected. You need to run `robot.connect()`."
            )

        # Prepare to assign the position of the leader to the follower
        leader_pos = {}
        for name in self.leader_arms:
            before_lread_t = time.perf_counter()
            leader_pos[name] = self.leader_arms[name].read("Present_Position")
            leader_pos[name] = torch.from_numpy(leader_pos[name])
            self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t

        # Send goal position to the follower
        follower_goal_pos = {}
        for name in self.follower_arms:
            before_fwrite_t = time.perf_counter()
            goal_pos = leader_pos[name]

            # Cap goal position when too far away from present position.
            # Slower fps expected due to reading from the follower.
            if self.config.max_relative_target is not None:
                present_pos = self.follower_arms[name].read("Present_Position")
                present_pos = torch.from_numpy(present_pos)
                goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)

            # Used when record_data=True
            follower_goal_pos[name] = goal_pos

            goal_pos = goal_pos.numpy().astype(np.float32)
            self.follower_arms[name].write("Goal_Position", goal_pos)
            self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t

        # Early exit when recording data is not requested
        if not record_data:
            return

        # TODO(rcadene): Add velocity and other info
        # Read follower position
        follower_pos = {}
        for name in self.follower_arms:
            before_fread_t = time.perf_counter()
            follower_pos[name] = self.follower_arms[name].read("Present_Position")
            follower_pos[name] = torch.from_numpy(follower_pos[name])
            self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t

        # Create state by concatenating follower current position
        state = []
        for name in self.follower_arms:
            if name in follower_pos:
                state.append(follower_pos[name])
        state = torch.cat(state)

        # Create action by concatenating follower goal position
        action = []
        for name in self.follower_arms:
            if name in follower_goal_pos:
                action.append(follower_goal_pos[name])
        action = torch.cat(action)

        # Capture images from cameras
        images = {}
        for name in self.cameras:
            before_camread_t = time.perf_counter()
            images[name] = self.cameras[name].async_read()
            images[name] = torch.from_numpy(images[name])
            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

        # Populate output dictionaries
        obs_dict, action_dict = {}, {}
        obs_dict["observation.state"] = state
        action_dict["action"] = action
        for name in self.cameras:
            obs_dict[f"observation.images.{name}"] = images[name]

        return obs_dict, action_dict
525
+
526
+ def capture_observation(self):
527
+ """The returned observations do not have a batch dimension."""
528
+ if not self.is_connected:
529
+ raise RobotDeviceNotConnectedError(
530
+ "ManipulatorRobot is not connected. You need to run `robot.connect()`."
531
+ )
532
+
533
+ # Read follower position
534
+ follower_pos = {}
535
+ for name in self.follower_arms:
536
+ before_fread_t = time.perf_counter()
537
+ follower_pos[name] = self.follower_arms[name].read("Present_Position")
538
+ follower_pos[name] = torch.from_numpy(follower_pos[name])
539
+ self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
540
+
541
+ # Create state by concatenating follower current position
542
+ state = []
543
+ for name in self.follower_arms:
544
+ if name in follower_pos:
545
+ state.append(follower_pos[name])
546
+ state = torch.cat(state)
547
+
548
+ # Capture images from cameras
549
+ images = {}
550
+ for name in self.cameras:
551
+ before_camread_t = time.perf_counter()
552
+ images[name] = self.cameras[name].async_read()
553
+ images[name] = torch.from_numpy(images[name])
554
+ self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
555
+ self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
556
+
557
+ # Populate output dictionaries and format to pytorch
558
+ obs_dict = {}
559
+ obs_dict["observation.state"] = state
560
+ for name in self.cameras:
561
+ obs_dict[f"observation.images.{name}"] = images[name]
562
+ return obs_dict
563
+
564
    def send_action(self, action: torch.Tensor) -> torch.Tensor:
        """Command the follower arms to move to a target joint configuration.

        The relative action magnitude may be clipped depending on the configuration parameter
        `max_relative_target`. In this case, the action sent differs from original action.
        Thus, this function always returns the action actually sent.

        Args:
            action: tensor containing the concatenated goal positions for the follower arms.

        Returns:
            The concatenated goal positions actually written to the motors
            (possibly clamped).

        Raises:
            RobotDeviceNotConnectedError: if `connect()` was not called first.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                "ManipulatorRobot is not connected. You need to run `robot.connect()`."
            )

        from_idx = 0
        to_idx = 0
        action_sent = []
        for name in self.follower_arms:
            # Get goal position of each follower arm by splitting the action vector
            to_idx += len(self.follower_arms[name].motor_names)
            goal_pos = action[from_idx:to_idx]
            from_idx = to_idx

            # Cap goal position when too far away from present position.
            # Slower fps expected due to reading from the follower.
            if self.config.max_relative_target is not None:
                present_pos = self.follower_arms[name].read("Present_Position")
                present_pos = torch.from_numpy(present_pos)
                goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)

            # Save tensor to concat and return
            action_sent.append(goal_pos)

            # Send goal position to each follower
            goal_pos = goal_pos.numpy().astype(np.float32)
            self.follower_arms[name].write("Goal_Position", goal_pos)

        return torch.cat(action_sent)
603
+
604
+ def print_logs(self):
605
+ pass
606
+ # TODO(aliberts): move robot-specific logs logic here
607
+
608
    def disconnect(self):
        """Disconnect every follower arm, leader arm and camera, then mark the robot as disconnected.

        Raises:
            RobotDeviceNotConnectedError: if the robot was never connected.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                "ManipulatorRobot is not connected. You need to run `robot.connect()` before disconnecting."
            )

        for name in self.follower_arms:
            self.follower_arms[name].disconnect()

        for name in self.leader_arms:
            self.leader_arms[name].disconnect()

        for name in self.cameras:
            self.cameras[name].disconnect()

        self.is_connected = False
624
+
625
+ def __del__(self):
626
+ if getattr(self, "is_connected", False):
627
+ self.disconnect()
lerobot/common/robot_devices/robots/mobile_manipulator.py ADDED
@@ -0,0 +1,703 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import base64
16
+ import json
17
+ import os
18
+ import sys
19
+ from pathlib import Path
20
+
21
+ import cv2
22
+ import numpy as np
23
+ import torch
24
+ import zmq
25
+
26
+ from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
27
+ from lerobot.common.robot_devices.motors.feetech import TorqueMode
28
+ from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
29
+ from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
30
+ from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
31
+ from lerobot.common.robot_devices.robots.utils import get_arm_id
32
+ from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError
33
+
34
+ PYNPUT_AVAILABLE = True
35
+ try:
36
+ # Only import if there's a valid X server or if we're not on a Pi
37
+ if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
38
+ print("No DISPLAY set. Skipping pynput import.")
39
+ raise ImportError("pynput blocked intentionally due to no display.")
40
+
41
+ from pynput import keyboard
42
+ except ImportError:
43
+ keyboard = None
44
+ PYNPUT_AVAILABLE = False
45
+ except Exception as e:
46
+ keyboard = None
47
+ PYNPUT_AVAILABLE = False
48
+ print(f"Could not import pynput: {e}")
49
+
50
+
51
class MobileManipulator:
    """
    MobileManipulator is a class for connecting to and controlling a remote mobile manipulator robot.
    The robot includes a three omniwheel mobile base and a remote follower arm.
    The leader arm is connected locally (on the laptop) and its joint positions are recorded and then
    forwarded to the remote follower arm (after applying a safety clamp).
    In parallel, keyboard teleoperation is used to generate raw velocity commands for the wheels.
    """

    def __init__(self, config: LeKiwiRobotConfig):
        """
        Expected keys in config:
        - ip, port, video_port for the remote connection.
        - calibration_dir, leader_arms, follower_arms, max_relative_target, etc.
        """
        self.robot_type = config.type
        self.config = config
        # Remote endpoints: one port for commands, one for the video/observation stream.
        self.remote_ip = config.ip
        self.remote_port = config.port
        self.remote_port_video = config.video_port
        self.calibration_dir = Path(self.config.calibration_dir)
        # Timing/diagnostic entries keyed by event name.
        self.logs = {}

        self.teleop_keys = self.config.teleop_keys

        # For teleoperation, the leader arm (local) is used to record the desired arm pose.
        self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)

        self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)

        self.cameras = make_cameras_from_configs(self.config.cameras)

        self.is_connected = False

        # Caches of the most recent remote data; reused when no new message arrives.
        self.last_frames = {}
        self.last_present_speed = {}
        self.last_remote_arm_state = torch.zeros(6, dtype=torch.float32)

        # Define three speed levels and a current index
        self.speed_levels = [
            {"xy": 0.1, "theta": 30},  # slow
            {"xy": 0.2, "theta": 60},  # medium
            {"xy": 0.3, "theta": 90},  # fast
        ]
        self.speed_index = 0  # Start at slow

        # ZeroMQ context and sockets (created in connect()).
        self.context = None
        self.cmd_socket = None
        self.video_socket = None

        # Keyboard state for base teleoperation.
        self.running = True
        self.pressed_keys = {
            "forward": False,
            "backward": False,
            "left": False,
            "right": False,
            "rotate_left": False,
            "rotate_right": False,
        }

        if PYNPUT_AVAILABLE:
            print("pynput is available - enabling local keyboard listener.")
            self.listener = keyboard.Listener(
                on_press=self.on_press,
                on_release=self.on_release,
            )
            self.listener.start()
        else:
            print("pynput not available - skipping local keyboard listener.")
            self.listener = None
123
+
124
+ def get_motor_names(self, arms: dict[str, MotorsBus]) -> list:
125
+ return [f"{arm}_{motor}" for arm, bus in arms.items() for motor in bus.motors]
126
+
127
+ @property
128
+ def camera_features(self) -> dict:
129
+ cam_ft = {}
130
+ for cam_key, cam in self.cameras.items():
131
+ key = f"observation.images.{cam_key}"
132
+ cam_ft[key] = {
133
+ "shape": (cam.height, cam.width, cam.channels),
134
+ "names": ["height", "width", "channels"],
135
+ "info": None,
136
+ }
137
+ return cam_ft
138
+
139
+ @property
140
+ def motor_features(self) -> dict:
141
+ follower_arm_names = [
142
+ "shoulder_pan",
143
+ "shoulder_lift",
144
+ "elbow_flex",
145
+ "wrist_flex",
146
+ "wrist_roll",
147
+ "gripper",
148
+ ]
149
+ observations = ["x_mm", "y_mm", "theta"]
150
+ combined_names = follower_arm_names + observations
151
+ return {
152
+ "action": {
153
+ "dtype": "float32",
154
+ "shape": (len(combined_names),),
155
+ "names": combined_names,
156
+ },
157
+ "observation.state": {
158
+ "dtype": "float32",
159
+ "shape": (len(combined_names),),
160
+ "names": combined_names,
161
+ },
162
+ }
163
+
164
+ @property
165
+ def features(self):
166
+ return {**self.motor_features, **self.camera_features}
167
+
168
+ @property
169
+ def has_camera(self):
170
+ return len(self.cameras) > 0
171
+
172
+ @property
173
+ def num_cameras(self):
174
+ return len(self.cameras)
175
+
176
+ @property
177
+ def available_arms(self):
178
+ available = []
179
+ for name in self.leader_arms:
180
+ available.append(get_arm_id(name, "leader"))
181
+ for name in self.follower_arms:
182
+ available.append(get_arm_id(name, "follower"))
183
+ return available
184
+
185
    def on_press(self, key):
        """pynput callback: latch movement flags, adjust speed, or request quit.

        Returning False from a pynput callback stops the listener, which is how
        the quit key and Esc terminate teleoperation.
        """
        try:
            # Movement
            if key.char == self.teleop_keys["forward"]:
                self.pressed_keys["forward"] = True
            elif key.char == self.teleop_keys["backward"]:
                self.pressed_keys["backward"] = True
            elif key.char == self.teleop_keys["left"]:
                self.pressed_keys["left"] = True
            elif key.char == self.teleop_keys["right"]:
                self.pressed_keys["right"] = True
            elif key.char == self.teleop_keys["rotate_left"]:
                self.pressed_keys["rotate_left"] = True
            elif key.char == self.teleop_keys["rotate_right"]:
                self.pressed_keys["rotate_right"] = True

            # Quit teleoperation
            elif key.char == self.teleop_keys["quit"]:
                self.running = False
                return False

            # Speed control: index is clamped to the 3 entries of speed_levels.
            elif key.char == self.teleop_keys["speed_up"]:
                self.speed_index = min(self.speed_index + 1, 2)
                print(f"Speed index increased to {self.speed_index}")
            elif key.char == self.teleop_keys["speed_down"]:
                self.speed_index = max(self.speed_index - 1, 0)
                print(f"Speed index decreased to {self.speed_index}")

        except AttributeError:
            # e.g., if key is special like Key.esc (no `.char` attribute)
            if key == keyboard.Key.esc:
                self.running = False
                return False
219
+
220
+ def on_release(self, key):
221
+ try:
222
+ if hasattr(key, "char"):
223
+ if key.char == self.teleop_keys["forward"]:
224
+ self.pressed_keys["forward"] = False
225
+ elif key.char == self.teleop_keys["backward"]:
226
+ self.pressed_keys["backward"] = False
227
+ elif key.char == self.teleop_keys["left"]:
228
+ self.pressed_keys["left"] = False
229
+ elif key.char == self.teleop_keys["right"]:
230
+ self.pressed_keys["right"] = False
231
+ elif key.char == self.teleop_keys["rotate_left"]:
232
+ self.pressed_keys["rotate_left"] = False
233
+ elif key.char == self.teleop_keys["rotate_right"]:
234
+ self.pressed_keys["rotate_right"] = False
235
+ except AttributeError:
236
+ pass
237
+
238
    def connect(self):
        """Calibrate the local leader arm(s) and open the ZeroMQ links to the remote robot.

        Raises:
            ValueError: if no leader arm is configured.
        """
        if not self.leader_arms:
            raise ValueError("MobileManipulator has no leader arm to connect.")
        # NOTE(review): calibrate_leader() iterates over all leader arms itself,
        # so calling it once per arm re-calibrates every arm — confirm intended.
        for name in self.leader_arms:
            print(f"Connecting {name} leader arm.")
            self.calibrate_leader()

        # Set up ZeroMQ sockets to communicate with the remote mobile robot.
        # PUSH carries commands out; PULL receives the observation/video stream.
        # CONFLATE=1 keeps only the most recent message in each socket queue.
        self.context = zmq.Context()
        self.cmd_socket = self.context.socket(zmq.PUSH)
        connection_string = f"tcp://{self.remote_ip}:{self.remote_port}"
        self.cmd_socket.connect(connection_string)
        self.cmd_socket.setsockopt(zmq.CONFLATE, 1)
        self.video_socket = self.context.socket(zmq.PULL)
        video_connection = f"tcp://{self.remote_ip}:{self.remote_port_video}"
        self.video_socket.connect(video_connection)
        self.video_socket.setsockopt(zmq.CONFLATE, 1)
        print(
            f"[INFO] Connected to remote robot at {connection_string} and video stream at {video_connection}."
        )
        self.is_connected = True
259
+
260
    def load_or_run_calibration_(self, name, arm, arm_type):
        """Load a cached calibration JSON for this arm, or run manual calibration and cache it.

        Args:
            name: arm name from the config.
            arm: the motors bus to calibrate.
            arm_type: "leader" or "follower".

        Returns:
            The calibration dict loaded from (or saved to) `calibration_dir/<arm_id>.json`.
        """
        arm_id = get_arm_id(name, arm_type)
        arm_calib_path = self.calibration_dir / f"{arm_id}.json"

        if arm_calib_path.exists():
            with open(arm_calib_path) as f:
                calibration = json.load(f)
        else:
            # No cached file: run the interactive calibration and persist the result.
            print(f"Missing calibration file '{arm_calib_path}'")
            calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
            print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
            arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
            with open(arm_calib_path, "w") as f:
                json.dump(calibration, f)

        return calibration
276
+
277
    def calibrate_leader(self):
        """Connect every leader arm, disable torque so it can be moved by hand, and apply calibration."""
        for name, arm in self.leader_arms.items():
            # Connect the bus
            arm.connect()

            # Disable torque on all motors so the arm is freely movable during calibration.
            for motor_id in arm.motors:
                arm.write("Torque_Enable", TorqueMode.DISABLED.value, motor_id)

            # Now run calibration
            calibration = self.load_or_run_calibration_(name, arm, "leader")
            arm.set_calibration(calibration)
289
+
290
    def calibrate_follower(self):
        """Calibrate follower arm joints only, temporarily hiding the wheel motors from the bus."""
        for name, bus in self.follower_arms.items():
            bus.connect()

            # Disable torque on all motors
            for motor_id in bus.motors:
                bus.write("Torque_Enable", 0, motor_id)

            # Then filter out wheels: calibration applies only to the arm joints.
            arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")}
            if not arm_only_dict:
                continue

            # Temporarily swap the bus motor table so calibration sees only arm joints.
            original_motors = bus.motors
            bus.motors = arm_only_dict

            calibration = self.load_or_run_calibration_(name, bus, "follower")
            bus.set_calibration(calibration)

            # Restore the full motor table (arm + wheels).
            bus.motors = original_motors
310
+
311
    def _get_data(self):
        """
        Polls the video socket for up to 15 ms. If data arrives, decode only
        the *latest* message, returning frames, speed, and arm state. If
        nothing arrives for any field, use the last known values.
        """
        frames = {}
        present_speed = {}
        remote_arm_state_tensor = torch.zeros(6, dtype=torch.float32)

        # Poll up to 15 ms
        poller = zmq.Poller()
        poller.register(self.video_socket, zmq.POLLIN)
        socks = dict(poller.poll(15))
        if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
            # No new data arrived → reuse ALL old data
            return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)

        # Drain all messages, keep only the last
        last_msg = None
        while True:
            try:
                obs_string = self.video_socket.recv_string(zmq.NOBLOCK)
                last_msg = obs_string
            except zmq.Again:
                break

        if not last_msg:
            # No new message → also reuse old
            return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)

        # Decode only the final message
        try:
            observation = json.loads(last_msg)

            images_dict = observation.get("images", {})
            new_speed = observation.get("present_speed", {})
            new_arm_state = observation.get("follower_arm_state", None)

            # Convert base64-encoded JPEGs into ndarrays (BGR, per cv2.IMREAD_COLOR).
            for cam_name, image_b64 in images_dict.items():
                if image_b64:
                    jpg_data = base64.b64decode(image_b64)
                    np_arr = np.frombuffer(jpg_data, dtype=np.uint8)
                    frame_candidate = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
                    if frame_candidate is not None:
                        frames[cam_name] = frame_candidate

            # If remote_arm_state is None and frames is None there is no message then use the previous message
            # NOTE(review): `frames` is always a dict here (possibly empty), never None,
            # so this condition effectively depends only on `new_arm_state` — confirm intended.
            if new_arm_state is not None and frames is not None:
                self.last_frames = frames

                remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32)
                self.last_remote_arm_state = remote_arm_state_tensor

                present_speed = new_speed
                self.last_present_speed = new_speed
            else:
                frames = self.last_frames

                remote_arm_state_tensor = self.last_remote_arm_state

                present_speed = self.last_present_speed

        except Exception as e:
            print(f"[DEBUG] Error decoding video message: {e}")
            # If decode fails, fall back to old data
            return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)

        return frames, present_speed, remote_arm_state_tensor
381
+
382
+ def _process_present_speed(self, present_speed: dict) -> torch.Tensor:
383
+ state_tensor = torch.zeros(3, dtype=torch.int32)
384
+ if present_speed:
385
+ decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()}
386
+ if "1" in decoded:
387
+ state_tensor[0] = decoded["1"]
388
+ if "2" in decoded:
389
+ state_tensor[1] = decoded["2"]
390
+ if "3" in decoded:
391
+ state_tensor[2] = decoded["3"]
392
+ return state_tensor
393
+
394
    def teleop_step(
        self, record_data: bool = False
    ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """One teleoperation tick: forward the leader arm pose and keyboard base command to the remote robot.

        When `record_data` is True, also returns (observation dict, action dict)
        suitable for dataset recording; otherwise returns None.

        Raises:
            RobotDeviceNotConnectedError: if `connect()` was not called first.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")

        speed_setting = self.speed_levels[self.speed_index]
        xy_speed = speed_setting["xy"]  # 0.1, 0.2, or 0.3 (m/s), per speed_levels
        theta_speed = speed_setting["theta"]  # 30, 60, or 90 (deg/s), per speed_levels

        # Prepare to assign the position of the leader to the follower
        arm_positions = []
        for name in self.leader_arms:
            pos = self.leader_arms[name].read("Present_Position")
            pos_tensor = torch.from_numpy(pos).float()
            arm_positions.extend(pos_tensor.tolist())

        # Accumulate a body-frame velocity command from the currently held keys.
        y_cmd = 0.0  # m/s forward/backward
        x_cmd = 0.0  # m/s lateral
        theta_cmd = 0.0  # deg/s rotation
        if self.pressed_keys["forward"]:
            y_cmd += xy_speed
        if self.pressed_keys["backward"]:
            y_cmd -= xy_speed
        if self.pressed_keys["left"]:
            x_cmd += xy_speed
        if self.pressed_keys["right"]:
            x_cmd -= xy_speed
        if self.pressed_keys["rotate_left"]:
            theta_cmd += theta_speed
        if self.pressed_keys["rotate_right"]:
            theta_cmd -= theta_speed

        wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)

        message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions}
        self.cmd_socket.send_string(json.dumps(message))

        if not record_data:
            return

        obs_dict = self.capture_observation()

        arm_state_tensor = torch.tensor(arm_positions, dtype=torch.float32)

        # Convert the commanded wheel speeds back to body frame (x/y in mm/s,
        # theta in deg/s) so the recorded action matches the state convention.
        wheel_velocity_tuple = self.wheel_raw_to_body(wheel_commands)
        wheel_velocity_mm = (
            wheel_velocity_tuple[0] * 1000.0,
            wheel_velocity_tuple[1] * 1000.0,
            wheel_velocity_tuple[2],
        )
        wheel_tensor = torch.tensor(wheel_velocity_mm, dtype=torch.float32)
        action_tensor = torch.cat([arm_state_tensor, wheel_tensor])
        action_dict = {"action": action_tensor}

        return obs_dict, action_dict
450
+
451
    def capture_observation(self) -> dict:
        """
        Capture observations from the remote robot: current follower arm positions,
        present wheel speeds (converted to body-frame velocities: x, y, theta),
        and a camera frame.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.")

        frames, present_speed, remote_arm_state_tensor = self._get_data()

        # NOTE(review): _process_present_speed expects keys "1"/"2"/"3" from the
        # remote, while wheel_raw_to_body reads "left_wheel"/"back_wheel"/"right_wheel".
        # Verify which key set the remote actually sends in `present_speed`.
        body_state = self.wheel_raw_to_body(present_speed)

        body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2])  # Convert x,y to mm/s
        wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
        # State layout: 6 arm joint positions followed by (x_mm, y_mm, theta).
        combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)

        obs_dict = {"observation.state": combined_state_tensor}

        # Loop over each configured camera
        for cam_name, cam in self.cameras.items():
            frame = frames.get(cam_name, None)
            if frame is None:
                # Create a black image using the camera's configured width, height, and channels
                frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8)
            obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame)

        return obs_dict
479
+
480
    def send_action(self, action: torch.Tensor) -> torch.Tensor:
        """Send a 9-DoF action (6 arm joints + base x_mm/y_mm/theta) to the remote robot.

        Returns:
            The action actually sent (zero-padded to 9 elements if shorter).

        Raises:
            RobotDeviceNotConnectedError: if `connect()` was not called first.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.")

        # Ensure the action tensor has at least 9 elements:
        # - First 6: arm positions.
        # - Last 3: base commands.
        if action.numel() < 9:
            # Pad with zeros if there are not enough elements.
            padded = torch.zeros(9, dtype=action.dtype)
            padded[: action.numel()] = action
            action = padded

        # Extract arm and base actions.
        arm_actions = action[:6].flatten()
        base_actions = action[6:].flatten()

        x_cmd_mm = base_actions[0].item()  # mm/s
        y_cmd_mm = base_actions[1].item()  # mm/s
        theta_cmd = base_actions[2].item()  # deg/s

        # Convert mm/s to m/s for the kinematics calculations.
        x_cmd = x_cmd_mm / 1000.0  # m/s
        y_cmd = y_cmd_mm / 1000.0  # m/s

        # Compute wheel commands from body commands.
        wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)

        arm_positions_list = arm_actions.tolist()

        message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions_list}
        self.cmd_socket.send_string(json.dumps(message))

        return action
514
+
515
+ def print_logs(self):
516
+ pass
517
+
518
    def disconnect(self):
        """Send a final stop command, close the ZeroMQ sockets, and stop the keyboard listener.

        Raises:
            RobotDeviceNotConnectedError: if the robot was never connected.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError("Not connected.")
        if self.cmd_socket:
            # Command zero wheel velocity so the base does not keep moving
            # after the link is dropped.
            stop_cmd = {
                "raw_velocity": {"left_wheel": 0, "back_wheel": 0, "right_wheel": 0},
                "arm_positions": {},
            }
            self.cmd_socket.send_string(json.dumps(stop_cmd))
            self.cmd_socket.close()
        if self.video_socket:
            self.video_socket.close()
        if self.context:
            self.context.term()
        if PYNPUT_AVAILABLE:
            self.listener.stop()
        self.is_connected = False
        print("[INFO] Disconnected from remote robot.")
536
+
537
+ def __del__(self):
538
+ if getattr(self, "is_connected", False):
539
+ self.disconnect()
540
+ if PYNPUT_AVAILABLE:
541
+ self.listener.stop()
542
+
543
+ @staticmethod
544
+ def degps_to_raw(degps: float) -> int:
545
+ steps_per_deg = 4096.0 / 360.0
546
+ speed_in_steps = abs(degps) * steps_per_deg
547
+ speed_int = int(round(speed_in_steps))
548
+ if speed_int > 0x7FFF:
549
+ speed_int = 0x7FFF
550
+ if degps < 0:
551
+ return speed_int | 0x8000
552
+ else:
553
+ return speed_int & 0x7FFF
554
+
555
+ @staticmethod
556
+ def raw_to_degps(raw_speed: int) -> float:
557
+ steps_per_deg = 4096.0 / 360.0
558
+ magnitude = raw_speed & 0x7FFF
559
+ degps = magnitude / steps_per_deg
560
+ if raw_speed & 0x8000:
561
+ degps = -degps
562
+ return degps
563
+
564
+ def body_to_wheel_raw(
565
+ self,
566
+ x_cmd: float,
567
+ y_cmd: float,
568
+ theta_cmd: float,
569
+ wheel_radius: float = 0.05,
570
+ base_radius: float = 0.125,
571
+ max_raw: int = 3000,
572
+ ) -> dict:
573
+ """
574
+ Convert desired body-frame velocities into wheel raw commands.
575
+
576
+ Parameters:
577
+ x_cmd : Linear velocity in x (m/s).
578
+ y_cmd : Linear velocity in y (m/s).
579
+ theta_cmd : Rotational velocity (deg/s).
580
+ wheel_radius: Radius of each wheel (meters).
581
+ base_radius : Distance from the center of rotation to each wheel (meters).
582
+ max_raw : Maximum allowed raw command (ticks) per wheel.
583
+
584
+ Returns:
585
+ A dictionary with wheel raw commands:
586
+ {"left_wheel": value, "back_wheel": value, "right_wheel": value}.
587
+
588
+ Notes:
589
+ - Internally, the method converts theta_cmd to rad/s for the kinematics.
590
+ - The raw command is computed from the wheels angular speed in deg/s
591
+ using degps_to_raw(). If any command exceeds max_raw, all commands
592
+ are scaled down proportionally.
593
+ """
594
+ # Convert rotational velocity from deg/s to rad/s.
595
+ theta_rad = theta_cmd * (np.pi / 180.0)
596
+ # Create the body velocity vector [x, y, theta_rad].
597
+ velocity_vector = np.array([x_cmd, y_cmd, theta_rad])
598
+
599
+ # Define the wheel mounting angles (defined from y axis cw)
600
+ angles = np.radians(np.array([300, 180, 60]))
601
+ # Build the kinematic matrix: each row maps body velocities to a wheel’s linear speed.
602
+ # The third column (base_radius) accounts for the effect of rotation.
603
+ m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
604
+
605
+ # Compute each wheel’s linear speed (m/s) and then its angular speed (rad/s).
606
+ wheel_linear_speeds = m.dot(velocity_vector)
607
+ wheel_angular_speeds = wheel_linear_speeds / wheel_radius
608
+
609
+ # Convert wheel angular speeds from rad/s to deg/s.
610
+ wheel_degps = wheel_angular_speeds * (180.0 / np.pi)
611
+
612
+ # Scaling
613
+ steps_per_deg = 4096.0 / 360.0
614
+ raw_floats = [abs(degps) * steps_per_deg for degps in wheel_degps]
615
+ max_raw_computed = max(raw_floats)
616
+ if max_raw_computed > max_raw:
617
+ scale = max_raw / max_raw_computed
618
+ wheel_degps = wheel_degps * scale
619
+
620
+ # Convert each wheel’s angular speed (deg/s) to a raw integer.
621
+ wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]
622
+
623
+ return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
624
+
625
+ def wheel_raw_to_body(
626
+ self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125
627
+ ) -> tuple:
628
+ """
629
+ Convert wheel raw command feedback back into body-frame velocities.
630
+
631
+ Parameters:
632
+ wheel_raw : Dictionary with raw wheel commands (keys: "left_wheel", "back_wheel", "right_wheel").
633
+ wheel_radius: Radius of each wheel (meters).
634
+ base_radius : Distance from the robot center to each wheel (meters).
635
+
636
+ Returns:
637
+ A tuple (x_cmd, y_cmd, theta_cmd) where:
638
+ x_cmd : Linear velocity in x (m/s).
639
+ y_cmd : Linear velocity in y (m/s).
640
+ theta_cmd : Rotational velocity in deg/s.
641
+ """
642
+ # Extract the raw values in order.
643
+ raw_list = [
644
+ int(wheel_raw.get("left_wheel", 0)),
645
+ int(wheel_raw.get("back_wheel", 0)),
646
+ int(wheel_raw.get("right_wheel", 0)),
647
+ ]
648
+
649
+ # Convert each raw command back to an angular speed in deg/s.
650
+ wheel_degps = np.array([MobileManipulator.raw_to_degps(r) for r in raw_list])
651
+ # Convert from deg/s to rad/s.
652
+ wheel_radps = wheel_degps * (np.pi / 180.0)
653
+ # Compute each wheel’s linear speed (m/s) from its angular speed.
654
+ wheel_linear_speeds = wheel_radps * wheel_radius
655
+
656
+ # Define the wheel mounting angles (defined from y axis cw)
657
+ angles = np.radians(np.array([300, 180, 60]))
658
+ m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
659
+
660
+ # Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds.
661
+ m_inv = np.linalg.inv(m)
662
+ velocity_vector = m_inv.dot(wheel_linear_speeds)
663
+ x_cmd, y_cmd, theta_rad = velocity_vector
664
+ theta_cmd = theta_rad * (180.0 / np.pi)
665
+ return (x_cmd, y_cmd, theta_cmd)
666
+
667
+
668
class LeKiwi:
    """Thin driver for the LeKiwi three-omniwheel base over a Feetech motor bus."""

    def __init__(self, motor_bus):
        """
        Initializes the LeKiwi with Feetech motors bus.
        """
        self.motor_bus = motor_bus
        self.motor_ids = ["left_wheel", "back_wheel", "right_wheel"]

        # Switch all wheel motors into velocity mode (unlock, set mode, relock).
        self.motor_bus.write("Lock", 0)
        self.motor_bus.write("Mode", [1, 1, 1], self.motor_ids)
        self.motor_bus.write("Lock", 1)
        print("Motors set to velocity mode.")

    def read_velocity(self):
        """Return the current raw wheel speeds keyed by motor name."""
        raw = self.motor_bus.read("Present_Speed", self.motor_ids)
        return dict(zip(self.motor_ids, (int(value) for value in raw)))

    def set_velocity(self, command_speeds):
        """
        Sends raw velocity commands (16-bit encoded values) directly to the motor bus.
        The order of speeds must correspond to self.motor_ids.
        """
        self.motor_bus.write("Goal_Speed", command_speeds, self.motor_ids)

    def stop(self):
        """Stops the robot by setting all motor speeds to zero."""
        self.motor_bus.write("Goal_Speed", [0, 0, 0], self.motor_ids)
        print("Motors stopped.")
lerobot/common/robot_devices/robots/stretch.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import time
18
+ from dataclasses import replace
19
+
20
+ import torch
21
+ from stretch_body.gamepad_teleop import GamePadTeleop
22
+ from stretch_body.robot import Robot as StretchAPI
23
+ from stretch_body.robot_params import RobotParams
24
+
25
+ from lerobot.common.robot_devices.robots.configs import StretchRobotConfig
26
+
27
+
28
class StretchRobot(StretchAPI):
    """Wrapper of stretch_body.robot.Robot

    Adds LeRobot-style camera handling, teleoperation, and observation/action
    capture on top of the vendor Stretch API. Timing of reads/writes is
    recorded into ``self.logs``.
    """

    def __init__(self, config: StretchRobotConfig | None = None, **kwargs):
        super().__init__()
        if config is None:
            self.config = StretchRobotConfig(**kwargs)
        else:
            # Overwrite config arguments using kwargs
            self.config = replace(config, **kwargs)

        self.robot_type = self.config.type
        self.cameras = self.config.cameras
        self.is_connected = False
        self.teleop = None  # GamePadTeleop, created lazily on first teleop/send_action call
        self.logs = {}  # timing diagnostics keyed by event name

        # TODO(aliberts): test this
        RobotParams.set_logging_level("WARNING")
        RobotParams.set_logging_formatter("brief_console_formatter")

        # Captured lazily from the first state/action read.
        self.state_keys = None
        self.action_keys = None

    def connect(self) -> None:
        """Start the Stretch process and connect all configured cameras.

        Raises:
            ConnectionError: if the robot is already claimed by another process
                or any camera fails to connect.
        """
        self.is_connected = self.startup()
        if not self.is_connected:
            print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
            raise ConnectionError()

        for name in self.cameras:
            self.cameras[name].connect()
            self.is_connected = self.is_connected and self.cameras[name].is_connected

        if not self.is_connected:
            print("Could not connect to the cameras, check that all cameras are plugged-in.")
            raise ConnectionError()

        self.run_calibration()

    def run_calibration(self) -> None:
        """Home the robot if it has not been homed yet."""
        if not self.is_homed():
            self.home()

    def teleop_step(
        self, record_data=False
    ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """Run one gamepad teleoperation cycle.

        When ``record_data`` is True, also returns an (observation, action)
        pair of tensor dicts; otherwise returns None.
        """
        # TODO(aliberts): return ndarrays instead of torch.Tensors
        if not self.is_connected:
            raise ConnectionError()

        # Lazily create the gamepad teleop on first use.
        if self.teleop is None:
            self.teleop = GamePadTeleop(robot_instance=False)
            self.teleop.startup(robot=self)

        before_read_t = time.perf_counter()
        state = self.get_state()
        action = self.teleop.gamepad_controller.get_state()
        self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t

        before_write_t = time.perf_counter()
        self.teleop.do_motion(robot=self)
        self.push_command()
        self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t

        if self.state_keys is None:
            self.state_keys = list(state)

        if not record_data:
            return

        state = torch.as_tensor(list(state.values()))
        action = torch.as_tensor(list(action.values()))

        # Capture images from cameras
        images = {}
        for name in self.cameras:
            before_camread_t = time.perf_counter()
            images[name] = self.cameras[name].async_read()
            images[name] = torch.from_numpy(images[name])
            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

        # Populate output dictionaries
        obs_dict, action_dict = {}, {}
        obs_dict["observation.state"] = state
        action_dict["action"] = action
        for name in self.cameras:
            obs_dict[f"observation.images.{name}"] = images[name]

        return obs_dict, action_dict

    def get_state(self) -> dict:
        """Flatten the vendor status tree into a flat {joint.pos / base.vel: value} dict."""
        status = self.get_status()
        return {
            "head_pan.pos": status["head"]["head_pan"]["pos"],
            "head_tilt.pos": status["head"]["head_tilt"]["pos"],
            "lift.pos": status["lift"]["pos"],
            "arm.pos": status["arm"]["pos"],
            "wrist_pitch.pos": status["end_of_arm"]["wrist_pitch"]["pos"],
            "wrist_roll.pos": status["end_of_arm"]["wrist_roll"]["pos"],
            "wrist_yaw.pos": status["end_of_arm"]["wrist_yaw"]["pos"],
            "gripper.pos": status["end_of_arm"]["stretch_gripper"]["pos"],
            "base_x.vel": status["base"]["x_vel"],
            "base_y.vel": status["base"]["y_vel"],
            "base_theta.vel": status["base"]["theta_vel"],
        }

    def capture_observation(self) -> dict:
        """Read robot state and camera frames; return them as a tensor dict."""
        # TODO(aliberts): return ndarrays instead of torch.Tensors
        before_read_t = time.perf_counter()
        state = self.get_state()
        self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t

        if self.state_keys is None:
            self.state_keys = list(state)

        state = torch.as_tensor(list(state.values()))

        # Capture images from cameras
        images = {}
        for name in self.cameras:
            before_camread_t = time.perf_counter()
            images[name] = self.cameras[name].async_read()
            images[name] = torch.from_numpy(images[name])
            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

        # Populate output dictionaries
        obs_dict = {}
        obs_dict["observation.state"] = state
        for name in self.cameras:
            obs_dict[f"observation.images.{name}"] = images[name]

        return obs_dict

    def send_action(self, action: torch.Tensor) -> torch.Tensor:
        """Apply an action tensor through the gamepad-teleop motion path.

        The action order must match the keys of the gamepad controller state
        (captured once into ``self.action_keys``).
        """
        # TODO(aliberts): return ndarrays instead of torch.Tensors
        if not self.is_connected:
            raise ConnectionError()

        if self.teleop is None:
            self.teleop = GamePadTeleop(robot_instance=False)
            self.teleop.startup(robot=self)

        if self.action_keys is None:
            # One dummy read just to learn the ordering/names of action fields.
            dummy_action = self.teleop.gamepad_controller.get_state()
            self.action_keys = list(dummy_action.keys())

        action_dict = dict(zip(self.action_keys, action.tolist(), strict=True))

        before_write_t = time.perf_counter()
        self.teleop.do_motion(state=action_dict, robot=self)
        self.push_command()
        self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t

        # TODO(aliberts): return action_sent when motion is limited
        return action

    def print_logs(self) -> None:
        pass
        # TODO(aliberts): move robot-specific logs logic here

    def teleop_safety_stop(self) -> None:
        """Trigger the teleop layer's safety stop, if teleop was ever started."""
        if self.teleop is not None:
            self.teleop._safety_stop(robot=self)

    def disconnect(self) -> None:
        """Stop motion, shut down teleop and cameras, and mark as disconnected."""
        self.stop()
        if self.teleop is not None:
            self.teleop.gamepad_controller.stop()
            self.teleop.stop()

        if len(self.cameras) > 0:
            for cam in self.cameras.values():
                cam.disconnect()

        self.is_connected = False

    def __del__(self):
        # NOTE(review): disconnect() is called unconditionally here; if __init__
        # raised before attributes were set, this can itself raise during GC —
        # TODO confirm and consider guarding with getattr(self, "is_connected", False).
        self.disconnect()
lerobot/common/robot_devices/robots/utils.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Protocol
16
+
17
+ from lerobot.common.robot_devices.robots.configs import (
18
+ AlohaRobotConfig,
19
+ KochBimanualRobotConfig,
20
+ KochRobotConfig,
21
+ LeKiwiRobotConfig,
22
+ ManipulatorRobotConfig,
23
+ MossRobotConfig,
24
+ RobotConfig,
25
+ So100RobotConfig,
26
+ StretchRobotConfig,
27
+ )
28
+
29
+
30
def get_arm_id(name, arm_type):
    """Returns the string identifier of a robot arm. For instance, for a bimanual manipulator
    like Aloha, it could be left_follower, right_follower, left_leader, or right_leader.
    """
    return "{}_{}".format(name, arm_type)
35
+
36
+
37
class Robot(Protocol):
    """Structural (duck-typed) interface that every robot wrapper must expose."""

    # TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes
    robot_type: str  # e.g. "aloha", "koch", "so100", "stretch", "lekiwi"
    features: dict

    def connect(self): ...  # open connections to motors/cameras
    def run_calibration(self): ...
    def teleop_step(self, record_data=False): ...  # one teleop cycle; may return (obs, action)
    def capture_observation(self): ...
    def send_action(self, action): ...
    def disconnect(self): ...
48
+
49
+
50
def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
    """Instantiate the RobotConfig subclass registered for `robot_type`.

    Raises:
        ValueError: if `robot_type` does not name a known robot.
    """
    config_classes = {
        "aloha": AlohaRobotConfig,
        "koch": KochRobotConfig,
        "koch_bimanual": KochBimanualRobotConfig,
        "moss": MossRobotConfig,
        "so100": So100RobotConfig,
        "stretch": StretchRobotConfig,
        "lekiwi": LeKiwiRobotConfig,
    }
    try:
        config_cls = config_classes[robot_type]
    except KeyError:
        raise ValueError(f"Robot type '{robot_type}' is not available.") from None
    return config_cls(**kwargs)
67
+
68
+
69
def make_robot_from_config(config: RobotConfig):
    """Build the concrete robot wrapper matching `config`, importing it lazily
    so optional hardware dependencies are only loaded when needed."""
    if isinstance(config, ManipulatorRobotConfig):
        from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot

        return ManipulatorRobot(config)

    if isinstance(config, LeKiwiRobotConfig):
        from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator

        return MobileManipulator(config)

    # Any other config type is assumed to target the Stretch robot.
    from lerobot.common.robot_devices.robots.stretch import StretchRobot

    return StretchRobot(config)
82
+
83
+
84
def make_robot(robot_type: str, **kwargs) -> Robot:
    """Convenience wrapper: build the config for `robot_type`, then the robot itself."""
    return make_robot_from_config(make_robot_config(robot_type, **kwargs))
lerobot/common/robot_devices/utils.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import platform
16
+ import time
17
+
18
+
19
def busy_wait(seconds):
    """Wait approximately `seconds`, busy-spinning on macOS where time.sleep is coarse."""
    if platform.system() != "Darwin":
        # On Linux time.sleep is accurate
        if seconds > 0:
            time.sleep(seconds)
        return

    # On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
    # but it consumes CPU cycles.
    # TODO(rcadene): find an alternative: from python 11, time.sleep is precise
    deadline = time.perf_counter() + seconds
    while time.perf_counter() < deadline:
        pass
31
+
32
+
33
def safe_disconnect(func):
    """Decorator for robot entry points: if the wrapped call raises, disconnect
    the robot (when it is connected) before re-raising the original exception.

    TODO(aliberts): Allow to pass custom exceptions
    (e.g. ThreadServiceExit, KeyboardInterrupt, SystemExit, UnpluggedError, DynamixelCommError)
    """
    import functools

    @functools.wraps(func)  # keep the wrapped function's name/docstring for debugging
    def wrapper(robot, *args, **kwargs):
        try:
            return func(robot, *args, **kwargs)
        except Exception:
            if robot.is_connected:
                robot.disconnect()
            # Bare `raise` re-raises the active exception with its traceback intact.
            raise

    return wrapper
45
+
46
+
47
class RobotDeviceNotConnectedError(Exception):
    """Raised when an operation requires a robot device that is not connected yet."""

    def __init__(
        self, message="This robot device is not connected. Try calling `robot_device.connect()` first."
    ):
        # Keep the message on the instance so callers can inspect it directly.
        self.message = message
        super().__init__(message)
55
+
56
+
57
class RobotDeviceAlreadyConnectedError(Exception):
    """Raised when connect() is called on a robot device that is already connected."""

    def __init__(
        self,
        message="This robot device is already connected. Try not calling `robot_device.connect()` twice.",
    ):
        # Keep the message on the instance so callers can inspect it directly.
        self.message = message
        super().__init__(message)
lerobot/common/utils/benchmark.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import threading
17
+ import time
18
+ from contextlib import ContextDecorator
19
+
20
+
21
+ class TimeBenchmark(ContextDecorator):
22
+ """
23
+ Measures execution time using a context manager or decorator.
24
+
25
+ This class supports both context manager and decorator usage, and is thread-safe for multithreaded
26
+ environments.
27
+
28
+ Args:
29
+ print: If True, prints the elapsed time upon exiting the context or completing the function. Defaults
30
+ to False.
31
+
32
+ Examples:
33
+
34
+ Using as a context manager:
35
+
36
+ >>> benchmark = TimeBenchmark()
37
+ >>> with benchmark:
38
+ ... time.sleep(1)
39
+ >>> print(f"Block took {benchmark.result:.4f} seconds")
40
+ Block took approximately 1.0000 seconds
41
+
42
+ Using with multithreading:
43
+
44
+ ```python
45
+ import threading
46
+
47
+ benchmark = TimeBenchmark()
48
+
49
+ def context_manager_example():
50
+ with benchmark:
51
+ time.sleep(0.01)
52
+ print(f"Block took {benchmark.result_ms:.2f} milliseconds")
53
+
54
+ threads = []
55
+ for _ in range(3):
56
+ t1 = threading.Thread(target=context_manager_example)
57
+ threads.append(t1)
58
+
59
+ for t in threads:
60
+ t.start()
61
+
62
+ for t in threads:
63
+ t.join()
64
+ ```
65
+ Expected output:
66
+ Block took approximately 10.00 milliseconds
67
+ Block took approximately 10.00 milliseconds
68
+ Block took approximately 10.00 milliseconds
69
+ """
70
+
71
+ def __init__(self, print=False):
72
+ self.local = threading.local()
73
+ self.print_time = print
74
+
75
+ def __enter__(self):
76
+ self.local.start_time = time.perf_counter()
77
+ return self
78
+
79
+ def __exit__(self, *exc):
80
+ self.local.end_time = time.perf_counter()
81
+ self.local.elapsed_time = self.local.end_time - self.local.start_time
82
+ if self.print_time:
83
+ print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds")
84
+ return False
85
+
86
+ @property
87
+ def result(self):
88
+ return getattr(self.local, "elapsed_time", None)
89
+
90
+ @property
91
+ def result_ms(self):
92
+ return self.result * 1e3
lerobot/common/utils/hub.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from pathlib import Path
16
+ from tempfile import TemporaryDirectory
17
+ from typing import Any, Type, TypeVar
18
+
19
+ from huggingface_hub import HfApi
20
+ from huggingface_hub.utils import validate_hf_hub_args
21
+
22
+ T = TypeVar("T", bound="HubMixin")
23
+
24
+
25
class HubMixin:
    """
    A Mixin containing the functionality to push an object to the hub.

    This is similar to huggingface_hub.ModelHubMixin but is lighter and makes less assumptions about its
    subclasses (in particular, the fact that it's not necessarily a model).

    The inheriting classes must implement '_save_pretrained' and 'from_pretrained'.
    """

    def save_pretrained(
        self,
        save_directory: str | Path,
        *,
        repo_id: str | None = None,
        push_to_hub: bool = False,
        card_kwargs: dict[str, Any] | None = None,
        **push_to_hub_kwargs,
    ) -> str | None:
        """
        Save object in local directory.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the object will be saved.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your object to the Huggingface Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
            card_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments passed to the card template to customize the card.
            push_to_hub_kwargs:
                Additional key word arguments passed along to the [`~HubMixin.push_to_hub`] method.
        Returns:
            `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # save object (weights, files, etc.) — delegated to the subclass hook
        self._save_pretrained(save_directory)

        # push to the Hub if required
        if push_to_hub:
            if repo_id is None:
                repo_id = save_directory.name  # Defaults to `save_directory` name
            return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
        return None

    def _save_pretrained(self, save_directory: Path) -> None:
        """
        Overwrite this method in subclass to define how to save your object.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the object files will be saved.
        """
        raise NotImplementedError

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls: Type[T],
        pretrained_name_or_path: str | Path,
        *,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **kwargs,
    ) -> T:
        """
        Download the object from the Huggingface Hub and instantiate it.

        Args:
            pretrained_name_or_path (`str`, `Path`):
                - Either the `repo_id` (string) of the object hosted on the Hub, e.g. `lerobot/diffusion_pusht`.
                - Or a path to a `directory` containing the object files saved using `.save_pretrained`,
                    e.g., `../path/to/my_model_directory/`.
            revision (`str`, *optional*):
                Revision on the Hub. Can be a branch name, a git tag or any commit id.
                Defaults to the latest commit on `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the files from the Hub, overriding the existing cache.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            kwargs (`Dict`, *optional*):
                Additional kwargs to pass to the object during initialization.
        """
        raise NotImplementedError

    @validate_hf_hub_args
    def push_to_hub(
        self,
        repo_id: str,
        *,
        commit_message: str | None = None,
        private: bool | None = None,
        token: str | None = None,
        branch: str | None = None,
        create_pr: bool | None = None,
        allow_patterns: list[str] | str | None = None,
        ignore_patterns: list[str] | str | None = None,
        delete_patterns: list[str] | str | None = None,
        card_kwargs: dict[str, Any] | None = None,
    ) -> str:
        """
        Upload model checkpoint to the Hub.

        Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
        `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
        details.

        Args:
            repo_id (`str`):
                ID of the repository to push to (example: `"username/my-model"`).
            commit_message (`str`, *optional*):
                Message to commit while pushing.
            private (`bool`, *optional*):
                Whether the repository created should be private.
                If `None` (default), the repo will be public unless the organization's default is private.
            token (`str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            branch (`str`, *optional*):
                The git branch on which to push the model. This defaults to `"main"`.
            create_pr (`boolean`, *optional*):
                Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
            allow_patterns (`List[str]` or `str`, *optional*):
                If provided, only files matching at least one pattern are pushed.
            ignore_patterns (`List[str]` or `str`, *optional*):
                If provided, files matching any of the patterns are not pushed.
            delete_patterns (`List[str]` or `str`, *optional*):
                If provided, remote files matching any of the patterns will be deleted from the repo.
            card_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments passed to the card template to customize the card.

        Returns:
            The url of the commit of your object in the given repository.
        """
        api = HfApi(token=token)
        repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id

        # Derive a default commit message from the subclass name when none is given.
        if commit_message is None:
            if "Policy" in self.__class__.__name__:
                commit_message = "Upload policy"
            elif "Config" in self.__class__.__name__:
                commit_message = "Upload config"
            else:
                commit_message = f"Upload {self.__class__.__name__}"

        # Push the files to the repo in a single commit
        # NOTE: TemporaryDirectory(ignore_cleanup_errors=...) requires Python >= 3.10.
        with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
            saved_path = Path(tmp) / repo_id
            self.save_pretrained(saved_path, card_kwargs=card_kwargs)
            return api.upload_folder(
                repo_id=repo_id,
                repo_type="model",
                folder_path=saved_path,
                commit_message=commit_message,
                revision=branch,
                create_pr=create_pr,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                delete_patterns=delete_patterns,
            )
lerobot/common/utils/import_utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import importlib
17
+ import logging
18
+
19
+
20
+ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
21
+ """Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
22
+ Check if the package spec exists and grab its version to avoid importing a local directory.
23
+ **Note:** this doesn't work for all packages.
24
+ """
25
+ package_exists = importlib.util.find_spec(pkg_name) is not None
26
+ package_version = "N/A"
27
+ if package_exists:
28
+ try:
29
+ # Primary method to get the package version
30
+ package_version = importlib.metadata.version(pkg_name)
31
+ except importlib.metadata.PackageNotFoundError:
32
+ # Fallback method: Only for "torch" and versions containing "dev"
33
+ if pkg_name == "torch":
34
+ try:
35
+ package = importlib.import_module(pkg_name)
36
+ temp_version = getattr(package, "__version__", "N/A")
37
+ # Check if the version contains "dev"
38
+ if "dev" in temp_version:
39
+ package_version = temp_version
40
+ package_exists = True
41
+ else:
42
+ package_exists = False
43
+ except ImportError:
44
+ # If the package can't be imported, it's not available
45
+ package_exists = False
46
+ else:
47
+ # For packages other than "torch", don't attempt the fallback and set as not available
48
+ package_exists = False
49
+ logging.debug(f"Detected {pkg_name} version: {package_version}")
50
+ if return_version:
51
+ return package_exists, package_version
52
+ else:
53
+ return package_exists
54
+
55
+
56
# Probe optional backends once at import time so call sites can branch cheaply
# without re-querying importlib metadata on every check.
_torch_available, _torch_version = is_package_available("torch", return_version=True)
_gym_xarm_available = is_package_available("gym_xarm")
_gym_aloha_available = is_package_available("gym_aloha")
_gym_pusht_available = is_package_available("gym_pusht")
lerobot/common/utils/io_utils.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import json
17
+ import warnings
18
+ from pathlib import Path
19
+ from typing import TypeVar
20
+
21
+ import imageio
22
+
23
+ JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
24
+ T = TypeVar("T", bound=JsonLike)
25
+
26
+
27
def write_video(video_path, stacked_frames, fps):
    """Encode `stacked_frames` into a video file at `video_path` with the given fps."""
    # imageio pulls in pkg_resources, which emits DeprecationWarnings;
    # silence only that specific message while writing.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
        )
        imageio.mimsave(video_path, stacked_frames, fps=fps)
34
+
35
+
36
+ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
37
+ """
38
+ Loads the JSON data from `fpath` and recursively fills `obj` with the
39
+ corresponding values (strictly matching structure and types).
40
+ Tuples in `obj` are expected to be lists in the JSON data, which will be
41
+ converted back into tuples.
42
+ """
43
+ with open(fpath, encoding="utf-8") as f:
44
+ data = json.load(f)
45
+
46
+ def _deserialize(target, source):
47
+ """
48
+ Recursively overwrite the structure in `target` with data from `source`,
49
+ performing strict checks on structure and type.
50
+ Returns the updated version of `target` (especially important for tuples).
51
+ """
52
+
53
+ # If the target is a dictionary, source must be a dictionary as well.
54
+ if isinstance(target, dict):
55
+ if not isinstance(source, dict):
56
+ raise TypeError(f"Type mismatch: expected dict, got {type(source)}")
57
+
58
+ # Check that they have exactly the same set of keys.
59
+ if target.keys() != source.keys():
60
+ raise ValueError(
61
+ f"Dictionary keys do not match.\nExpected: {target.keys()}, got: {source.keys()}"
62
+ )
63
+
64
+ # Recursively update each key.
65
+ for k in target:
66
+ target[k] = _deserialize(target[k], source[k])
67
+
68
+ return target
69
+
70
+ # If the target is a list, source must be a list as well.
71
+ elif isinstance(target, list):
72
+ if not isinstance(source, list):
73
+ raise TypeError(f"Type mismatch: expected list, got {type(source)}")
74
+
75
+ # Check length
76
+ if len(target) != len(source):
77
+ raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}")
78
+
79
+ # Recursively update each element.
80
+ for i in range(len(target)):
81
+ target[i] = _deserialize(target[i], source[i])
82
+
83
+ return target
84
+
85
+ # If the target is a tuple, the source must be a list in JSON,
86
+ # which we'll convert back to a tuple.
87
+ elif isinstance(target, tuple):
88
+ if not isinstance(source, list):
89
+ raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}")
90
+
91
+ if len(target) != len(source):
92
+ raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}")
93
+
94
+ # Convert each element, forming a new tuple.
95
+ converted_items = []
96
+ for t_item, s_item in zip(target, source, strict=False):
97
+ converted_items.append(_deserialize(t_item, s_item))
98
+
99
+ # Return a brand new tuple (tuples are immutable in Python).
100
+ return tuple(converted_items)
101
+
102
+ # Otherwise, we're dealing with a "primitive" (int, float, str, bool, None).
103
+ else:
104
+ # Check the exact type. If these must match 1:1, do:
105
+ if type(target) is not type(source):
106
+ raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}")
107
+ return source
108
+
109
+ # Perform the in-place/recursive deserialization
110
+ updated_obj = _deserialize(obj, data)
111
+ return updated_obj
lerobot/common/utils/logging_utils.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from typing import Any
17
+
18
+ from lerobot.common.utils.utils import format_big_number
19
+
20
+
21
class AverageMeter:
    """Tracks a running average of a scalar metric.

    Keeps the latest value (`val`), the running `sum`, the sample `count`,
    and their ratio (`avg`).
    Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py
    """

    def __init__(self, name: str, fmt: str = ":f"):
        # `name` labels the metric in __str__; `fmt` is the format spec
        # (e.g. ":.3f") applied to the running average when rendering.
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self) -> None:
        """Zero out all accumulated statistics."""
        self.val, self.avg, self.sum, self.count = 0.0, 0.0, 0.0, 0.0

    def update(self, val: float, n: int = 1) -> None:
        """Record `val` observed `n` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Build e.g. "{name}:{avg:.3f}" then fill it from our attributes.
        template = "{name}:{avg" + self.fmt + "}"
        return template.format(name=self.name, avg=self.avg)
47
+
48
+
49
class MetricsTracker:
    """
    A helper class to track and log metrics over time.

    Usage pattern:

    ```python
    # initialize, potentially with non-zero initial step (e.g. if resuming run)
    metrics = {"loss": AverageMeter("loss", ":.3f")}
    train_metrics = MetricsTracker(cfg, dataset, metrics, initial_step=step)

    # update metrics derived from step (samples, episodes, epochs) at each training step
    train_metrics.step()

    # update various metrics
    loss = policy.forward(batch)
    train_metrics.loss = loss

    # display current metrics
    logging.info(train_metrics)

    # export for wandb
    wandb.log(train_metrics.to_dict())

    # reset averages after logging
    train_metrics.reset_averages()
    ```
    """

    # Attribute names stored as plain entries in __dict__. Any other attribute
    # name is routed to the `metrics` dict by __getattr__/__setattr__ below, so
    # `tracker.loss = x` updates the "loss" AverageMeter.
    __keys__ = [
        "_batch_size",
        "_num_frames",
        "_avg_samples_per_ep",
        "metrics",
        "steps",
        "samples",
        "episodes",
        "epochs",
    ]

    def __init__(
        self,
        batch_size: int,
        num_frames: int,
        num_episodes: int,
        metrics: dict[str, AverageMeter],
        initial_step: int = 0,
    ):
        # Seed __dict__ directly: the custom __setattr__ below only accepts
        # names that already exist in __dict__ (or in `metrics`), so the
        # assignments that follow would raise AttributeError otherwise.
        self.__dict__.update({k: None for k in self.__keys__})
        self._batch_size = batch_size
        self._num_frames = num_frames
        self._avg_samples_per_ep = num_frames / num_episodes
        self.metrics = metrics

        self.steps = initial_step
        # A sample is an (observation,action) pair, where observation and action
        # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
        self.samples = self.steps * self._batch_size
        self.episodes = self.samples / self._avg_samples_per_ep
        self.epochs = self.samples / self._num_frames

    def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
        # Only invoked when normal attribute lookup fails, i.e. effectively for
        # metric names; unknown names raise AttributeError as usual.
        if name in self.__dict__:
            return self.__dict__[name]
        elif name in self.metrics:
            return self.metrics[name]
        else:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        # Assigning to a metric name updates the underlying AverageMeter
        # (via .update) instead of shadowing it with a plain attribute.
        if name in self.__dict__:
            super().__setattr__(name, value)
        elif name in self.metrics:
            self.metrics[name].update(value)
        else:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def step(self) -> None:
        """
        Updates metrics that depend on 'step' for one step.
        """
        self.steps += 1
        self.samples += self._batch_size
        self.episodes = self.samples / self._avg_samples_per_ep
        self.epochs = self.samples / self._num_frames

    def __str__(self) -> str:
        # Compact one-line summary of the step-derived counters plus every meter.
        display_list = [
            f"step:{format_big_number(self.steps)}",
            # number of samples seen during training
            f"smpl:{format_big_number(self.samples)}",
            # number of episodes seen during training
            f"ep:{format_big_number(self.episodes)}",
            # number of time all unique samples are seen
            f"epch:{self.epochs:.2f}",
            *[str(m) for m in self.metrics.values()],
        ]
        return " ".join(display_list)

    def to_dict(self, use_avg: bool = True) -> dict[str, int | float]:
        """
        Returns the current metric values (or averages if `use_avg=True`) as a dict.
        """
        return {
            "steps": self.steps,
            "samples": self.samples,
            "episodes": self.episodes,
            "epochs": self.epochs,
            **{k: m.avg if use_avg else m.val for k, m in self.metrics.items()},
        }

    def reset_averages(self) -> None:
        """Resets average meters."""
        for m in self.metrics.values():
            m.reset()
lerobot/common/utils/random_utils.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import random
17
+ from contextlib import contextmanager
18
+ from pathlib import Path
19
+ from typing import Any, Generator
20
+
21
+ import numpy as np
22
+ import torch
23
+ from safetensors.torch import load_file, save_file
24
+
25
+ from lerobot.common.constants import RNG_STATE
26
+ from lerobot.common.datasets.utils import flatten_dict, unflatten_dict
27
+
28
+
29
def serialize_python_rng_state() -> dict[str, torch.Tensor]:
    """
    Returns the rng state for `random` in the form of a flat dict[str, torch.Tensor] to be saved using
    `safetensors.save_file()` or `torch.save()`.
    """
    # random.getstate() -> (version, internal_state_tuple, gauss_next)
    version, internal_state, _gauss = random.getstate()
    return {
        "py_rng_version": torch.tensor([version], dtype=torch.int64),
        "py_rng_state": torch.tensor(internal_state, dtype=torch.int64),
    }
39
+
40
+
41
def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.
    """
    version = rng_state_dict["py_rng_version"].item()
    internal_state = tuple(rng_state_dict["py_rng_state"].tolist())
    # The third slot (gauss_next) is not serialized; None is a valid value for it.
    random.setstate((version, internal_state, None))
47
+
48
+
49
+ def serialize_numpy_rng_state() -> dict[str, torch.Tensor]:
50
+ """
51
+ Returns the rng state for `numpy` in the form of a flat dict[str, torch.Tensor] to be saved using
52
+ `safetensors.save_file()` or `torch.save()`.
53
+ """
54
+ np_state = np.random.get_state()
55
+ # Ensure no breaking changes from numpy
56
+ assert np_state[0] == "MT19937"
57
+ return {
58
+ "np_rng_state_values": torch.tensor(np_state[1], dtype=torch.int64),
59
+ "np_rng_state_index": torch.tensor([np_state[2]], dtype=torch.int64),
60
+ "np_rng_has_gauss": torch.tensor([np_state[3]], dtype=torch.int64),
61
+ "np_rng_cached_gaussian": torch.tensor([np_state[4]], dtype=torch.float32),
62
+ }
63
+
64
+
65
+ def deserialize_numpy_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
66
+ """
67
+ Restores the rng state for `numpy` from a dictionary produced by `serialize_numpy_rng_state()`.
68
+ """
69
+ np_state = (
70
+ "MT19937",
71
+ rng_state_dict["np_rng_state_values"].numpy(),
72
+ rng_state_dict["np_rng_state_index"].item(),
73
+ rng_state_dict["np_rng_has_gauss"].item(),
74
+ rng_state_dict["np_rng_cached_gaussian"].item(),
75
+ )
76
+ np.random.set_state(np_state)
77
+
78
+
79
def serialize_torch_rng_state() -> dict[str, torch.Tensor]:
    """
    Returns the rng state for `torch` in the form of a flat dict[str, torch.Tensor] to be saved using
    `safetensors.save_file()` or `torch.save()`.
    """
    state = {"torch_rng_state": torch.get_rng_state()}
    # The CUDA generator state is captured only when a GPU is present.
    if torch.cuda.is_available():
        state["torch_cuda_rng_state"] = torch.cuda.get_rng_state()
    return state
88
+
89
+
90
def deserialize_torch_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Restores the rng state for `torch` from a dictionary produced by `serialize_torch_rng_state()`.
    """
    torch.set_rng_state(rng_state_dict["torch_rng_state"])
    # Restore the CUDA state only when it was captured AND a GPU is present here.
    cuda_state = rng_state_dict.get("torch_cuda_rng_state")
    if cuda_state is not None and torch.cuda.is_available():
        torch.cuda.set_rng_state(cuda_state)
97
+
98
+
99
def serialize_rng_state() -> dict[str, torch.Tensor]:
    """
    Returns the rng state for `random`, `numpy`, and `torch`, in the form of a flat
    dict[str, torch.Tensor] to be saved using `safetensors.save_file()` `torch.save()`.
    """
    merged: dict[str, torch.Tensor] = {}
    # The three serializers use disjoint key prefixes (py/np/torch), so merging is safe.
    merged.update(serialize_python_rng_state())
    merged.update(serialize_numpy_rng_state())
    merged.update(serialize_torch_rng_state())
    return merged
113
+
114
+
115
def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Restores the rng state for `random`, `numpy`, and `torch` from a dictionary produced by
    `serialize_rng_state()`.
    """
    # Route each entry to the deserializer matching its key prefix.
    deserialize_python_rng_state({k: v for k, v in rng_state_dict.items() if k.startswith("py")})
    deserialize_numpy_rng_state({k: v for k, v in rng_state_dict.items() if k.startswith("np")})
    deserialize_torch_rng_state({k: v for k, v in rng_state_dict.items() if k.startswith("torch")})
127
+
128
+
129
def save_rng_state(save_dir: Path) -> None:
    """Serialize all RNG states (random/numpy/torch) and write them to `save_dir / RNG_STATE`."""
    save_file(flatten_dict(serialize_rng_state()), save_dir / RNG_STATE)
133
+
134
+
135
def load_rng_state(save_dir: Path) -> None:
    """Read `save_dir / RNG_STATE` and restore all RNG states (random/numpy/torch) from it."""
    deserialize_rng_state(unflatten_dict(load_file(save_dir / RNG_STATE)))
139
+
140
+
141
def get_rng_state() -> dict[str, Any]:
    """Get the random state for `random`, `numpy`, and `torch` (plus CUDA when available)."""
    state = {
        "random_state": random.getstate(),
        "numpy_random_state": np.random.get_state(),
        "torch_random_state": torch.random.get_rng_state(),
    }
    if torch.cuda.is_available():
        state["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
    return state
151
+
152
+
153
def set_rng_state(random_state_dict: dict[str, Any]):
    """Set the random state for `random`, `numpy`, and `torch`.

    Args:
        random_state_dict: A dictionary of the form returned by `get_rng_state`.
    """
    random.setstate(random_state_dict["random_state"])
    np.random.set_state(random_state_dict["numpy_random_state"])
    torch.random.set_rng_state(random_state_dict["torch_random_state"])
    # Only restore the CUDA generator when its state was actually captured: a
    # snapshot taken on a CPU-only machine has no "torch_cuda_random_state" key,
    # and indexing it unconditionally raised KeyError when such a snapshot was
    # restored on a CUDA machine. This mirrors deserialize_torch_rng_state().
    if torch.cuda.is_available() and "torch_cuda_random_state" in random_state_dict:
        torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
164
+
165
+
166
def set_seed(seed) -> None:
    """Set seed for reproducibility across `random`, `numpy` and `torch` (all CUDA devices)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
173
+
174
+
175
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
    """Set the seed when entering a context, and restore the prior random state at exit.

    Example usage:

    ```
    a = random.random()  # produces some random number
    with seeded_context(1337):
        b = random.random()  # produces some other random number
    c = random.random()  # produces yet another random number, but the same it would have if we never made `b`
    ```
    """
    random_state_dict = get_rng_state()
    set_seed(seed)
    try:
        yield None
    finally:
        # Restore even when the body raises, so an exception inside the context
        # cannot leak the seeded state into the caller's RNG streams.
        set_rng_state(random_state_dict)
lerobot/common/utils/train_utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import logging
17
+ from pathlib import Path
18
+
19
+ from termcolor import colored
20
+ from torch.optim import Optimizer
21
+ from torch.optim.lr_scheduler import LRScheduler
22
+
23
+ from lerobot.common.constants import (
24
+ CHECKPOINTS_DIR,
25
+ LAST_CHECKPOINT_LINK,
26
+ PRETRAINED_MODEL_DIR,
27
+ TRAINING_STATE_DIR,
28
+ TRAINING_STEP,
29
+ )
30
+ from lerobot.common.datasets.utils import load_json, write_json
31
+ from lerobot.common.optim.optimizers import load_optimizer_state, save_optimizer_state
32
+ from lerobot.common.optim.schedulers import load_scheduler_state, save_scheduler_state
33
+ from lerobot.common.policies.pretrained import PreTrainedPolicy
34
+ from lerobot.common.utils.random_utils import load_rng_state, save_rng_state
35
+ from lerobot.configs.train import TrainPipelineConfig
36
+
37
+
38
def log_output_dir(out_dir):
    """Log the run's output directory with a highlighted prefix."""
    prefix = colored("Output dir:", "yellow", attrs=["bold"])
    logging.info(prefix + f" {out_dir}")
40
+
41
+
42
def get_step_identifier(step: int, total_steps: int) -> str:
    """Zero-pad `step` to at least 6 digits (more when `total_steps` needs them)."""
    width = max(6, len(str(total_steps)))
    return str(step).zfill(width)
45
+
46
+
47
def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Path:
    """Returns the checkpoint sub-directory corresponding to the step number."""
    return output_dir / CHECKPOINTS_DIR / get_step_identifier(step, total_steps)
51
+
52
+
53
def save_training_step(step: int, save_dir: Path) -> None:
    """Persist the current training step to `save_dir / TRAINING_STEP` as JSON."""
    payload = {"step": step}
    write_json(payload, save_dir / TRAINING_STEP)
55
+
56
+
57
def load_training_step(save_dir: Path) -> int:
    """Read back the step persisted by `save_training_step`."""
    return load_json(save_dir / TRAINING_STEP)["step"]
60
+
61
+
62
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
    """Point the "last checkpoint" symlink (a sibling of `checkpoint_dir`) at `checkpoint_dir`.

    The symlink target is relative so the checkpoints tree can be moved or
    mounted elsewhere without breaking the link.

    Returns:
        Path: the path of the symlink itself.
    """
    last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
    if last_checkpoint_dir.is_symlink():
        last_checkpoint_dir.unlink()
    relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent)
    last_checkpoint_dir.symlink_to(relative_target)
    # Bug fix: the signature promises `-> Path` but the function previously
    # fell off the end and returned None.
    return last_checkpoint_dir
68
+
69
+
70
def save_checkpoint(
    checkpoint_dir: Path,
    step: int,
    cfg: TrainPipelineConfig,
    policy: PreTrainedPolicy,
    optimizer: Optimizer,
    scheduler: LRScheduler | None = None,
) -> None:
    """This function creates the following directory structure:

    005000/  # training step at checkpoint
    ├── pretrained_model/
    │   ├── config.json  # policy config
    │   ├── model.safetensors  # policy weights
    │   └── train_config.json  # train config
    └── training_state/
        ├── optimizer_param_groups.json  # optimizer param groups
        ├── optimizer_state.safetensors  # optimizer state
        ├── rng_state.safetensors  # rng states
        ├── scheduler_state.json  # scheduler state
        └── training_step.json  # training step

    Args:
        checkpoint_dir (Path): The root directory of this checkpoint (e.g. ".../checkpoints/005000").
        cfg (TrainPipelineConfig): The training config used for this run.
        step (int): The training step at that checkpoint.
        policy (PreTrainedPolicy): The policy to save.
        optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
        scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
    """
    # Model weights + configs go under pretrained_model/ ...
    pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
    policy.save_pretrained(pretrained_dir)
    cfg.save_pretrained(pretrained_dir)
    # ... while step/optimizer/scheduler/rng state go under training_state/.
    save_training_state(checkpoint_dir, step, optimizer, scheduler)
103
+
104
+
105
def save_training_state(
    checkpoint_dir: Path,
    train_step: int,
    optimizer: Optimizer | None = None,
    scheduler: LRScheduler | None = None,
) -> None:
    """
    Saves the training step, RNG states and (optionally) optimizer/scheduler state.

    Everything is written under `checkpoint_dir / TRAINING_STATE_DIR`, which is
    created if needed.

    Args:
        checkpoint_dir (Path): The checkpoint directory to save artifacts under.
        train_step (int): Current training step.
        optimizer (Optimizer | None, optional): The optimizer from which to save the state_dict.
            Defaults to None.
        scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
            Defaults to None.
    """
    state_dir = checkpoint_dir / TRAINING_STATE_DIR
    state_dir.mkdir(parents=True, exist_ok=True)
    save_training_step(train_step, state_dir)
    save_rng_state(state_dir)
    if optimizer is not None:
        save_optimizer_state(optimizer, state_dir)
    if scheduler is not None:
        save_scheduler_state(scheduler, state_dir)
130
+
131
+
132
def load_training_state(
    checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None
) -> tuple[int, Optimizer, LRScheduler | None]:
    """
    Loads the training step, optimizer state, scheduler state, and rng state.
    This is used to resume a training run.

    Args:
        checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
        optimizer (Optimizer): The optimizer to load the state_dict to.
        scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None).

    Raises:
        NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir

    Returns:
        tuple[int, Optimizer, LRScheduler | None]: training step, optimizer and scheduler with their
            state_dict loaded.
    """
    state_dir = checkpoint_dir / TRAINING_STATE_DIR
    if not state_dir.is_dir():
        raise NotADirectoryError(state_dir)

    # Restore the RNG streams first so any randomness below matches the saved run.
    load_rng_state(state_dir)
    step = load_training_step(state_dir)
    optimizer = load_optimizer_state(optimizer, state_dir)
    if scheduler is not None:
        scheduler = load_scheduler_state(scheduler, state_dir)

    return step, optimizer, scheduler
lerobot/common/utils/utils.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import logging
17
+ import os
18
+ import os.path as osp
19
+ import platform
20
+ import subprocess
21
+ from copy import copy
22
+ from datetime import datetime, timezone
23
+ from pathlib import Path
24
+
25
+ import numpy as np
26
+ import torch
27
+
28
+
29
def none_or_int(value):
    """Parse a CLI-style string: the literal "None" maps to None, anything else to int."""
    return None if value == "None" else int(value)
+
34
+
35
def inside_slurm():
    """Check whether the python process was launched through slurm"""
    # TODO(rcadene): return False for interactive mode `--pty bash`
    return os.environ.get("SLURM_JOB_ID") is not None
+
40
+
41
def auto_select_torch_device() -> torch.device:
    """Tries to select automatically a torch device.

    Preference order: CUDA, then Apple Metal (MPS), then CPU.
    """
    if torch.cuda.is_available():
        logging.info("Cuda backend detected, using cuda.")
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        # Bug fix: this branch previously logged "using cuda." on MPS.
        logging.info("Metal backend detected, using mps.")
        return torch.device("mps")
    else:
        logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
        return torch.device("cpu")
+
53
+
54
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
    """Given a string, return a torch.device with checks on whether the device is available."""
    requested = str(try_device)
    if requested == "cuda":
        assert torch.cuda.is_available()
        return torch.device("cuda")
    if requested == "mps":
        assert torch.backends.mps.is_available()
        return torch.device("mps")
    if requested == "cpu":
        if log:
            logging.warning("Using CPU, this will be slow.")
        return torch.device("cpu")
    # Fall through: trust the caller's custom device string (e.g. "cuda:1").
    if log:
        logging.warning(f"Using custom {requested} device.")
    return torch.device(requested)
+
76
+
77
+ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
78
+ """
79
+ mps is currently not compatible with float64
80
+ """
81
+ if isinstance(device, torch.device):
82
+ device = device.type
83
+ if device == "mps" and dtype == torch.float64:
84
+ return torch.float32
85
+ else:
86
+ return dtype
87
+
88
+
89
def is_torch_device_available(try_device: str) -> bool:
    """Return True when `try_device` ("cuda", "mps" or "cpu") is usable on this host."""
    device = str(try_device)  # Ensure try_device is a string
    availability_checks = {
        "cuda": torch.cuda.is_available,
        "mps": torch.backends.mps.is_available,
        "cpu": lambda: True,
    }
    if device not in availability_checks:
        raise ValueError(f"Unknown device {device}. Supported devices are: cuda, mps or cpu.")
    return availability_checks[device]()
+
100
+
101
def is_amp_available(device: str):
    """Whether automatic mixed precision is usable on `device` ("cuda"/"cpu": yes, "mps": no)."""
    if device in ["cuda", "cpu"]:
        return True
    elif device == "mps":
        return False
    else:
        # Bug fix: the message was missing the closing quote around the device name.
        raise ValueError(f"Unknown device '{device}'.")
+
109
+
110
def init_logging():
    """Install a single console handler on the root logger with a compact custom format."""

    def custom_format(record):
        # "<LEVEL> <timestamp> <last 15 chars of path:line> <message>"
        dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        fnameline = f"{record.pathname}:{record.lineno}"
        return f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"

    logging.basicConfig(level=logging.INFO)

    # Drop any handlers basicConfig (or earlier code) installed so ours is the only one.
    root = logging.root
    for handler in list(root.handlers):
        root.removeHandler(handler)

    formatter = logging.Formatter()
    formatter.format = custom_format
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logging.getLogger().addHandler(console_handler)
+
128
+
129
def format_big_number(num, precision=0):
    """Render `num` with a metric suffix (K, M, B, T, Q), e.g. 1234 -> "1K"."""
    scaled = num
    for suffix in ["", "K", "M", "B", "T", "Q"]:
        if abs(scaled) < 1000.0:
            return f"{scaled:.{precision}f}{suffix}"
        scaled /= 1000.0
    # Beyond the largest suffix: give back the scaled number unformatted.
    return scaled
+
140
+
141
+ def _relative_path_between(path1: Path, path2: Path) -> Path:
142
+ """Returns path1 relative to path2."""
143
+ path1 = path1.absolute()
144
+ path2 = path2.absolute()
145
+ try:
146
+ return path1.relative_to(path2)
147
+ except ValueError: # most likely because path1 is not a subpath of path2
148
+ common_parts = Path(osp.commonpath([path1, path2])).parts
149
+ return Path(
150
+ "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
151
+ )
152
+
153
+
154
def print_cuda_memory_usage():
    """Use this function to locate and debug memory leak."""
    import gc

    gc.collect()
    # Also clear the cache if you want to fully release the memory
    torch.cuda.empty_cache()
    stats = [
        ("Current GPU Memory Allocated", torch.cuda.memory_allocated(0)),
        ("Maximum GPU Memory Allocated", torch.cuda.max_memory_allocated(0)),
        ("Current GPU Memory Reserved", torch.cuda.memory_reserved(0)),
        ("Maximum GPU Memory Reserved", torch.cuda.max_memory_reserved(0)),
    ]
    for label, num_bytes in stats:
        print(f"{label}: {num_bytes / 1024**2:.2f} MB")
+
166
+
167
def capture_timestamp_utc():
    """Return the current moment as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)
+
170
+
171
def say(text, blocking=False):
    """Speak `text` with the platform's text-to-speech tool.

    macOS uses `say`, Linux uses `spd-say`, Windows uses the .NET speech
    synthesizer via PowerShell. With `blocking=True` the call waits until
    speech finishes; otherwise the process is spawned and we return at once.
    """
    system = platform.system()

    if system == "Darwin":
        cmd = ["say", text]
    elif system == "Linux":
        # spd-say returns immediately by default; --wait makes it block.
        cmd = ["spd-say", text] + (["--wait"] if blocking else [])
    elif system == "Windows":
        script = (
            "Add-Type -AssemblyName System.Speech; "
            f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')"
        )
        cmd = ["PowerShell", "-Command", script]
    else:
        raise RuntimeError("Unsupported operating system for text-to-speech.")

    if blocking:
        subprocess.run(cmd, check=True)
    else:
        # CREATE_NO_WINDOW only exists (and is only needed) on Windows; the
        # conditional keeps the attribute access off other platforms.
        flags = subprocess.CREATE_NO_WINDOW if system == "Windows" else 0
        subprocess.Popen(cmd, creationflags=flags)
+
198
+
199
def log_say(text, play_sounds, blocking=False):
    """Log `text`; additionally speak it aloud when `play_sounds` is True."""
    logging.info(text)
    if play_sounds:
        say(text, blocking)
+
205
+
206
def get_channel_first_image_shape(image_shape: tuple) -> tuple:
    """Convert an (h, w, c) shape to (c, h, w); pass through shapes already channel-first.

    Raises ValueError when the smallest dimension is neither first nor last
    (ambiguous layout).
    """
    shape = copy(image_shape)
    channels_last = shape[2] < shape[0] and shape[2] < shape[1]
    channels_first = shape[0] < shape[1] and shape[0] < shape[2]
    if channels_last:  # (h, w, c) -> (c, h, w)
        return (shape[2], shape[0], shape[1])
    if not channels_first:
        raise ValueError(image_shape)
    return shape
+
215
+
216
def has_method(cls: object, method_name: str) -> bool:
    """True when `cls` (class or instance) exposes a callable attribute named `method_name`."""
    # getattr with a None default collapses the hasattr+getattr pair into one lookup.
    return callable(getattr(cls, method_name, None))
+
219
+
220
def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
    """
    Return True if a given string can be converted to a numpy dtype.
    """
    try:
        np.dtype(dtype_str)
    except TypeError:
        # np.dtype raises TypeError for strings it cannot interpret.
        return False
    return True
lerobot/common/utils/wandb_utils.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import logging
17
+ import os
18
+ import re
19
+ from glob import glob
20
+ from pathlib import Path
21
+
22
+ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
23
+ from termcolor import colored
24
+
25
+ from lerobot.common.constants import PRETRAINED_MODEL_DIR
26
+ from lerobot.configs.train import TrainPipelineConfig
27
+
28
+
29
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
    """Return a group name for logging. Optionally returns group name as list."""
    parts = [
        f"policy:{cfg.policy.type}",
        f"dataset:{cfg.dataset.repo_id}",
        f"seed:{cfg.seed}",
    ]
    if cfg.env is not None:
        parts.append(f"env:{cfg.env.type}")
    if return_list:
        return parts
    return "-".join(parts)
+
40
+
41
def get_wandb_run_id_from_filesystem(log_dir: Path) -> str:
    """Recover the ID of the latest WandB run recorded under `log_dir`."""
    # The latest run's data file lives at wandb/latest-run/run-<id>.wandb.
    paths = glob(str(log_dir / "wandb/latest-run/run-*"))
    if len(paths) != 1:
        raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
    filename = paths[0].split("/")[-1]
    match = re.search(r"run-([^\.]+).wandb", filename)
    if match is None:
        raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
    return match.group(1)
+
52
+
53
def get_safe_wandb_artifact_name(name: str):
    """WandB artifacts don't accept ":" or "/" in their name."""
    # Single C-level pass instead of chained .replace() calls.
    return name.translate(str.maketrans({":": "_", "/": "_"}))
+
57
+
58
class WandBLogger:
    """A helper class to log object using wandb."""

    def __init__(self, cfg: TrainPipelineConfig):
        # Keep the wandb sub-config plus the few run-level fields we need later.
        self.cfg = cfg.wandb
        self.log_dir = cfg.output_dir
        self.job_name = cfg.job_name
        # Used as the frame rate when logging evaluation videos; None without an env.
        self.env_fps = cfg.env.fps if cfg.env else None
        self._group = cfg_to_group(cfg)

        # Set up WandB.
        os.environ["WANDB_SILENT"] = "True"
        import wandb

        # When resuming, reuse an explicit run_id if given, otherwise recover it
        # from the wandb files on disk; fresh runs start with no id.
        wandb_run_id = (
            cfg.wandb.run_id
            if cfg.wandb.run_id
            else get_wandb_run_id_from_filesystem(self.log_dir)
            if cfg.resume
            else None
        )
        wandb.init(
            id=wandb_run_id,
            project=self.cfg.project,
            entity=self.cfg.entity,
            name=self.job_name,
            notes=self.cfg.notes,
            tags=cfg_to_group(cfg, return_list=True),
            dir=self.log_dir,
            config=cfg.to_dict(),
            # TODO(rcadene): try set to True
            save_code=False,
            # TODO(rcadene): split train and eval, and run async eval with job_type="eval"
            job_type="train_eval",
            resume="must" if cfg.resume else None,
        )
        print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
        logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
        self._wandb = wandb

    def log_policy(self, checkpoint_dir: Path):
        """Checkpoints the policy to wandb."""
        if self.cfg.disable_artifact:
            return

        step_id = checkpoint_dir.name
        artifact_name = f"{self._group}-{step_id}"
        # Sanitize: artifact names reject ":" and "/".
        artifact_name = get_safe_wandb_artifact_name(artifact_name)
        artifact = self._wandb.Artifact(artifact_name, type="model")
        artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
        self._wandb.log_artifact(artifact)

    def log_dict(self, d: dict, step: int, mode: str = "train"):
        """Log scalar/str entries of `d` under the `mode`/ prefix at `step`.

        Non-scalar values are skipped with a warning rather than raising.
        """
        if mode not in {"train", "eval"}:
            raise ValueError(mode)

        for k, v in d.items():
            if not isinstance(v, (int, float, str)):
                logging.warning(
                    f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
                )
                continue
            self._wandb.log({f"{mode}/{k}": v}, step=step)

    def log_video(self, video_path: str, step: int, mode: str = "train"):
        """Log the mp4 at `video_path` as `mode`/video at `step` (fps from the env config)."""
        if mode not in {"train", "eval"}:
            raise ValueError(mode)

        wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
        self._wandb.log({f"{mode}/video": wandb_video}, step=step)
lerobot/configs/default.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from dataclasses import dataclass, field
18
+
19
+ from lerobot.common import (
20
+ policies, # noqa: F401
21
+ )
22
+ from lerobot.common.datasets.transforms import ImageTransformsConfig
23
+ from lerobot.common.datasets.video_utils import get_safe_default_codec
24
+
25
+
26
@dataclass
class DatasetConfig:
    """Configuration of the dataset(s) used for training."""

    # You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
    # keys common between the datasets are kept. Each dataset gets an additional transform that inserts the
    # "dataset_index" into the returned item. The index mapping is made according to the order in which the
    # datasets are provided.
    repo_id: str
    # Root directory where the dataset will be stored (e.g. 'dataset/path').
    root: str | None = None
    # Subset of episode indices to load; None presumably loads all episodes — confirm against the dataset loader.
    episodes: list[int] | None = None
    # Image augmentations applied to camera frames.
    image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
    # Dataset revision to pull from the hub; None uses the loader's default.
    revision: str | None = None
    use_imagenet_stats: bool = True
    # Video decoding backend; the default is picked at construction time based on availability.
    video_backend: str = field(default_factory=get_safe_default_codec)
40
+
41
+
42
@dataclass
class WandBConfig:
    """Weights & Biases logging options."""

    # Presumably gates wandb logging entirely — confirm against the training loop.
    enable: bool = False
    # Set to true to disable saving an artifact despite training.save_checkpoint=True
    disable_artifact: bool = False
    # wandb project name.
    project: str = "lerobot"
    # wandb entity (user or team); None falls back to the logged-in account's default.
    entity: str | None = None
    # Free-form notes attached to the run.
    notes: str | None = None
    # Explicit wandb run id (presumably used when resuming a run) — TODO confirm against WandBLogger.
    run_id: str | None = None
51
+
52
+
53
@dataclass
class EvalConfig:
    """Settings controlling policy evaluation rollouts."""

    # Total number of evaluation episodes.
    n_episodes: int = 50
    # `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
    batch_size: int = 50
    # `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
    use_async_envs: bool = False

    def __post_init__(self):
        # A batch larger than the episode count would spin up environments that are never used.
        if self.batch_size <= self.n_episodes:
            return
        raise ValueError(
            "The eval batch size is greater than the number of eval episodes "
            f"({self.batch_size} > {self.n_episodes}). As a result, {self.batch_size} "
            f"eval environments will be instantiated, but only {self.n_episodes} will be used. "
            "This might significantly slow down evaluation. To fix this, you should update your command "
            f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), "
            f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)."
        )
lerobot/configs/eval.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import datetime as dt
16
+ import logging
17
+ from dataclasses import dataclass, field
18
+ from pathlib import Path
19
+
20
+ from lerobot.common import envs, policies # noqa: F401
21
+ from lerobot.configs import parser
22
+ from lerobot.configs.default import EvalConfig
23
+ from lerobot.configs.policies import PreTrainedConfig
24
+
25
+
26
@dataclass
class EvalPipelineConfig:
    """Configuration for the standalone policy-evaluation entry point."""

    env: envs.EnvConfig
    eval: EvalConfig = field(default_factory=EvalConfig)
    # Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
    # saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch
    # (useful for debugging). This argument is mutually exclusive with `--config`.
    policy: PreTrainedConfig | None = None
    # Where evaluation outputs are written; auto-generated from date/time and job name if unset.
    output_dir: Path | None = None
    job_name: str | None = None
    seed: int | None = 1000

    def __post_init__(self):
        # HACK: We parse again the cli args here to get the pretrained path if there was one.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path

        else:
            logging.warning(
                "No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
            )

        # NOTE(review): if no policy was supplied at all, `self.policy` is None here and the
        # f-strings below raise AttributeError — confirm a policy is always provided upstream.
        if not self.job_name:
            if self.env is None:
                self.job_name = f"{self.policy.type}"
            else:
                self.job_name = f"{self.env.type}_{self.policy.type}"

        if not self.output_dir:
            # Default output directory is date/time-stamped so successive runs never collide.
            now = dt.datetime.now()
            eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
            self.output_dir = Path("outputs/eval") / eval_dir

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]
lerobot/configs/parser.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import importlib
15
+ import inspect
16
+ import pkgutil
17
+ import sys
18
+ from argparse import ArgumentError
19
+ from functools import wraps
20
+ from pathlib import Path
21
+ from typing import Sequence
22
+
23
+ import draccus
24
+
25
+ from lerobot.common.utils.utils import has_method
26
+
27
+ PATH_KEY = "path"
28
+ PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
29
+ draccus.set_config_type("json")
30
+
31
+
32
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
    """Collect the CLI overrides nested under `field_name`, de-nested by one level.

    Given `python myscript.py --arg1=1 --arg2.subarg1=abc --arg2.subarg2=some/path`,
    `get_cli_overrides("arg2")` returns `["--subarg1=abc", "--subarg2=some/path"]`.
    The field's `.type` and `.path` arguments are excluded because the parser handles
    them separately.
    """
    if args is None:
        args = sys.argv[1:]
    prefix = f"--{field_name}."
    skipped = (f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", f"--{field_name}.{PATH_KEY}=")
    return [
        f"--{arg.removeprefix(prefix)}"
        for arg in args
        if arg.startswith(prefix) and not arg.startswith(skipped)
    ]
52
+
53
+
54
+ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
55
+ if args is None:
56
+ args = sys.argv[1:]
57
+ prefix = f"--{arg_name}="
58
+ for arg in args:
59
+ if arg.startswith(prefix):
60
+ return arg[len(prefix) :]
61
+ return None
62
+
63
+
64
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
    """Parse plugin-related arguments from command-line arguments.

    This function extracts arguments from command-line arguments whose *key* matches the
    specified suffix pattern. It processes arguments in the format '--key=value' and
    returns them as a dictionary.

    Args:
        plugin_arg_suffix (str): The suffix to identify plugin-related arguments.
        args (Sequence[str]): A sequence of command-line arguments to parse.

    Returns:
        dict: A dictionary containing the parsed plugin arguments where:
            - Keys are the argument names (with '--' prefix removed if present)
            - Values are the corresponding argument values

    Example:
        >>> args = ['--env.discover_packages_path=my_package',
        ...         '--other_arg=value']
        >>> parse_plugin_args('discover_packages_path', args)
        {'env.discover_packages_path': 'my_package'}
    """
    plugin_args = {}
    for arg in args:
        if "=" not in arg:
            continue
        key, value = arg.split("=", 1)
        # Remove leading '--' if present
        key = key.removeprefix("--")
        # Match on the key only: previously the suffix was searched in the whole
        # '--key=value' string, so a *value* containing the suffix was wrongly
        # registered as a plugin argument.
        if plugin_arg_suffix in key:
            plugin_args[key] = value
    return plugin_args
94
+
95
+
96
class PluginLoadError(Exception):
    """Raised when a plugin package (or one of its submodules) fails to import."""
98
+
99
+
100
def load_plugin(plugin_path: str) -> None:
    """Import a plugin package and every one of its submodules.

    Plugin registration is expected to happen as an import-time side effect: importing
    the package (and its submodules) should register gym environments and config
    subclasses (via the `register_subclass` decorator) with their parents.

    Args:
        plugin_path (str): The Python package path to the plugin
            (e.g. "mypackage.plugins.myplugin")

    Raises:
        PluginLoadError: If the package path is invalid or any import fails.

    Examples:
        >>> load_plugin("external_plugin.core")  # Loads plugin from external package

    See Also:
        https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/
    """
    # ModuleNotFoundError is a subclass of ImportError, so one handler covers both the
    # package import and the submodule imports.
    try:
        root = importlib.import_module(plugin_path, __package__)
        # Import every submodule so each gets a chance to register itself.
        for _finder, submodule_name, _ispkg in pkgutil.iter_modules(root.__path__, root.__name__ + "."):
            importlib.import_module(submodule_name)
    except ImportError as e:
        raise PluginLoadError(
            f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
        ) from e
142
+
143
+
144
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
    """Return the value of the CLI argument `--{field_name}.path`, or None if absent."""
    return parse_arg(f"{field_name}.{PATH_KEY}", args)
146
+
147
+
148
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
    """Return the value of the field's draccus choice-type argument (typically
    `--{field_name}.type`), or None if absent."""
    return parse_arg(f"{field_name}.{draccus.CHOICE_TYPE_KEY}", args)
150
+
151
+
152
+ def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
153
+ return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
154
+
155
+
156
def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
    """Drop all `--{field}.*` arguments for every field that was given a `.path` argument.

    Args:
        fields_to_filter (str | list[str]): One field name, or a list of field names,
            whose sub-arguments are removed when a `--{field}.path` argument is present.
        args (Sequence[str] | None): The command-line arguments to filter. Defaults to None.

    Returns:
        list[str]: The filtered argument list.

    Raises:
        ArgumentError: If both a path argument (e.g., `--field_name.path`) and a type
            argument (e.g., `--field_name.type`) are specified for the same field.
    """
    fields = [fields_to_filter] if isinstance(fields_to_filter, str) else fields_to_filter

    remaining = args
    for field in fields:
        # Presence checks run against the *original* args; filtering is cumulative.
        if not get_path_arg(field, args):
            continue
        if get_type_arg(field, args):
            raise ArgumentError(
                argument=None,
                message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
            )
        prefix = f"--{field}."
        remaining = [arg for arg in remaining if not arg.startswith(prefix)]

    return remaining
187
+
188
+
189
def wrap(config_path: Path | None = None):
    """
    HACK: Similar to draccus.wrap but does three additional things:
    - Will remove '.path' arguments from CLI in order to process them later on.
    - If a 'config_path' is passed and the main config class has a 'from_pretrained' method, will
    initialize it from there to allow to fetch configs from the hub directly
    - Will load plugins specified in the CLI arguments. These plugins will typically register
    their own subclasses of config classes, so that draccus can find the right class to instantiate
    from the CLI '.type' arguments
    """

    def wrapper_outer(fn):
        @wraps(fn)
        def wrapper_inner(*args, **kwargs):
            # The config class is taken from the type annotation of the wrapped
            # function's first parameter.
            argspec = inspect.getfullargspec(fn)
            argtype = argspec.annotations[argspec.args[0]]
            if len(args) > 0 and type(args[0]) is argtype:
                # A config instance was passed directly: skip CLI parsing entirely.
                cfg = args[0]
                args = args[1:]
            else:
                cli_args = sys.argv[1:]
                # Plugins must be loaded *before* parsing so that the config subclasses
                # they register are known to draccus when resolving '.type' arguments.
                plugin_args = parse_plugin_args(PLUGIN_DISCOVERY_SUFFIX, cli_args)
                for plugin_cli_arg, plugin_path in plugin_args.items():
                    try:
                        load_plugin(plugin_path)
                    except PluginLoadError as e:
                        # add the relevant CLI arg to the error message
                        raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
                    # The plugin arg was consumed here; draccus must not see it.
                    cli_args = filter_arg(plugin_cli_arg, cli_args)
                config_path_cli = parse_arg("config_path", cli_args)
                if has_method(argtype, "__get_path_fields__"):
                    # '.path' args are handled by the config class itself (e.g. in
                    # __post_init__), so strip them before draccus parsing.
                    path_fields = argtype.__get_path_fields__()
                    cli_args = filter_path_args(path_fields, cli_args)
                if has_method(argtype, "from_pretrained") and config_path_cli:
                    cli_args = filter_arg("config_path", cli_args)
                    cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
                else:
                    cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
            response = fn(cfg, *args, **kwargs)
            return response

        return wrapper_inner

    return wrapper_outer
lerobot/configs/policies.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import abc
15
+ import logging
16
+ import os
17
+ from dataclasses import dataclass, field
18
+ from pathlib import Path
19
+ from typing import Type, TypeVar
20
+
21
+ import draccus
22
+ from huggingface_hub import hf_hub_download
23
+ from huggingface_hub.constants import CONFIG_NAME
24
+ from huggingface_hub.errors import HfHubHTTPError
25
+
26
+ from lerobot.common.optim.optimizers import OptimizerConfig
27
+ from lerobot.common.optim.schedulers import LRSchedulerConfig
28
+ from lerobot.common.utils.hub import HubMixin
29
+ from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
30
+ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
31
+
32
+ # Generic variable that is either PreTrainedConfig or a subclass thereof
33
+ T = TypeVar("T", bound="PreTrainedConfig")
34
+
35
+
36
@dataclass
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
    """
    Base configuration class for policy models.

    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
        normalization_mapping: A dictionary with key representing the modality and the value specifying
            the normalization mode to apply.
        input_features: A dictionary defining the features used as input by the policy.
        output_features: A dictionary defining the features produced by the policy.
        device: Device to run on ("cuda" | "cpu" | "mps"); auto-selected when unset or unavailable.
        use_amp: Whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
            automatic gradient scaling is used.
    """

    n_obs_steps: int = 1
    normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)

    input_features: dict[str, PolicyFeature] = field(default_factory=dict)
    output_features: dict[str, PolicyFeature] = field(default_factory=dict)

    device: str | None = None  # cuda | cpu | mps
    use_amp: bool = False

    def __post_init__(self):
        self.pretrained_path = None
        # Fall back to an auto-selected device when none was provided or the requested
        # one is unavailable on this machine.
        if not self.device or not is_torch_device_available(self.device):
            auto_device = auto_select_torch_device()
            logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
            self.device = auto_device.type

        # Automatically deactivate AMP if necessary
        if self.use_amp and not is_amp_available(self.device):
            logging.warning(
                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
            )
            self.use_amp = False

    @property
    def type(self) -> str:
        """Name under which this config subclass is registered in the choice registry."""
        return self.get_choice_name(self.__class__)

    # NOTE: `abc.abstractproperty` is deprecated since Python 3.3; stacking `@property`
    # over `@abc.abstractmethod` is the supported equivalent.
    @property
    @abc.abstractmethod
    def observation_delta_indices(self) -> list | None:
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def action_delta_indices(self) -> list | None:
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def reward_delta_indices(self) -> list | None:
        raise NotImplementedError

    @abc.abstractmethod
    def get_optimizer_preset(self) -> OptimizerConfig:
        raise NotImplementedError

    @abc.abstractmethod
    def get_scheduler_preset(self) -> LRSchedulerConfig | None:
        raise NotImplementedError

    @abc.abstractmethod
    def validate_features(self) -> None:
        raise NotImplementedError

    @property
    def robot_state_feature(self) -> PolicyFeature | None:
        """First input feature of type STATE, if any."""
        for ft in self.input_features.values():
            if ft.type is FeatureType.STATE:
                return ft
        return None

    @property
    def env_state_feature(self) -> PolicyFeature | None:
        """First input feature of type ENV, if any."""
        for ft in self.input_features.values():
            if ft.type is FeatureType.ENV:
                return ft
        return None

    @property
    def image_features(self) -> dict[str, PolicyFeature]:
        """All VISUAL input features, keyed by feature name."""
        return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}

    @property
    def action_feature(self) -> PolicyFeature | None:
        """First output feature of type ACTION, if any."""
        for ft in self.output_features.values():
            if ft.type is FeatureType.ACTION:
                return ft
        return None

    def _save_pretrained(self, save_directory: Path) -> None:
        with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
            draccus.dump(self, f, indent=4)

    @classmethod
    def from_pretrained(
        cls: Type[T],
        pretrained_name_or_path: str | Path,
        *,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **policy_kwargs,
    ) -> T:
        """Load a policy config from a local directory or from the HuggingFace Hub.

        Raises:
            FileNotFoundError: when the config file cannot be fetched from the Hub.
        """
        model_id = str(pretrained_name_or_path)
        config_file: str | None = None
        if Path(model_id).is_dir():
            if CONFIG_NAME in os.listdir(model_id):
                config_file = os.path.join(model_id, CONFIG_NAME)
            else:
                # Keep going with config_file=None; draccus.parse below then falls
                # back to CLI-only construction.
                print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
        else:
            try:
                config_file = hf_hub_download(
                    repo_id=model_id,
                    filename=CONFIG_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
                ) from e

        # HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
        # something like --policy.path (in addition to --policy.type)
        cli_overrides = policy_kwargs.pop("cli_overrides", [])
        return draccus.parse(cls, config_file, args=cli_overrides)
lerobot/configs/train.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import datetime as dt
15
+ import os
16
+ from dataclasses import dataclass, field
17
+ from pathlib import Path
18
+ from typing import Type
19
+
20
+ import draccus
21
+ from huggingface_hub import hf_hub_download
22
+ from huggingface_hub.errors import HfHubHTTPError
23
+
24
+ from lerobot.common import envs
25
+ from lerobot.common.optim import OptimizerConfig
26
+ from lerobot.common.optim.schedulers import LRSchedulerConfig
27
+ from lerobot.common.utils.hub import HubMixin
28
+ from lerobot.configs import parser
29
+ from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
30
+ from lerobot.configs.policies import PreTrainedConfig
31
+
32
+ TRAIN_CONFIG_NAME = "train_config.json"
33
+
34
+
35
@dataclass
class TrainPipelineConfig(HubMixin):
    """Top-level configuration for a training run (dataset, policy, schedule, logging)."""

    dataset: DatasetConfig
    env: envs.EnvConfig | None = None
    policy: PreTrainedConfig | None = None
    # Set `dir` to where you would like to save all of the run outputs. If you run another training session
    # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
    output_dir: Path | None = None
    job_name: str | None = None
    # Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
    # `dir` is the directory of an existing run with at least one checkpoint in it.
    # Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
    # regardless of what's provided with the training command at the time of resumption.
    resume: bool = False
    # `seed` is used for training (eg: model initialization, dataset shuffling)
    # AND for the evaluation environments.
    seed: int | None = 1000
    # Number of workers for the dataloader.
    num_workers: int = 4
    batch_size: int = 8
    steps: int = 100_000
    eval_freq: int = 20_000
    log_freq: int = 200
    save_checkpoint: bool = True
    # Checkpoint is saved every `save_freq` training iterations and after the last training step.
    save_freq: int = 20_000
    # When True, optimizer/scheduler are taken from the policy's presets in `validate()`.
    use_policy_training_preset: bool = True
    optimizer: OptimizerConfig | None = None
    scheduler: LRSchedulerConfig | None = None
    eval: EvalConfig = field(default_factory=EvalConfig)
    wandb: WandBConfig = field(default_factory=WandBConfig)

    def __post_init__(self):
        # Set lazily in validate() when resuming from a checkpoint.
        self.checkpoint_path = None

    def validate(self):
        # HACK: We parse again the cli args here to get the pretrained paths if there was some.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            # Only load the policy config
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path
        elif self.resume:
            # The entire train config is already loaded, we just need to get the checkpoint dir
            config_path = parser.parse_arg("config_path")
            if not config_path:
                raise ValueError(
                    f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
                )
            if not Path(config_path).resolve().exists():
                raise NotADirectoryError(
                    f"{config_path=} is expected to be a local path. "
                    "Resuming from the hub is not supported for now."
                )
            policy_path = Path(config_path).parent
            self.policy.pretrained_path = policy_path
            self.checkpoint_path = policy_path.parent

        # NOTE(review): the branches below dereference `self.policy`; if neither a
        # `--policy.path` nor a policy config was supplied they raise AttributeError —
        # confirm a policy is always provided upstream.
        if not self.job_name:
            if self.env is None:
                self.job_name = f"{self.policy.type}"
            else:
                self.job_name = f"{self.env.type}_{self.policy.type}"

        if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
            raise FileExistsError(
                f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
                f"Please change your output directory so that {self.output_dir} is not overwritten."
            )
        elif not self.output_dir:
            # Default output directory is date/time-stamped so successive runs never collide.
            now = dt.datetime.now()
            train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
            self.output_dir = Path("outputs/train") / train_dir

        if isinstance(self.dataset.repo_id, list):
            raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")

        if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
            raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
        elif self.use_policy_training_preset and not self.resume:
            self.optimizer = self.policy.get_optimizer_preset()
            self.scheduler = self.policy.get_scheduler_preset()

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]

    def to_dict(self) -> dict:
        """Serialize this config with draccus (e.g. for wandb logging)."""
        return draccus.encode(self)

    def _save_pretrained(self, save_directory: Path) -> None:
        with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
            draccus.dump(self, f, indent=4)

    @classmethod
    def from_pretrained(
        cls: Type["TrainPipelineConfig"],
        pretrained_name_or_path: str | Path,
        *,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **kwargs,
    ) -> "TrainPipelineConfig":
        """Load a train config from a local dir/file or from the HuggingFace Hub.

        Raises:
            FileNotFoundError: when the config file cannot be fetched from the Hub.
        """
        model_id = str(pretrained_name_or_path)
        config_file: str | None = None
        if Path(model_id).is_dir():
            if TRAIN_CONFIG_NAME in os.listdir(model_id):
                config_file = os.path.join(model_id, TRAIN_CONFIG_NAME)
            else:
                # Keep going with config_file=None; draccus.parse below then falls
                # back to CLI-only construction.
                print(f"{TRAIN_CONFIG_NAME} not found in {Path(model_id).resolve()}")
        elif Path(model_id).is_file():
            config_file = model_id
        else:
            try:
                config_file = hf_hub_download(
                    repo_id=model_id,
                    filename=TRAIN_CONFIG_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
                ) from e

        cli_args = kwargs.pop("cli_args", [])
        cfg = draccus.parse(cls, config_file, args=cli_args)

        return cfg
lerobot/configs/types.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Note: We subclass str so that serialization is straightforward
15
+ # https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
16
+ from dataclasses import dataclass
17
+ from enum import Enum
18
+ from typing import Any, Protocol
19
+
20
+
21
class FeatureType(str, Enum):
    """Semantic category of a policy input/output feature."""

    STATE = "STATE"
    VISUAL = "VISUAL"
    ENV = "ENV"
    ACTION = "ACTION"
26
+
27
+
28
class NormalizationMode(str, Enum):
    """How a feature is normalized: min/max scaling, mean/std standardization, or none."""

    MIN_MAX = "MIN_MAX"
    MEAN_STD = "MEAN_STD"
    IDENTITY = "IDENTITY"
32
+
33
+
34
class DictLike(Protocol):
    """Structural type for any object supporting `obj[key]` subscript access."""

    def __getitem__(self, key: Any) -> Any: ...
36
+
37
+
38
@dataclass
class PolicyFeature:
    """A single policy input/output feature: its semantic category and shape."""

    # Semantic category (STATE / VISUAL / ENV / ACTION). Shadows the `type` builtin,
    # but renaming would break the serialized config format.
    type: FeatureType
    # Feature shape tuple; axis convention presumably depends on the feature type — TODO confirm.
    shape: tuple
lerobot/scripts/configure_motor.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ This script configures a single motor at a time to a given ID and baudrate.
16
+
17
+ Example of usage:
18
+ ```bash
19
+ python lerobot/scripts/configure_motor.py \
20
+ --port /dev/tty.usbmodem585A0080521 \
21
+ --brand feetech \
22
+ --model sts3215 \
23
+ --baudrate 1000000 \
24
+ --ID 1
25
+ ```
26
+ """
27
+
28
+ import argparse
29
+ import time
30
+
31
+
32
def get_motor_bus_cls(brand: str) -> tuple:
    """Resolve the driver classes and baudrate tables for a motor brand.

    Returns a 4-tuple: (bus config class, bus class, model baudrate table,
    series baudrate table). Imports are done lazily inside each branch so that
    only the selected brand's driver module needs to be importable.

    Raises:
        ValueError: if `brand` is neither "feetech" nor "dynamixel".
    """
    if brand == "feetech":
        from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
        from lerobot.common.robot_devices.motors.feetech import (
            MODEL_BAUDRATE_TABLE,
            SCS_SERIES_BAUDRATE_TABLE,
            FeetechMotorsBus,
        )

        return FeetechMotorsBusConfig, FeetechMotorsBus, MODEL_BAUDRATE_TABLE, SCS_SERIES_BAUDRATE_TABLE

    if brand == "dynamixel":
        from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
        from lerobot.common.robot_devices.motors.dynamixel import (
            MODEL_BAUDRATE_TABLE,
            X_SERIES_BAUDRATE_TABLE,
            DynamixelMotorsBus,
        )

        return DynamixelMotorsBusConfig, DynamixelMotorsBus, MODEL_BAUDRATE_TABLE, X_SERIES_BAUDRATE_TABLE

    raise ValueError(
        f"Currently we do not support this motor brand: {brand}. We currently support feetech and dynamixel motors."
    )
57
+
58
+
59
def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
    """Find the single motor attached to `port` and program its ID and baudrate.

    The bus is scanned at every baudrate supported by the motor series until
    exactly one motor answers (the script refuses to run with more than one
    motor connected). That motor is then switched to `baudrate_des`,
    re-addressed to `motor_idx_des`, and — for feetech — given a higher
    acceleration limit before being centered at position 2048.

    All hardware errors are printed rather than propagated; the bus is always
    disconnected on exit.
    """
    config_cls, bus_cls, model_baudrate_table, series_baudrate_table = get_motor_bus_cls(brand)

    # Reject models that the selected brand's driver does not know about.
    if model not in model_baudrate_table:
        raise ValueError(
            f"Invalid model '{model}' for brand '{brand}'. Supported models: {list(model_baudrate_table.keys())}"
        )

    # A single arbitrary motor entry suffices: we only ever talk to one motor.
    config = config_cls(port=port, motors={"motor": (motor_idx_des, model)})
    motor_bus = bus_cls(config=config)

    # Connection errors are reported and abort the configuration early.
    try:
        motor_bus.connect()
        print(f"Connected on port {motor_bus.port}")
    except OSError as e:
        print(f"Error occurred when connecting to the motor bus: {e}")
        return

    try:
        print("Scanning all baudrates and motor indices")
        all_baudrates = set(series_baudrate_table.values())
        motor_index = -1  # Out-of-range sentinel meaning "not found yet".

        for baudrate in all_baudrates:
            motor_bus.set_bus_baudrate(baudrate)
            present_ids = motor_bus.find_motor_indices(list(range(1, 10)))
            if len(present_ids) > 1:
                raise ValueError(
                    "Error: More than one motor ID detected. This script is designed to only handle one motor at a time. Please disconnect all but one motor."
                )

            if len(present_ids) == 1:
                if motor_index != -1:
                    raise ValueError(
                        "Error: More than one motor ID detected. This script is designed to only handle one motor at a time. Please disconnect all but one motor."
                    )
                motor_index = present_ids[0]
                break

        if motor_index == -1:
            raise ValueError("No motors detected. Please ensure you have one motor connected.")

        print(f"Motor index found at: {motor_index}")

        if brand == "feetech":
            # Allows ID and BAUDRATE to be written in memory
            motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)

        # `baudrate` is the value at which the motor answered during the scan.
        if baudrate != baudrate_des:
            print(f"Setting its baudrate to {baudrate_des}")
            baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des)

            # The write can fail, so we allow retries
            motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx)
            time.sleep(0.5)
            motor_bus.set_bus_baudrate(baudrate_des)
            present_baudrate_idx = motor_bus.read_with_motor_ids(
                motor_bus.motor_models, motor_index, "Baud_Rate", num_retry=2
            )

            if present_baudrate_idx != baudrate_idx:
                raise OSError("Failed to write baudrate.")

        print(f"Setting its index to desired index {motor_idx_des}")
        if brand == "feetech":
            motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
        motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)

        # Read back at the new ID to confirm the re-addressing took effect.
        present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2)
        if present_idx != motor_idx_des:
            raise OSError("Failed to write index.")

        if brand == "feetech":
            # Set Maximum_Acceleration to 254 to speedup acceleration and deceleration of
            # the motors. Note: this configuration is not in the official STS3215 Memory Table
            motor_bus.write("Lock", 0)
            motor_bus.write("Maximum_Acceleration", 254)

        motor_bus.write("Goal_Position", 2048)
        time.sleep(4)
        print("Present Position", motor_bus.read("Present_Position"))

        motor_bus.write("Offset", 0)
        time.sleep(4)
        print("Offset", motor_bus.read("Offset"))

    except Exception as e:
        print(f"Error occurred during motor configuration: {e}")

    finally:
        motor_bus.disconnect()
        print("Disconnected from motor bus.")
163
+
164
+
165
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Fix: the original --port help text was copy-pasted from --brand
    # ("(e.g. dynamixel,feetech)"); show an actual serial port instead.
    parser.add_argument(
        "--port", type=str, required=True, help="Motors bus port (e.g. /dev/tty.usbmodem585A0080521)"
    )
    parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)")
    parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)")
    parser.add_argument("--ID", type=int, required=True, help="Desired ID of the current motor (e.g. 1,2,3)")
    parser.add_argument(
        "--baudrate", type=int, default=1000000, help="Desired baudrate for the motor (default: 1000000)"
    )
    args = parser.parse_args()

    configure_motor(args.port, args.brand, args.model, args.ID, args.baudrate)
lerobot/scripts/control_robot.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Utilities to control a robot.
16
+
17
+ Useful to record a dataset, replay a recorded episode, run the policy on your robot
18
+ and record an evaluation dataset, and to recalibrate your robot if needed.
19
+
20
+ Examples of usage:
21
+
22
+ - Recalibrate your robot:
23
+ ```bash
24
+ python lerobot/scripts/control_robot.py \
25
+ --robot.type=so100 \
26
+ --control.type=calibrate
27
+ ```
28
+
29
+ - Unlimited teleoperation at highest frequency (~200 Hz is expected), to exit with CTRL+C:
30
+ ```bash
31
+ python lerobot/scripts/control_robot.py \
32
+ --robot.type=so100 \
33
+ --robot.cameras='{}' \
34
+ --control.type=teleoperate
35
+
36
+ # Add the cameras from the robot definition to visualize them:
37
+ python lerobot/scripts/control_robot.py \
38
+ --robot.type=so100 \
39
+ --control.type=teleoperate
40
+ ```
41
+
42
+ - Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency:
43
+ ```bash
44
+ python lerobot/scripts/control_robot.py \
45
+ --robot.type=so100 \
46
+ --control.type=teleoperate \
47
+ --control.fps=30
48
+ ```
49
+
50
+ - Record one episode in order to test replay:
51
+ ```bash
52
+ python lerobot/scripts/control_robot.py \
53
+ --robot.type=so100 \
54
+ --control.type=record \
55
+ --control.fps=30 \
56
+ --control.single_task="Grasp a lego block and put it in the bin." \
57
+ --control.repo_id=$USER/koch_test \
58
+ --control.num_episodes=1 \
59
+ --control.push_to_hub=True
60
+ ```
61
+
62
+ - Visualize dataset:
63
+ ```bash
64
+ python lerobot/scripts/visualize_dataset.py \
65
+ --repo-id $USER/koch_test \
66
+ --episode-index 0
67
+ ```
68
+
69
+ - Replay this test episode:
70
+ ```bash
71
+ python lerobot/scripts/control_robot.py \
72
+ --robot.type=so100 \
73
+ --control.type=replay \
74
+ --control.fps=30 \
75
+ --control.repo_id=$USER/koch_test \
76
+ --control.episode=0
77
+ ```
78
+
79
+ - Record a full dataset in order to train a policy, with 2 seconds of warmup,
80
+ 30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
81
+ ```bash
82
+ python lerobot/scripts/control_robot.py \
83
+ --robot.type=so100 \
84
+ --control.type=record \
85
+ --control.fps 30 \
86
+ --control.repo_id=$USER/koch_pick_place_lego \
87
+ --control.num_episodes=50 \
88
+ --control.warmup_time_s=2 \
89
+ --control.episode_time_s=30 \
90
+ --control.reset_time_s=10
91
+ ```
92
+
93
+ - For remote controlled robots like LeKiwi, run this script on the robot edge device (e.g. RaspBerryPi):
94
+ ```bash
95
+ python lerobot/scripts/control_robot.py \
96
+ --robot.type=lekiwi \
97
+ --control.type=remote_robot
98
+ ```
99
+
100
+ **NOTE**: You can use your keyboard to control data recording flow.
101
+ - Tap right arrow key '->' to early exit while recording an episode and go to resetting the environment.
102
+ - Tap right arrow key '->' to early exit while resetting the environment and go to recording the next episode.
103
+ - Tap left arrow key '<-' to early exit and re-record the current episode.
104
+ - Tap escape key 'esc' to stop the data recording.
105
+ This might require a sudo permission to allow your terminal to monitor keyboard events.
106
+
107
+ **NOTE**: You can resume/continue data recording by running the same data recording command and adding `--control.resume=true`.
108
+
109
+ - Train on this dataset with the ACT policy:
110
+ ```bash
111
+ python lerobot/scripts/train.py \
112
+ --dataset.repo_id=${HF_USER}/koch_pick_place_lego \
113
+ --policy.type=act \
114
+ --output_dir=outputs/train/act_koch_pick_place_lego \
115
+ --job_name=act_koch_pick_place_lego \
116
+ --device=cuda \
117
+ --wandb.enable=true
118
+ ```
119
+
120
+ - Run the pretrained policy on the robot:
121
+ ```bash
122
+ python lerobot/scripts/control_robot.py \
123
+ --robot.type=so100 \
124
+ --control.type=record \
125
+ --control.fps=30 \
126
+ --control.single_task="Grasp a lego block and put it in the bin." \
127
+ --control.repo_id=$USER/eval_act_koch_pick_place_lego \
128
+ --control.num_episodes=10 \
129
+ --control.warmup_time_s=2 \
130
+ --control.episode_time_s=30 \
131
+ --control.reset_time_s=10 \
132
+ --control.push_to_hub=true \
133
+ --control.policy.path=outputs/train/act_koch_pick_place_lego/checkpoints/080000/pretrained_model
134
+ ```
135
+ """
136
+
137
+ import logging
138
+ import time
139
+ from dataclasses import asdict
140
+ from pprint import pformat
141
+
142
+ # from safetensors.torch import load_file, save_file
143
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
144
+ from lerobot.common.policies.factory import make_policy
145
+ from lerobot.common.robot_devices.control_configs import (
146
+ CalibrateControlConfig,
147
+ ControlPipelineConfig,
148
+ RecordControlConfig,
149
+ RemoteRobotConfig,
150
+ ReplayControlConfig,
151
+ TeleoperateControlConfig,
152
+ )
153
+ from lerobot.common.robot_devices.control_utils import (
154
+ control_loop,
155
+ init_keyboard_listener,
156
+ log_control_info,
157
+ record_episode,
158
+ reset_environment,
159
+ sanity_check_dataset_name,
160
+ sanity_check_dataset_robot_compatibility,
161
+ stop_recording,
162
+ warmup_record,
163
+ )
164
+ from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
165
+ from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
166
+ from lerobot.common.utils.utils import has_method, init_logging, log_say
167
+ from lerobot.configs import parser
168
+
169
+ ########################################################################################
170
+ # Control modes
171
+ ########################################################################################
172
+
173
+
174
@safe_disconnect
def calibrate(robot: Robot, cfg: CalibrateControlConfig):
    """Remove stale calibration files for the requested arms and re-calibrate.

    For stretch robots, calibration is delegated to the robot's own homing
    routine. For lekiwi, a single arm can be calibrated on its own. Otherwise,
    deleting the calibration file and reconnecting triggers recalibration.
    """
    # TODO(aliberts): move this code in robots' classes
    if robot.robot_type.startswith("stretch"):
        if not robot.is_connected:
            robot.connect()
        if not robot.is_homed():
            robot.home()
        return

    arms = robot.available_arms if cfg.arms is None else cfg.arms
    available_arms_str = " ".join(robot.available_arms)

    # Fix: validate `arms` BEFORE iterating over it. The original code built
    # `unknown_arms` first, which raised a TypeError (instead of this
    # intended, user-friendly ValueError) whenever `arms` was None.
    if arms is None or len(arms) == 0:
        raise ValueError(
            "No arm provided. Use `--arms` as argument with one or more available arms.\n"
            f"For instance, to recalibrate all arms add: `--arms {available_arms_str}`"
        )

    unknown_arms = [arm_id for arm_id in arms if arm_id not in robot.available_arms]
    if len(unknown_arms) > 0:
        unknown_arms_str = " ".join(unknown_arms)
        raise ValueError(
            f"Unknown arms provided ('{unknown_arms_str}'). Available arms are `{available_arms_str}`."
        )

    # Delete existing calibration files so that reconnecting recalibrates.
    for arm_id in arms:
        arm_calib_path = robot.calibration_dir / f"{arm_id}.json"
        if arm_calib_path.exists():
            print(f"Removing '{arm_calib_path}'")
            arm_calib_path.unlink()
        else:
            print(f"Calibration file not found '{arm_calib_path}'")

    if robot.is_connected:
        robot.disconnect()

    if robot.robot_type.startswith("lekiwi") and "main_follower" in arms:
        print("Calibrating only the lekiwi follower arm 'main_follower'...")
        robot.calibrate_follower()
        return

    if robot.robot_type.startswith("lekiwi") and "main_leader" in arms:
        print("Calibrating only the lekiwi leader arm 'main_leader'...")
        robot.calibrate_leader()
        return

    # Calling `connect` automatically runs calibration
    # when the calibration file is missing
    robot.connect()
    robot.disconnect()
    print("Calibration is done! You can now teleoperate and record datasets!")
226
+
227
+
228
@safe_disconnect
def teleoperate(robot: Robot, cfg: TeleoperateControlConfig):
    """Run a plain teleoperation loop: the leader arm drives the follower arm."""
    control_loop(
        robot,
        teleoperate=True,
        control_time_s=cfg.teleop_time_s,
        display_cameras=cfg.display_cameras,
        fps=cfg.fps,
    )
237
+
238
+
239
+ @safe_disconnect
240
+ def record(
241
+ robot: Robot,
242
+ cfg: RecordControlConfig,
243
+ ) -> LeRobotDataset:
244
+ # TODO(rcadene): Add option to record logs
245
+ if cfg.resume:
246
+ dataset = LeRobotDataset(
247
+ cfg.repo_id,
248
+ root=cfg.root,
249
+ )
250
+ if len(robot.cameras) > 0:
251
+ dataset.start_image_writer(
252
+ num_processes=cfg.num_image_writer_processes,
253
+ num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
254
+ )
255
+ sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video)
256
+ else:
257
+ # Create empty dataset or load existing saved episodes
258
+ sanity_check_dataset_name(cfg.repo_id, cfg.policy)
259
+ dataset = LeRobotDataset.create(
260
+ cfg.repo_id,
261
+ cfg.fps,
262
+ root=cfg.root,
263
+ robot=robot,
264
+ use_videos=cfg.video,
265
+ image_writer_processes=cfg.num_image_writer_processes,
266
+ image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
267
+ )
268
+
269
+ # Load pretrained policy
270
+ policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
271
+
272
+ if not robot.is_connected:
273
+ robot.connect()
274
+
275
+ listener, events = init_keyboard_listener()
276
+
277
+ # Execute a few seconds without recording to:
278
+ # 1. teleoperate the robot to move it in starting position if no policy provided,
279
+ # 2. give times to the robot devices to connect and start synchronizing,
280
+ # 3. place the cameras windows on screen
281
+ enable_teleoperation = policy is None
282
+ log_say("Warmup record", cfg.play_sounds)
283
+ warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)
284
+
285
+ if has_method(robot, "teleop_safety_stop"):
286
+ robot.teleop_safety_stop()
287
+
288
+ recorded_episodes = 0
289
+ while True:
290
+ if recorded_episodes >= cfg.num_episodes:
291
+ break
292
+
293
+ log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
294
+ record_episode(
295
+ robot=robot,
296
+ dataset=dataset,
297
+ events=events,
298
+ episode_time_s=cfg.episode_time_s,
299
+ display_cameras=cfg.display_cameras,
300
+ policy=policy,
301
+ fps=cfg.fps,
302
+ single_task=cfg.single_task,
303
+ )
304
+
305
+ # Execute a few seconds without recording to give time to manually reset the environment
306
+ # Current code logic doesn't allow to teleoperate during this time.
307
+ # TODO(rcadene): add an option to enable teleoperation during reset
308
+ # Skip reset for the last episode to be recorded
309
+ if not events["stop_recording"] and (
310
+ (recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"]
311
+ ):
312
+ log_say("Reset the environment", cfg.play_sounds)
313
+ reset_environment(robot, events, cfg.reset_time_s, cfg.fps)
314
+
315
+ if events["rerecord_episode"]:
316
+ log_say("Re-record episode", cfg.play_sounds)
317
+ events["rerecord_episode"] = False
318
+ events["exit_early"] = False
319
+ dataset.clear_episode_buffer()
320
+ continue
321
+
322
+ dataset.save_episode()
323
+ recorded_episodes += 1
324
+
325
+ if events["stop_recording"]:
326
+ break
327
+
328
+ log_say("Stop recording", cfg.play_sounds, blocking=True)
329
+ stop_recording(robot, listener, cfg.display_cameras)
330
+
331
+ if cfg.push_to_hub:
332
+ dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
333
+
334
+ log_say("Exiting", cfg.play_sounds)
335
+ return dataset
336
+
337
+
338
@safe_disconnect
def replay(
    robot: Robot,
    cfg: ReplayControlConfig,
):
    """Replay the actions of a previously recorded episode on the robot at `cfg.fps`."""
    # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
    # TODO(rcadene): Add option to record logs
    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
    actions = dataset.hf_dataset.select_columns("action")

    if not robot.is_connected:
        robot.connect()

    log_say("Replaying episode", cfg.play_sounds, blocking=True)
    for frame_idx in range(dataset.num_frames):
        frame_start_t = time.perf_counter()

        robot.send_action(actions[frame_idx]["action"])

        # Sleep away whatever remains of the frame budget to hold `cfg.fps`.
        elapsed_s = time.perf_counter() - frame_start_t
        busy_wait(1 / cfg.fps - elapsed_s)

        elapsed_s = time.perf_counter() - frame_start_t
        log_control_info(robot, elapsed_s, fps=cfg.fps)
364
+
365
+
366
@parser.wrap()
def control_robot(cfg: ControlPipelineConfig):
    """CLI entry point: build the robot and dispatch to the selected control mode."""
    init_logging()
    logging.info(pformat(asdict(cfg)))

    robot = make_robot_from_config(cfg.robot)

    control = cfg.control
    if isinstance(control, CalibrateControlConfig):
        calibrate(robot, control)
    elif isinstance(control, TeleoperateControlConfig):
        teleoperate(robot, control)
    elif isinstance(control, RecordControlConfig):
        record(robot, control)
    elif isinstance(control, ReplayControlConfig):
        replay(robot, control)
    elif isinstance(control, RemoteRobotConfig):
        # Imported lazily: only needed on the robot's edge device (e.g. a Pi).
        from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi

        run_lekiwi(cfg.robot)

    if robot.is_connected:
        # Disconnect manually to avoid a "Core dump" during process
        # termination due to camera threads not properly exiting.
        robot.disconnect()


if __name__ == "__main__":
    control_robot()
lerobot/scripts/control_sim_robot.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Utilities to control a robot in simulation.
16
+
17
+ Useful to record a dataset, replay a recorded episode and record an evaluation dataset.
18
+
19
+ Examples of usage:
20
+
21
+
22
+ - Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency.
23
+ You can modify this value depending on how fast your simulation can run:
24
+ ```bash
25
+ python lerobot/scripts/control_robot.py teleoperate \
26
+ --fps 30 \
27
+ --robot-path lerobot/configs/robot/your_robot_config.yaml \
28
+ --sim-config lerobot/configs/env/your_sim_config.yaml
29
+ ```
30
+
31
+ - Record one episode in order to test replay:
32
+ ```bash
33
+ python lerobot/scripts/control_sim_robot.py record \
34
+ --robot-path lerobot/configs/robot/your_robot_config.yaml \
35
+ --sim-config lerobot/configs/env/your_sim_config.yaml \
36
+ --fps 30 \
37
+ --repo-id $USER/robot_sim_test \
38
+ --num-episodes 1 \
39
+ --run-compute-stats 0
40
+ ```
41
+
42
+ Enable the --push-to-hub 1 to push the recorded dataset to the huggingface hub.
43
+
44
+ - Visualize dataset:
45
+ ```bash
46
+ python lerobot/scripts/visualize_dataset.py \
47
+ --repo-id $USER/robot_sim_test \
48
+ --episode-index 0
49
+ ```
50
+
51
+ - Replay a sequence of test episodes:
52
+ ```bash
53
+ python lerobot/scripts/control_sim_robot.py replay \
54
+ --robot-path lerobot/configs/robot/your_robot_config.yaml \
55
+ --sim-config lerobot/configs/env/your_sim_config.yaml \
56
+ --fps 30 \
57
+ --repo-id $USER/robot_sim_test \
58
+ --episode 0
59
+ ```
60
+ Note: The seed is saved, therefore, during replay we can load the same environment state as the one during collection.
61
+
62
+ - Record a full dataset in order to train a policy,
63
+ 30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
64
+ ```bash
65
+ python lerobot/scripts/control_sim_robot.py record \
66
+ --robot-path lerobot/configs/robot/your_robot_config.yaml \
67
+ --sim-config lerobot/configs/env/your_sim_config.yaml \
68
+ --fps 30 \
69
+ --repo-id $USER/robot_sim_test \
70
+ --num-episodes 50 \
71
+ --episode-time-s 30 \
72
+ ```
73
+
74
+ **NOTE**: You can use your keyboard to control data recording flow.
75
+ - Tap right arrow key '->' to early exit while recording an episode and go to resetting the environment.
76
+ - Tap right arrow key '->' to early exit while resetting the environment and go to recording the next episode.
77
+ - Tap left arrow key '<-' to early exit and re-record the current episode.
78
+ - Tap escape key 'esc' to stop the data recording.
79
+ This might require a sudo permission to allow your terminal to monitor keyboard events.
80
+
81
+ **NOTE**: You can resume/continue data recording by running the same data recording command twice.
82
+ """
83
+
84
+ import argparse
85
+ import importlib
86
+ import logging
87
+ import time
88
+ from pathlib import Path
89
+
90
+ import cv2
91
+ import gymnasium as gym
92
+ import numpy as np
93
+ import torch
94
+
95
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
96
+ from lerobot.common.robot_devices.control_utils import (
97
+ init_keyboard_listener,
98
+ init_policy,
99
+ is_headless,
100
+ log_control_info,
101
+ predict_action,
102
+ sanity_check_dataset_name,
103
+ sanity_check_dataset_robot_compatibility,
104
+ stop_recording,
105
+ )
106
+ from lerobot.common.robot_devices.robots.utils import Robot, make_robot
107
+ from lerobot.common.robot_devices.utils import busy_wait
108
+ from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say
109
+
110
+ raise NotImplementedError("This script is currently deactivated")
111
+
112
+ DEFAULT_FEATURES = {
113
+ "next.reward": {
114
+ "dtype": "float32",
115
+ "shape": (1,),
116
+ "names": None,
117
+ },
118
+ "next.success": {
119
+ "dtype": "bool",
120
+ "shape": (1,),
121
+ "names": None,
122
+ },
123
+ "seed": {
124
+ "dtype": "int64",
125
+ "shape": (1,),
126
+ "names": None,
127
+ },
128
+ "timestamp": {
129
+ "dtype": "float32",
130
+ "shape": (1,),
131
+ "names": None,
132
+ },
133
+ }
134
+
135
+
136
+ ########################################################################################
137
+ # Utilities
138
+ ########################################################################################
139
def none_or_int(value):
    """argparse helper: map the literal string "None" to None, else parse an int."""
    return None if value == "None" else int(value)
143
+
144
+
145
def init_sim_calibration(robot, cfg):
    """Collect the constants that map real leader-arm joint counts to sim joints.

    `start_pos` is read from the real robot's leader-arm calibration;
    `axis_directions` and `offsets` (given in units of pi) come from the sim
    config, with identity defaults when absent.
    """
    calibration = robot.leader_arms.main.calibration
    return {
        "start_pos": np.array(calibration["start_pos"]),
        "axis_directions": np.array(cfg.get("axis_directions", [1])),
        "offsets": np.array(cfg.get("offsets", [0])) * np.pi,
    }
153
+
154
+
155
def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
    """Map raw motor counts to sim joint angles (radians).

    Pipeline: counts - starting position -> radians -> align axes -> offset.
    A full motor revolution is 4096 counts.
    """
    centered = real_positions - start_pos
    return axis_directions * centered * 2.0 * np.pi / 4096 + offsets
158
+
159
+
160
+ ########################################################################################
161
+ # Control modes
162
+ ########################################################################################
163
+
164
+
165
def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None):
    """Drive the sim env from the real leader arm.

    `env` is a factory (callable) returning a gym-style environment. When
    `teleop_time_s` is None the loop runs forever (exit with CTRL+C).
    """
    sim_env = env()
    sim_env.reset()
    t_start = time.perf_counter()
    while True:
        leader_pos = robot.leader_arms.main.read("Present_Position")
        sim_action = process_action_fn(leader_pos)
        # The env expects a batch dimension, hence the expand_dims.
        sim_env.step(np.expand_dims(sim_action, 0))
        timed_out = teleop_time_s is not None and time.perf_counter() - t_start > teleop_time_s
        if timed_out:
            print("Teleoperation processes finished.")
            break
176
+
177
+
178
def record(
    env,
    robot: Robot,
    process_action_from_leader,
    root: Path,
    repo_id: str,
    task: str,
    fps: int | None = None,
    tags: list[str] | None = None,
    pretrained_policy_name_or_path: str | None = None,
    policy_overrides: list[str] | None = None,
    episode_time_s: int = 30,
    num_episodes: int = 50,
    video: bool = True,
    push_to_hub: bool = True,
    num_image_writer_processes: int = 0,
    num_image_writer_threads_per_camera: int = 4,
    display_cameras: bool = False,
    play_sounds: bool = True,
    resume: bool = False,
    local_files_only: bool = False,
    run_compute_stats: bool = True,
) -> LeRobotDataset:
    """Record episodes into a LeRobotDataset by rolling out either a pretrained
    policy or teleoperated leader-arm actions in a simulated gym environment.

    Args:
        env: Zero-argument callable returning the gym env (constructed lazily so
            the keyboard listener can be initialized first).
        robot: Connected robot whose leader arm supplies actions when no policy is given.
        process_action_from_leader: Maps raw leader-arm positions to env actions.
        root: Local directory the dataset is stored under.
        repo_id: Hub-style dataset identifier (e.g. 'lerobot/test').
        task: Natural-language description stored with each episode.
        fps: Recording rate; falls back to the policy's configured fps when None.
        pretrained_policy_name_or_path: Optional policy checkpoint to drive the env.
        resume: Load an existing dataset and append to it instead of creating one.

    Returns:
        The recorded (and optionally consolidated / pushed) LeRobotDataset.

    Raises:
        ValueError: If neither a policy nor a leader-action processing function is set.
    """
    # Load pretrained policy, if any.
    policy = None
    if pretrained_policy_name_or_path is not None:
        policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)

        if fps is None:
            fps = policy_fps
            logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")

    if policy is None and process_action_from_leader is None:
        raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")

    # Initialize listener before sim env so keyboard events are captured from the start.
    listener, events = init_keyboard_listener()

    # Create sim env (deferred construction).
    env = env()

    num_cameras = sum(1 for key in env.observation_space if "image" in key)

    # Get image keys.
    image_keys = [key for key in env.observation_space if "image" in key]
    # NOTE(review): `env_cfg` is a module-level global set in `__main__`; this couples
    # `record()` to script execution — consider passing the state keys explicitly.
    state_keys_dict = env_cfg.state_keys

    if resume:
        dataset = LeRobotDataset(
            repo_id,
            root=root,
            local_files_only=local_files_only,
        )
        dataset.start_image_writer(
            num_processes=num_image_writer_processes,
            num_threads=num_image_writer_threads_per_camera * num_cameras,
        )
        sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
    else:
        # Copy DEFAULT_FEATURES: the previous code aliased the module-level dict and
        # mutated it in place, leaking env-specific keys into every later call.
        features = dict(DEFAULT_FEATURES)
        # Add image keys to features.
        for key in image_keys:
            shape = env.observation_space[key].shape
            if not key.startswith("observation.image."):
                key = "observation.image." + key
            features[key] = {"dtype": "video", "names": ["channels", "height", "width"], "shape": shape}

        for key, obs_key in state_keys_dict.items():
            features[key] = {
                "dtype": "float32",
                "names": None,
                "shape": env.observation_space[obs_key].shape,
            }

        features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}

        # Create empty dataset.
        sanity_check_dataset_name(repo_id, policy)
        dataset = LeRobotDataset.create(
            repo_id,
            fps,
            root=root,
            features=features,
            use_videos=video,
            image_writer_processes=num_image_writer_processes,
            image_writer_threads=num_image_writer_threads_per_camera * num_cameras,
        )

    recorded_episodes = 0
    while True:
        log_say(f"Recording episode {dataset.num_episodes}", play_sounds)

        # Fallback with every event key read below, so a missing keyboard listener
        # (e.g. headless session) cannot trigger a KeyError.
        if events is None:
            events = {"exit_early": False, "rerecord_episode": False, "stop_recording": False}

        if episode_time_s is None:
            episode_time_s = float("inf")

        timestamp = 0
        start_episode_t = time.perf_counter()

        seed = np.random.randint(0, 1e5)
        observation, info = env.reset(seed=seed)

        while timestamp < episode_time_s:
            start_loop_t = time.perf_counter()

            if policy is not None:
                action = predict_action(observation, policy, device, use_amp)
            else:
                leader_pos = robot.leader_arms.main.read("Present_Position")
                action = process_action_from_leader(leader_pos)

            observation, reward, terminated, _, info = env.step(action)

            success = info.get("is_success", False)
            env_timestamp = info.get("timestamp", dataset.episode_buffer["size"] / fps)

            frame = {
                "action": torch.from_numpy(action),
                "next.reward": reward,
                "next.success": success,
                "seed": seed,
                "timestamp": env_timestamp,
            }

            for key in image_keys:
                if not key.startswith("observation.image"):
                    frame["observation.image." + key] = observation[key]
                else:
                    frame[key] = observation[key]

            for key, obs_key in state_keys_dict.items():
                frame[key] = torch.from_numpy(observation[obs_key])

            dataset.add_frame(frame)

            if display_cameras and not is_headless():
                for key in image_keys:
                    cv2.imshow(key, cv2.cvtColor(observation[key], cv2.COLOR_RGB2BGR))
                cv2.waitKey(1)

            if fps is not None:
                dt_s = time.perf_counter() - start_loop_t
                busy_wait(1 / fps - dt_s)

            dt_s = time.perf_counter() - start_loop_t
            log_control_info(robot, dt_s, fps=fps)

            timestamp = time.perf_counter() - start_episode_t
            if events["exit_early"] or terminated:
                events["exit_early"] = False
                break

        # `.get` keeps this safe even if a stripped-down events dict was injected.
        if events.get("rerecord_episode"):
            log_say("Re-record episode", play_sounds)
            events["rerecord_episode"] = False
            events["exit_early"] = False
            dataset.clear_episode_buffer()
            continue

        dataset.save_episode(task=task)
        recorded_episodes += 1

        if events.get("stop_recording") or recorded_episodes >= num_episodes:
            break
        else:
            logging.info("Waiting for a few seconds before starting next episode recording...")
            busy_wait(3)

    log_say("Stop recording", play_sounds, blocking=True)
    stop_recording(robot, listener, display_cameras)

    if run_compute_stats:
        logging.info("Computing dataset statistics")
        dataset.consolidate(run_compute_stats)

    if push_to_hub:
        dataset.push_to_hub(tags=tags)

    log_say("Exiting", play_sounds)
    return dataset
361
+
362
+
363
def replay(
    env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True
):
    """Replay the recorded actions of one dataset episode inside the sim env.

    Args:
        env: Zero-argument callable returning the gym environment.
        root: Local directory the dataset lives under.
        repo_id: Dataset identifier ('{hf_username}/{dataset_name}').
        episode: Index of the episode to replay.
        fps: Playback rate; falls back to the dataset's recorded fps when None.
        local_files_only: Load the dataset from disk without hitting the hub.

    Raises:
        ValueError: If the local dataset directory does not exist.
    """
    env = env()

    local_dir = Path(root) / repo_id
    if not local_dir.exists():
        # Include context in the message instead of raising with a bare Path.
        raise ValueError(f"Local dataset directory does not exist: {local_dir}")

    dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
    items = dataset.hf_dataset.select_columns("action")
    seeds = dataset.hf_dataset.select_columns("seed")["seed"]

    if fps is None:
        # Previously `1 / fps` below crashed with TypeError when fps was left at its
        # default; fall back to the rate the episode was recorded at.
        fps = dataset.fps

    from_idx = dataset.episode_data_index["from"][episode].item()
    to_idx = dataset.episode_data_index["to"][episode].item()
    # Reset with the recorded seed so the replayed actions match the original scene.
    env.reset(seed=seeds[from_idx].item())
    logging.info("Replaying episode")
    log_say("Replaying episode", play_sounds=True)
    for idx in range(from_idx, to_idx):
        start_step_t = time.perf_counter()
        action = items[idx]["action"]
        env.step(action.unsqueeze(0).numpy())
        dt_s = time.perf_counter() - start_step_t
        busy_wait(1 / fps - dt_s)
387
+
388
+
389
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # One sub-command per control mode; `dest="mode"` drives the dispatch below.
    subparsers = parser.add_subparsers(dest="mode", required=True)

    # Set common options for all the subparsers
    base_parser = argparse.ArgumentParser(add_help=False)
    base_parser.add_argument(
        "--robot-path",
        type=str,
        default="lerobot/configs/robot/koch.yaml",
        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
    )

    base_parser.add_argument(
        "--sim-config",
        help="Path to a yaml config you want to use for initializing a sim environment based on gym ",
    )

    # NOTE(review): "teleoperate" takes no options beyond the shared base ones;
    # the variable is intentionally rebound to the "record" subparser just below.
    parser_record = subparsers.add_parser("teleoperate", parents=[base_parser])

    parser_record = subparsers.add_parser("record", parents=[base_parser])
    parser_record.add_argument(
        "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
    )
    parser_record.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
    )
    parser_record.add_argument(
        "--repo-id",
        type=str,
        default="lerobot/test",
        help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
    )
    parser_record.add_argument(
        "--episode-time-s",
        type=int,
        default=60,
        help="Number of seconds for data recording for each episode.",
    )
    parser_record.add_argument(
        "--task",
        type=str,
        required=True,
        help="A description of the task preformed during recording that can be used as a language instruction.",
    )
    parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
    parser_record.add_argument(
        "--run-compute-stats",
        type=int,
        default=1,
        help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
    )
    parser_record.add_argument(
        "--push-to-hub",
        type=int,
        default=1,
        help="Upload dataset to Hugging Face hub.",
    )
    parser_record.add_argument(
        "--tags",
        type=str,
        nargs="*",
        help="Add tags to your dataset on the hub.",
    )
    parser_record.add_argument(
        "--num-image-writer-processes",
        type=int,
        default=0,
        help=(
            "Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; "
            "set to ≥1 to use subprocesses, each using threads to write images. The best number of processes "
            "and threads depends on your system. We recommend 4 threads per camera with 0 processes. "
            "If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses."
        ),
    )
    parser_record.add_argument(
        "--num-image-writer-threads-per-camera",
        type=int,
        default=4,
        help=(
            "Number of threads writing the frames as png images on disk, per camera. "
            "Too much threads might cause unstable teleoperation fps due to main thread being blocked. "
            "Not enough threads might cause low camera fps."
        ),
    )
    parser_record.add_argument(
        "--display-cameras",
        type=int,
        default=0,
        help="Visualize image observations with opencv.",
    )
    parser_record.add_argument(
        "--resume",
        type=int,
        default=0,
        help="Resume recording on an existing dataset.",
    )
    parser_replay = subparsers.add_parser("replay", parents=[base_parser])
    parser_replay.add_argument(
        "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
    )
    parser_replay.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory where the dataset will be stored locally (e.g. 'data/hf_username/dataset_name'). By default, stored in cache folder.",
    )
    parser_replay.add_argument(
        "--repo-id",
        type=str,
        default="lerobot/test",
        help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
    )
    parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.")

    args = parser.parse_args()

    init_logging()

    control_mode = args.mode
    robot_path = args.robot_path
    env_config_path = args.sim_config
    # The remaining argparse entries are forwarded as keyword args to record()/replay(),
    # so everything consumed here must be deleted from the dict first.
    kwargs = vars(args)
    del kwargs["mode"]
    del kwargs["robot_path"]
    del kwargs["sim_config"]

    # make gym env
    env_cfg = init_hydra_config(env_config_path)
    # Importing the gym_* package registers the environment with gymnasium.
    importlib.import_module(f"gym_{env_cfg.env.type}")

    def env_constructor():
        # Deferred construction: callees create the env only after their listeners are set up.
        return gym.make(env_cfg.env.handle, disable_env_checker=True, **env_cfg.env.gym)

    robot = None
    process_leader_actions_fn = None

    if control_mode in ["teleoperate", "record"]:
        # make robot
        robot_overrides = ["~cameras", "~follower_arms"]
        # TODO(rcadene): remove
        robot_cfg = init_hydra_config(robot_path, robot_overrides)
        robot = make_robot(robot_cfg)
        robot.connect()

        calib_kwgs = init_sim_calibration(robot, env_cfg.calibration)

        def process_leader_actions_fn(action):
            # Convert raw leader-arm positions into the sim action space.
            return real_positions_to_sim(action, **calib_kwgs)

        robot.leader_arms.main.calibration = None

    if control_mode == "teleoperate":
        teleoperate(env_constructor, robot, process_leader_actions_fn)

    elif control_mode == "record":
        record(env_constructor, robot, process_leader_actions_fn, **kwargs)

    elif control_mode == "replay":
        replay(env_constructor, **kwargs)

    else:
        raise ValueError(
            f"Invalid control mode: '{control_mode}', only valid modes are teleoperate, record and replay."
        )

    if robot and robot.is_connected:
        # Disconnect manually to avoid a "Core dump" during process
        # termination due to camera threads not properly exiting.
        robot.disconnect()
lerobot/scripts/display_sys_info.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Use this script to get a quick summary of your system config.
It should be able to run without any of LeRobot's dependencies or LeRobot itself installed.
"""

import platform

# Probe each optional dependency; the HAS_* flag records whether its import succeeded.
try:
    import huggingface_hub

    HAS_HF_HUB = True
except ImportError:
    HAS_HF_HUB = False

try:
    import datasets

    HAS_HF_DATASETS = True
except ImportError:
    HAS_HF_DATASETS = False

try:
    import numpy as np

    HAS_NP = True
except ImportError:
    HAS_NP = False

try:
    import torch

    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

try:
    import lerobot

    HAS_LEROBOT = True
except ImportError:
    HAS_LEROBOT = False


# Resolve the version strings up front; "N/A" marks anything that failed to import.
lerobot_version = lerobot.__version__ if HAS_LEROBOT else "N/A"
hf_hub_version = huggingface_hub.__version__ if HAS_HF_HUB else "N/A"
hf_datasets_version = datasets.__version__ if HAS_HF_DATASETS else "N/A"
np_version = np.__version__ if HAS_NP else "N/A"

torch_version = torch.__version__ if HAS_TORCH else "N/A"
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
63
+
64
+
65
# TODO(aliberts): refactor into an actual command `lerobot env`
def display_sys_info() -> dict:
    """Run this to get basic system info to help for tracking issues & bugs."""
    # Summarize torch build + CUDA availability in one field.
    torch_summary = f"{torch_version} ({torch_cuda_available})"
    # Assemble the report from the module-level probes computed at import time.
    report = {
        "`lerobot` version": lerobot_version,
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
        "Huggingface_hub version": hf_hub_version,
        "Dataset version": hf_datasets_version,
        "Numpy version": np_version,
        "PyTorch version (GPU?)": torch_summary,
        "Cuda version": cuda_version,
        "Using GPU in script?": "<fill in>",
        # "Using distributed or parallel set-up in script?": "<fill in>",
    }
    print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
    print(format_dict(report))
    return report
83
+
84
+
85
def format_dict(d: dict) -> str:
    """Render *d* as one '- key: value' bullet per line, with a trailing newline."""
    rendered = [f"- {prop}: {val}" for prop, val in d.items()]
    return "\n".join(rendered) + "\n"
87
+
88
+
89
if __name__ == "__main__":
    # Allow running the module directly as a quick diagnostic script.
    display_sys_info()
lerobot/scripts/eval.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """Evaluate a policy on an environment by running rollouts and computing metrics.
17
+
18
+ Usage examples:
19
+
20
+ You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/diffusion_pusht)
21
+ for 10 episodes.
22
+
23
+ ```
24
+ python lerobot/scripts/eval.py \
25
+ --policy.path=lerobot/diffusion_pusht \
26
+ --env.type=pusht \
27
+ --eval.batch_size=10 \
28
+ --eval.n_episodes=10 \
29
+ --use_amp=false \
30
+ --device=cuda
31
+ ```
32
+
33
+ OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes.
34
+ ```
35
+ python lerobot/scripts/eval.py \
36
+ --policy.path=outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \
37
+ --env.type=pusht \
38
+ --eval.batch_size=10 \
39
+ --eval.n_episodes=10 \
40
+ --use_amp=false \
41
+ --device=cuda
42
+ ```
43
+
44
+ Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files.
45
+
46
+ You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
47
+ """
48
+
49
+ import json
50
+ import logging
51
+ import threading
52
+ import time
53
+ from contextlib import nullcontext
54
+ from copy import deepcopy
55
+ from dataclasses import asdict
56
+ from pathlib import Path
57
+ from pprint import pformat
58
+ from typing import Callable
59
+
60
+ import einops
61
+ import gymnasium as gym
62
+ import numpy as np
63
+ import torch
64
+ from termcolor import colored
65
+ from torch import Tensor, nn
66
+ from tqdm import trange
67
+
68
+ from lerobot.common.envs.factory import make_env
69
+ from lerobot.common.envs.utils import preprocess_observation
70
+ from lerobot.common.policies.factory import make_policy
71
+ from lerobot.common.policies.pretrained import PreTrainedPolicy
72
+ from lerobot.common.policies.utils import get_device_from_parameters
73
+ from lerobot.common.utils.io_utils import write_video
74
+ from lerobot.common.utils.random_utils import set_seed
75
+ from lerobot.common.utils.utils import (
76
+ get_safe_torch_device,
77
+ init_logging,
78
+ inside_slurm,
79
+ )
80
+ from lerobot.configs import parser
81
+ from lerobot.configs.eval import EvalPipelineConfig
82
+
83
+
84
def rollout(
    env: gym.vector.VectorEnv,
    policy: PreTrainedPolicy,
    seeds: list[int] | None = None,
    return_observations: bool = False,
    render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
) -> dict:
    """Run a batched policy rollout once through a batch of environments.

    Note that all environments in the batch are run until the last environment is done. This means some
    data will probably need to be discarded (for environments that aren't the first one to be done).

    The return dictionary contains:
        (optional) "observation": A dictionary of (batch, sequence + 1, *) tensors mapped to observation
            keys. NOTE that this has an extra sequence element relative to the other keys in the
            dictionary. This is because an extra observation is included for after the environment is
            terminated or truncated.
        "action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not
            including the last observations).
        "reward": A (batch, sequence) tensor of rewards received for applying the actions.
        "success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
            environment termination/truncation).
        "done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
            the first True is followed by True's all the way till the end. This can be used for masking
            extraneous elements from the sequences above.

    Args:
        env: The batch of environments.
        policy: The policy. Must be a PyTorch nn module.
        seeds: The environments are seeded once at the start of the rollout. If provided, this argument
            specifies the seeds for each of the environments.
        return_observations: Whether to include all observations in the returned rollout data. Observations
            are returned optionally because they typically take more memory to cache. Defaults to False.
        render_callback: Optional rendering callback to be used after the environments are reset, and after
            every step.
    Returns:
        The dictionary described above.
    """
    assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
    device = get_device_from_parameters(policy)

    # Reset the policy and environments.
    policy.reset()

    observation, info = env.reset(seed=seeds)
    if render_callback is not None:
        render_callback(env)

    all_observations = []
    all_actions = []
    all_rewards = []
    all_successes = []
    all_dones = []

    step = 0
    # Keep track of which environments are done.
    done = np.array([False] * env.num_envs)
    max_steps = env.call("_max_episode_steps")[0]
    progbar = trange(
        max_steps,
        desc=f"Running rollout with at most {max_steps} steps",
        disable=inside_slurm(),  # we dont want progress bar when we use slurm, since it clutters the logs
        leave=False,
    )
    while not np.all(done):
        # Numpy array to tensor and changing dictionary keys to LeRobot policy format.
        observation = preprocess_observation(observation)
        if return_observations:
            all_observations.append(deepcopy(observation))

        observation = {
            key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
        }

        with torch.inference_mode():
            action = policy.select_action(observation)

        # Convert to CPU / numpy.
        action = action.to("cpu").numpy()
        assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"

        # Apply the next action.
        observation, reward, terminated, truncated, info = env.step(action)
        if render_callback is not None:
            render_callback(env)

        # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
        # available of none of the envs finished.
        # FIX: use a distinct loop variable — the previous comprehension shadowed and
        # clobbered `info` itself while iterating `info["final_info"]`.
        if "final_info" in info:
            successes = [
                final_info["is_success"] if final_info is not None else False
                for final_info in info["final_info"]
            ]
        else:
            successes = [False] * env.num_envs

        # Keep track of which environments are done so far.
        done = terminated | truncated | done

        all_actions.append(torch.from_numpy(action))
        all_rewards.append(torch.from_numpy(reward))
        all_dones.append(torch.from_numpy(done))
        all_successes.append(torch.tensor(successes))

        step += 1
        running_success_rate = (
            einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
        )
        progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
        progbar.update()

    # Track the final observation.
    if return_observations:
        observation = preprocess_observation(observation)
        all_observations.append(deepcopy(observation))

    # Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
    ret = {
        "action": torch.stack(all_actions, dim=1),
        "reward": torch.stack(all_rewards, dim=1),
        "success": torch.stack(all_successes, dim=1),
        "done": torch.stack(all_dones, dim=1),
    }
    if return_observations:
        stacked_observations = {}
        for key in all_observations[0]:
            stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
        ret["observation"] = stacked_observations

    if hasattr(policy, "use_original_modules"):
        policy.use_original_modules()

    return ret
214
+
215
+
216
def eval_policy(
    env: gym.vector.VectorEnv,
    policy: PreTrainedPolicy,
    n_episodes: int,
    max_episodes_rendered: int = 0,
    videos_dir: Path | None = None,
    return_episode_data: bool = False,
    start_seed: int | None = None,
) -> dict:
    """
    Args:
        env: The batch of environments.
        policy: The policy.
        n_episodes: The number of episodes to evaluate.
        max_episodes_rendered: Maximum number of episodes to render into videos.
        videos_dir: Where to save rendered videos.
        return_episode_data: Whether to return episode data for online training. Incorporates the data into
            the "episodes" key of the returned dictionary.
        start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the
            seed is incremented by 1. If not provided, the environments are not manually seeded.
    Returns:
        Dictionary with metrics and data regarding the rollouts.
    """
    if max_episodes_rendered > 0 and not videos_dir:
        raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")

    if not isinstance(policy, PreTrainedPolicy):
        raise ValueError(
            f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided."
        )

    start = time.time()
    policy.eval()

    # Determine how many batched rollouts we need to get n_episodes. Note that if n_episodes is not evenly
    # divisible by env.num_envs we end up discarding some data in the last batch.
    n_batches = n_episodes // env.num_envs + int((n_episodes % env.num_envs) != 0)

    # Keep track of some metrics.
    sum_rewards = []
    max_rewards = []
    all_successes = []
    all_seeds = []
    threads = []  # for video saving threads
    n_episodes_rendered = 0  # for saving the correct number of videos

    # Callback for visualization.
    def render_frame(env: gym.vector.VectorEnv):
        # noqa: B023
        if n_episodes_rendered >= max_episodes_rendered:
            return
        n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
        if isinstance(env, gym.vector.SyncVectorEnv):
            ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)]))  # noqa: B023
        elif isinstance(env, gym.vector.AsyncVectorEnv):
            # Here we must render all frames and discard any we don't need.
            ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))

    if max_episodes_rendered > 0:
        video_paths: list[str] = []

    # Initialized unconditionally so the `return_episode_data` branch below can never
    # hit an unbound local (resolves the previous FIXME about conditional binding).
    episode_data: dict | None = None

    # we dont want progress bar when we use slurm, since it clutters the logs
    progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
    for batch_ix in progbar:
        # Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
        # step.
        if max_episodes_rendered > 0:
            ep_frames: list[np.ndarray] = []

        if start_seed is None:
            seeds = None
        else:
            seeds = range(
                start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
            )
        rollout_data = rollout(
            env,
            policy,
            seeds=list(seeds) if seeds else None,
            return_observations=return_episode_data,
            render_callback=render_frame if max_episodes_rendered > 0 else None,
        )

        # Figure out where in each rollout sequence the first done condition was encountered (results after
        # this won't be included).
        n_steps = rollout_data["done"].shape[1]
        # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
        done_indices = torch.argmax(rollout_data["done"].to(int), dim=1)

        # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
        # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
        mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
        # Extend metrics.
        batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
        sum_rewards.extend(batch_sum_rewards.tolist())
        batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
        max_rewards.extend(batch_max_rewards.tolist())
        batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
        all_successes.extend(batch_successes.tolist())
        if seeds:
            all_seeds.extend(seeds)
        else:
            all_seeds.append(None)

        if return_episode_data:
            this_episode_data = _compile_episode_data(
                rollout_data,
                done_indices,
                start_episode_index=batch_ix * env.num_envs,
                start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
                fps=env.unwrapped.metadata["render_fps"],
            )
            if episode_data is None:
                episode_data = this_episode_data
            else:
                # Some sanity checks to make sure we are correctly compiling the data.
                assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
                assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
                # Concatenate the episode data.
                episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}

        # Maybe render video for visualization.
        if max_episodes_rendered > 0 and len(ep_frames) > 0:
            batch_stacked_frames = np.stack(ep_frames, axis=1)  # (b, t, *)
            for stacked_frames, done_index in zip(
                batch_stacked_frames, done_indices.flatten().tolist(), strict=False
            ):
                if n_episodes_rendered >= max_episodes_rendered:
                    break

                videos_dir.mkdir(parents=True, exist_ok=True)
                video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4"
                video_paths.append(str(video_path))
                thread = threading.Thread(
                    target=write_video,
                    args=(
                        str(video_path),
                        stacked_frames[: done_index + 1],  # + 1 to capture the last observation
                        env.unwrapped.metadata["render_fps"],
                    ),
                )
                thread.start()
                threads.append(thread)
                n_episodes_rendered += 1

        progbar.set_postfix(
            {"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
        )

    # Wait till all video rendering threads are done.
    for thread in threads:
        thread.join()

    # Compile eval info.
    info = {
        "per_episode": [
            {
                "episode_ix": i,
                "sum_reward": sum_reward,
                "max_reward": max_reward,
                "success": success,
                "seed": seed,
            }
            for i, (sum_reward, max_reward, success, seed) in enumerate(
                zip(
                    sum_rewards[:n_episodes],
                    max_rewards[:n_episodes],
                    all_successes[:n_episodes],
                    all_seeds[:n_episodes],
                    strict=True,
                )
            )
        ],
        "aggregated": {
            "avg_sum_reward": float(np.nanmean(sum_rewards[:n_episodes])),
            "avg_max_reward": float(np.nanmean(max_rewards[:n_episodes])),
            "pc_success": float(np.nanmean(all_successes[:n_episodes]) * 100),
            "eval_s": time.time() - start,
            "eval_ep_s": (time.time() - start) / n_episodes,
        },
    }

    if return_episode_data:
        info["episodes"] = episode_data

    if max_episodes_rendered > 0:
        info["video_paths"] = video_paths

    return info
409
+
410
+
411
def _compile_episode_data(
    rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float
) -> dict:
    """Convenience function for `eval_policy(return_episode_data=True)`

    Compiles all the rollout data into a Hugging Face dataset.

    Similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`).
    """
    per_episode = []
    total_frames = 0
    n_rollout_episodes = rollout_data["action"].shape[0]
    for ep_ix in range(n_rollout_episodes):
        # + 2 to include the first done frame and the last observation frame.
        num_frames = done_indices[ep_ix].item() + 2
        total_frames += num_frames

        # The final observation frame has no associated action/reward, hence `num_frames - 1`.
        n_data = num_frames - 1
        ep_dict = {
            "action": rollout_data["action"][ep_ix, :n_data],
            "episode_index": torch.tensor([start_episode_index + ep_ix] * n_data),
            "frame_index": torch.arange(0, n_data, 1),
            "timestamp": torch.arange(0, n_data, 1) / fps,
            "next.done": rollout_data["done"][ep_ix, :n_data],
            "next.success": rollout_data["success"][ep_ix, :n_data],
            "next.reward": rollout_data["reward"][ep_ix, :n_data].type(torch.float32),
        }

        # Copy-pad every data key so the last observation frame has matching entries.
        # (Observations are added after this loop, so they are sliced, not padded.)
        for k in ep_dict:
            ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]])

        for key in rollout_data["observation"]:
            ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames]

        per_episode.append(ep_dict)

    data_dict = {key: torch.cat([ep[key] for ep in per_episode]) for key in per_episode[0]}
    data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)

    return data_dict
+ return data_dict
454
+
455
+
456
@parser.wrap()
def eval_main(cfg: EvalPipelineConfig):
    """Entry point: build env + policy from config, run evaluation, dump metrics to JSON."""
    logging.info(pformat(asdict(cfg)))

    # Check device is available
    device = get_safe_torch_device(cfg.policy.device, log=True)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    set_seed(cfg.seed)

    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")

    logging.info("Making environment.")
    env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

    logging.info("Making policy.")
    policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env)
    policy.eval()

    output_dir = Path(cfg.output_dir)
    amp_ctx = torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext()
    with torch.no_grad(), amp_ctx:
        info = eval_policy(
            env,
            policy,
            cfg.eval.n_episodes,
            max_episodes_rendered=10,
            videos_dir=output_dir / "videos",
            start_seed=cfg.seed,
        )
    print(info["aggregated"])

    # Save info
    with open(output_dir / "eval_info.json", "w") as f:
        json.dump(info, f, indent=2)

    env.close()

    logging.info("End of eval")
497
+ logging.info("End of eval")
498
+
499
+
500
+ if __name__ == "__main__":
501
+ init_logging()
502
+ eval_main()
lerobot/scripts/find_motors_bus_port.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import time
16
+ from pathlib import Path
17
+
18
+ from serial.tools import list_ports # Part of pyserial library
19
+
20
+
21
def find_available_ports():
    """Return the serial port names currently visible on this machine.

    On Windows this enumerates COM ports through pyserial; on Unix-like
    systems it lists the `tty*` device nodes under `/dev`.
    """
    if os.name == "nt":  # Windows
        # List COM ports using pyserial
        return [port.device for port in list_ports.comports()]
    # Linux/macOS: list /dev/tty* ports for Unix-based systems
    return [str(path) for path in Path("/dev").glob("tty*")]
29
+
30
+
31
def find_port():
    """Interactively identify the serial port of a MotorsBus.

    Snapshots the available ports, asks the user to unplug the bus, and
    reports the single port that disappeared. Raises OSError when zero or
    more than one port changed between the two snapshots.
    """
    print("Finding all available ports for the MotorsBus.")
    before = find_available_ports()
    print("Ports before disconnecting:", before)

    print("Remove the USB cable from your MotorsBus and press Enter when done.")
    input()  # Wait for user to disconnect the device

    time.sleep(0.5)  # Allow some time for port to be released
    after = find_available_ports()
    removed = list(set(before) - set(after))

    if len(removed) == 0:
        raise OSError(f"Could not detect the port. No difference was found ({removed}).")
    if len(removed) > 1:
        raise OSError(f"Could not detect the port. More than one port was found ({removed}).")

    port = removed[0]
    print(f"The port of this MotorsBus is '{port}'")
    print("Reconnect the USB cable.")
51
+
52
+
53
if __name__ == "__main__":
    # Helper to find the USB port associated with your MotorsBus:
    # run this script, then follow the unplug/replug prompts.
    find_port()
lerobot/scripts/push_dataset_to_hub.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
18
+ or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
19
+ installation of neural net specific packages like pytorch, tensorflow, jax.
20
+
21
+ Example of how to download raw datasets, convert them into LeRobotDataset format, and push them to the hub:
22
+ ```
23
+ python lerobot/scripts/push_dataset_to_hub.py \
24
+ --raw-dir data/pusht_raw \
25
+ --raw-format pusht_zarr \
26
+ --repo-id lerobot/pusht
27
+
28
+ python lerobot/scripts/push_dataset_to_hub.py \
29
+ --raw-dir data/xarm_lift_medium_raw \
30
+ --raw-format xarm_pkl \
31
+ --repo-id lerobot/xarm_lift_medium
32
+
33
+ python lerobot/scripts/push_dataset_to_hub.py \
34
+ --raw-dir data/aloha_sim_insertion_scripted_raw \
35
+ --raw-format aloha_hdf5 \
36
+ --repo-id lerobot/aloha_sim_insertion_scripted
37
+
38
+ python lerobot/scripts/push_dataset_to_hub.py \
39
+ --raw-dir data/umi_cup_in_the_wild_raw \
40
+ --raw-format umi_zarr \
41
+ --repo-id lerobot/umi_cup_in_the_wild
42
+ ```
43
+ """
44
+
45
+ import argparse
46
+ import json
47
+ import shutil
48
+ import warnings
49
+ from pathlib import Path
50
+ from typing import Any
51
+
52
+ import torch
53
+ from huggingface_hub import HfApi
54
+ from safetensors.torch import save_file
55
+
56
+ from lerobot.common.datasets.compute_stats import compute_stats
57
+ from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
58
+ from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
59
+ from lerobot.common.datasets.utils import create_branch, create_lerobot_dataset_card, flatten_dict
60
+
61
+
62
def get_from_raw_to_lerobot_format_fn(raw_format: str):
    """Return the `from_raw_to_lerobot_format` conversion function for `raw_format`.

    The format-specific module is imported lazily so only the dependencies of
    the requested format are loaded.

    Raises:
        ValueError: if `raw_format` is not one of the supported formats.
    """
    supported = {"pusht_zarr", "umi_zarr", "aloha_hdf5", "rlds", "openx", "dora_parquet", "xarm_pkl", "cam_png"}
    if raw_format not in supported:
        raise ValueError(
            f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
        )

    if raw_format == "pusht_zarr":
        from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
    elif raw_format == "umi_zarr":
        from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
    elif raw_format == "aloha_hdf5":
        from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
    elif raw_format in ("rlds", "openx"):
        from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format
    elif raw_format == "dora_parquet":
        from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
    elif raw_format == "xarm_pkl":
        from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
    else:  # "cam_png"
        from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import from_raw_to_lerobot_format

    return from_raw_to_lerobot_format
83
+
84
+
85
def save_meta_data(
    info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
):
    """Write dataset metadata (info, stats, episode index) into `meta_data_dir`.

    Creates the directory if needed, then writes `info.json`,
    `stats.safetensors` and `episode_data_index.safetensors`.
    """
    meta_data_dir.mkdir(parents=True, exist_ok=True)

    # info.json: human-readable dataset description.
    with open(str(meta_data_dir / "info.json"), "w") as f:
        json.dump(info, f, indent=4)

    # stats.safetensors: per-feature statistics, flattened to a single-level dict.
    save_file(flatten_dict(stats), meta_data_dir / "stats.safetensors")

    # episode_data_index.safetensors: per-episode frame ranges, converted to tensors
    # because safetensors only stores tensors.
    tensorized_index = {key: torch.tensor(value) for key, value in episode_data_index.items()}
    save_file(tensorized_index, meta_data_dir / "episode_data_index.safetensors")
103
+
104
+
105
def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
    """Upload the local "meta_data" directory to the dataset repo on the hub.

    Expect all meta data files to be stored in a single "meta_data" directory;
    on the Hugging Face repository they are uploaded under a "meta_data"
    directory at the root.
    """
    HfApi().upload_folder(
        repo_id=repo_id,
        repo_type="dataset",
        revision=revision,
        folder_path=meta_data_dir,
        path_in_repo="meta_data",
    )
117
+
118
+
119
def push_dataset_card_to_hub(
    repo_id: str,
    revision: str | None,
    tags: list | None = None,
    license: str = "apache-2.0",
    **card_kwargs,
):
    """Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
    dataset_card = create_lerobot_dataset_card(tags=tags, license=license, **card_kwargs)
    dataset_card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
129
+
130
+
131
def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
    """Upload the local "videos" directory (mp4 files only) to the dataset repo on the hub.

    Expect mp4 files to be stored in a single "videos" directory; on the
    Hugging Face repository they are uploaded under a "videos" directory at
    the root.
    """
    HfApi().upload_folder(
        repo_id=repo_id,
        repo_type="dataset",
        revision=revision,
        folder_path=videos_dir,
        path_in_repo="videos",
        allow_patterns="*.mp4",
    )
144
+
145
+
146
def push_dataset_to_hub(
    raw_dir: Path,
    raw_format: str,
    repo_id: str,
    push_to_hub: bool = True,
    local_dir: Path | None = None,
    fps: int | None = None,
    video: bool = True,
    batch_size: int = 32,
    num_workers: int = 8,
    episodes: list[int] | None = None,
    force_override: bool = False,
    resume: bool = False,
    cache_dir: Path = Path("/tmp"),
    tests_data_dir: Path | None = None,
    encoding: dict | None = None,
):
    """Convert a raw dataset to LeRobotDataset format and optionally push it to the hub.

    Args:
        raw_dir: Directory containing the raw dataset.
        raw_format: One of the formats known to `get_from_raw_to_lerobot_format_fn`.
        repo_id: Hub dataset identifier, "<user_or_org>/<dataset_name>".
        push_to_hub: When True, upload data, metadata, card and (optionally) videos.
        local_dir: When provided, also save the converted dataset there; otherwise
            intermediate artifacts go under `cache_dir` and are deleted at the end.
        fps / video / episodes / encoding: Forwarded to the format-specific converter.
        batch_size / num_workers: DataLoader settings for computing dataset statistics.
        force_override: Delete a pre-existing `local_dir` instead of raising.
        resume: Allow reusing a pre-existing `local_dir` without deleting it.
        tests_data_dir: When provided, save first-episode artifacts for the test suite.

    Returns:
        The in-memory `LeRobotDataset` built from the converted data.
    """
    check_repo_id(repo_id)
    user_id, dataset_id = repo_id.split("/")

    # Robustify when `raw_dir` is str instead of Path
    raw_dir = Path(raw_dir)
    if not raw_dir.exists():
        raise NotADirectoryError(
            f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub: "
            f"`python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw`"
        )

    if local_dir:
        # Robustify when `local_dir` is str instead of Path
        local_dir = Path(local_dir)

        # Send warning if local_dir isn't well formatted
        if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
            warnings.warn(
                f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
                stacklevel=1,
            )

        # Check we don't override an existing `local_dir` by mistake
        if local_dir.exists():
            if force_override:
                shutil.rmtree(local_dir)
            elif not resume:
                raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")

        meta_data_dir = local_dir / "meta_data"
        videos_dir = local_dir / "videos"
    else:
        # Temporary directory used to store images, videos, meta_data
        meta_data_dir = Path(cache_dir) / "meta_data"
        videos_dir = Path(cache_dir) / "videos"

    if raw_format is None:
        # TODO(rcadene, adilzouitine): implement auto_find_raw_format
        raise NotImplementedError()
        # raw_format = auto_find_raw_format(raw_dir)

    # convert dataset from original raw format to LeRobot format
    from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)

    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
        raw_dir,
        videos_dir,
        fps,
        video,
        episodes,
        encoding,
    )

    # Wrap the converted data so stats can be computed with the standard pipeline.
    lerobot_dataset = LeRobotDataset.from_preloaded(
        repo_id=repo_id,
        hf_dataset=hf_dataset,
        episode_data_index=episode_data_index,
        info=info,
        videos_dir=videos_dir,
    )
    stats = compute_stats(lerobot_dataset, batch_size, num_workers)

    if local_dir:
        hf_dataset = hf_dataset.with_format(None)  # to remove transforms that cant be saved
        hf_dataset.save_to_disk(str(local_dir / "train"))

    if push_to_hub or local_dir:
        # mandatory for upload
        save_meta_data(info, stats, episode_data_index, meta_data_dir)

    if push_to_hub:
        hf_dataset.push_to_hub(repo_id, revision="main")
        push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
        push_dataset_card_to_hub(repo_id, revision="main")
        if video:
            push_videos_to_hub(repo_id, videos_dir, revision="main")
        # Tag the upload with the codebase version so loaders can pin a compatible revision.
        create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)

    if tests_data_dir:
        # get the first episode only, as a small fixture for the test suite
        num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
        test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
        episode_data_index = {k: v[:1] for k, v in episode_data_index.items()}

        test_hf_dataset = test_hf_dataset.with_format(None)
        test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))

        tests_meta_data = tests_data_dir / repo_id / "meta_data"
        save_meta_data(info, stats, episode_data_index, tests_meta_data)

        # copy videos of first episode to tests directory
        episode_index = 0
        tests_videos_dir = tests_data_dir / repo_id / "videos"
        tests_videos_dir.mkdir(parents=True, exist_ok=True)
        for key in lerobot_dataset.camera_keys:
            fname = f"{key}_episode_{episode_index:06d}.mp4"
            shutil.copy(videos_dir / fname, tests_videos_dir / fname)

    if local_dir is None:
        # clear cache
        shutil.rmtree(meta_data_dir)
        shutil.rmtree(videos_dir)

    return lerobot_dataset
267
+
268
+
269
def main():
    """CLI entry point: parse arguments and run `push_dataset_to_hub`.

    Argument names mirror the parameters of `push_dataset_to_hub`, so the
    parsed namespace can be forwarded directly via `**vars(args)`.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--raw-dir",
        type=Path,
        required=True,
        # Fixed: unbalanced backtick in the example ("`data/pusht_raw)").
        help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw`).",
    )
    # TODO(rcadene): add automatic detection of the format
    parser.add_argument(
        "--raw-format",
        type=str,
        required=True,
        help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `rlds`, `openx`).",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        # Fixed: typo "Repositery" -> "Repository" in user-facing help text.
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
    )
    parser.add_argument(
        "--local-dir",
        type=Path,
        help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
    )
    # NOTE: int-typed flags below (--push-to-hub, --video, --force-override, --resume)
    # are consumed as booleans by `push_dataset_to_hub` (0 = False, nonzero = True).
    parser.add_argument(
        "--push-to-hub",
        type=int,
        default=1,
        help="Upload to hub.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        help="Frame rate used to collect videos. If not provided, use the default one specified in the code.",
    )
    parser.add_argument(
        "--video",
        type=int,
        default=1,
        help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size loaded by DataLoader for computing the dataset statistics.",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=8,
        help="Number of processes of Dataloader for computing the dataset statistics.",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="*",
        help="When provided, only converts the provided episodes (e.g `--episodes 2 3 4`). Useful to test the code on 1 episode.",
    )
    parser.add_argument(
        "--force-override",
        type=int,
        default=0,
        help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.",
    )
    parser.add_argument(
        "--resume",
        type=int,
        default=0,
        help="When set to 1, resumes a previous run.",
    )
    parser.add_argument(
        "--cache-dir",
        type=Path,
        required=False,
        default="/tmp",
        help="Directory to store the temporary videos and images generated while creating the dataset.",
    )
    parser.add_argument(
        "--tests-data-dir",
        type=Path,
        help=(
            "When provided, save tests artifacts into the given directory "
            "(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
        ),
    )

    args = parser.parse_args()
    # Argument dests match push_dataset_to_hub's parameter names one-to-one.
    push_dataset_to_hub(**vars(args))


if __name__ == "__main__":
    main()
lerobot/scripts/push_pretrained.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Once you have trained a policy with our training script (lerobot/scripts/train.py), use this script to push it
18
+ to the hub.
19
+
20
+ Example:
21
+
22
+ ```bash
23
+ python lerobot/scripts/push_pretrained.py \
24
+ --pretrained_path=outputs/train/act_aloha_sim_transfer_cube_human/checkpoints/last/pretrained_model \
25
+ --repo_id=lerobot/act_aloha_sim_transfer_cube_human
26
+ ```
27
+ """
28
+
29
+ from dataclasses import dataclass
30
+ from pathlib import Path
31
+
32
+ import draccus
33
+ from huggingface_hub import HfApi
34
+
35
+
36
@dataclass
class PushPreTrainedConfig:
    """Configuration for uploading a trained policy checkpoint to the Hugging Face hub."""

    # Local directory holding the pretrained model artifacts to upload.
    pretrained_path: Path
    # Target model repo on the hub, e.g. "lerobot/act_aloha_sim_transfer_cube_human".
    repo_id: str
    # Optional branch to create and upload to; None uploads to the default revision.
    branch: str | None = None
    # Create the hub repo as private.
    private: bool = False
    # Do not fail if the repo (or branch) already exists.
    exist_ok: bool = False
44
+
45
@draccus.wrap()
def main(cfg: PushPreTrainedConfig):
    """Upload a pretrained policy folder to a (possibly new) model repo on the hub.

    Creates the repo (and optional branch) first, then uploads the contents of
    `cfg.pretrained_path`.
    """
    hub_api = HfApi()
    hub_api.create_repo(
        repo_id=cfg.repo_id,
        private=cfg.private,
        repo_type="model",
        exist_ok=cfg.exist_ok,
    )
    if cfg.branch:
        hub_api.create_branch(
            repo_id=cfg.repo_id,
            branch=cfg.branch,
            repo_type="model",
            exist_ok=cfg.exist_ok,
        )

    # revision=None (no branch configured) uploads to the repo's default branch.
    hub_api.upload_folder(
        repo_id=cfg.repo_id,
        folder_path=cfg.pretrained_path,
        repo_type="model",
        revision=cfg.branch,
    )


if __name__ == "__main__":
    main()
lerobot/scripts/train.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import logging
17
+ import time
18
+ from contextlib import nullcontext
19
+ from pprint import pformat
20
+ from typing import Any
21
+
22
+ import torch
23
+ from termcolor import colored
24
+ from torch.amp import GradScaler
25
+ from torch.optim import Optimizer
26
+
27
+ from lerobot.common.datasets.factory import make_dataset
28
+ from lerobot.common.datasets.sampler import EpisodeAwareSampler
29
+ from lerobot.common.datasets.utils import cycle
30
+ from lerobot.common.envs.factory import make_env
31
+ from lerobot.common.optim.factory import make_optimizer_and_scheduler
32
+ from lerobot.common.policies.factory import make_policy
33
+ from lerobot.common.policies.pretrained import PreTrainedPolicy
34
+ from lerobot.common.policies.utils import get_device_from_parameters
35
+ from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
36
+ from lerobot.common.utils.random_utils import set_seed
37
+ from lerobot.common.utils.train_utils import (
38
+ get_step_checkpoint_dir,
39
+ get_step_identifier,
40
+ load_training_state,
41
+ save_checkpoint,
42
+ update_last_checkpoint,
43
+ )
44
+ from lerobot.common.utils.utils import (
45
+ format_big_number,
46
+ get_safe_torch_device,
47
+ has_method,
48
+ init_logging,
49
+ )
50
+ from lerobot.common.utils.wandb_utils import WandBLogger
51
+ from lerobot.configs import parser
52
+ from lerobot.configs.train import TrainPipelineConfig
53
+ from lerobot.scripts.eval import eval_policy
54
+
55
+
56
+ def update_policy(
57
+ train_metrics: MetricsTracker,
58
+ policy: PreTrainedPolicy,
59
+ batch: Any,
60
+ optimizer: Optimizer,
61
+ grad_clip_norm: float,
62
+ grad_scaler: GradScaler,
63
+ lr_scheduler=None,
64
+ use_amp: bool = False,
65
+ lock=None,
66
+ ) -> tuple[MetricsTracker, dict]:
67
+ start_time = time.perf_counter()
68
+ device = get_device_from_parameters(policy)
69
+ policy.train()
70
+ with torch.autocast(device_type=device.type) if use_amp else nullcontext():
71
+ loss, output_dict = policy.forward(batch)
72
+ # TODO(rcadene): policy.unnormalize_outputs(out_dict)
73
+ grad_scaler.scale(loss).backward()
74
+
75
+ # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
76
+ grad_scaler.unscale_(optimizer)
77
+
78
+ grad_norm = torch.nn.utils.clip_grad_norm_(
79
+ policy.parameters(),
80
+ grad_clip_norm,
81
+ error_if_nonfinite=False,
82
+ )
83
+
84
+ # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
85
+ # although it still skips optimizer.step() if the gradients contain infs or NaNs.
86
+ with lock if lock is not None else nullcontext():
87
+ grad_scaler.step(optimizer)
88
+ # Updates the scale for next iteration.
89
+ grad_scaler.update()
90
+
91
+ optimizer.zero_grad()
92
+
93
+ # Step through pytorch scheduler at every batch instead of epoch
94
+ if lr_scheduler is not None:
95
+ lr_scheduler.step()
96
+
97
+ if has_method(policy, "update"):
98
+ # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
99
+ policy.update()
100
+
101
+ train_metrics.loss = loss.item()
102
+ train_metrics.grad_norm = grad_norm.item()
103
+ train_metrics.lr = optimizer.param_groups[0]["lr"]
104
+ train_metrics.update_s = time.perf_counter() - start_time
105
+ return train_metrics, output_dict
106
+
107
+
108
+ @parser.wrap()
109
+ def train(cfg: TrainPipelineConfig):
110
+ cfg.validate()
111
+ logging.info(pformat(cfg.to_dict()))
112
+
113
+ if cfg.wandb.enable and cfg.wandb.project:
114
+ wandb_logger = WandBLogger(cfg)
115
+ else:
116
+ wandb_logger = None
117
+ logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
118
+
119
+ if cfg.seed is not None:
120
+ set_seed(cfg.seed)
121
+
122
+ # Check device is available
123
+ device = get_safe_torch_device(cfg.policy.device, log=True)
124
+ torch.backends.cudnn.benchmark = True
125
+ torch.backends.cuda.matmul.allow_tf32 = True
126
+
127
+ logging.info("Creating dataset")
128
+ dataset = make_dataset(cfg)
129
+
130
+ # Create environment used for evaluating checkpoints during training on simulation data.
131
+ # On real-world data, no need to create an environment as evaluations are done outside train.py,
132
+ # using the eval.py instead, with gym_dora environment and dora-rs.
133
+ eval_env = None
134
+ if cfg.eval_freq > 0 and cfg.env is not None:
135
+ logging.info("Creating env")
136
+ eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)
137
+
138
+ logging.info("Creating policy")
139
+ policy = make_policy(
140
+ cfg=cfg.policy,
141
+ ds_meta=dataset.meta,
142
+ )
143
+
144
+ logging.info("Creating optimizer and scheduler")
145
+ optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
146
+ grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
147
+
148
+ step = 0 # number of policy updates (forward + backward + optim)
149
+
150
+ if cfg.resume:
151
+ step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
152
+
153
+ num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
154
+ num_total_params = sum(p.numel() for p in policy.parameters())
155
+
156
+ logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
157
+ if cfg.env is not None:
158
+ logging.info(f"{cfg.env.task=}")
159
+ logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
160
+ logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
161
+ logging.info(f"{dataset.num_episodes=}")
162
+ logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
163
+ logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
164
+
165
+ # create dataloader for offline training
166
+ if hasattr(cfg.policy, "drop_n_last_frames"):
167
+ shuffle = False
168
+ sampler = EpisodeAwareSampler(
169
+ dataset.episode_data_index,
170
+ drop_n_last_frames=cfg.policy.drop_n_last_frames,
171
+ shuffle=True,
172
+ )
173
+ else:
174
+ shuffle = True
175
+ sampler = None
176
+
177
+ dataloader = torch.utils.data.DataLoader(
178
+ dataset,
179
+ num_workers=cfg.num_workers,
180
+ batch_size=cfg.batch_size,
181
+ shuffle=shuffle,
182
+ sampler=sampler,
183
+ pin_memory=device.type != "cpu",
184
+ drop_last=False,
185
+ )
186
+ dl_iter = cycle(dataloader)
187
+
188
+ policy.train()
189
+
190
+ train_metrics = {
191
+ "loss": AverageMeter("loss", ":.3f"),
192
+ "grad_norm": AverageMeter("grdn", ":.3f"),
193
+ "lr": AverageMeter("lr", ":0.1e"),
194
+ "update_s": AverageMeter("updt_s", ":.3f"),
195
+ "dataloading_s": AverageMeter("data_s", ":.3f"),
196
+ }
197
+
198
+ train_tracker = MetricsTracker(
199
+ cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
200
+ )
201
+
202
+ logging.info("Start offline training on a fixed dataset")
203
+ for _ in range(step, cfg.steps):
204
+ start_time = time.perf_counter()
205
+ batch = next(dl_iter)
206
+ train_tracker.dataloading_s = time.perf_counter() - start_time
207
+
208
+ for key in batch:
209
+ if isinstance(batch[key], torch.Tensor):
210
+ batch[key] = batch[key].to(device, non_blocking=True)
211
+
212
+ train_tracker, output_dict = update_policy(
213
+ train_tracker,
214
+ policy,
215
+ batch,
216
+ optimizer,
217
+ cfg.optimizer.grad_clip_norm,
218
+ grad_scaler=grad_scaler,
219
+ lr_scheduler=lr_scheduler,
220
+ use_amp=cfg.policy.use_amp,
221
+ )
222
+
223
+ # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
224
+ # increment `step` here.
225
+ step += 1
226
+ train_tracker.step()
227
+ is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
228
+ is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
229
+ is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
230
+
231
+ if is_log_step:
232
+ logging.info(train_tracker)
233
+ if wandb_logger:
234
+ wandb_log_dict = train_tracker.to_dict()
235
+ if output_dict:
236
+ wandb_log_dict.update(output_dict)
237
+ wandb_logger.log_dict(wandb_log_dict, step)
238
+ train_tracker.reset_averages()
239
+
240
+ if cfg.save_checkpoint and is_saving_step:
241
+ logging.info(f"Checkpoint policy after step {step}")
242
+ checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
243
+ save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
244
+ update_last_checkpoint(checkpoint_dir)
245
+ if wandb_logger:
246
+ wandb_logger.log_policy(checkpoint_dir)
247
+
248
+ if cfg.env and is_eval_step:
249
+ step_id = get_step_identifier(step, cfg.steps)
250
+ logging.info(f"Eval policy at step {step}")
251
+ with (
252
+ torch.no_grad(),
253
+ torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
254
+ ):
255
+ eval_info = eval_policy(
256
+ eval_env,
257
+ policy,
258
+ cfg.eval.n_episodes,
259
+ videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
260
+ max_episodes_rendered=4,
261
+ start_seed=cfg.seed,
262
+ )
263
+
264
+ eval_metrics = {
265
+ "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
266
+ "pc_success": AverageMeter("success", ":.1f"),
267
+ "eval_s": AverageMeter("eval_s", ":.3f"),
268
+ }
269
+ eval_tracker = MetricsTracker(
270
+ cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
271
+ )
272
+ eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
273
+ eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
274
+ eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
275
+ logging.info(eval_tracker)
276
+ if wandb_logger:
277
+ wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
278
+ wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
279
+ wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
280
+
281
+ if eval_env:
282
+ eval_env.close()
283
+ logging.info("End of training")
284
+
285
+
286
+ if __name__ == "__main__":
287
+ init_logging()
288
+ train()
lerobot/scripts/visualize_dataset.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
17
+
18
+ Note: The last frame of the episode doesn't always correspond to a final state.
19
+ That's because our datasets are composed of transition from state to state up to
20
+ the antepenultimate state associated with the ultimate action to arrive in the final state.
21
+ However, there might not be a transition from a final state to another state.
22
+
23
+ Note: This script aims to visualize the data used to train the neural networks.
24
+ ~What you see is what you get~. When visualizing image modality, it is often expected to observe
25
+ lossy compression artifacts since these images have been decoded from compressed mp4 videos to
26
+ save disk space. The compression factor applied has been tuned to not affect success rate.
27
+
28
+ Examples:
29
+
30
+ - Visualize data stored on a local machine:
31
+ ```
32
+ local$ python lerobot/scripts/visualize_dataset.py \
33
+ --repo-id lerobot/pusht \
34
+ --episode-index 0
35
+ ```
36
+
37
+ - Visualize data stored on a distant machine with a local viewer:
38
+ ```
39
+ distant$ python lerobot/scripts/visualize_dataset.py \
40
+ --repo-id lerobot/pusht \
41
+ --episode-index 0 \
42
+ --save 1 \
43
+ --output-dir path/to/directory
44
+
45
+ local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
46
+ local$ rerun lerobot_pusht_episode_0.rrd
47
+ ```
48
+
49
+ - Visualize data stored on a distant machine through streaming:
50
+ (You need to forward the websocket port to the distant machine, with
51
+ `ssh -L 9087:localhost:9087 username@remote-host`)
52
+ ```
53
+ distant$ python lerobot/scripts/visualize_dataset.py \
54
+ --repo-id lerobot/pusht \
55
+ --episode-index 0 \
56
+ --mode distant \
57
+ --ws-port 9087
58
+
59
+ local$ rerun ws://localhost:9087
60
+ ```
61
+
62
+ """
63
+
64
+ import argparse
65
+ import gc
66
+ import logging
67
+ import time
68
+ from pathlib import Path
69
+ from typing import Iterator
70
+
71
+ import numpy as np
72
+ import rerun as rr
73
+ import torch
74
+ import torch.utils.data
75
+ import tqdm
76
+
77
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
78
+
79
+
80
+ class EpisodeSampler(torch.utils.data.Sampler):
81
+ def __init__(self, dataset: LeRobotDataset, episode_index: int):
82
+ from_idx = dataset.episode_data_index["from"][episode_index].item()
83
+ to_idx = dataset.episode_data_index["to"][episode_index].item()
84
+ self.frame_ids = range(from_idx, to_idx)
85
+
86
+ def __iter__(self) -> Iterator:
87
+ return iter(self.frame_ids)
88
+
89
+ def __len__(self) -> int:
90
+ return len(self.frame_ids)
91
+
92
+
93
+ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
94
+ assert chw_float32_torch.dtype == torch.float32
95
+ assert chw_float32_torch.ndim == 3
96
+ c, h, w = chw_float32_torch.shape
97
+ assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
98
+ hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
99
+ return hwc_uint8_numpy
100
+
101
+
102
+ def visualize_dataset(
103
+ dataset: LeRobotDataset,
104
+ episode_index: int,
105
+ batch_size: int = 32,
106
+ num_workers: int = 0,
107
+ mode: str = "local",
108
+ web_port: int = 9090,
109
+ ws_port: int = 9087,
110
+ save: bool = False,
111
+ output_dir: Path | None = None,
112
+ ) -> Path | None:
113
+ if save:
114
+ assert output_dir is not None, (
115
+ "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
116
+ )
117
+
118
+ repo_id = dataset.repo_id
119
+
120
+ logging.info("Loading dataloader")
121
+ episode_sampler = EpisodeSampler(dataset, episode_index)
122
+ dataloader = torch.utils.data.DataLoader(
123
+ dataset,
124
+ num_workers=num_workers,
125
+ batch_size=batch_size,
126
+ sampler=episode_sampler,
127
+ )
128
+
129
+ logging.info("Starting Rerun")
130
+
131
+ if mode not in ["local", "distant"]:
132
+ raise ValueError(mode)
133
+
134
+ spawn_local_viewer = mode == "local" and not save
135
+ rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
136
+
137
+ # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
138
+ # when iterating on a dataloader with `num_workers` > 0
139
+ # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
140
+ gc.collect()
141
+
142
+ if mode == "distant":
143
+ rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
144
+
145
+ logging.info("Logging to Rerun")
146
+
147
+ for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
148
+ # iterate over the batch
149
+ for i in range(len(batch["index"])):
150
+ rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
151
+ rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
152
+
153
+ # display each camera image
154
+ for key in dataset.meta.camera_keys:
155
+ # TODO(rcadene): add `.compress()`? is it lossless?
156
+ rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
157
+
158
+ # display each dimension of action space (e.g. actuators command)
159
+ if "action" in batch:
160
+ for dim_idx, val in enumerate(batch["action"][i]):
161
+ rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
162
+
163
+ # display each dimension of observed state space (e.g. agent position in joint space)
164
+ if "observation.state" in batch:
165
+ for dim_idx, val in enumerate(batch["observation.state"][i]):
166
+ rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
167
+
168
+ if "next.done" in batch:
169
+ rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
170
+
171
+ if "next.reward" in batch:
172
+ rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
173
+
174
+ if "next.success" in batch:
175
+ rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
176
+
177
+ if mode == "local" and save:
178
+ # save .rrd locally
179
+ output_dir = Path(output_dir)
180
+ output_dir.mkdir(parents=True, exist_ok=True)
181
+ repo_id_str = repo_id.replace("/", "_")
182
+ rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
183
+ rr.save(rrd_path)
184
+ return rrd_path
185
+
186
+ elif mode == "distant":
187
+ # stop the process from exiting since it is serving the websocket connection
188
+ try:
189
+ while True:
190
+ time.sleep(1)
191
+ except KeyboardInterrupt:
192
+ print("Ctrl-C received. Exiting.")
193
+
194
+
195
+ def main():
196
+ parser = argparse.ArgumentParser()
197
+
198
+ parser.add_argument(
199
+ "--repo-id",
200
+ type=str,
201
+ required=True,
202
+ help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
203
+ )
204
+ parser.add_argument(
205
+ "--episode-index",
206
+ type=int,
207
+ required=True,
208
+ help="Episode to visualize.",
209
+ )
210
+ parser.add_argument(
211
+ "--root",
212
+ type=Path,
213
+ default=None,
214
+ help="Root directory for the dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
215
+ )
216
+ parser.add_argument(
217
+ "--output-dir",
218
+ type=Path,
219
+ default=None,
220
+ help="Directory path to write a .rrd file when `--save 1` is set.",
221
+ )
222
+ parser.add_argument(
223
+ "--batch-size",
224
+ type=int,
225
+ default=32,
226
+ help="Batch size loaded by DataLoader.",
227
+ )
228
+ parser.add_argument(
229
+ "--num-workers",
230
+ type=int,
231
+ default=4,
232
+ help="Number of processes of Dataloader for loading the data.",
233
+ )
234
+ parser.add_argument(
235
+ "--mode",
236
+ type=str,
237
+ default="local",
238
+ help=(
239
+ "Mode of viewing between 'local' or 'distant'. "
240
+ "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
241
+ "'distant' creates a server on the distant machine where the data is stored. "
242
+ "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
243
+ ),
244
+ )
245
+ parser.add_argument(
246
+ "--web-port",
247
+ type=int,
248
+ default=9090,
249
+ help="Web port for rerun.io when `--mode distant` is set.",
250
+ )
251
+ parser.add_argument(
252
+ "--ws-port",
253
+ type=int,
254
+ default=9087,
255
+ help="Web socket port for rerun.io when `--mode distant` is set.",
256
+ )
257
+ parser.add_argument(
258
+ "--save",
259
+ type=int,
260
+ default=0,
261
+ help=(
262
+ "Save a .rrd file in the directory provided by `--output-dir`. "
263
+ "It also deactivates the spawning of a viewer. "
264
+ "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
265
+ ),
266
+ )
267
+
268
+ parser.add_argument(
269
+ "--tolerance-s",
270
+ type=float,
271
+ default=1e-4,
272
+ help=(
273
+ "Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
274
+ "This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
275
+ "If not given, defaults to 1e-4."
276
+ ),
277
+ )
278
+
279
+ args = parser.parse_args()
280
+ kwargs = vars(args)
281
+ repo_id = kwargs.pop("repo_id")
282
+ root = kwargs.pop("root")
283
+ tolerance_s = kwargs.pop("tolerance_s")
284
+
285
+ logging.info("Loading dataset")
286
+ dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
287
+
288
+ visualize_dataset(dataset, **vars(args))
289
+
290
+
291
+ if __name__ == "__main__":
292
+ main()
lerobot/scripts/visualize_dataset_html.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
17
+
18
+ Note: The last frame of the episode doesn't always correspond to a final state.
19
+ That's because our datasets are composed of transition from state to state up to
20
+ the antepenultimate state associated with the ultimate action to arrive in the final state.
21
+ However, there might not be a transition from a final state to another state.
22
+
23
+ Note: This script aims to visualize the data used to train the neural networks.
24
+ ~What you see is what you get~. When visualizing image modality, it is often expected to observe
25
+ lossy compression artifacts since these images have been decoded from compressed mp4 videos to
26
+ save disk space. The compression factor applied has been tuned to not affect success rate.
27
+
28
+ Example of usage:
29
+
30
+ - Visualize data stored on a local machine:
31
+ ```bash
32
+ local$ python lerobot/scripts/visualize_dataset_html.py \
33
+ --repo-id lerobot/pusht
34
+
35
+ local$ open http://localhost:9090
36
+ ```
37
+
38
+ - Visualize data stored on a distant machine with a local viewer:
39
+ ```bash
40
+ distant$ python lerobot/scripts/visualize_dataset_html.py \
41
+ --repo-id lerobot/pusht
42
+
43
+ local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
44
+ local$ open http://localhost:9090
45
+ ```
46
+
47
+ - Select episodes to visualize:
48
+ ```bash
49
+ python lerobot/scripts/visualize_dataset_html.py \
50
+ --repo-id lerobot/pusht \
51
+ --episodes 7 3 5 1 4
52
+ ```
53
+ """
54
+
55
+ import argparse
56
+ import csv
57
+ import json
58
+ import logging
59
+ import re
60
+ import shutil
61
+ import tempfile
62
+ from io import StringIO
63
+ from pathlib import Path
64
+
65
+ import numpy as np
66
+ import pandas as pd
67
+ import requests
68
+ from flask import Flask, redirect, render_template, request, url_for
69
+
70
+ from lerobot import available_datasets
71
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
72
+ from lerobot.common.datasets.utils import IterableNamespace
73
+ from lerobot.common.utils.utils import init_logging
74
+
75
+
76
+ def run_server(
77
+ dataset: LeRobotDataset | IterableNamespace | None,
78
+ episodes: list[int] | None,
79
+ host: str,
80
+ port: str,
81
+ static_folder: Path,
82
+ template_folder: Path,
83
+ ):
84
+ app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
85
+ app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
86
+
87
+ @app.route("/")
88
+ def hommepage(dataset=dataset):
89
+ if dataset:
90
+ dataset_namespace, dataset_name = dataset.repo_id.split("/")
91
+ return redirect(
92
+ url_for(
93
+ "show_episode",
94
+ dataset_namespace=dataset_namespace,
95
+ dataset_name=dataset_name,
96
+ episode_id=0,
97
+ )
98
+ )
99
+
100
+ dataset_param, episode_param = None, None
101
+ all_params = request.args
102
+ if "dataset" in all_params:
103
+ dataset_param = all_params["dataset"]
104
+ if "episode" in all_params:
105
+ episode_param = int(all_params["episode"])
106
+
107
+ if dataset_param:
108
+ dataset_namespace, dataset_name = dataset_param.split("/")
109
+ return redirect(
110
+ url_for(
111
+ "show_episode",
112
+ dataset_namespace=dataset_namespace,
113
+ dataset_name=dataset_name,
114
+ episode_id=episode_param if episode_param is not None else 0,
115
+ )
116
+ )
117
+
118
+ featured_datasets = [
119
+ "lerobot/aloha_static_cups_open",
120
+ "lerobot/columbia_cairlab_pusht_real",
121
+ "lerobot/taco_play",
122
+ ]
123
+ return render_template(
124
+ "visualize_dataset_homepage.html",
125
+ featured_datasets=featured_datasets,
126
+ lerobot_datasets=available_datasets,
127
+ )
128
+
129
+ @app.route("/<string:dataset_namespace>/<string:dataset_name>")
130
+ def show_first_episode(dataset_namespace, dataset_name):
131
+ first_episode_id = 0
132
+ return redirect(
133
+ url_for(
134
+ "show_episode",
135
+ dataset_namespace=dataset_namespace,
136
+ dataset_name=dataset_name,
137
+ episode_id=first_episode_id,
138
+ )
139
+ )
140
+
141
+ @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
142
+ def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
143
+ repo_id = f"{dataset_namespace}/{dataset_name}"
144
+ try:
145
+ if dataset is None:
146
+ dataset = get_dataset_info(repo_id)
147
+ except FileNotFoundError:
148
+ return (
149
+ "Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461",
150
+ 400,
151
+ )
152
+ dataset_version = (
153
+ str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
154
+ )
155
+ match = re.search(r"v(\d+)\.", dataset_version)
156
+ if match:
157
+ major_version = int(match.group(1))
158
+ if major_version < 2:
159
+ return "Make sure to convert your LeRobotDataset to v2 & above."
160
+
161
+ episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
162
+ dataset_info = {
163
+ "repo_id": f"{dataset_namespace}/{dataset_name}",
164
+ "num_samples": dataset.num_frames
165
+ if isinstance(dataset, LeRobotDataset)
166
+ else dataset.total_frames,
167
+ "num_episodes": dataset.num_episodes
168
+ if isinstance(dataset, LeRobotDataset)
169
+ else dataset.total_episodes,
170
+ "fps": dataset.fps,
171
+ }
172
+ if isinstance(dataset, LeRobotDataset):
173
+ video_paths = [
174
+ dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
175
+ ]
176
+ videos_info = [
177
+ {"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
178
+ for video_path in video_paths
179
+ ]
180
+ tasks = dataset.meta.episodes[episode_id]["tasks"]
181
+ else:
182
+ video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
183
+ videos_info = [
184
+ {
185
+ "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
186
+ + dataset.video_path.format(
187
+ episode_chunk=int(episode_id) // dataset.chunks_size,
188
+ video_key=video_key,
189
+ episode_index=episode_id,
190
+ ),
191
+ "filename": video_key,
192
+ }
193
+ for video_key in video_keys
194
+ ]
195
+
196
+ response = requests.get(
197
+ f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
198
+ )
199
+ response.raise_for_status()
200
+ # Split into lines and parse each line as JSON
201
+ tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
202
+
203
+ filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
204
+ tasks = filtered_tasks_jsonl[0]["tasks"]
205
+
206
+ videos_info[0]["language_instruction"] = tasks
207
+
208
+ if episodes is None:
209
+ episodes = list(
210
+ range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
211
+ )
212
+
213
+ return render_template(
214
+ "visualize_dataset_template.html",
215
+ episode_id=episode_id,
216
+ episodes=episodes,
217
+ dataset_info=dataset_info,
218
+ videos_info=videos_info,
219
+ episode_data_csv_str=episode_data_csv_str,
220
+ columns=columns,
221
+ ignored_columns=ignored_columns,
222
+ )
223
+
224
+ app.run(host=host, port=port)
225
+
226
+
227
+ def get_ep_csv_fname(episode_id: int):
228
+ ep_csv_fname = f"episode_{episode_id}.csv"
229
+ return ep_csv_fname
230
+
231
+
232
+ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
233
+ """Get a csv str containing timeseries data of an episode (e.g. state and action).
234
+ This file will be loaded by Dygraph javascript to plot data in real time."""
235
+ columns = []
236
+
237
+ selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
238
+ selected_columns.remove("timestamp")
239
+
240
+ ignored_columns = []
241
+ for column_name in selected_columns:
242
+ shape = dataset.features[column_name]["shape"]
243
+ shape_dim = len(shape)
244
+ if shape_dim > 1:
245
+ selected_columns.remove(column_name)
246
+ ignored_columns.append(column_name)
247
+
248
+ # init header of csv with state and action names
249
+ header = ["timestamp"]
250
+
251
+ for column_name in selected_columns:
252
+ dim_state = (
253
+ dataset.meta.shapes[column_name][0]
254
+ if isinstance(dataset, LeRobotDataset)
255
+ else dataset.features[column_name].shape[0]
256
+ )
257
+
258
+ if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
259
+ column_names = dataset.features[column_name]["names"]
260
+ while not isinstance(column_names, list):
261
+ column_names = list(column_names.values())[0]
262
+ else:
263
+ column_names = [f"{column_name}_{i}" for i in range(dim_state)]
264
+ columns.append({"key": column_name, "value": column_names})
265
+
266
+ header += column_names
267
+
268
+ selected_columns.insert(0, "timestamp")
269
+
270
+ if isinstance(dataset, LeRobotDataset):
271
+ from_idx = dataset.episode_data_index["from"][episode_index]
272
+ to_idx = dataset.episode_data_index["to"][episode_index]
273
+ data = (
274
+ dataset.hf_dataset.select(range(from_idx, to_idx))
275
+ .select_columns(selected_columns)
276
+ .with_format("pandas")
277
+ )
278
+ else:
279
+ repo_id = dataset.repo_id
280
+
281
+ url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
282
+ episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
283
+ )
284
+ df = pd.read_parquet(url)
285
+ data = df[selected_columns] # Select specific columns
286
+
287
+ rows = np.hstack(
288
+ (
289
+ np.expand_dims(data["timestamp"], axis=1),
290
+ *[np.vstack(data[col]) for col in selected_columns[1:]],
291
+ )
292
+ ).tolist()
293
+
294
+ # Convert data to CSV string
295
+ csv_buffer = StringIO()
296
+ csv_writer = csv.writer(csv_buffer)
297
+ # Write header
298
+ csv_writer.writerow(header)
299
+ # Write data rows
300
+ csv_writer.writerows(rows)
301
+ csv_string = csv_buffer.getvalue()
302
+
303
+ return csv_string, columns, ignored_columns
304
+
305
+
306
+ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
307
+ # get first frame of episode (hack to get video_path of the episode)
308
+ first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
309
+ return [
310
+ dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
311
+ for key in dataset.meta.video_keys
312
+ ]
313
+
314
+
315
+ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
316
+ # check if the dataset has language instructions
317
+ if "language_instruction" not in dataset.features:
318
+ return None
319
+
320
+ # get first frame index
321
+ first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
322
+
323
+ language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
324
+ # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
325
+ # with the tf.tensor appearing in the string
326
+ return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
327
+
328
+
329
+ def get_dataset_info(repo_id: str) -> IterableNamespace:
330
+ response = requests.get(
331
+ f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
332
+ )
333
+ response.raise_for_status() # Raises an HTTPError for bad responses
334
+ dataset_info = response.json()
335
+ dataset_info["repo_id"] = repo_id
336
+ return IterableNamespace(dataset_info)
337
+
338
+
339
+ def visualize_dataset_html(
340
+ dataset: LeRobotDataset | None,
341
+ episodes: list[int] | None = None,
342
+ output_dir: Path | None = None,
343
+ serve: bool = True,
344
+ host: str = "127.0.0.1",
345
+ port: int = 9090,
346
+ force_override: bool = False,
347
+ ) -> Path | None:
348
+ init_logging()
349
+
350
+ template_dir = Path(__file__).resolve().parent.parent / "templates"
351
+
352
+ if output_dir is None:
353
+ # Create a temporary directory that will be automatically cleaned up
354
+ output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
355
+
356
+ output_dir = Path(output_dir)
357
+ if output_dir.exists():
358
+ if force_override:
359
+ shutil.rmtree(output_dir)
360
+ else:
361
+ logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
362
+
363
+ output_dir.mkdir(parents=True, exist_ok=True)
364
+
365
+ static_dir = output_dir / "static"
366
+ static_dir.mkdir(parents=True, exist_ok=True)
367
+
368
+ if dataset is None:
369
+ if serve:
370
+ run_server(
371
+ dataset=None,
372
+ episodes=None,
373
+ host=host,
374
+ port=port,
375
+ static_folder=static_dir,
376
+ template_folder=template_dir,
377
+ )
378
+ else:
379
+ # Create a symlink from the dataset video folder containing mp4 files to the output directory
380
+ # so that the http server can get access to the mp4 files.
381
+ if isinstance(dataset, LeRobotDataset):
382
+ ln_videos_dir = static_dir / "videos"
383
+ if not ln_videos_dir.exists():
384
+ ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
385
+
386
+ if serve:
387
+ run_server(dataset, episodes, host, port, static_dir, template_dir)
388
+
389
+
390
+ def main():
391
+ parser = argparse.ArgumentParser()
392
+
393
+ parser.add_argument(
394
+ "--repo-id",
395
+ type=str,
396
+ default=None,
397
+ help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
398
+ )
399
+ parser.add_argument(
400
+ "--root",
401
+ type=Path,
402
+ default=None,
403
+ help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
404
+ )
405
+ parser.add_argument(
406
+ "--load-from-hf-hub",
407
+ type=int,
408
+ default=0,
409
+ help="Load videos and parquet files from HF Hub rather than local system.",
410
+ )
411
+ parser.add_argument(
412
+ "--episodes",
413
+ type=int,
414
+ nargs="*",
415
+ default=None,
416
+ help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
417
+ )
418
+ parser.add_argument(
419
+ "--output-dir",
420
+ type=Path,
421
+ default=None,
422
+ help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
423
+ )
424
+ parser.add_argument(
425
+ "--serve",
426
+ type=int,
427
+ default=1,
428
+ help="Launch web server.",
429
+ )
430
+ parser.add_argument(
431
+ "--host",
432
+ type=str,
433
+ default="127.0.0.1",
434
+ help="Web host used by the http server.",
435
+ )
436
+ parser.add_argument(
437
+ "--port",
438
+ type=int,
439
+ default=9090,
440
+ help="Web port used by the http server.",
441
+ )
442
+ parser.add_argument(
443
+ "--force-override",
444
+ type=int,
445
+ default=0,
446
+ help="Delete the output directory if it exists already.",
447
+ )
448
+
449
+ parser.add_argument(
450
+ "--tolerance-s",
451
+ type=float,
452
+ default=1e-4,
453
+ help=(
454
+ "Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
455
+ "This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
456
+ "If not given, defaults to 1e-4."
457
+ ),
458
+ )
459
+
460
+ args = parser.parse_args()
461
+ kwargs = vars(args)
462
+ repo_id = kwargs.pop("repo_id")
463
+ load_from_hf_hub = kwargs.pop("load_from_hf_hub")
464
+ root = kwargs.pop("root")
465
+ tolerance_s = kwargs.pop("tolerance_s")
466
+
467
+ dataset = None
468
+ if repo_id:
469
+ dataset = (
470
+ LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
471
+ if not load_from_hf_hub
472
+ else get_dataset_info(repo_id)
473
+ )
474
+
475
+ visualize_dataset_html(dataset, **vars(args))
476
+
477
+
478
+ if __name__ == "__main__":
479
+ main()
lerobot/scripts/visualize_image_transforms.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ Visualize effects of image transforms for a given configuration.
17
+
18
+ This script will generate examples of transformed images as they are output by LeRobot dataset.
19
+ Additionally, each individual transform can be visualized separately as well as examples of combined transforms
20
+
21
+ Example:
22
+ ```bash
23
+ python lerobot/scripts/visualize_image_transforms.py \
24
+ --repo_id=lerobot/pusht \
25
+ --episodes='[0]' \
26
+ --image_transforms.enable=True
27
+ ```
28
+ """
29
+
30
+ import logging
31
+ from copy import deepcopy
32
+ from dataclasses import replace
33
+ from pathlib import Path
34
+
35
+ import draccus
36
+ from torchvision.transforms import ToPILImage
37
+
38
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
39
+ from lerobot.common.datasets.transforms import (
40
+ ImageTransforms,
41
+ ImageTransformsConfig,
42
+ make_transform_from_config,
43
+ )
44
+ from lerobot.configs.default import DatasetConfig
45
+
46
+ OUTPUT_DIR = Path("outputs/image_transforms")
47
+ to_pil = ToPILImage()
48
+
49
+
50
+ def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
51
+ output_dir_all = output_dir / "all"
52
+ output_dir_all.mkdir(parents=True, exist_ok=True)
53
+
54
+ tfs = ImageTransforms(cfg)
55
+ for i in range(1, n_examples + 1):
56
+ transformed_frame = tfs(original_frame)
57
+ to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100)
58
+
59
+ print("Combined transforms examples saved to:")
60
+ print(f" {output_dir_all}")
61
+
62
+
63
+ def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
64
+ if not cfg.enable:
65
+ logging.warning(
66
+ "No single transforms will be saved, because `image_transforms.enable=False`. To enable, set `enable` to True in `ImageTransformsConfig` or in the command line with `--image_transforms.enable=True`."
67
+ )
68
+ return
69
+
70
+ print("Individual transforms examples saved to:")
71
+ for tf_name, tf_cfg in cfg.tfs.items():
72
+ # Apply a few transformation with random value in min_max range
73
+ output_dir_single = output_dir / tf_name
74
+ output_dir_single.mkdir(parents=True, exist_ok=True)
75
+
76
+ tf = make_transform_from_config(tf_cfg)
77
+ for i in range(1, n_examples + 1):
78
+ transformed_frame = tf(original_frame)
79
+ to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100)
80
+
81
+ # Apply min, max, average transformations
82
+ tf_cfg_kwgs_min = deepcopy(tf_cfg.kwargs)
83
+ tf_cfg_kwgs_max = deepcopy(tf_cfg.kwargs)
84
+ tf_cfg_kwgs_avg = deepcopy(tf_cfg.kwargs)
85
+
86
+ for key, (min_, max_) in tf_cfg.kwargs.items():
87
+ avg = (min_ + max_) / 2
88
+ tf_cfg_kwgs_min[key] = [min_, min_]
89
+ tf_cfg_kwgs_max[key] = [max_, max_]
90
+ tf_cfg_kwgs_avg[key] = [avg, avg]
91
+
92
+ tf_min = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min}))
93
+ tf_max = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max}))
94
+ tf_avg = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg}))
95
+
96
+ tf_frame_min = tf_min(original_frame)
97
+ tf_frame_max = tf_max(original_frame)
98
+ tf_frame_avg = tf_avg(original_frame)
99
+
100
+ to_pil(tf_frame_min).save(output_dir_single / "min.png", quality=100)
101
+ to_pil(tf_frame_max).save(output_dir_single / "max.png", quality=100)
102
+ to_pil(tf_frame_avg).save(output_dir_single / "mean.png", quality=100)
103
+
104
+ print(f" {output_dir_single}")
105
+
106
+
107
+ @draccus.wrap()
108
+ def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5):
109
+ dataset = LeRobotDataset(
110
+ repo_id=cfg.repo_id,
111
+ episodes=cfg.episodes,
112
+ revision=cfg.revision,
113
+ video_backend=cfg.video_backend,
114
+ )
115
+
116
+ output_dir = output_dir / cfg.repo_id.split("/")[-1]
117
+ output_dir.mkdir(parents=True, exist_ok=True)
118
+
119
+ # Get 1st frame from 1st camera of 1st episode
120
+ original_frame = dataset[0][dataset.meta.camera_keys[0]]
121
+ to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
122
+ print("\nOriginal frame saved to:")
123
+ print(f" {output_dir / 'original_frame.png'}.")
124
+
125
+ save_all_transforms(cfg.image_transforms, original_frame, output_dir, n_examples)
126
+ save_each_transform(cfg.image_transforms, original_frame, output_dir, n_examples)
127
+
128
+
129
+ if __name__ == "__main__":
130
+ visualize_image_transforms()
lerobot/templates/visualize_dataset_homepage.html ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!-- Jinja2 template: landing page for the LeRobot dataset visualizer.
     Rendered with two context variables: `featured_datasets` and
     `lerobot_datasets` (lists of dataset repo ids). -->
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Interactive Video Background Page</title>
    <!-- Tailwind for styling; Alpine.js provides the small client-side state below. -->
    <script src="https://cdn.tailwindcss.com"></script>
    <script defer src="https://cdn.jsdelivr.net/npm/alpinejs@3.x.x/dist/cdn.min.js"></script>
</head>
<!-- Alpine x-data: holds the text-box value and navigates to /<dataset id> on submit. -->
<body class="h-screen overflow-hidden font-mono text-white" x-data="{
    inputValue: '',
    navigateToDataset() {
        const trimmedValue = this.inputValue.trim();
        if (trimmedValue) {
            window.location.href = `/${trimmedValue}`;
        }
    }
}">
    <!-- Full-screen looping background video, centered and cropped to cover. -->
    <div class="fixed inset-0 w-full h-full overflow-hidden">
        <video class="absolute min-w-full min-h-full w-auto h-auto top-1/2 left-1/2 transform -translate-x-1/2 -translate-y-1/2" autoplay muted loop>
            <source src="https://huggingface.co/datasets/cadene/koch_bimanual_folding/resolve/v1.6/videos/observation.images.phone_episode_000037.mp4" type="video/mp4">
            Your browser does not support HTML5 video.
        </video>
    </div>
    <!-- Dark overlay so foreground text stays readable over the video. -->
    <div class="fixed inset-0 bg-black bg-opacity-80"></div>
    <div class="relative z-10 flex flex-col items-center justify-center h-screen">
        <div class="text-center mb-8">
            <h1 class="text-4xl font-bold mb-4">LeRobot Dataset Visualizer</h1>

            <a href="https://x.com/RemiCadene/status/1825455895561859185" target="_blank" rel="noopener noreferrer" class="underline">create & train your own robots</a>

            <p class="text-xl mb-4"></p>
            <!-- Short curated list of example datasets, linked by repo id. -->
            <div class="text-left inline-block">
                <h3 class="font-semibold mb-2 mt-4">Example Datasets:</h3>
                <ul class="list-disc list-inside">
                    {% for dataset in featured_datasets %}
                    <li><a href="/{{ dataset }}" class="text-blue-300 hover:text-blue-100 hover:underline">{{ dataset }}</a></li>
                    {% endfor %}
                </ul>
            </div>
        </div>
        <!-- Free-form dataset lookup: Enter key or the Go button triggers navigateToDataset(). -->
        <div class="flex w-full max-w-lg px-4 mb-4">
            <input
                type="text"
                x-model="inputValue"
                @keyup.enter="navigateToDataset"
                placeholder="enter dataset id (ex: lerobot/droid_100)"
                class="flex-grow px-4 py-2 rounded-l bg-white bg-opacity-20 text-white placeholder-gray-300 focus:outline-none focus:ring-2 focus:ring-blue-300"
            >
            <button
                @click="navigateToDataset"
                class="px-4 py-2 bg-blue-500 text-white rounded-r hover:bg-blue-600 focus:outline-none focus:ring-2 focus:ring-blue-300"
            >
                Go
            </button>
        </div>

        <!-- Collapsible full catalogue of known datasets. -->
        <details class="mt-4 max-w-full px-4">
            <summary>More example datasets</summary>
            <ul class="list-disc list-inside max-h-28 overflow-y-auto break-all">
                {% for dataset in lerobot_datasets %}
                <li><a href="/{{ dataset }}" class="text-blue-300 hover:text-blue-100 hover:underline">{{ dataset }}</a></li>
                {% endfor %}
            </ul>
        </details>
    </div>
</body>
</html>