Source code for trifinger_rl_datasets.dataset_env

from copy import deepcopy
import hashlib
import os
from pathlib import Path
from threading import Thread
from typing import Union, Tuple, Dict, Optional, List, Any
import urllib.request

import cv2
import gymnasium as gym
import gymnasium.spaces as spaces
import numpy as np
from tqdm import tqdm
import yaml
import zarr

from .sim_env import SimTriFingerCubeEnv


class ImageLoader(Thread):
    """Thread for loading and processing images from the dataset.

    This thread is responsible for loading and processing every
    loader_id-th image. Processing includes decoding, reordering
    of pixels and debayering."""

    def __init__(
        self,
        loader_id,
        n_loaders,
        image_data,
        unique_images,
        n_unique_images,
        n_cameras,
        reorder_pixels,
        timestep_dimension,
    ):
        """
        Args:
            loader_id: ID of this loader.  This loader will load every
                loader_id-th image.
            n_loaders: Total number of loaders.
            image_data: Numpy array containing the image data.
            unique_images: Numpy array to which the images are written.
            n_unique_images: Number of unique images to load. If this
                number is not divisible by n_cameras,
                self.unique_images will be padded with zeros.
            n_cameras: Number of cameras.
            reorder_pixels: Whether to undo the reordering of the pixels
                which was done during creation of the dataset to improve
                the image compression.
            timestep_dimension: If True, the image data is expected to
                contain images from all cameras in a row and
                n_unique_images is expected to have shape
                (n_timesteps, n_cameras, height, width). If False, the
                shape is expected to be
                (n_unique_images, n_cameras, height, width)."""
        super().__init__()
        self.loader_id = loader_id
        self.n_loaders = n_loaders
        self.image_data = image_data
        self.unique_images = unique_images
        self.n_unique_images = n_unique_images
        self.n_cameras = n_cameras
        self.reorder_pixels = reorder_pixels
        self.timestep_dimension = timestep_dimension

    def _reorder_pixels(self, img: np.ndarray) -> np.ndarray:
        """Undo reordering of Bayer pattern."""
        new = np.empty_like(img)
        a = img.shape[0] // 2
        b = img.shape[1] // 2

        red = img[0:a, 0:b]
        blue = img[a:, 0:b]
        green1 = img[0:a, b:]
        green2 = img[a:, b:]

        new[0::2, 0::2] = red
        new[1::2, 1::2] = blue
        new[0::2, 1::2] = green1
        new[1::2, 0::2] = green2

        return new

    def _decode_image(self, image: np.ndarray) -> np.ndarray:
        """Decode image from numpy array of type void."""
        # convert numpy array of type V1 to use with cv2 imdecode
        image = np.frombuffer(image, dtype=np.uint8)
        # use cv2 to decode image
        image = cv2.imdecode(image, cv2.IMREAD_UNCHANGED)
        return image

    def run(self):
        # this thread is responsible for every loader_id-th image
        for i in range(self.loader_id, self.n_unique_images, self.n_loaders):
            if self.timestep_dimension:
                timestep, camera = divmod(i, self.n_cameras)
            compressed_image = self.image_data[i]
            # decode image
            image = self._decode_image(compressed_image)
            if self.reorder_pixels:
                # undo reordering of pixels
                image = self._reorder_pixels(image)
            # debayer image (output channels in RGB order)
            image = cv2.cvtColor(image, cv2.COLOR_BAYER_BG2RGB)
            # convert to channel first
            image = np.transpose(image, (2, 0, 1))
            if self.timestep_dimension:
                self.unique_images[timestep, camera, ...] = image
            else:
                self.unique_images[i, ...] = image


