Skip to content

Commit

Permalink
Merge pull request octo-models#127 from rail-berkeley/dibya_add_pad_masks
Browse files Browse the repository at this point in the history

Paying attention to Pad Masks
  • Loading branch information
kvablack authored Dec 7, 2023
2 parents 06a8545 + c0dff44 commit 6df0a40
Show file tree
Hide file tree
Showing 12 changed files with 574 additions and 352 deletions.
5 changes: 3 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,12 @@ def get_config(
entity=placeholder(str),
),
wandb_resume_id=placeholder(str),
eval_datasets=[
eval_datasets=(
"bridge_dataset",
"taco_play",
"berkeley_cable_routing",
"berkeley_autolab_ur5",
],
),
)
)

Expand Down Expand Up @@ -230,6 +230,7 @@ def get_model_config(transformer_size):

return {
**get_transformer_kwargs(transformer_size),
"proper_pad_mask": True,
"max_horizon": 10,
"readouts": dict(),
"heads": dict(
Expand Down
33 changes: 23 additions & 10 deletions orca/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,36 @@ def _subsample(traj, subsample_length):
return traj


def _add_pad_mask_dict(traj):
    """Attaches a per-key padding indicator to the observation dict.

    After this runs, traj["observation"]["pad_mask_dict"][k] is a boolean
    tensor of length traj_len that is False wherever observation key k is
    padding, and True otherwise.
    """
    # one mask entry per timestep; the action tensor defines the trajectory length
    num_steps = tf.shape(traj["action"])[0]
    observation = traj["observation"]
    observation["pad_mask_dict"] = {
        name: (
            # string-valued entries are treated as padding iff they are empty
            tf.strings.length(tensor) != 0
            if tensor.dtype == tf.string
            # non-string entries are never considered padding
            else tf.ones([num_steps], dtype=tf.bool)
        )
        for name, tensor in observation.items()
    }
    return traj


def _decode_images(obs: dict) -> dict:
"""Decodes images and depth images, marking empty strings as padding."""
pad_mask_dict = {} # indicates which keys in the dict are padding
"""Decodes images and depth images."""
for key in obs:
if "image" in key:
if obs[key].dtype == tf.string:
if tf.strings.length(obs[key]) == 0:
# this is a padding image
obs[key] = tf.zeros((1, 1, 3), dtype=tf.uint8)
pad_mask_dict[key] = False
else:
obs[key] = tf.io.decode_image(
obs[key], expand_animations=False, dtype=tf.uint8
)
pad_mask_dict[key] = True
elif obs[key].dtype == tf.uint8:
pad_mask_dict[key] = True
pass
else:
raise ValueError(
f"Unsupported image dtype: found {key} with dtype {obs[key].dtype}"
Expand All @@ -112,20 +125,16 @@ def _decode_images(obs: dict) -> dict:
if tf.strings.length(obs[key]) == 0:
# this is a padding image
obs[key] = tf.zeros((1, 1), dtype=tf.float32)
pad_mask_dict[key] = False
else:
obs[key] = tf.io.decode_image(
obs[key], expand_animations=False, dtype=tf.float32
)[..., 0]
pad_mask_dict[key] = True
elif obs[key].dtype == tf.float32:
pad_mask_dict[key] = True
pass
else:
raise ValueError(
f"Unsupported depth dtype: found {key} with dtype {obs[key].dtype}"
)

obs["pad_mask_dict"] = pad_mask_dict
return obs


Expand Down Expand Up @@ -204,6 +213,7 @@ def apply_trajectory_transforms(
tf.math.abs(x["observation"]["proprio"]) <= max_proprio
)
)
dataset = dataset.map(_add_pad_mask_dict, num_parallel_calls)

# adds the "tasks" key
if goal_relabeling_strategy is not None:
Expand All @@ -217,6 +227,9 @@ def apply_trajectory_transforms(

def move_language_instruction_to_tasks(traj):
    # Relocate the instruction under "tasks" and record, in the task
    # pad-mask dict, whether it is actually present (non-empty string).
    instruction = traj.pop("language_instruction")
    traj["tasks"]["language_instruction"] = instruction
    has_instruction = tf.strings.length(instruction) != 0
    traj["tasks"]["pad_mask_dict"]["language_instruction"] = has_instruction
    return traj

dataset = dataset.map(move_language_instruction_to_tasks, num_parallel_calls)
Expand Down
61 changes: 5 additions & 56 deletions orca/data/utils/task_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,62 +8,6 @@
import tensorflow as tf


def drop_keys_independent(
    traj: Dict[str, Any],
    drop_key_groups_probs: List[Tuple[List[str], float]],
    allow_drop_all: bool = False,
) -> Dict[str, Any]:
    """
    Independently drop groups of keys from the trajectory's "tasks" dictionary.

    :param traj: A dictionary containing trajectory data. Should have a "tasks" key.
    :param drop_key_groups_probs: A list of tuples, where each tuple contains a list
        of task keys and the probability of dropping that whole group together.
    :param allow_drop_all: If True, allow dropping every group. Otherwise, if all
        groups were dropped, return the original trajectory unchanged.
    :return: A dictionary with task keys dropped (zeroed / blanked) according to
        the specified probabilities.
    """

    # don't drop keys if there is no language instruction
    if tf.math.reduce_all(traj["tasks"]["language_instruction"] == ""):
        return traj

    tasks = traj["tasks"]
    # shallow copy: the tensors themselves are shared with `tasks`
    new_tasks = tasks.copy()
    dropped_all = True
    image_keys = [key for key in tasks.keys() if "image" in key]

    for key_group, prob in drop_key_groups_probs:
        if not all(key in tasks for key in key_group):
            raise KeyError(
                f"keys {key_group} are not all present in tasks dictionary. tasks keys: {tasks.keys()}"
            )

        # NOTE(review): `drop_group` is a scalar tf.bool tensor, so the `and`
        # below and the `if ... dropped_all` check at the end rely on tensor
        # truthiness — that only behaves as intended in eager mode; confirm
        # this function is never traced (e.g. inside tf.data map / tf.function).
        drop_group = tf.random.uniform([]) < prob
        dropped_all = dropped_all and drop_group

        # When no goal images are present, the goal timestep becomes the final timestep
        if all([image_key in key_group for image_key in image_keys]):
            new_tasks["goal_timestep"] = tf.where(
                drop_group,
                tasks["end_timestep"],
                tasks["goal_timestep"],
            )

        for key in key_group:
            # dropped values become zeros for numeric tensors, "" for strings
            new_tasks[key] = tf.where(
                drop_group,
                tf.zeros_like(tasks[key])
                if tf.debugging.is_numeric_tensor(tasks[key])
                else "",
                tasks[key],
            )

    if not allow_drop_all and dropped_all:
        return traj

    traj["tasks"] = new_tasks
    return traj


def delete_task_conditioning(
traj: Dict[str, Any],
delete_key_groups_probs: List[Tuple[List[str], float]],
Expand Down Expand Up @@ -109,6 +53,11 @@ def delete_task_conditioning(
else "",
tasks[key],
)
new_tasks["pad_mask_dict"][key] = tf.where(
i == delete_group_idx,
tf.zeros_like(tasks["pad_mask_dict"][key]),
new_tasks["pad_mask_dict"][key],
)

traj["tasks"] = new_tasks
return traj
12 changes: 9 additions & 3 deletions orca/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def create_model_def(
max_horizon,
transformer_kwargs,
heads,
proper_pad_mask=False,
**kwargs,
):
"""
Expand All @@ -37,13 +38,18 @@ def create_model_def(
"""
if len(kwargs) > 0:
logging.warn(f"Extra kwargs passed into create_model_def: {kwargs}")

if proper_pad_mask:
logging.warn("Using proper_pad_mask=True")
observation_tokenizer_defs = {
k: TOKENIZERS.get(info["cls_name"])(**info["kwargs"])
k: TOKENIZERS.get(info["cls_name"])(
**info["kwargs"], proper_pad_mask=proper_pad_mask
)
for k, info in observation_tokenizers.items()
}
task_tokenizer_defs = {
k: TOKENIZERS.get(info["cls_name"])(**info["kwargs"])
k: TOKENIZERS.get(info["cls_name"])(
**info["kwargs"], proper_pad_mask=proper_pad_mask
)
for k, info in task_tokenizers.items()
}

Expand Down
Loading

0 comments on commit 6df0a40

Please sign in to comment.