diff --git a/config.py b/config.py index 27bc4121..83b59f86 100644 --- a/config.py +++ b/config.py @@ -56,11 +56,11 @@ def get_config( dataset_kwargs=get_dataset_config(modality, window_size), optimizer=dict( learning_rate=dict( + name="rsqrt", init_value=0.0, peak_value=3e-4, warmup_steps=2000, - decay_steps=num_steps, - end_value=0.0, + timescale=10000, ), weight_decay=0.1, clip_gradient=1.0, @@ -107,8 +107,8 @@ def get_dataset_config(modality="multimodal", window_size=1): normalization_type = "normal" if modality == "multimodal": task_augmentation = dict( - task_augmentation_strategy="delete_task_conditioning", - task_augmentation_kwargs=dict( + task_augment_strategy="delete_task_conditioning", + task_augment_kwargs=dict( delete_key_groups_probs=[ (["image_*"], 0.5), (["language_instruction"], 0.5), @@ -137,7 +137,6 @@ def get_dataset_config(modality="multimodal", window_size=1): additional_action_window_size=0, goal_relabeling_strategy="uniform", subsample_length=100, - **task_augmentation, ), "frame_transform_kwargs": dict( resize_size=(256, 256), @@ -155,6 +154,7 @@ def get_dataset_config(modality="multimodal", window_size=1): "random_hue", ], ), + **task_augmentation, ), "traj_transform_threads": 48, # shared between all datasets "traj_read_threads": 48, # shared between all datasets diff --git a/experiments/dibya/finetune_config.py b/experiments/dibya/finetune_config.py index 78a98dd3..56c1d62a 100644 --- a/experiments/dibya/finetune_config.py +++ b/experiments/dibya/finetune_config.py @@ -1,7 +1,7 @@ +from config import update_config, wrap from ml_collections import ConfigDict from ml_collections.config_dict import FieldReference, placeholder -from config import update_config, wrap from orca.data.utils.data_utils import ActionEncoding, StateEncoding @@ -71,6 +71,7 @@ def get_config( window_size=window_size, optimizer=dict( learning_rate=dict( + name="cosine", init_value=0.0, peak_value=3e-4, warmup_steps=2000, @@ -116,28 +117,47 @@ def get_config( window_size=window_size, additional_action_window_size=0, goal_relabeling_strategy=goal_relabeling_strategy, - task_augmentation_strategy="delete_task_conditioning", - task_augmentation_kwargs=dict( - delete_key_groups_probs=delete_key_groups_probs, - ), # If the default data loading speed is too slow, try these: # num_parallel_calls=16, # for less CPU-intensive ops ) + workspace_augment_kwargs = dict( + random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]), + random_brightness=[0.1], + random_contrast=[0.9, 1.1], + random_saturation=[0.9, 1.1], + random_hue=[0.05], + augment_order=[ + "random_resized_crop", + "random_brightness", + "random_contrast", + "random_saturation", + "random_hue", + ], + ) + wrist_augment_kwargs = dict( + random_brightness=[0.1], + random_contrast=[0.9, 1.1], + random_saturation=[0.9, 1.1], + random_hue=[0.05], + augment_order=[ + "random_brightness", + "random_contrast", + "random_saturation", + "random_hue", + ], + ) frame_transform_kwargs = dict( - resize_size=(256, 256), - image_augment_kwargs=dict( - random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]), - random_brightness=[0.2], - random_contrast=[0.8, 1.2], - random_saturation=[0.8, 1.2], - random_hue=[0.1], - augment_order=[ - "random_resized_crop", - "random_brightness", - "random_contrast", - "random_saturation", - "random_hue", - ], + resize_size=[ + (256, 256), # workspace (3rd person) camera is at 256x256 + (128, 128), # wrist camera is at 128x128 + ], + image_augment_kwargs=[ + workspace_augment_kwargs, + wrist_augment_kwargs, + ], + task_augment_strategy="delete_task_conditioning", + task_augment_kwargs=dict( + delete_key_groups_probs=delete_key_groups_probs, ), ) # If the default data loading speed is too slow, try these: diff --git a/experiments/kevin/golden_config.py b/experiments/kevin/golden_config.py new file mode 100644 index 00000000..91aa3a0c --- /dev/null +++ b/experiments/kevin/golden_config.py @@ -0,0 +1,118 @@ +import copy + +from config import get_config as get_base_config +from config import update_config, wrap + + +def get_config(config_string=None): + base_config = get_base_config(config_string) + + # Can't delete with update_config + del base_config["model"]["observation_tokenizers"] + # Field reference can't be updated with update_config + base_config["window_size"] = 2 + base_config["num_steps"] = 300000 + + # different augmentations for wrist and workspace + workspace_augment_kwargs = dict( + random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]), + random_brightness=[0.1], + random_contrast=[0.9, 1.1], + random_saturation=[0.9, 1.1], + random_hue=[0.05], + augment_order=[ + "random_resized_crop", + "random_brightness", + "random_contrast", + "random_saturation", + "random_hue", + ], + ) + wrist_augment_kwargs = dict( + random_brightness=[0.1], + random_contrast=[0.9, 1.1], + random_saturation=[0.9, 1.1], + random_hue=[0.05], + augment_order=[ + "random_brightness", + "random_contrast", + "random_saturation", + "random_hue", + ], + ) + + del base_config["dataset_kwargs"]["frame_transform_kwargs"]["resize_size"] + del base_config["dataset_kwargs"]["frame_transform_kwargs"]["image_augment_kwargs"] + + base_config["dataset_kwargs"]["frame_transform_kwargs"]["resize_size"] = [ + (256, 256), # workspace (3rd person) camera is at 256x256 + (128, 128), # wrist camera is at 128x128 + ] + base_config["dataset_kwargs"]["frame_transform_kwargs"]["image_augment_kwargs"] = [ + workspace_augment_kwargs, + wrist_augment_kwargs, + ] + + config = update_config( + base_config, + optimizer=dict( + frozen_keys=("*hf_model*",), + ), + dataset_kwargs=dict( + oxe_kwargs=dict( + data_mix="oxe_magic_soup", + data_dir="gs://rail-orca-central2/resize_256_256", + n_wrist_cameras=1, + ), + batch_size=256, + shuffle_buffer_size=500000, + balance_weights=True, + ), + model={ + "observation_tokenizers": { + "workspace": { + "cls_name": "image_tokenizer", + "kwargs": dict( + obs_stack_keys=["image_0"], + task_stack_keys=["image_0"], + task_film_keys=[], + encoder="small-stem-16", + ), + }, + "wrist": { + "cls_name": "image_tokenizer", + "kwargs": dict( + obs_stack_keys=["image_1"], + task_stack_keys=["image_1"], + task_film_keys=[], + encoder="small-stem-16", + ), + }, + }, + "task_tokenizers": { + "language": { + "cls_name": "language_tokenizer", + "kwargs": dict( + encoder="t5-base", + finetune_encoder=False, + ), + }, + }, + }, + text_processor="hf_tokenizer", + text_processor_kwargs=dict( + tokenizer_name="t5-base", + encode_with_model=False, + tokenizer_kwargs={ + "max_length": 16, + "padding": "max_length", + "truncation": True, + "return_tensors": "np", + }, + ), + pretrained_loaders=["from_huggingface"], + pretrained_loader_kwargs=[dict(hf_model="t5-base")], + eval_datasets=["bridge_dataset"], + ) + + return config diff --git a/finetune.py b/finetune.py index e1d277fc..403f80e1 100644 --- a/finetune.py +++ b/finetune.py @@ -198,7 +198,7 @@ def process_batch(batch): tx, lr_callable, param_norm_callable = create_optimizer( params, - FLAGS.config.optimizer.to_dict(), + **FLAGS.config.optimizer.to_dict(), ) train_state = TrainState.create( apply_fn=model.model_def.apply, diff --git a/orca/data/dataset.py b/orca/data/dataset.py index 921dca08..b1d73369 100644 --- a/orca/data/dataset.py +++ b/orca/data/dataset.py @@ -140,23 +140,51 @@ def _decode_images(obs: dict) -> dict: def _augment(obs: dict, seed, augment_kwargs) -> dict: """Augments images, skipping padding images.""" - for key in obs: - if "image" in key: + num_image_keys = sum(["image" in key for key in obs]) + + if not isinstance(augment_kwargs, Sequence): + augment_kwargs = [copy.deepcopy(augment_kwargs)] * num_image_keys + + for i in range(num_image_keys): + if augment_kwargs[i] is not None: + key = f"image_{i}" if obs["pad_mask_dict"][key]: obs[key] = dl.transforms.augment_image( - obs[key], **augment_kwargs, seed=seed + obs[key], **augment_kwargs[i], seed=seed + i ) return obs +def _resize(obs: dict, resize_size, depth_resize_size) -> dict: + """Resizes images and depth images.""" + num_image_keys = sum(["image" in key for key in obs]) + num_depth_keys = sum(["depth" in key for key in obs]) + + if resize_size is None or isinstance(resize_size[0], int): + resize_size = [resize_size] * num_image_keys + if depth_resize_size is None or isinstance(depth_resize_size[0], int): + depth_resize_size = [depth_resize_size] * num_depth_keys + + for i in range(num_image_keys): + if resize_size[i] is not None: + key = f"image_{i}" + obs[key] = dl.transforms.resize_image(obs[key], size=resize_size[i]) + + for i in range(num_depth_keys): + if depth_resize_size[i] is not None: + key = f"depth_{i}" + obs[key] = dl.transforms.resize_depth_image( + obs[key], size=depth_resize_size[i] + ) + return obs + + def apply_trajectory_transforms( dataset: dl.DLataset, *, train: bool, goal_relabeling_strategy: Optional[str] = None, goal_relabeling_kwargs: dict = {}, - task_augmentation_strategy: Optional[str] = None, - task_augmentation_kwargs: dict = {}, window_size: int = 1, additional_action_window_size: int = 0, action_encoding: ActionEncoding = ActionEncoding.EEF_POS, @@ -181,9 +209,6 @@ def apply_trajectory_transforms( goal_relabeling_strategy (str, optional): The goal relabeling strategy to use, or None for no goal relabeling. See `bc_goal_relabeling.py`. goal_relabeling_kwargs (dict, optional): Additional keyword arguments to pass to the goal relabeling function. - task_augmentation_strategy (Optional[str], optional): The task augmentation strategy to use, or None for no task - augmentation. See `task_augmentation.py`. - task_augmentation_kwargs (dict, optional): Additional keyword arguments to pass to the task augmentation function. window_size (int, optional): The length of the snippets that trajectories are chunked into. additional_action_window_size (int, optional): The number of additional actions beyond window_size to include in the chunked actions. @@ -239,15 +264,6 @@ def move_language_instruction_to_tasks(traj): partial(_subsample, subsample_length=subsample_length), num_parallel_calls ) - if train and task_augmentation_strategy is not None: - dataset = dataset.map( - partial( - getattr(task_augmentation, task_augmentation_strategy), - **task_augmentation_kwargs, - ), - num_parallel_calls, - ) - dataset = dataset.map( partial( _chunk_act_obs, @@ -263,8 +279,15 @@ def move_language_instruction_to_tasks(traj): def get_frame_transforms( train: bool, - image_augment_kwargs: Optional[dict] = None, - resize_size: Optional[Tuple[int, int]] = None, + image_augment_kwargs: Union[Optional[dict], Sequence[Optional[dict]]] = None, + resize_size: Union[ + Optional[Tuple[int, int]], Sequence[Optional[Tuple[int, int]]] + ] = None, + depth_resize_size: Union[ + Optional[Tuple[int, int]], Sequence[Optional[Tuple[int, int]]] + ] = None, + task_augment_strategy: Optional[str] = None, + task_augment_kwargs: dict = {}, ) -> List[Callable[[dict], dict]]: """ Returns a list of functions to be applied to each frame. These transforms are usually @@ -272,9 +295,21 @@ def get_frame_transforms( Args: train (bool): Whether the dataset is for training (affects image augmentation). - image_augment_kwargs (dict): Keyword arguments to pass to the image augmentation function. See - `dlimp.transforms.augment_image` for documentation. - resize_size (Tuple[int, int], optional): If provided, images will be resized to this size. + image_augment_kwargs (dict or Sequence[dict]): Keyword arguments to pass to the image + augmentation function. See `dlimp.transforms.augment_image` for documentation. If a list + of dicts is provided, then the ith entry will be used for "image_i" (order determined by + "image_obs_keys"). A value of None or a None list entry will skip image augmentation for + the corresponding image(s). + resize_size (Tuple[int, int] or Sequence[Tuple[int, int]]): If provided, images will be + resized to this size. If a list of tuples is provided, then the ith entry will be used + for "image_i" and "depth_i" (order determined by "image_obs_keys" and "depth_obs_keys", + respectively). A value of None or a None list entry will skip resizing for the + corresponding image(s). + depth_resize_size (Tuple[int, int] or Sequence[Tuple[int, int]]): Same as resize_size, but + for depth images. + task_augmentation_strategy (Optional[str], optional): The task augmentation strategy to use, or None for no task + augmentation. See `task_augmentation.py`. + task_augmentation_kwargs (dict, optional): Additional keyword arguments to pass to the task augmentation function. """ # convenience wrapper that takes a function that operates on a non-chunked "observation" dict @@ -288,25 +323,29 @@ def apply_obs_transform(fn: Callable[[dict], dict], frame): transforms = [] - # decode images (and depth images), marking empty strings as padding - transforms.append(partial(apply_obs_transform, _decode_images)) - - # resize images, if requested - if resize_size is not None: + if train and task_augment_strategy is not None: + # perform task augmentation (e.g., dropping keys) transforms.append( partial( - apply_obs_transform, - partial(dl.transforms.resize_images, size=resize_size), - ) + getattr(task_augmentation, task_augment_strategy), + **task_augment_kwargs, + ), ) - transforms.append( + + # decode images (and depth images) + transforms.append(partial(apply_obs_transform, _decode_images)) + + # resize images (and depth images) + transforms.append( + partial( + apply_obs_transform, partial( - apply_obs_transform, - partial(dl.transforms.resize_depth_images, size=resize_size), - ) + _resize, resize_size=resize_size, depth_resize_size=depth_resize_size + ), ) + ) - if train and image_augment_kwargs is not None: + if train: # augment all images with the same seed, skipping padding images def aug(frame): seed = tf.random.uniform([2], maxval=tf.dtypes.int32.max, dtype=tf.int32) @@ -323,9 +362,9 @@ def make_dataset_from_rlds( data_dir: str, train: bool, shuffle: bool = True, - image_obs_keys: Union[str, List[str]] = [], - depth_obs_keys: Union[str, List[str]] = [], - state_obs_keys: Union[str, List[str]] = [], + image_obs_keys: Union[str, Sequence[str]] = (), + depth_obs_keys: Union[str, Sequence[str]] = (), + state_obs_keys: Union[str, Sequence[str]] = (), state_encoding: StateEncoding = StateEncoding.NONE, action_encoding: ActionEncoding = ActionEncoding.EEF_POS, action_proprio_normalization_type: Optional[str] = None, diff --git a/orca/model/components/tokenizers.py b/orca/model/components/tokenizers.py index 37f0d831..1b2f52c5 100644 --- a/orca/model/components/tokenizers.py +++ b/orca/model/components/tokenizers.py @@ -1,4 +1,3 @@ -import functools as ft import logging import re from typing import Dict, Optional, Sequence @@ -13,8 +12,7 @@ from orca.model.components.transformer import MAPHead EPS = 1e-6 -from dataclasses import dataclass -import os +from dataclasses import field def generate_proper_pad_mask( @@ -84,7 +82,7 @@ class ImageTokenizer(nn.Module): """ encoder: str - encoder_kwargs: dict = None + encoder_kwargs: dict = field(default_factory=dict) use_token_learner: bool = False num_tokens: int = 8 conditioning_type: str = "none" diff --git a/orca/utils/train_callbacks.py b/orca/utils/train_callbacks.py index f77d53e3..6d55420f 100644 --- a/orca/utils/train_callbacks.py +++ b/orca/utils/train_callbacks.py @@ -62,9 +62,10 @@ class SaveCallback(Callback): save_dir: Optional[str] def __post_init__(self): - if self.save_dir is not None and jax.process_index() == 0: - tf.io.gfile.makedirs(self.save_dir) - logging.info(f"Created {self.save_dir}") + if self.save_dir is not None: + if jax.process_index() == 0: + tf.io.gfile.makedirs(self.save_dir) + logging.info(f"Created {self.save_dir}") # make checkpointers # only keep latest full TrainState self.state_checkpointer = orbax.checkpoint.CheckpointManager( @@ -81,7 +82,7 @@ def __post_init__(self): ) def __call__(self, train_state: TrainState, step: int): - if self.save_dir is not None and jax.process_index() == 0: + if self.save_dir is not None: self.params_checkpointer.save( step, train_state.params, @@ -219,7 +220,10 @@ def __post_init__(self): val_iterator = map(self.process_batch_fn, val_iterator) self.val_iterators[single_dataset_kwargs["name"]] = val_iterator - @jax.jit + @partial( + jax.jit, + out_shardings=jax.sharding.PositionalSharding(jax.devices()).replicate(), + ) def eval_step(state, batch): loss_fn_partial = partial( self.loss_fn, diff --git a/orca/utils/train_utils.py b/orca/utils/train_utils.py index ed6365d5..ba4b6756 100644 --- a/orca/utils/train_utils.py +++ b/orca/utils/train_utils.py @@ -189,36 +189,68 @@ def filter_eval_datasets(dataset_kwargs_list, sample_weights, eval_datasets=None ) -def create_optimizer( - params_or_params_shape: Params, optimizer_kwargs: dict -) -> optax.GradientTransformation: +def create_lr_schedule(name: str, **kwargs): + """Creates a learning rate callable. + + Currently supported schedules: + cosine: cosine decay with warmup. + kwargs: init_value, peak_value, warmup_steps, decay_steps + rsqrt: inverse square root decay with warmup, from the "Scaling Vision Transformers" paper. + kwargs: init_value, peak_value, warmup_steps, timescale (optional, default 10000) + + Args: + name: name of the schedule + **kwargs: additional kwargs, which vary by schedule + """ + if name == "cosine": + return optax.warmup_cosine_decay_schedule(**kwargs) + elif name == "rsqrt": + timescale = kwargs.get("timescale", 10000) + return optax.join_schedules( + [ + optax.linear_schedule( + init_value=kwargs["init_value"], + end_value=kwargs["peak_value"], + transition_steps=kwargs["warmup_steps"], + ), + lambda step: kwargs["peak_value"] + / jnp.sqrt((step + timescale) / timescale), + ], + [kwargs["warmup_steps"]], + ) + else: + raise ValueError(f"Unsupported lr schedule: {name}") + + +def create_optimizer(params_or_params_shape: Params, **kwargs: dict): """Creates optimizer for ORCA. - Optimizer_kwargs are the kwargs for optax.adamw; if the learning rate is a dict, - it is interpreted as the kwargs for optax.warmup_cosine_decay_schedule. If clip_gradient - is specified, then gradient clipping is applied. + kwargs are the kwargs for optax.adamw; if the "learning_rate" key is a dict, it is interpreted + as the kwargs for create_lr_schedule (see above), otherwise it is interpreted as a constant + learning rate. + + If clip_gradient is specified, then gradient clipping is applied. If frozen_keys is specified, + then those parameters are frozen (i.e. not updated) during training. Returns: tx: an Optax optimizer lr_callable: Function that takes the current step and returns the learning rate """ - if isinstance(optimizer_kwargs["learning_rate"], dict): - optimizer_kwargs["learning_rate"] = optax.warmup_cosine_decay_schedule( - **optimizer_kwargs["learning_rate"] - ) - lr_callable = optimizer_kwargs["learning_rate"] + if isinstance(kwargs["learning_rate"], dict): + lr_callable = create_lr_schedule(**kwargs["learning_rate"]) else: - lr_callable = lambda _: optimizer_kwargs["learning_rate"] + lr_callable = lambda _: kwargs["learning_rate"] + kwargs["learning_rate"] = lr_callable # Following ViT, timm, MAE: this mask skips weight decay on biases and LayerNorm parameters wd_mask = jax.tree_util.tree_map_with_path( lambda path, x: "kernel" in jax.tree_util.keystr(path), params_or_params_shape ) - clip_gradient = optimizer_kwargs.pop("clip_gradient", None) - frozen_keys = optimizer_kwargs.pop("frozen_keys", None) + clip_gradient = kwargs.pop("clip_gradient", None) + frozen_keys = kwargs.pop("frozen_keys", None) - tx = optax.adamw(mu_dtype=jnp.bfloat16, **optimizer_kwargs, mask=wd_mask) + tx = optax.adamw(mu_dtype=jnp.bfloat16, **kwargs, mask=wd_mask) if clip_gradient is not None: tx = optax.chain( optax.clip_by_global_norm(clip_gradient), diff --git a/train.py b/train.py index a91c61b1..a5d6335c 100644 --- a/train.py +++ b/train.py @@ -240,7 +240,7 @@ def process_batch(batch): )["params"] tx, lr_callable, param_norm_callable = create_optimizer( params_shape, - FLAGS.config.optimizer.to_dict(), + **FLAGS.config.optimizer.to_dict(), ) train_state = create_train_state( construct_rng,