
Commit

Merge pull request octo-models#138 from rail-berkeley/final-dataset-refactor

Dataset refactor pt. ???
kvablack authored Dec 11, 2023
2 parents cd7d81e + e0c1d5f commit cf170b6
Showing 37 changed files with 1,774 additions and 1,501 deletions.
3 changes: 0 additions & 3 deletions .flake8
@@ -6,11 +6,8 @@ ignore=W503,
E203,
E731,
E722,
F401,
F841,
E402,
E741,
E501,
C406,
per-file-ignores =
orca/data/dataset_transforms.py: F405
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -7,6 +7,7 @@ repos:
- id: check-yaml
- id: check-ast
- id: check-added-large-files
exclude: ^examples/
- id: check-case-conflict
- id: check-merge-conflict
- id: end-of-file-fixer
22 changes: 8 additions & 14 deletions config.py
@@ -4,6 +4,8 @@
from ml_collections import ConfigDict
from ml_collections.config_dict import FieldReference, placeholder

from orca.data.utils.data_utils import NormalizationType


def update_config(config, **kwargs):
updates = ConfigDict(kwargs)
@@ -104,15 +106,12 @@ def get_config(


def get_dataset_config(modality="multimodal", window_size=1):
normalization_type = "normal"
normalization_type = NormalizationType.NORMAL
if modality == "multimodal":
task_augmentation = dict(
task_augment_strategy="delete_task_conditioning",
task_augment_kwargs=dict(
delete_key_groups_probs=[
(["image_*"], 0.5),
(["language_instruction"], 0.5),
],
keep_image_prob=0.5,
),
)
else:
@@ -124,19 +123,15 @@ def get_dataset_config(modality="multimodal", window_size=1):
data_mix=placeholder(str),
# for v4 TPUs: "gs://rail-orca-central2/resize_336_336"
data_dir=placeholder(str),
n_third_person_cameras=1,
n_wrist_cameras=0,
load_camera_views=("primary", "wrist"),
load_depth=False,
),
# common_dataset_kwargs override specific kwargs from dataset_kwargs_list
"common_dataset_kwargs": dict(
action_proprio_normalization_type=normalization_type,
),
"traj_transform_kwargs": dict(
window_size=window_size,
additional_action_window_size=0,
future_action_window_size=0,
goal_relabeling_strategy="uniform",
subsample_length=100,
**task_augmentation,
),
"frame_transform_kwargs": dict(
resize_size=(256, 256),
Expand All @@ -154,11 +149,10 @@ def get_dataset_config(modality="multimodal", window_size=1):
"random_hue",
],
),
**task_augmentation,
num_parallel_calls=200,
),
"traj_transform_threads": 48, # shared between all datasets
"traj_read_threads": 48, # shared between all datasets
"frame_transform_threads": 200, # not shared between datasets
"shuffle_buffer_size": 100000, # shared between all datasets
"batch_size": 1024,
"balance_weights": True,
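
As a reader's summary of the config.py diff above: camera views are now loaded by name ("primary", "wrist") rather than by the positional n_third_person_cameras / n_wrist_cameras counts, normalization is the NormalizationType.NORMAL enum rather than the bare string "normal", and transform settings are split into trajectory-level kwargs (windowing, goal relabeling, task augmentation) and frame-level kwargs (resizing, image augmentation). A minimal sketch of inspecting the result, assuming the repo layout above; the assertions are illustrative, not tests that ship with the commit:

from config import get_dataset_config

cfg = get_dataset_config(modality="multimodal", window_size=2)

# Named camera views replace the old per-type camera counts.
assert cfg["oxe_kwargs"]["load_camera_views"] == ("primary", "wrist")

# Trajectory-level settings (windowing, goal relabeling, task augmentation)...
assert cfg["traj_transform_kwargs"]["window_size"] == 2

# ...now live apart from per-frame settings (resizing, image augmentation).
assert cfg["frame_transform_kwargs"]["resize_size"] == (256, 256)
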
528 changes: 528 additions & 0 deletions examples/dataloading.ipynb

Large diffs are not rendered by default.

4 changes: 0 additions & 4 deletions examples/envs/aloha_sim_env.py
@@ -1,12 +1,8 @@
import copy
import os
from typing import List

import cv2
import dlimp as dl
from einops import rearrange
import gym
import jax
import jax.numpy as jnp
import numpy as np

1 change: 0 additions & 1 deletion examples/eval_finetuned.py
@@ -19,7 +19,6 @@
import wandb

sys.path.append("/nfs/nfs2/users/karl/code/act")
from envs.aloha_sim_env import AlohaGymEnv

from orca.utils.gym_wrappers import HistoryWrapper, RHCWrapper, UnnormalizeActionProprio
from orca.utils.pretrained_utils import ORCAModel
9 changes: 2 additions & 7 deletions examples/eval_finetuned_on_robot.py
@@ -23,16 +23,11 @@

from orca.utils.eval_utils import (
download_checkpoint_from_gcs,
load_jaxrlm_checkpoint,
sample_actions,
supply_rng,
)
from orca.utils.gym_wrappers import (
HistoryWrapper,
RHCWrapper,
TemporalEnsembleWrapper,
UnnormalizeActionProprio,
)
from orca.utils.gym_wrappers import HistoryWrapper, RHCWrapper, UnnormalizeActionProprio
from orca.utils.gym_wrappers import TemporalEnsembleWrapper # noqa: F401
from orca.utils.pretrained_utils import ORCAModel

np.set_printoptions(suppress=True)
3 changes: 1 addition & 2 deletions examples/finetune_new_observation_action.py
@@ -13,8 +13,7 @@
import wandb

from orca.data.dataset import make_single_dataset
from orca.data.utils.data_utils import ActionEncoding, StateEncoding
from orca.data.utils.text_processing import text_processors
from orca.data.oxe.oxe_dataset_configs import ActionEncoding, StateEncoding
from orca.utils.jax_utils import initialize_compilation_cache
from orca.utils.pretrained_utils import ORCAModel
from orca.utils.train_callbacks import SaveCallback
32 changes: 32 additions & 0 deletions experiments/kevin/custom_standardization_transforms.py
@@ -0,0 +1,32 @@
"""Episode transforms for custom (non-OXE) RLDS datasets to canonical dataset definition."""
from typing import Any, Dict

import tensorflow as tf


def r2_d2_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# every input feature is batched, ie has leading batch dimension
trajectory["action"] = tf.concat(
(
trajectory["action_dict"]["cartesian_velocity"],
trajectory["action_dict"]["gripper_velocity"],
),
axis=-1,
)
return trajectory


def fmb_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# every input feature is batched, ie has leading batch dimension
trajectory["observation"]["state"] = tf.concat(
(
trajectory["observation"]["state"],
trajectory["observation"]["gripper_state"][..., None],
),
axis=-1,
)
return trajectory


def aloha_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
return trajectory
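
These standardization transforms run once per trajectory, before the shared pipeline, to map dataset-specific keys onto the canonical layout. A minimal sketch of applying one by hand, assuming the dummy shapes below; only fmb_dataset_transform comes from the file above, the shapes are invented for illustration:

import tensorflow as tf

from experiments.kevin.custom_standardization_transforms import fmb_dataset_transform

traj = {
    "observation": {
        "state": tf.zeros([10, 6]),       # 10 timesteps of 6-dim arm state
        "gripper_state": tf.zeros([10]),  # one scalar gripper reading per timestep
    }
}
traj = fmb_dataset_transform(traj)

# gripper_state is appended as the 7th state dimension
assert traj["observation"]["state"].shape == (10, 7)
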
27 changes: 12 additions & 15 deletions experiments/kevin/golden_config.py
@@ -1,5 +1,3 @@
import copy

from config import get_config as get_base_config
from config import update_config, wrap

@@ -44,14 +42,14 @@ def get_config(config_string=None):
del base_config["dataset_kwargs"]["frame_transform_kwargs"]["resize_size"]
del base_config["dataset_kwargs"]["frame_transform_kwargs"]["image_augment_kwargs"]

base_config["dataset_kwargs"]["frame_transform_kwargs"]["resize_size"] = [
(256, 256), # workspace (3rd person) camera is at 256x256
(128, 128), # wrist camera is at 128x128
]
base_config["dataset_kwargs"]["frame_transform_kwargs"]["image_augment_kwargs"] = [
workspace_augment_kwargs,
wrist_augment_kwargs,
]
base_config["dataset_kwargs"]["frame_transform_kwargs"]["resize_size"] = {
"primary": (256, 256), # workspace camera is at 256x256
"wrist": (128, 128), # wrist camera is at 128x128
}
base_config["dataset_kwargs"]["frame_transform_kwargs"]["image_augment_kwargs"] = {
"primary": workspace_augment_kwargs,
"wrist": wrist_augment_kwargs,
}

config = update_config(
base_config,
@@ -62,7 +60,6 @@ def get_config(config_string=None):
oxe_kwargs=dict(
data_mix="oxe_magic_soup",
data_dir="gs://rail-orca-central2/resize_256_256",
n_wrist_cameras=1,
),
batch_size=256,
shuffle_buffer_size=500000,
@@ -73,17 +70,17 @@
"workspace": {
"cls_name": "image_tokenizer",
"kwargs": dict(
obs_stack_keys=["image_0"],
task_stack_keys=["image_0"],
obs_stack_keys=["image_primary"],
task_stack_keys=["image_primary"],
task_film_keys=[],
encoder="small-stem-16",
),
},
"wrist": {
"cls_name": "image_tokenizer",
"kwargs": dict(
obs_stack_keys=["image_1"],
task_stack_keys=["image_1"],
obs_stack_keys=["image_wrist"],
task_stack_keys=["image_wrist"],
task_film_keys=[],
encoder="small-stem-16",
),
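
Switching from positional lists to dicts keyed by view name means each camera's resize and augmentation settings are looked up explicitly, so adding or dropping a view cannot silently shift another view's settings. A short sketch of how such a mapping might be consumed, assuming images sit under image_<view> keys as in the tokenizer kwargs above; the resize loop is illustrative, not the repo's pipeline code:

import tensorflow as tf

resize_size = {
    "primary": (256, 256),  # workspace camera
    "wrist": (128, 128),    # wrist camera
}

frame = {
    "image_primary": tf.zeros([480, 640, 3], tf.uint8),
    "image_wrist": tf.zeros([240, 320, 3], tf.uint8),
}

for view, size in resize_size.items():
    # each view is resized independently, keyed by name rather than position
    frame[f"image_{view}"] = tf.image.resize(frame[f"image_{view}"], size)
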
14 changes: 12 additions & 2 deletions finetune.py
@@ -1,5 +1,6 @@
import datetime
from functools import partial
import imp
import json
import os

@@ -154,12 +155,21 @@ def process_batch(batch):
del batch["dataset_name"]
return batch

# load standardize_fn from `path/to/file.py:fn_name` format
if (
standardize_fn := FLAGS.config["dataset_kwargs"].get("standardize_fn", None)
) is not None:
path, name = standardize_fn.split(":")
# imp is deprecated, but it's also what ml_collections uses
standardize_fn = getattr(imp.load_source("standardize_fn", path), name)
del FLAGS.config["dataset_kwargs"]["standardize_fn"]
FLAGS.config["dataset_kwargs"]["standardize_fn"] = standardize_fn

dataset = make_single_dataset(
FLAGS.config.dataset_kwargs,
FLAGS.config.traj_transform_kwargs,
FLAGS.config.frame_transform_kwargs,
train=True,
frame_transform_threads=FLAGS.config.frame_transform_threads,
)
train_data_iter = (
dataset.repeat()
@@ -267,7 +277,7 @@ def loss_fn(params, state, batch, rng, train=True):
model = model_def.bind({"params": params}, rngs={"dropout": rng})
transformer_embeddings = model.orca_transformer(
batch["observation"],
batch["tasks"],
batch["task"],
batch["observation"]["pad_mask"],
train=train,
)
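
The finetune.py hunk above resolves standardize_fn from a "path/to/file.py:fn_name" string via the deprecated imp module, as the inline comment notes. For comparison, a sketch of the same resolution using importlib from the standard library; this variant is illustrative, not what the commit ships:

import importlib.util

def load_standardize_fn(spec_str: str):
    path, name = spec_str.split(":")
    spec = importlib.util.spec_from_file_location("standardize_fn_module", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # run the file once so the function gets defined
    return getattr(module, name)

# e.g.:
# fn = load_standardize_fn(
#     "experiments/kevin/custom_standardization_transforms.py:fmb_dataset_transform")
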
9 changes: 4 additions & 5 deletions finetune_config.py
@@ -1,9 +1,7 @@
from config import update_config, wrap
from config import wrap
from ml_collections import ConfigDict
from ml_collections.config_dict import FieldReference, placeholder

from orca.data.utils.data_utils import ActionEncoding, StateEncoding


@wrap
def get_config(
@@ -25,9 +23,10 @@ def get_config(
"data_dir": "./tests/debug_dataset",
"image_obs_keys": ["image_0", None],
"state_obs_keys": ["state", None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"action_proprio_normalization_type": "normal",
# standardize_fn is dynamically loaded from a file
# for example: "experiments/kevin/custom_standardization_transforms.py:aloha_dataset_transform"
"standardize_fn": "orca/data/oxe/oxe_standardization_transforms.py:bridge_dataset_transform",
# If the default data loading speed is too slow, try these:
# "num_parallel_reads": 8, # for reading from disk / GCS
# "num_parallel_calls": 16, # for initial dataset construction
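
Taken together with the finetune.py change, pointing fine-tuning at a new dataset now means supplying a standardize_fn string instead of hand-picking StateEncoding / ActionEncoding values. A minimal sketch of such an entry, mirroring the defaults above; the dataset name is a placeholder:

dataset_kwargs = {
    "name": "my_rlds_dataset",  # placeholder; substitute your RLDS dataset name
    "data_dir": "./tests/debug_dataset",
    "image_obs_keys": ["image_0", None],
    "state_obs_keys": ["state", None],
    # resolved at runtime by finetune.py from `path/to/file.py:fn_name`
    "standardize_fn": "experiments/kevin/custom_standardization_transforms.py:fmb_dataset_transform",
}
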