[docs]class TriFingerDatasetEnv(gym.Env): """TriFinger environment which can load an offline RL dataset from a file. Similar to D4RL's OfflineEnv but with different data loading and options for customization of observation space.""" _PRELOAD_VECTOR_KEYS = ["observations", "actions"] _PRELOAD_SCALAR_KEYS = ["rewards", "timeouts"] def __init__( self, name, dataset_url, ref_max_score, ref_min_score, trifinger_kwargs, real_robot=False, image_obs=False, visualization=False, obs_to_keep=None, flatten_obs=True, scale_obs=False, set_terminals=False, data_dir=None, **kwargs, ): """ Args: name (str): Name of the dataset. dataset_url (str): URL pointing to the dataset. ref_max_score (float): Maximum score (for score normalization) ref_min_score (float): Minimum score (for score normalization) trifinger_kwargs (dict): Keyword arguments for underlying SimTriFingerCubeEnv environment. real_robot (bool): Whether the data was collected on real robots. image_obs (bool): Whether observations contain camera images. visualization (bool): Enables rendering for simulated environment. obs_to_keep (dict): Dictionary with the same structure as the observation of SimTriFingerCubeEnv. The boolean value of each item indicates whether it should be included in the observation. If None, the SimTriFingerCubeEnv is used. flatten_obs (bool): Whether to flatten the observation. Can be combined with obs_to_keep. scale_obs (bool): Whether to scale all components of the observation to interval [-1, 1]. Only implemented for flattend observations. set_terminals (bool): Whether to set the terminals instead of the timeouts. data_dir (str or Path): Directory where the dataset is stored. If None, the default data directory (~/.trifinger_rl_datasets) is used. """ super().__init__(**kwargs) self.name = name self.dataset_url = dataset_url self.ref_max_score = ref_max_score self.ref_min_score = ref_min_score self.real_robot = real_robot self.image_obs = image_obs self.obs_to_keep = obs_to_keep self.flatten_obs = flatten_obs self.scale_obs = scale_obs self.set_terminals = set_terminals self._local_dataset_path = None if data_dir is None: data_dir = Path.home() / ".trifinger_rl_datasets" self.data_dir = Path(data_dir) self.t_kwargs = deepcopy(trifinger_kwargs) self.t_kwargs["image_obs"] = image_obs self.t_kwargs["visualization"] = visualization # underlying simulated TriFinger environment self.sim_env = SimTriFingerCubeEnv(**self.t_kwargs) # a copy of the original observation space which is used when # filtering the observations self._orig_obs_space = deepcopy(self.sim_env.observation_space) # the space used for unflattening the observations (images will # be removed from this space) self._unflattening_space = deepcopy(self.sim_env.observation_space) # remove camera observations from space used for flattening # and unflattening as images are treated separetely and not # flattened if self.image_obs: stripped_camera_observations = spaces.Dict( { k: v for k, v in self._orig_obs_space.spaces[ "camera_observation" ].spaces.items() if k != "images" } ) self._unflattening_space["camera_observation"] = stripped_camera_observations if self.flatten_obs: # if the observations are eventually flattened, they do not contain # images anymore self._orig_obs_space["camera_observation"] = stripped_camera_observations self._orig_flat_obs_space = spaces.flatten_space(self._orig_obs_space) self._flat_unflattening_space = spaces.flatten_space(self._unflattening_space) if scale_obs and not flatten_obs: raise NotImplementedError( "Scaling of observations only " "implemented for flattened observations, i.e., for " "flatten_obs=True." ) # action space self.action_space = self.sim_env.action_space # observation space # self._filtered_obs_space is the Dict observation space after # filtering if self.obs_to_keep is not None: # construct filtered observation space self._filtered_obs_space = self._filter_dict( keys_to_keep=self.obs_to_keep, d=self._orig_obs_space ) else: self._filtered_obs_space = self._orig_obs_space # self.observation_space is potentially also flattened if self.flatten_obs: # flat obs space self.observation_space = spaces.flatten_space(self._filtered_obs_space) if self.scale_obs: self._obs_unscaled_low = self.observation_space.low self._obs_unscaled_high = self.observation_space.high # scale observations to [-1, 1] self.observation_space = spaces.Box( low=-1.0, high=1.0, shape=self.observation_space.shape, dtype=self.observation_space.dtype, ) else: self.observation_space = self._filtered_obs_space def _download_dataset(self): """Download dataset files if not already present. `self.dataset_url` is expected to point to a YAML file with the following structure: ``` n_parts: <number of parts> md5_hash_parts: - <md5 hash of part 1> - <md5 hash of part 2> ... md5_hash_complete: <md5 hash of complete dataset> ``` The dataset is split into multiple parts to allow for continuing a download if it was interrupted. The complete dataset is then reconstructed by concatenating the parts.""" if self._local_dataset_path is None: dataset_dir = self.data_dir / (self.name + ".zarr") dataset_dir.mkdir(exist_ok=True, parents=True) local_path = dataset_dir / "data.mdb" if not local_path.exists(): print(f"Downloading dataset {self.name}.") # first download YAML file with info about dataset files with urllib.request.urlopen(self.dataset_url) as web_url: dataset_info = yaml.safe_load(web_url) # download dataset parts for i, part_hash in enumerate(tqdm(dataset_info["md5_hash_parts"])): part_path = dataset_dir / f"{self.name}_{i:03d}" if not part_path.exists(): # strip filename from url stripped_url = self.dataset_url.rsplit("/", 1)[0] part_url = stripped_url + f"/part_{i:03d}" urllib.request.urlretrieve(part_url, part_path) if not part_path.exists(): raise IOError( f"Failed to download part {i} of dataset from URL {part_url}." ) # check hash with open(part_path, "rb") as f: m = hashlib.md5() m.update(f.read()) if m.hexdigest() != part_hash: raise IOError( f"Hash of downloaded part {part_path} does not " f"match expected hash. Please delete " f"the file and try again." ) # combine parts with open(local_path, "wb") as f: print("Assembling dataset parts.") for i in tqdm(range(dataset_info["n_parts"])): part_path = dataset_dir / f"{self.name}_{i:03d}" with open(part_path, "rb") as part_file: f.write(part_file.read()) # delete part file part_path.unlink() if not local_path.exists(): raise IOError( f"Failed to assemble dataset {self.dataset_url} locally at {local_path}." ) self._local_dataset_path = dataset_dir return self._local_dataset_path def _filter_dict(self, keys_to_keep, d): """Keep only a subset of keys in dict. Applied recursively. Args: keys_to_keep (dict): (Nested) dictionary with values being either a dict or a bolean indicating whether to keep an item. d (dict or gymnasium.spaces.Dict): Dicitionary or Dict space that is to be filtered.""" filtered_dict = {} for k, v in keys_to_keep.items(): if isinstance(v, dict): subspace = self._filter_dict(v, d[k]) filtered_dict[k] = subspace elif isinstance(v, bool) and v: filtered_dict[k] = d[k] elif not isinstance(v, bool): raise TypeError( "Expected boolean to indicate whether item " "in observation space is to be kept." ) if isinstance(d, spaces.Dict): filtered_dict = spaces.Dict(spaces=filtered_dict) return filtered_dict def _scale_obs(self, obs: np.ndarray) -> np.ndarray: """Scale observation components to [-1, 1].""" interval = self._obs_unscaled_high.high - self._obs_unscaled_low.low a = (obs - self._obs_unscaled_low.low) / interval return a * 2.0 - 1.0 def _process_obs(self, obs: Union[np.ndarray, Dict]) -> np.ndarray: """Process obs according to params. Assumes that if `self.obs_to_keep` is not None, then the observations are provided as a dictionary. Args: obs: Dictionary or array containing the observations. Returns: Processed observations. If `self.flatten_obs` is False then as a dictionary. If `self.flatten_obs` is True then either as a 1D NumPy array (if no images are contained in obs) or as a tuple (if images are contained in the obs dictionary) consisting of * a 1D NumPy array containing all observations except the camera images, and * a NumPy array of shape (n_cameras, n_channels, height, width) containing the camera images.""" images = None if self.obs_to_keep is not None: # filter obs obs = self._filter_dict(self.obs_to_keep, obs) if self.flatten_obs and isinstance(obs, dict): if "images" in obs["camera_observation"]: # remove camera_observations/images from obs images = obs["camera_observation"].pop("images") # flatten obs obs = spaces.flatten(self._filtered_obs_space, obs) if self.scale_obs: # scale obs obs = self._scale_obs(obs) if images is not None: return obs, images else: return obs
[docs] def get_obs_indices(self) -> Tuple[Dict, Dict]: """Get index ranges that correspond to the different observation components. Also returns a dictionary containing the shapes of these observation components. Returns: - A dictionary with keys corresponding to the observation components and values being tuples of the form (start, end), where start and end are the indices at which the observation component starts and ends. The nested dictionary structure of the observation is preserved. - A dictionary of the same structure but with values being the shapes of the observation components.""" def _construct_dummy_obs(spaces_dict, counter=[0]): """Construct dummy observation which has an array repeating a different integer as the value of each component.""" dummy_obs = {} for i, (k, v) in enumerate(spaces_dict.items()): if isinstance(v, spaces.Dict): dummy_obs[k] = _construct_dummy_obs(v.spaces, counter) else: dummy_obs[k] = counter * np.ones(v.shape, dtype=np.int32) counter[0] += 1 return dummy_obs dummy_obs = _construct_dummy_obs(self._orig_obs_space.spaces) flat_dummy_obs = spaces.flatten(self._orig_obs_space, dummy_obs) def _get_indices_and_shape(dummy_obs, flat_dummy_obs): indices = {} shape = {} for k, v in dummy_obs.items(): if isinstance(v, dict): indices[k], shape[k] = _get_indices_and_shape(v, flat_dummy_obs) else: where = np.where(flat_dummy_obs == v.flatten()[0])[0] indices[k] = (int(where[0]), int(where[-1]) + 1) shape[k] = v.shape return indices, shape return _get_indices_and_shape(dummy_obs, flat_dummy_obs)
[docs] def get_dataset_stats(self, zarr_path: Union[str, os.PathLike] = None) -> Dict: """Get statistics of dataset such as number of timesteps. Args: zarr_path: Optional path to a Zarr directory containing the dataset, which will be used instead of the default. Returns: The statistics of the dataset as a dictionary with keys - n_timesteps: Number of timesteps in dataset. Corresponds to the number of observations, actions and rewards. - obs_size: Size of the observation vector. - action_size: Size of the action vector. """ if zarr_path is None: zarr_path = self._download_dataset() store = zarr.LMDBStore(zarr_path, readonly=True) with zarr.open(store=store) as root: dataset_stats = { "n_timesteps": root["observations"].shape[0], "obs_size": root["observations"].shape[1], "action_size": root["actions"].shape[1], } return dataset_stats
[docs] def get_image_stats(self, zarr_path: Union[str, os.PathLike] = None) -> Dict: """Get statistics of image data in dataset. Args: zarr_path: Optional path to a Zarr directory containing the dataset, which will be used instead of the default. Returns: The statistics of the image data as a dictionary with keys - n_images: Number of images in the dataset. - n_cameras: Number of cameras used to capture the images. - n_channels: Number of channels in the images. - image_shape: Shape of the images in the format (height, width). - reorder_pixels: Whether the pixels in the images have been reordered to have the pixels corresponding to one color in the Bayer pattern together in blocks (to improve image compression). """ if zarr_path is None: zarr_path = self._download_dataset() store = zarr.LMDBStore(zarr_path, readonly=True) with zarr.open(store=store) as root: image_stats = { "n_images": root["images"].shape[0], "n_cameras": root["images"].attrs["n_cameras"], "n_channels": root["images"].attrs["n_channels"], "image_shape": tuple(root["images"].attrs["image_shape"]), "reorder_pixels": root["images"].attrs["reorder_pixels"], } return image_stats
[docs] def get_image_data( self, rng: Optional[Tuple[int, int]] = None, indices: Optional[np.ndarray] = None, zarr_path: Union[str, os.PathLike] = None, timestep_dimension: bool = True, n_threads: Optional[int] = None, ) -> np.ndarray: """Get image data from dataset. Args: rng: Optional range of images to return. rng=(m,n) means that the images with indices m to n-1 are returned. indices: Optional array of image indices for which to load data. rng and indices are mutually exclusive, only one of them can be set. zarr_path: Optional path to a Zarr directory containing the dataset, which will be used instead of the default. timestep_dimension: Whether to include the timestep dimension in the returned array. This is useful if the given range of indices always contains `n_cameras` of image indices in a row which correspond to the camera images at one camera timestep. If this assumption is violated, the first dimension will not correspond to camera timesteps anymore. n_threads: Number of threads to use for processing the images. If None, the number of threads is set to the number of CPUs available to the process. Returns: The image data (or a part of it specified by rng or indices) as a numpy array. If `timestep_dimension` is True the shape will be (n_camera_timesteps, n_cameras, n_channels, height, width) else (n_images, n_channels, height, width). The channels are ordered as RGB. """ if rng is not None and indices is not None: raise ValueError("rng and indices cannot be specified at the same time.") if n_threads is None: n_threads = len(os.sched_getaffinity(0)) if zarr_path is None: zarr_path = self._download_dataset() store = zarr.LMDBStore(zarr_path, readonly=True) root = zarr.open(store=store) n_cameras = root["images"].attrs["n_cameras"] n_channels = root["images"].attrs["n_channels"] image_shape = tuple(root["images"].attrs["image_shape"]) reorder_pixels = root["images"].attrs["reorder_pixels"] compression = root["images"].attrs["compression"] assert compression == "image", "Only image compression is supported." # load only relevant image data if indices is not None: image_data = root["images"].get_orthogonal_selection(indices) else: image_data = root["images"][slice(*rng)] n_unique_images = image_data.shape[0] if timestep_dimension: n_timesteps = int(np.ceil(n_unique_images / n_cameras)) out_shape = (n_timesteps, n_cameras, n_channels) + image_shape else: out_shape = (n_unique_images, n_channels) + image_shape unique_images = np.zeros(out_shape, dtype=np.uint8) threads = [] # distribute image loading and processing over multiple threads for i in range(n_threads): image_loader = ImageLoader( loader_id=i, n_loaders=n_threads, image_data=image_data, unique_images=unique_images, n_unique_images=n_unique_images, n_cameras=n_cameras, reorder_pixels=reorder_pixels, timestep_dimension=timestep_dimension, ) threads.append(image_loader) image_loader.start() for thread in threads: thread.join() store.close() return unique_images
[docs] def convert_timestep_to_image_index( self, timesteps: np.ndarray, zarr_path: Union[str, os.PathLike] = None, ) -> np.ndarray: """Convert camera timesteps to image indices. Args: timesteps: Array of camera timesteps. Returns: Array of image indices. """ if zarr_path is None: zarr_path = self._download_dataset() store = zarr.LMDBStore(zarr_path, readonly=True) root = zarr.open(store=store) # mapping from observation index to image index # (necessary since the camera frequency < control frequency) image_indices = root["obs_to_image_index"].get_coordinate_selection(timesteps) store.close() return image_indices
[docs] def get_dataset( self, zarr_path: Union[str, os.PathLike] = None, clip: bool = True, rng: Optional[Tuple[int, int]] = None, indices: Optional[np.ndarray] = None, n_threads: Optional[int] = None, ) -> Dict[str, Any]: """Get the dataset. When called for the first time, the dataset is automatically downloaded and saved to ``~/.trifinger_rl_datasets``. Args: zarr_path: Optional path to a Zarr directory containing the dataset, which will be used instead of the default. clip: If True, observations are clipped to be within the environment's observation space. rng: Optional range to return. rng=(m,n) means that observations, actions and rewards m to n-1 are returned. If not specified, the entire dataset is returned. indices: Optional array of timestep indices for which to load data. rng and indices are mutually exclusive, only one of them can be set. n_threads: Number of threads to use for processing the images. If None, the number of threads is set to the number of CPUs available to the process. Returns: A dictionary containing the following keys - observations: Either an array or a list of dictionaries containing the observations depending on whether `flatten_obs` is True or False. - actions: Array containing the actions. - rewards: Array containing the rewards. - timeouts: Array containing the timeouts (True only at the end of an episode by default. Always False if `set_terminals` is True). - terminals: Array containing the terminals (Always False by default. If `set_terminals` is True, only True at the last timestep of an episode). - images (only if present in dataset): Array of the shape (n_control_timesteps, n_cameras, n_channels, height, width) containing the image data. The cannels are ordered as RGB. """ # The offline RL dataset is loaded from a Zarr directory which contains # the following Zarr arrays (this is an implementation detail and # not necessary to understand for users of the class): # - observations: Two-dimensional array of shape # `(n_control_timesteps, n_obs)` containing the observations as # flat vectors of length `n_obs` (except for the camera images # which are stored in image_data if present in the dataset). # - actions: Two-dimensional array of shape `(n_control_timesteps, # n_actions)` containing the actions. # - rewards: One-dimensional array of length `n_control_timesteps` # containing the rewards. # - episode_ends: One-dimensional array of length `n_episodes` # containing the indices of the last control timestep of each # episode. # - timeouts: One-dimensional array of length `n_control_timesteps` # with values of type bool. Only True at timesteps where the # episode ends, False otherwise. # - image_data: Ragged array of type bytes, which contains the # compressed image data. The images obtained from all cameras # at each camera time step are written one after another to this # array. After decompression the color information is contained # in a Bayer pattern. The images should therefore be debayerd # before use. Also note the information on the reorder_pixels # attribute below. The dataset has the following attributes: # - n_cameras: Number of cameras. # - n_channels: Number of channels per camera image. # - compression: Type of compression used. Only "image" is # supported by this class. # - image_codec: Codec used to compress the image data. Only # "jpeg" and "png" are supported by this class. # - image_shape: Tuple of length 2 containing the height and width # of the images. # - reorder_pixels: If true, the pixels of the Bayer pattern have # been reordered, such that all pixels of a specific colour are # next to each other in one big block (i.e. one block with all # red pixels, one with all blue pixels and one with all green # pixels). This leads to more continuity of the data (compared # to the original Bayer pattern) and thus tends to improve the # performance of standard image compression algorithms (e.g. # PNG). To restore the original image, the pixels need to be # reordered back before debayering. # - obs_to_image_index: One-dimensional array of length # `n_control_timesteps` containing the index of the camera # image corresponding to each control timestep. This mapping # is necessary because the camera frequency is lower than the # control frequency. if rng is not None and indices is not None: raise ValueError("rng and indices cannot be specified at the same time.") if zarr_path is None: zarr_path = self._download_dataset() store = zarr.LMDBStore(zarr_path, readonly=True) root = zarr.open(store=store) data_dict = {} if indices is None: # turn range into slice n_avail_transitions = root["observations"].shape[0] if rng is None: rng = (None, None) rng = ( 0 if rng[0] is None else rng[0], n_avail_transitions if rng[1] is None else rng[1], ) range_slice = slice(*rng) for k in self._PRELOAD_VECTOR_KEYS + self._PRELOAD_SCALAR_KEYS: data_dict[k] = root[k][range_slice] else: for k in self._PRELOAD_VECTOR_KEYS: data_dict[k] = root[k].get_orthogonal_selection((indices, slice(None))) for k in self._PRELOAD_SCALAR_KEYS: data_dict[k] = root[k].get_coordinate_selection(indices) n_control_timesteps = data_dict["observations"].shape[0] # clip to make sure that there are no outliers in the data if clip: data_dict["observations"] = data_dict["observations"].clip( min=self._flat_unflattening_space.low, max=self._flat_unflattening_space.high, dtype=self._flat_unflattening_space.dtype, ) if not (self.flatten_obs and self.obs_to_keep is None): # unflatten observations, i.e., turn them into dicts again unflattened_obs = [] obs = data_dict["observations"] for i in range(obs.shape[0]): unflattened_obs.append( spaces.unflatten(self._unflattening_space, obs[i, ...]) ) data_dict["observations"] = unflattened_obs # timeouts, terminals and info if self.set_terminals: data_dict["terminals"] = data_dict["timeouts"] data_dict["timeouts"] = np.zeros(n_control_timesteps, dtype=bool) data_dict["infos"] = [{} for _ in range(n_control_timesteps)] # process obs (filtering, flattening, scaling) for i in range(n_control_timesteps): data_dict["observations"][i] = self._process_obs( obs=data_dict["observations"][i] ) # turn observations into array if obs are flattened if self.flatten_obs: data_dict["observations"] = np.array( data_dict["observations"], dtype=self.observation_space.dtype ) if "images" in root.keys(): n_cameras = root["images"].attrs["n_cameras"] if indices is None: # mapping from observation index to image index # (necessary since the camera frequency < control frequency) obs_to_image_index = root["obs_to_image_index"][range_slice] image_index_range = ( obs_to_image_index[0], # add n_cameras to include last images as well obs_to_image_index[-1] + n_cameras, ) # load images unique_images = self.get_image_data( rng=image_index_range, zarr_path=zarr_path, n_threads=n_threads ) else: obs_to_image_index = root[ "obs_to_image_index" ].get_coordinate_selection(indices) # load images from all cameras, not only first one all_cam_indices = np.zeros( obs_to_image_index.shape[0] * n_cameras, dtype=np.int64 ) for i in range(n_cameras): all_cam_indices[i::n_cameras] = obs_to_image_index + i # remove duplicates and sort image_indices, unique_to_original = np.unique( all_cam_indices, return_inverse=True ) # load images unique_images = self.get_image_data( indices=image_indices, zarr_path=zarr_path, n_threads=n_threads ) # repeat images to account for control frequency > camera frequency images = np.zeros( (n_control_timesteps,) + unique_images.shape[1:], dtype=np.uint8 ) for i in range(n_control_timesteps): if indices is None: index = (obs_to_image_index[i] - obs_to_image_index[0]) // n_cameras else: # map from original image index to unique image index index = unique_to_original[i * n_cameras] // n_cameras images[i] = unique_images[index] data_dict["images"] = images store.close() return data_dict
def get_dataset_chunk(self, chunk_id, zarr_path=None): raise NotImplementedError()
[docs] def compute_reward( self, achieved_goal: dict, desired_goal: dict, info: dict ) -> float: """Compute the reward for the given achieved and desired goal. Args: achieved_goal: Current pose of the object. desired_goal: Goal pose of the object. info: An info dictionary containing a field "time_index" which contains the time index of the achieved_goal. Returns: The reward that corresponds to the provided achieved goal w.r.t. to the desired goal. """ return self.sim_env.compute_reward(achieved_goal, desired_goal, info)
[docs] def step( self, action: np.ndarray, **kwargs ) -> Tuple[Union[Dict, np.ndarray], float, bool, bool, Dict]: """Execute one step. Args: action: Array of 9 torque commands, one for each robot joint. Returns: A tuple with - observation (dict or tuple): agent's observation of the current environment. If `self.flatten_obs` is False then as a dictionary. If `self.flatten_obs` is True then either as a 1D NumPy array (if no images are to be included) or as a tuple (if images are to be included) consisting of * a 1D NumPy array containing all observations except the camera images, and * a NumPy array of shape (n_cameras, n_channels, height, width) containing the camera images. - reward (float): amount of reward returned after previous action. - terminated (bool): whether the MDP has reached a terminal state. If true, the user needs to call `reset()`. - truncated (bool): Whether the truncation condition outside the scope of the MDP is satisfied. For this environment this corresponds to a timeout. If true, the user needs to call `reset()`. - info (dict): info dictionary containing the current time index. """ if self.real_robot: raise NotImplementedError( "The step method is not available for real-robot data." ) obs, rew, terminated, truncated, info = self.sim_env.step(action, **kwargs) # process obs processed_obs = self._process_obs(obs) return processed_obs, rew, terminated, truncated, info
[docs] def reset(self) -> Tuple[Union[Dict, np.ndarray], Dict]: """Reset the environment. Returns: Tuple of observation and info dictionary. """ if self.real_robot: raise NotImplementedError( "The reset method is not available for real-robot data." ) obs, info = self.sim_env.reset() # process obs processed_obs = self._process_obs(obs) return processed_obs, info
[docs] def seed(self, seed: Optional[int] = None) -> List[int]: """Set random seed of the environment.""" return self.sim_env.seed(seed)
[docs] def render(self, mode: str = "human"): """Does not do anything for this environment.""" if self.real_robot: raise NotImplementedError( "The render method is not available for real-robot data." ) self.sim_env.render(mode)
[docs] def reset_fingers(self, reset_wait_time: int = 3000): """Moves the fingers to initial position. This resets neither the frontend nor the cube. This method is supposed to be used for 'soft resets' between episodes in one job. """ if self.real_robot: raise NotImplementedError( "The reset_fingers method is not available for real-robot data." ) obs, info = self.sim_env.reset_fingers(reset_wait_time) processed_obs = self._process_obs(obs) return processed_obs, info