Commit
Merge pull request octo-models#134 from rail-berkeley/kevin-aug-per-image

Allow separate augmentation/resizing for different images
dibyaghosh authored Dec 7, 2023
2 parents 6df0a40 + 28840cf commit aefd71e
Showing 9 changed files with 299 additions and 88 deletions.
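In short: resize_size and image_augment_kwargs inside frame_transform_kwargs can now each be a list with one entry per camera, so the workspace and wrist images get their own resolution and augmentation pipeline, and the task-augmentation options move into frame_transform_kwargs under the shortened task_augment_* names. Below is a minimal sketch of the resulting configuration, not the full config; the values and the camera order (workspace first, wrist second) are copied from the diffs in this commit.

    # Sketch only: shows the new per-camera list form, with values taken
    # from the configs changed in this commit.
    # Entry 0 = workspace (3rd person) camera, entry 1 = wrist camera.
    workspace_augment_kwargs = dict(
        random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
        random_brightness=[0.1],
        random_contrast=[0.9, 1.1],
        random_saturation=[0.9, 1.1],
        random_hue=[0.05],
        augment_order=[
            "random_resized_crop",
            "random_brightness",
            "random_contrast",
            "random_saturation",
            "random_hue",
        ],
    )
    wrist_augment_kwargs = dict(  # same color jitter, but no random crop
        random_brightness=[0.1],
        random_contrast=[0.9, 1.1],
        random_saturation=[0.9, 1.1],
        random_hue=[0.05],
        augment_order=["random_brightness", "random_contrast", "random_saturation", "random_hue"],
    )
    frame_transform_kwargs = dict(
        resize_size=[
            (256, 256),  # workspace camera
            (128, 128),  # wrist camera
        ],
        image_augment_kwargs=[
            workspace_augment_kwargs,
            wrist_augment_kwargs,
        ],
        # task augmentation now lives here (moved out of traj_transform_kwargs)
        # and uses the shortened key names
        task_augment_strategy="delete_task_conditioning",
        task_augment_kwargs=dict(
            delete_key_groups_probs=[
                (["image_*"], 0.5),
                (["language_instruction"], 0.5),
            ],
        ),
    )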
10 changes: 5 additions & 5 deletions config.py
@@ -56,11 +56,11 @@ def get_config(
dataset_kwargs=get_dataset_config(modality, window_size),
optimizer=dict(
learning_rate=dict(
name="rsqrt",
init_value=0.0,
peak_value=3e-4,
warmup_steps=2000,
decay_steps=num_steps,
end_value=0.0,
timescale=10000,
),
weight_decay=0.1,
clip_gradient=1.0,
@@ -107,8 +107,8 @@ def get_dataset_config(modality="multimodal", window_size=1):
normalization_type = "normal"
if modality == "multimodal":
task_augmentation = dict(
task_augmentation_strategy="delete_task_conditioning",
task_augmentation_kwargs=dict(
task_augment_strategy="delete_task_conditioning",
task_augment_kwargs=dict(
delete_key_groups_probs=[
(["image_*"], 0.5),
(["language_instruction"], 0.5),
@@ -137,7 +137,6 @@ def get_dataset_config(modality="multimodal", window_size=1):
additional_action_window_size=0,
goal_relabeling_strategy="uniform",
subsample_length=100,
**task_augmentation,
),
"frame_transform_kwargs": dict(
resize_size=(256, 256),
@@ -155,6 +154,7 @@ def get_dataset_config(modality="multimodal", window_size=1):
"random_hue",
],
),
**task_augmentation,
),
"traj_transform_threads": 48, # shared between all datasets
"traj_read_threads": 48, # shared between all datasets
58 changes: 39 additions & 19 deletions experiments/dibya/finetune_config.py
@@ -1,7 +1,7 @@
from config import update_config, wrap
from ml_collections import ConfigDict
from ml_collections.config_dict import FieldReference, placeholder

from config import update_config, wrap
from orca.data.utils.data_utils import ActionEncoding, StateEncoding


@@ -71,6 +71,7 @@ def get_config(
window_size=window_size,
optimizer=dict(
learning_rate=dict(
name="cosine",
init_value=0.0,
peak_value=3e-4,
warmup_steps=2000,
@@ -116,28 +117,47 @@ def get_config(
window_size=window_size,
additional_action_window_size=0,
goal_relabeling_strategy=goal_relabeling_strategy,
task_augmentation_strategy="delete_task_conditioning",
task_augmentation_kwargs=dict(
delete_key_groups_probs=delete_key_groups_probs,
),
# If the default data loading speed is too slow, try these:
# num_parallel_calls=16, # for less CPU-intensive ops
)
workspace_augment_kwargs = dict(
random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
random_brightness=[0.1],
random_contrast=[0.9, 1.1],
random_saturation=[0.9, 1.1],
random_hue=[0.05],
augment_order=[
"random_resized_crop",
"random_brightness",
"random_contrast",
"random_saturation",
"random_hue",
],
)
wrist_augment_kwargs = dict(
random_brightness=[0.1],
random_contrast=[0.9, 1.1],
random_saturation=[0.9, 1.1],
random_hue=[0.05],
augment_order=[
"random_brightness",
"random_contrast",
"random_saturation",
"random_hue",
],
)
frame_transform_kwargs = dict(
resize_size=(256, 256),
image_augment_kwargs=dict(
random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
random_brightness=[0.2],
random_contrast=[0.8, 1.2],
random_saturation=[0.8, 1.2],
random_hue=[0.1],
augment_order=[
"random_resized_crop",
"random_brightness",
"random_contrast",
"random_saturation",
"random_hue",
],
resize_size=[
(256, 256), # workspace (3rd person) camera is at 256x256
(128, 128), # wrist camera is at 128x128
],
image_augment_kwargs=[
workspace_augment_kwargs,
wrist_augment_kwargs,
],
task_augment_strategy="delete_task_conditioning",
task_augment_kwargs=dict(
delete_key_groups_probs=delete_key_groups_probs,
),
)
# If the default data loading speed is too slow, try these:
118 changes: 118 additions & 0 deletions experiments/kevin/golden_config.py
@@ -0,0 +1,118 @@
import copy

from config import get_config as get_base_config
from config import update_config, wrap


def get_config(config_string=None):
base_config = get_base_config(config_string)

# Can't delete with update_config
del base_config["model"]["observation_tokenizers"]
# Field reference can't be updated with update_config
base_config["window_size"] = 2
base_config["num_steps"] = 300000

# different augmentations for wrist and workspace
workspace_augment_kwargs = dict(
random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
random_brightness=[0.1],
random_contrast=[0.9, 1.1],
random_saturation=[0.9, 1.1],
random_hue=[0.05],
augment_order=[
"random_resized_crop",
"random_brightness",
"random_contrast",
"random_saturation",
"random_hue",
],
)
wrist_augment_kwargs = dict(
random_brightness=[0.1],
random_contrast=[0.9, 1.1],
random_saturation=[0.9, 1.1],
random_hue=[0.05],
augment_order=[
"random_brightness",
"random_contrast",
"random_saturation",
"random_hue",
],
)

del base_config["dataset_kwargs"]["frame_transform_kwargs"]["resize_size"]
del base_config["dataset_kwargs"]["frame_transform_kwargs"]["image_augment_kwargs"]

base_config["dataset_kwargs"]["frame_transform_kwargs"]["resize_size"] = [
(256, 256), # workspace (3rd person) camera is at 256x256
(128, 128), # wrist camera is at 128x128
]
base_config["dataset_kwargs"]["frame_transform_kwargs"]["image_augment_kwargs"] = [
workspace_augment_kwargs,
wrist_augment_kwargs,
]

config = update_config(
base_config,
optimizer=dict(
frozen_keys=("*hf_model*",),
),
dataset_kwargs=dict(
oxe_kwargs=dict(
data_mix="oxe_magic_soup",
data_dir="gs://rail-orca-central2/resize_256_256",
n_wrist_cameras=1,
),
batch_size=256,
shuffle_buffer_size=500000,
balance_weights=True,
),
model={
"observation_tokenizers": {
"workspace": {
"cls_name": "image_tokenizer",
"kwargs": dict(
obs_stack_keys=["image_0"],
task_stack_keys=["image_0"],
task_film_keys=[],
encoder="small-stem-16",
),
},
"wrist": {
"cls_name": "image_tokenizer",
"kwargs": dict(
obs_stack_keys=["image_1"],
task_stack_keys=["image_1"],
task_film_keys=[],
encoder="small-stem-16",
),
},
},
"task_tokenizers": {
"language": {
"cls_name": "language_tokenizer",
"kwargs": dict(
encoder="t5-base",
finetune_encoder=False,
),
},
},
},
text_processor="hf_tokenizer",
text_processor_kwargs=dict(
tokenizer_name="t5-base",
encode_with_model=False,
tokenizer_kwargs={
"max_length": 16,
"padding": "max_length",
"truncation": True,
"return_tensors": "np",
},
),
pretrained_loaders=["from_huggingface"],
pretrained_loader_kwargs=[dict(hf_model="t5-base")],
eval_datasets=["bridge_dataset"],
)

return config
2 changes: 1 addition & 1 deletion finetune.py
@@ -198,7 +198,7 @@ def process_batch(batch):

tx, lr_callable, param_norm_callable = create_optimizer(
params,
FLAGS.config.optimizer.to_dict(),
**FLAGS.config.optimizer.to_dict(),
)
train_state = TrainState.create(
apply_fn=model.model_def.apply,
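The finetune.py change pairs with the config updates above: create_optimizer now receives the optimizer config as keyword arguments rather than as a single dict. A minimal sketch of the call, assuming only the signature implied by this diff (create_optimizer taking params plus keyword options such as learning_rate, weight_decay, and clip_gradient from the optimizer config):

    # Sketch; the exact create_optimizer signature is assumed from this call site.
    optimizer_cfg = FLAGS.config.optimizer.to_dict()
    # before: create_optimizer(params, optimizer_cfg)   # config passed as one dict
    # after:  each config entry becomes a keyword argument
    tx, lr_callable, param_norm_callable = create_optimizer(params, **optimizer_cfg)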