diff --git a/config.py b/config.py index 837a9664..27bc4121 100644 --- a/config.py +++ b/config.py @@ -93,12 +93,12 @@ def get_config( entity=placeholder(str), ), wandb_resume_id=placeholder(str), - eval_datasets=[ + eval_datasets=( "bridge_dataset", "taco_play", "berkeley_cable_routing", "berkeley_autolab_ur5", - ], + ), ) ) @@ -230,6 +230,7 @@ def get_model_config(transformer_size): return { **get_transformer_kwargs(transformer_size), + "proper_pad_mask": True, "max_horizon": 10, "readouts": dict(), "heads": dict( diff --git a/orca/data/dataset.py b/orca/data/dataset.py index 85bb7908..921dca08 100644 --- a/orca/data/dataset.py +++ b/orca/data/dataset.py @@ -86,23 +86,36 @@ def _subsample(traj, subsample_length): return traj +def _add_pad_mask_dict(traj): + """Adds a dictionary indicating which elements of the observation are padding. + + traj["observation"]["pad_mask_dict"] = {k: traj["observation"][k] is not padding} + """ + traj_len = tf.shape(traj["action"])[0] + pad_masks = {} + for key in traj["observation"]: + if traj["observation"][key].dtype == tf.string: + pad_masks[key] = tf.strings.length(traj["observation"][key]) != 0 + else: + pad_masks[key] = tf.ones([traj_len], dtype=tf.bool) + traj["observation"]["pad_mask_dict"] = pad_masks + return traj + + def _decode_images(obs: dict) -> dict: - """Decodes images and depth images, marking empty strings as padding.""" - pad_mask_dict = {} # indicates which keys in the dict are padding + """Decodes images and depth images.""" for key in obs: if "image" in key: if obs[key].dtype == tf.string: if tf.strings.length(obs[key]) == 0: # this is a padding image obs[key] = tf.zeros((1, 1, 3), dtype=tf.uint8) - pad_mask_dict[key] = False else: obs[key] = tf.io.decode_image( obs[key], expand_animations=False, dtype=tf.uint8 ) - pad_mask_dict[key] = True elif obs[key].dtype == tf.uint8: - pad_mask_dict[key] = True + pass else: raise ValueError( f"Unsupported image dtype: found {key} with dtype {obs[key].dtype}" @@ -112,20 +125,16 @@ def _decode_images(obs: dict) -> dict: if tf.strings.length(obs[key]) == 0: # this is a padding image obs[key] = tf.zeros((1, 1), dtype=tf.float32) - pad_mask_dict[key] = False else: obs[key] = tf.io.decode_image( obs[key], expand_animations=False, dtype=tf.float32 )[..., 0] - pad_mask_dict[key] = True elif obs[key].dtype == tf.float32: - pad_mask_dict[key] = True + pass else: raise ValueError( f"Unsupported depth dtype: found {key} with dtype {obs[key].dtype}" ) - - obs["pad_mask_dict"] = pad_mask_dict return obs @@ -204,6 +213,7 @@ def apply_trajectory_transforms( tf.math.abs(x["observation"]["proprio"]) <= max_proprio ) ) + dataset = dataset.map(_add_pad_mask_dict, num_parallel_calls) # adds the "tasks" key if goal_relabeling_strategy is not None: @@ -217,6 +227,9 @@ def apply_trajectory_transforms( def move_language_instruction_to_tasks(traj): traj["tasks"]["language_instruction"] = traj.pop("language_instruction") + traj["tasks"]["pad_mask_dict"]["language_instruction"] = ( + tf.strings.length(traj["tasks"]["language_instruction"]) != 0 + ) return traj dataset = dataset.map(move_language_instruction_to_tasks, num_parallel_calls) diff --git a/orca/data/utils/task_augmentation.py b/orca/data/utils/task_augmentation.py index 90c405f0..b63c81d9 100644 --- a/orca/data/utils/task_augmentation.py +++ b/orca/data/utils/task_augmentation.py @@ -8,62 +8,6 @@ import tensorflow as tf -def drop_keys_independent( - traj: Dict[str, Any], - drop_key_groups_probs: List[Tuple[List[str], float]], - allow_drop_all: bool = False, -) 
-> Dict[str, Any]: - """ - Independently drop keys in the tasks dictionary. - - :param traj: A dictionary containing trajectory data. should have a "tasks" key. - :param drop_key_groups_probs: A list of tuples, where each tuple contains a list of keys and a dropout probability. - :param allow_drop_all: If True, allow dropping all keys. Otherwise, if all keys are dropped, return the original - :return: A dictionary with keys dropped out according to the specified probabilities. - """ - - # don't drop keys if there is no language instruction - if tf.math.reduce_all(traj["tasks"]["language_instruction"] == ""): - return traj - - tasks = traj["tasks"] - new_tasks = tasks.copy() - dropped_all = True - image_keys = [key for key in tasks.keys() if "image" in key] - - for key_group, prob in drop_key_groups_probs: - if not all(key in tasks for key in key_group): - raise KeyError( - f"keys {key_group} are not all present in tasks dictionary. tasks keys: {tasks.keys()}" - ) - - drop_group = tf.random.uniform([]) < prob - dropped_all = dropped_all and drop_group - - # When no goal images are present, the goal timestep becomes the final timestep - if all([image_key in key_group for image_key in image_keys]): - new_tasks["goal_timestep"] = tf.where( - drop_group, - tasks["end_timestep"], - tasks["goal_timestep"], - ) - - for key in key_group: - new_tasks[key] = tf.where( - drop_group, - tf.zeros_like(tasks[key]) - if tf.debugging.is_numeric_tensor(tasks[key]) - else "", - tasks[key], - ) - - if not allow_drop_all and dropped_all: - return traj - - traj["tasks"] = new_tasks - return traj - - def delete_task_conditioning( traj: Dict[str, Any], delete_key_groups_probs: List[Tuple[List[str], float]], @@ -109,6 +53,11 @@ def delete_task_conditioning( else "", tasks[key], ) + new_tasks["pad_mask_dict"][key] = tf.where( + i == delete_group_idx, + tf.zeros_like(tasks["pad_mask_dict"][key]), + new_tasks["pad_mask_dict"][key], + ) traj["tasks"] = new_tasks return traj diff --git a/orca/model/__init__.py b/orca/model/__init__.py index 1dd63c64..55c7c71c 100644 --- a/orca/model/__init__.py +++ b/orca/model/__init__.py @@ -13,6 +13,7 @@ def create_model_def( max_horizon, transformer_kwargs, heads, + proper_pad_mask=False, **kwargs, ): """ @@ -37,13 +38,18 @@ def create_model_def( """ if len(kwargs) > 0: logging.warn(f"Extra kwargs passed into create_model_def: {kwargs}") - + if proper_pad_mask: + logging.warn("Using proper_pad_mask=True") observation_tokenizer_defs = { - k: TOKENIZERS.get(info["cls_name"])(**info["kwargs"]) + k: TOKENIZERS.get(info["cls_name"])( + **info["kwargs"], proper_pad_mask=proper_pad_mask + ) for k, info in observation_tokenizers.items() } task_tokenizer_defs = { - k: TOKENIZERS.get(info["cls_name"])(**info["kwargs"]) + k: TOKENIZERS.get(info["cls_name"])( + **info["kwargs"], proper_pad_mask=proper_pad_mask + ) for k, info in task_tokenizers.items() } diff --git a/orca/model/components/action_heads.py b/orca/model/components/action_heads.py index c2548e85..fa28a8d6 100644 --- a/orca/model/components/action_heads.py +++ b/orca/model/components/action_heads.py @@ -1,5 +1,5 @@ # adapted from https://github.com/google-research/robotics_transformer/blob/master/transformer_network.py -from typing import Optional +from typing import Dict, Optional import distrax from einops import rearrange @@ -7,6 +7,7 @@ import jax import jax.numpy as jnp +from orca.model.components.base import TokenGroup from orca.model.components.tokenizers import BinTokenizer from orca.model.components.transformer import 
MAPHead from orca.utils.typing import PRNGKey @@ -53,39 +54,39 @@ def setup(self): else self.normalization_type, ) - def __call__(self, embeddings, train=True) -> jax.Array: + def __call__( + self, transformer_outputs: Dict[str, TokenGroup], train=True + ) -> jax.Array: """ Args: - embeddings: jnp.ndarray w/ shape (batch_size, horizon, n_tokens, embedding_size) + transformer_outputs: Dict[str, TokenGroup] the output of an OrcaTransformer """ - if isinstance(embeddings, dict): - assert ( - self.readout_key is not None - ), "Must specify readout_key if passing in a dictionary of OrcaTransformer embeddings" - embeddings = embeddings[self.readout_key] - batch_size, horizon, n_tokens, embedding_size = embeddings.shape + assert self.readout_key is not None + token_group = transformer_outputs[self.readout_key] + assert token_group.tokens.ndim == 4, ( + f"Expected token_group.tokens to have shape (batch_size, horizon, num_tokens, embedding_size), " + f"but got shape {token_group.tokens.shape}" + ) if self.use_map: - embeddings = self.map_head(embeddings, train=train)[:, :, 0] + embeddings = self.map_head(token_group, train=train)[:, :, 0] else: - embeddings = embeddings.mean(axis=-2) + embeddings = token_group.tokens.mean(axis=-2) # Now, embeddings is (batch_size, horizon, embedding_size) - logits = self.vocab_proj( - embeddings - ) # (batch_size, horizon, vocab_size * pred_horizon * action_dim) - logits = jnp.reshape( + # (batch_size, horizon, pred_horizon * action_dim * vocab_size) + logits = self.vocab_proj(embeddings) + logits = rearrange( logits, - ( - batch_size, - horizon, - self.pred_horizon, - self.action_dim, - self.vocab_size, - ), + "b h (p a d) -> b h p a d", + p=self.pred_horizon, + a=self.action_dim, + d=self.vocab_size, ) return logits - def loss(self, embeddings, actions, pad_mask, train=True): + def loss( + self, transformer_outputs: Dict[str, TokenGroup], actions, pad_mask, train=True + ): """ Args: embeddings: jnp.ndarray w/ shape (batch_size, horizon, num_tokens, embedding_size) @@ -101,7 +102,7 @@ def loss(self, embeddings, actions, pad_mask, train=True): # unfolding the pred_horizon dim, and projecting to the vocab size # (batch, horizon, pred_horizon, action_dim, token_embedding_size) - action_logits = self.__call__(embeddings, train=train) + action_logits = self.__call__(transformer_outputs, train=train) horizon = action_logits.shape[1] assert ( @@ -155,7 +156,7 @@ def loss(self, embeddings, actions, pad_mask, train=True): def predict_action( self, - embeddings, + transformer_outputs: Dict[str, TokenGroup], train: bool = True, argmax: bool = False, sample_shape: tuple = (), @@ -166,7 +167,7 @@ def predict_action( # unfolding the pred_horizon dim, and projecting to the vocab size # (batch, tokens_per_action, token_embedding_size) - action_logits = self.__call__(embeddings, train=train) + action_logits = self.__call__(transformer_outputs, train=train) action_logits = action_logits[:, -1] if argmax: @@ -220,24 +221,26 @@ def setup(self): else self.normalization_type, ) - def __call__(self, embeddings, train=True) -> jax.Array: + def __call__( + self, transformer_outputs: Dict[str, TokenGroup], train=True + ) -> jax.Array: """ Args: embeddings: jnp.ndarray w/ shape (batch_size, horizon, n_tokens, embedding_size) """ - if isinstance(embeddings, dict): - assert ( - self.readout_key is not None - ), "Must specify readout_key if passing in a dictionary of OrcaTransformer embeddings" - embeddings = embeddings[self.readout_key] - - batch_size, horizon, n_tokens, embedding_size = 
embeddings.shape + assert self.readout_key is not None + token_group = transformer_outputs[self.readout_key] + assert token_group.tokens.ndim == 4, ( + f"Expected token_group.tokens to have shape (batch_size, horizon, num_tokens, embedding_size), " + f"but got shape {token_group.tokens.shape}" + ) if self.use_map: - embeddings = self.map_head(embeddings, train=train) + embeddings = self.map_head(token_group, train=train) else: - assert n_tokens == self.pred_horizon * self.action_dim + embeddings = token_group.tokens + assert embeddings.shape[-2] == self.pred_horizon * self.action_dim - # embeddings is now (batch_size, horizon, pred_horizon * action_dim, vocab_size) + # Now, embeddings is (batch_size, horizon, pred_horizon * action_dim, embedding_size) logits = self.vocab_proj(embeddings) logits = rearrange( logits, "b h (p a) d -> b h p a d", p=self.pred_horizon, a=self.action_dim @@ -264,20 +267,24 @@ def setup(self): self.map_head = MAPHead() self.mean_proj = nn.Dense(self.pred_horizon * self.action_dim) - def __call__(self, embeddings, train=True) -> jax.Array: + def __call__( + self, transformer_outputs: Dict[str, TokenGroup], train=True + ) -> jax.Array: """ Args: embeddings: jnp.ndarray w/ shape (batch_size, horizon, n_tokens, embedding_size) """ - if isinstance(embeddings, dict): - assert ( - self.readout_key is not None - ), "Must specify readout_key if passing in a dictionary of OrcaTransformer embeddings" - embeddings = embeddings[self.readout_key] + assert self.readout_key is not None + token_group = transformer_outputs[self.readout_key] + assert token_group.tokens.ndim == 4, ( + f"Expected token_group.tokens to have shape (batch_size, horizon, num_tokens, embedding_size), " + f"but got shape {token_group.tokens.shape}" + ) if self.use_map: - embeddings = self.map_head(embeddings, train=train)[:, :, 0] + embeddings = self.map_head(token_group, train=train)[:, :, 0] else: - embeddings = embeddings.mean(axis=-2) + embeddings = token_group.tokens.mean(axis=-2) + # Now, embeddings is (batch_size, horizon, embedding_size) mean = self.mean_proj(embeddings) mean = rearrange( mean, "b h (p a) -> b h p a", p=self.pred_horizon, a=self.action_dim @@ -285,7 +292,9 @@ def __call__(self, embeddings, train=True) -> jax.Array: mean = jnp.tanh(mean / self.max_action) * self.max_action return mean - def loss(self, embeddings, actions, pad_mask, train=True): + def loss( + self, transformer_outputs: Dict[str, TokenGroup], actions, pad_mask, train=True + ): """ Trains the mean head with MSE and the logstd head with KL divergence. 
@@ -299,7 +308,7 @@ def loss(self, embeddings, actions, pad_mask, train=True): metrics: dict """ # (batch, horizon, pred_horizon, action_dim) - mean = self.__call__(embeddings, train=train) + mean = self.__call__(transformer_outputs, train=train) horizon = mean.shape[1] assert ( @@ -331,7 +340,7 @@ def loss(self, embeddings, actions, pad_mask, train=True): def predict_action( self, - embeddings, + transformer_outputs: Dict[str, TokenGroup], train: bool = True, argmax: bool = False, sample_shape: tuple = (), @@ -341,7 +350,7 @@ def predict_action( # get the logits for the last action by taking the action tokens of the last timestep, # unfolding the pred_horizon dim, and projecting to the vocab size # (batch, tokens_per_action, token_embedding_size) - mean = self.__call__(embeddings, train=train) + mean = self.__call__(transformer_outputs, train=train) mean = mean[:, -1] logstd = jnp.full_like(mean, -10.0) diff --git a/orca/model/components/base.py b/orca/model/components/base.py new file mode 100644 index 00000000..1a58c6d1 --- /dev/null +++ b/orca/model/components/base.py @@ -0,0 +1,33 @@ +import flax +import jax +import jax.numpy as jnp + +from orca.utils.typing import Sequence + + +@flax.struct.dataclass +class TokenGroup: + """A group of tokens that have semantic meaning together (e.g. the tokens for a single observation) + + Attributes: + tokens: jax.Array of shape (..., n_tokens, token_dim) + mask: jax.Array of shape (..., n_tokens) indicating which tokens are valid (1) vs padding (0) + """ + + tokens: jax.typing.ArrayLike + mask: jax.typing.ArrayLike + + @classmethod + def create( + cls, tokens: jax.typing.ArrayLike, mask: jax.typing.ArrayLike = None, **kwargs + ): + if mask is None: + mask = jnp.ones(tokens.shape[:-1]) + assert mask.ndim == tokens.ndim - 1 + return cls(tokens, mask, **kwargs) + + @classmethod + def concatenate(cls, group_list: Sequence["TokenGroup"], axis=-2): + data = jnp.concatenate([t.tokens for t in group_list], axis=axis) + mask = jnp.concatenate([t.mask for t in group_list], axis=axis + 1) + return cls(data, mask) diff --git a/orca/model/components/block_transformer.py b/orca/model/components/block_transformer.py index a9bf6f2a..2b673167 100644 --- a/orca/model/components/block_transformer.py +++ b/orca/model/components/block_transformer.py @@ -1,8 +1,8 @@ # Written by Dibya -from dataclasses import asdict, dataclass, replace from enum import Enum +from fnmatch import fnmatch import logging -from typing import Mapping, Tuple +from typing import Any, Dict, Mapping, Sequence, Tuple, Union import einops import flax @@ -11,8 +11,8 @@ import jax.numpy as jnp import numpy as np +from orca.model.components.base import TokenGroup from orca.model.components.transformer import Transformer -from orca.utils.typing import Sequence, Union class AttentionRule(Enum): @@ -27,35 +27,57 @@ class AttentionRule(Enum): ALL = "all" # Breaks causal structure! Be careful -@dataclass -class PrefixGroup: - """A group of tokens that will be at the beginning of the token sequence. (e.g. task tokens)""" +@flax.struct.dataclass +class PrefixGroup(TokenGroup): + """A group of tokens that will be at the beginning of the token sequence. (e.g. task tokens) + + Adds a name identifying the group, and a dictionary indicating what groups it should attend to. + + name (str): Name of the group, which other groups will look at when deciding whether to attend to this group. 
+ attention_rules (Dict[str, AttentionRule]): A dictionary of {pattern: AttentionRule} where the attention rule + is recovered by fnmatch-ing the name of the other group until a match is found (or the end). + """ name: str - tokens: jax.typing.ArrayLike # with shape (batch, n_tokens, token_embedding_size) attention_rules: Mapping[str, AttentionRule] def __post_init__(self): - assert self.tokens.ndim == 3, "PrefixGroup tokens must be (batch, n_tokens, d)" + assert ( + len(self.tokens.shape) == 3 + ), "PrefixGroup tokens must be (batch, n_tokens, d)" + assert len(self.mask.shape) == 2, "PrefixGroup mask must be (batch, n_tokens)" -@dataclass -class TimestepGroup: - """A group of tokens that is repeated for each timestep. (e.g. observation tokens)""" +@flax.struct.dataclass +class TimestepGroup(TokenGroup): + """A group of tokens that is repeated for each timestep. (e.g. observation tokens) - name: str - tokens: jax.typing.ArrayLike # with shape (batch, horizon, n_tokens, token_embedding_size) - attention_rules: Mapping[str, AttentionRule] + See PrefixGroup for details on the name and attention_rules fields. + """ + + name: str = flax.struct.field(pytree_node=False) + attention_rules: Mapping[str, AttentionRule] = flax.struct.field(pytree_node=False) def __post_init__(self): assert ( - self.tokens.ndim == 4 - ), "TimestepGroup tokens must be (batch, horizon, n_tokens, d))" + len(self.tokens.shape) == 4 + ), "TimestepGroup tokens must be (batch, horizon, n_tokens, d)" + assert ( + len(self.mask.shape) == 3 + ), "TimestepGroup mask must be (batch, horizon, n_tokens)" + + +def find_match(pattern_dict: Dict[str, Any], name: str, default: Any) -> Any: + """Find the first matching pattern in the dictionary, or return the default value.""" + for pattern, value in pattern_dict.items(): + if fnmatch(name, pattern): + return value + return default -@dataclass +@flax.struct.dataclass class TokenMetadata: - """Useful metadata for computing attention masks. Note that all tokens within the + """Attention mask logic supported by AttentionRule. 
Note that all tokens within the same group at the same timestep always attend to each other unless you explicitly have attention_rules[self.name] = AttentionRule.NEVER """ @@ -66,16 +88,15 @@ class TokenMetadata: @classmethod def create(cls, group: Union[PrefixGroup, TimestepGroup], timestep: int): - group_dict = asdict(group) - group_dict.pop("tokens") return cls( timestep=timestep, - **group_dict, + name=group.name, + attention_rules=group.attention_rules, ) def should_attend_to(self, other_metadata: "TokenMetadata") -> bool: - attention_rule = self.attention_rules.get( - other_metadata.name, AttentionRule.NEVER + attention_rule = find_match( + self.attention_rules, other_metadata.name, AttentionRule.NEVER ) if attention_rule == AttentionRule.CAUSAL: @@ -92,20 +113,16 @@ def should_attend_to(self, other_metadata: "TokenMetadata") -> bool: raise ValueError(f"Invalid attention rule: {attention_rule}") -def split_tokens(ary, n_tokens_per_group, axis): +def split_tokens(ary: jax.Array, n_tokens_per_group: Sequence[int], axis: int): cumsum = np.cumsum(n_tokens_per_group) return jnp.split(ary, cumsum, axis=axis) class BlockTransformer(nn.Module): - # Forwarded to Transformer - num_layers: int = 4 - mlp_dim: int = 1024 - num_attention_heads: int = 8 - dropout_rate: float = 0.1 - attention_dropout_rate: float = 0.1 - add_position_embedding: bool = False + """A transformer that acts on multiple groups of tokens, which may attend to each other (in complex patterns).""" + # Forwarded to Transformer + transformer_kwargs: Dict # Enforce that timestep causal structure is not broken (future timesteps can't attend to past timesteps) enforce_causal: bool = True @@ -114,19 +131,23 @@ def __call__( self, prefix_groups: Sequence[PrefixGroup], timestep_groups: Sequence[TimestepGroup], - timestep_pad_mask: jax.typing.ArrayLike, train: bool, verbose: bool = False, ) -> Tuple[Sequence[PrefixGroup], Sequence[TimestepGroup]]: """ Args: prefix_groups: A list of PrefixGroup objects. - Each group has tokens with shape (batch, n_tokens, token_embedding_size) - Each group also dictates which other groups it will attend to. + Each group has + - tokens with shape (batch, n_tokens, token_embedding_size) + - mask with shape (batch, n_tokens) indicating which tokens are padding. + - name identifying the group + - dictionary of attention patterns dictating which other groups it will attend to. timestep_groups: A list of TimestepGroup objects. - Each group has tokens with shape (batch, horizon, n_tokens, token_embedding_size) - Each group also dictates which other groups it will attend to. - timestep_pad_mask: A boolean mask of shape (batch, horizon) indicating which timesteps are padding. + Each group has + - tokens with shape (batch, horizon, n_tokens, token_embedding_size) + - mask with shape (batch, horizon, n_tokens) indicating which tokens are padding. + - name identifying the group + - dictionary of attention patterns dictating which other groups it will attend to. train: Whether to use dropout. 
Returns: @@ -143,62 +164,26 @@ def __call__( assert all([group.tokens.shape[-1] == token_dim for group in prefix_groups]) assert all([group.tokens.shape[-1] == token_dim for group in timestep_groups]) - # Creates correct attention mask for transformer using group attention rules - attention_mask = self.generate_attention_mask( - prefix_groups, timestep_groups, timestep_pad_mask - ) - # Assemble input tokens (batch, total_tokens, token_embedding_size) input_tokens = self.assemble_input_tokens(prefix_groups, timestep_groups) - # Run transformer - transformer = Transformer( - num_layers=self.num_layers, - mlp_dim=self.mlp_dim, - num_heads=self.num_attention_heads, - dropout_rate=self.dropout_rate, - attention_dropout_rate=self.attention_dropout_rate, - add_position_embedding=self.add_position_embedding, - ) - output = transformer(input_tokens, attention_mask, train=train) - - # Split output into prefix and timestep groups - - tokens_per_prefix_group = [group.tokens.shape[1] for group in prefix_groups] - n_prefix_tokens = sum(tokens_per_prefix_group) + # Creates correct attention mask for transformer using group attention rules and masks + # Shape: (batch, 1, total_tokens, total_tokens) + attention_mask = self.generate_attention_mask(prefix_groups, timestep_groups) - prefix_embeddings, timestep_embeddings = jnp.split( - output, [n_prefix_tokens], axis=1 - ) + # Sows attention mask for ease of retrieval when debugging + # https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.sow + self.sow("intermediates", "attention_mask", attention_mask) - # Process prefix group outputs - if len(prefix_groups) > 0: - prefix_embeddings_split = split_tokens( - prefix_embeddings, tokens_per_prefix_group, axis=1 - ) - all_prefix_outputs = [ - replace(group, tokens=embeddings) - for group, embeddings in zip(prefix_groups, prefix_embeddings_split) - ] - else: - all_prefix_outputs = [] - - # Process timestep group outputs - timestep_embeddings = einops.rearrange( - timestep_embeddings, - "batch (horizon n_tokens) d -> batch horizon n_tokens d", - horizon=horizon, + # Run transformer + output = Transformer(**self.transformer_kwargs)( + input_tokens, attention_mask, train=train ) - tokens_per_timestep_group = [group.tokens.shape[2] for group in timestep_groups] - timestep_embeddings_split = split_tokens( - timestep_embeddings, tokens_per_timestep_group, axis=2 + # Split output into prefix and timestep groups + all_prefix_outputs, all_timestep_outputs = self.split_output_tokens( + output, prefix_groups, timestep_groups ) - - all_timestep_outputs = [ - replace(group, tokens=embeddings) - for group, embeddings in zip(timestep_groups, timestep_embeddings_split) - ] return all_prefix_outputs, all_timestep_outputs def assemble_input_tokens( @@ -238,46 +223,76 @@ def assemble_input_tokens( tokens = jnp.concatenate([all_prefix_tokens, all_timestep_tokens], axis=1) return tokens + def split_output_tokens( + self, + output_tokens: jax.Array, + prefix_groups: Sequence[PrefixGroup], + timestep_groups: Sequence[TimestepGroup], + ): + """Reverses the process of assemble_input_tokens.""" + + horizon = timestep_groups[0].tokens.shape[1] + tokens_per_prefix_group = [group.tokens.shape[1] for group in prefix_groups] + n_prefix_tokens = sum(tokens_per_prefix_group) + + prefix_embeddings, timestep_embeddings = jnp.split( + output_tokens, [n_prefix_tokens], axis=1 + ) + + # Process prefix group outputs + if len(prefix_groups) > 0: + prefix_embeddings_split = split_tokens( + prefix_embeddings, 
tokens_per_prefix_group, axis=1 + ) + all_prefix_outputs = [ + group.replace(tokens=embeddings) + for group, embeddings in zip(prefix_groups, prefix_embeddings_split) + ] + else: + all_prefix_outputs = [] + + # Process timestep group outputs + timestep_embeddings = einops.rearrange( + timestep_embeddings, + "batch (horizon n_tokens) d -> batch horizon n_tokens d", + horizon=horizon, + ) + + tokens_per_timestep_group = [group.tokens.shape[2] for group in timestep_groups] + timestep_embeddings_split = split_tokens( + timestep_embeddings, tokens_per_timestep_group, axis=2 + ) + + all_timestep_outputs = [ + group.replace(tokens=embeddings) + for group, embeddings in zip(timestep_groups, timestep_embeddings_split) + ] + return all_prefix_outputs, all_timestep_outputs + def generate_attention_mask( self, prefix_groups: Sequence[PrefixGroup], timestep_groups: Sequence[TimestepGroup], - timestep_pad_mask: jax.typing.ArrayLike, ): """ Args: prefix_groups: A list of PrefixGroup objects. timestep_groups: A list of TimestepGroup objects. - pad_mask: A boolean mask of shape (batch, horizon) indicating which timesteps are padding. Returns: - attention_mask: A boolean mask of shape (batch, num_heads, total_tokens, total_tokens) + attention_mask: A boolean mask of shape (batch, 1, total_tokens, total_tokens) - We use the attention rules within each group to determine the transformer attention mask. + We use the attention rules specified by each group to determine the transformer attention mask. + We then combine this with the padding mask to ensure that padding tokens are not attended to. """ if self.enforce_causal: - # First verify that prefix group isn't attending to any timestep group - for prefix_group in prefix_groups: - for ts_group in timestep_groups: - assert ( - prefix_group.attention_rules.get( - ts_group.name, AttentionRule.NEVER - ) - == AttentionRule.NEVER - ), f"Causality broken! Prefix group {prefix_group.name} is attending to timestep group {ts_group.name}" - - # Next, make sure that timestep groups aren't attending to future timesteps - for group in prefix_groups + timestep_groups: - for rule in group.attention_rules.values(): - assert ( - rule != AttentionRule.ALL - ), "Causality broken! WhenToAttend.ALL attends to future timesteps too." + self.verify_causality(prefix_groups, timestep_groups) def _get_position(i, tokens_per_elem): return np.searchsorted(np.cumsum(tokens_per_elem), i) - horizon = timestep_pad_mask.shape[1] + horizon = timestep_groups[0].tokens.shape[1] tokens_per_prefix_group = [group.tokens.shape[1] for group in prefix_groups] tokens_per_timestep_group = [group.tokens.shape[2] for group in timestep_groups] @@ -305,46 +320,83 @@ def get_token_metadata(i): attention_mask[i, j] = mask pad_attention_mask = self.generate_pad_attention_mask( - timestep_pad_mask, tokens_per_time_step, tokens_for_prefix + prefix_groups, timestep_groups ) attention_mask = jnp.logical_and(attention_mask, pad_attention_mask) return attention_mask def generate_pad_attention_mask( self, - timestep_pad_mask: jax.typing.ArrayLike, - tokens_per_time_step: int, - tokens_for_prefix: int, + prefix_groups: Sequence[PrefixGroup], + timestep_groups: Sequence[TimestepGroup], ): """ - Generate attention mask that ignores padding. `timestep_pad_mask` has shape (batch, horizon) and - records which time steps are padding. We first expand the mask to shape (batch, horizon * tokens_per_time_step) - and then prepend a mask for the task prefix to get shape (batch, total_tokens). 
-        We broadcast to (batch, num_heads, total_tokens, total_tokens).
+        Generate an nn.MultiHeadDotProductAttention mask that ignores padding: we take the padding masks from all
+        timestep groups, unfold the horizon dim, and prepend the concatenated prefix group masks.
+        We broadcast this (batch, total_tokens) mask to the requisite (batch, 1, total_tokens, total_tokens).
         """
-        batch_size, horizon = timestep_pad_mask.shape
-
-        total_tokens = tokens_for_prefix + tokens_per_time_step * horizon
-        sequence_mask = jnp.repeat(timestep_pad_mask, tokens_per_time_step, axis=1)
-        task_mask = jnp.ones((batch_size, tokens_for_prefix), dtype=int)
-        full_mask = jnp.concatenate([task_mask, sequence_mask], axis=1)
-
-        full_mask = jnp.broadcast_to(
-            full_mask[:, None, None, :],
+        batch_size, horizon = timestep_groups[0].tokens.shape[:2]
+        if len(prefix_groups) > 0:
+            prefix_pad_mask = jnp.concatenate(
+                [group.mask for group in prefix_groups], axis=1
+            )
+        else:
+            prefix_pad_mask = jnp.zeros((batch_size, 0), dtype=jnp.bool_)
+        timestep_pad_mask = jnp.concatenate(
+            [group.mask for group in timestep_groups], axis=2
+        )
+        timestep_pad_mask = einops.rearrange(
+            timestep_pad_mask,
+            "batch horizon n_tokens -> batch (horizon n_tokens)",
+        )
+        pad_mask = jnp.concatenate([prefix_pad_mask, timestep_pad_mask], axis=1)
+        # pad_mask has shape (batch, total_tokens)
+        pad_mask = jnp.broadcast_to(
+            pad_mask[:, None, None, :],
             (
-                full_mask.shape[0],
-                self.num_attention_heads,
-                total_tokens,
-                total_tokens,
+                batch_size,
+                1,
+                pad_mask.shape[1],
+                pad_mask.shape[1],
             ),
         )
-        return full_mask
+        return pad_mask
+
+    def verify_causality(
+        self,
+        prefix_groups: Sequence[PrefixGroup],
+        timestep_groups: Sequence[TimestepGroup],
+    ):
+        """Ensures that no token can attend to another token in a future timestep."""
+        # First verify that no prefix group attends to any timestep group
+        for prefix_group in prefix_groups:
+            for ts_group in timestep_groups:
+                rule = find_match(
+                    prefix_group.attention_rules, ts_group.name, AttentionRule.NEVER
+                )
+                assert rule == AttentionRule.NEVER, (
+                    f"Causality broken! Prefix group {prefix_group.name} "
+                    f"is attending to timestep group {ts_group.name}"
+                )
+
+        # Next, make sure that nothing is attending to future timesteps
+        for group in prefix_groups + timestep_groups:
+            for other_group in prefix_groups + timestep_groups:
+                rule = find_match(
+                    group.attention_rules, other_group.name, AttentionRule.NEVER
+                )
+                assert (
+                    rule != AttentionRule.ALL
+                ), "Causality broken! AttentionRule.ALL attends to future timesteps too."
 
     def pretty_print_attention_mask(
         self,
         prefix_groups: Sequence[PrefixGroup],
         timestep_groups: Sequence[TimestepGroup],
     ):
+        """
+        Visualizes the attention patterns for each token group for debugging purposes.
+ """ logging.warning("Prefix groups:") for prefix_group in prefix_groups: logging.warning( diff --git a/orca/model/components/tokenizers.py b/orca/model/components/tokenizers.py index 32da6065..37f0d831 100644 --- a/orca/model/components/tokenizers.py +++ b/orca/model/components/tokenizers.py @@ -1,7 +1,7 @@ import functools as ft import logging import re -from typing import Sequence +from typing import Dict, Optional, Sequence import flax.linen as nn import jax @@ -9,9 +9,33 @@ from jax.scipy.stats import norm from orca.model.components import encoders +from orca.model.components.base import TokenGroup from orca.model.components.transformer import MAPHead EPS = 1e-6 +from dataclasses import dataclass +import os + + +def generate_proper_pad_mask( + tokens: jax.Array, + pad_mask_dict: Optional[Dict[str, jax.Array]], + keys: Sequence[str], +) -> jax.Array: + if pad_mask_dict is None: + logging.warning("No pad_mask_dict found. Nothing will be masked.") + return jnp.ones(tokens.shape[:-1]) + if not all([key in pad_mask_dict for key in keys]): + logging.warning( + f"pad_mask_dict missing keys {set(keys) - set(pad_mask_dict.keys())}." + "Nothing will be masked." + ) + return jnp.ones(tokens.shape[:-1]) + + pad_mask = jnp.stack([pad_mask_dict[key] for key in keys], axis=-1) + pad_mask = jnp.any(pad_mask, axis=-1) + pad_mask = jnp.broadcast_to(pad_mask[..., None], tokens.shape[:-1]) + return pad_mask class TokenLearner(nn.Module): @@ -38,6 +62,14 @@ def __call__(self, inputs, train: bool = True): return MAPHead(num_readouts=self.num_tokens)(x, train=train) +def regex_match(regex_keys, x): + return any([re.match(r_key, x) for r_key in regex_keys]) + + +def regex_filter(regex_keys, xs): + return list(filter(lambda x: regex_match(regex_keys, x), xs)) + + class ImageTokenizer(nn.Module): """Image tokenizer that encodes image stack into tokens with optional FiLM conditioning. @@ -59,6 +91,7 @@ class ImageTokenizer(nn.Module): obs_stack_keys: Sequence[str] = ("image_.*", "depth_.*") task_stack_keys: Sequence[str] = tuple() task_film_keys: Sequence[str] = tuple() + proper_pad_mask: bool = False @nn.compact def __call__( @@ -67,24 +100,30 @@ def __call__( tasks=None, train: bool = True, ): - def extract_inputs(regex_keys, inputs, check_spatial=False): + def extract_inputs(keys, inputs, check_spatial=False): extracted_outputs = [] - for r_key in regex_keys: - for key in filter(re.compile(r_key).match, sorted(inputs.keys())): - if check_spatial: - assert len(inputs[key].shape) >= 4 - extracted_outputs.append(inputs[key]) + for key in keys: + if check_spatial: + assert len(inputs[key].shape) >= 4 + extracted_outputs.append(inputs[key]) return jnp.concatenate(extracted_outputs, axis=-1) + obs_stack_keys = regex_filter(self.obs_stack_keys, sorted(observations.keys())) + if len(obs_stack_keys) == 0: + logging.info( + f"No image inputs matching {self.obs_stack_keys} were found." + "Skipping tokenizer entirely." + ) + assert self.proper_pad_mask, "Cannot skip unless using proper_pad_mask." 
+ return None + # stack all spatial observation and task inputs - enc_inputs = extract_inputs( - self.obs_stack_keys, observations, check_spatial=True - ) + enc_inputs = extract_inputs(obs_stack_keys, observations, check_spatial=True) if tasks and self.task_stack_keys: - task_inputs = extract_inputs( - self.task_stack_keys, tasks, check_spatial=True - ) + task_stack_keys = regex_filter(self.task_stack_keys, sorted(tasks.keys())) + task_inputs = extract_inputs(task_stack_keys, tasks, check_spatial=True) task_inputs = task_inputs[:, None].repeat(enc_inputs.shape[1], axis=1) + # TODO: allow somehow for task inputs to be not provided... enc_inputs = jnp.concatenate([enc_inputs, task_inputs], axis=-1) b, t, h, w, c = enc_inputs.shape enc_inputs = jnp.reshape(enc_inputs, (b * t, h, w, c)) @@ -108,7 +147,16 @@ def extract_inputs(regex_keys, inputs, check_spatial=False): image_tokens = TokenLearner(num_tokens=self.num_tokens)( image_tokens, train=train ) - return image_tokens + + if self.proper_pad_mask: + pad_mask = generate_proper_pad_mask( + image_tokens, + observations.get("pad_mask_dict", None), + obs_stack_keys, + ) + else: + pad_mask = jnp.ones(image_tokens.shape[:-1]) + return TokenGroup(image_tokens, pad_mask) class LanguageTokenizer(nn.Module): @@ -123,6 +171,7 @@ class LanguageTokenizer(nn.Module): encoder: str = None finetune_encoder: bool = False + proper_pad_mask: bool = False def setup(self): if self.encoder is not None: @@ -140,10 +189,18 @@ def __call__( tasks=None, train: bool = True, ): - if self.encoder is not None: + if "language_instruction" not in tasks: + logging.warning("No language inputs found. Skipping tokenizer entirely.") + assert self.proper_pad_mask, "Cannot skip unless using proper pad mask." + return None + + if not isinstance(tasks["language_instruction"], jax.Array): + assert ( + self.encoder is not None + ), "Received language tokens but no encoder specified." tokens = self.hf_model(**tasks["language_instruction"]).last_hidden_state else: - # add a time dimension to language + # add a # tokens dimension to language if tasks["language_instruction"].ndim == 2: tokens = tasks["language_instruction"][:, None, :] else: @@ -152,7 +209,17 @@ def __call__( if not self.finetune_encoder: tokens = jax.lax.stop_gradient(tokens) - return tokens + # TODO: incorporate padding info from language tokens here too + if self.proper_pad_mask: + pad_mask = generate_proper_pad_mask( + tokens, + tasks.get("pad_mask_dict", None), + ("language_instruction",), + ) + else: + pad_mask = jnp.ones(tokens.shape[:-1]) + + return TokenGroup(tokens, pad_mask) class BinTokenizer(nn.Module): @@ -209,9 +276,18 @@ class LowdimObsTokenizer(BinTokenizer): obs_keys: Sequence[str] = tuple() discretize: bool = False + proper_pad_mask: bool = False def __call__(self, observations, *unused_args, **unused_kwargs): assert self.obs_keys, "Need to specify observation keys to tokenize." + if len(regex_filter(self.obs_keys, sorted(observations.keys()))) == 0: + logging.warning( + f"No observation inputs matching {self.obs_keys} were found." + "Skipping tokenizer entirely." + ) + assert self.proper_pad_mask, "Cannot skip unless using proper pad mask." 
+ return None + tokenizer_inputs = [] for o_key in self.obs_keys: for key in filter(re.compile(o_key).match, sorted(observations.keys())): @@ -222,9 +298,11 @@ def __call__(self, observations, *unused_args, **unused_kwargs): tokenizer_inputs = jnp.concatenate(tokenizer_inputs, axis=-1) if self.discretize: tokenized_inputs = super().__call__(tokenizer_inputs) - return jax.nn.one_hot(tokenized_inputs, self.n_bins) + tokens = jax.nn.one_hot(tokenized_inputs, self.n_bins) else: - return tokenizer_inputs[..., None] + tokens = tokenizer_inputs[..., None] + mask = jnp.ones(tokens.shape[:-1]) + return TokenGroup(tokens, mask) TOKENIZERS = { diff --git a/orca/model/components/transformer.py b/orca/model/components/transformer.py index 91dd9a3c..aed0f7bb 100644 --- a/orca/model/components/transformer.py +++ b/orca/model/components/transformer.py @@ -5,7 +5,8 @@ import jax import jax.numpy as jnp -from orca.utils.typing import Dtype, PRNGKey, Shape +from orca.model.components.base import TokenGroup +from orca.utils.typing import Dtype, PRNGKey, Shape, Union class AddPositionEmbs(nn.Module): @@ -83,7 +84,12 @@ class MAPHead(nn.Module): num_readouts: int = 1 @nn.compact - def __call__(self, x, train=True): + def __call__(self, x: Union[jax.Array, TokenGroup], train=True): + if isinstance(x, TokenGroup): + x, mask = x.tokens, x.mask + else: + mask = None + *batch_dims, l, d = x.shape x = x.reshape(-1, l, d) batch_size = x.shape[0] @@ -95,9 +101,16 @@ def __call__(self, x, train=True): x.dtype, ) probe = jnp.tile(probe, [batch_size, 1, 1]) + + if mask is not None: + mask = mask.reshape(-1, l) + mask = jnp.broadcast_to( + mask[:, None, None, :], (batch_size, 1, self.num_readouts, l) + ) + out = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform() - )(probe, x) + )(probe, x, mask=mask) # TODO: dropout on head? y = nn.LayerNorm()(out) @@ -176,7 +189,7 @@ class Transformer(nn.Module): num_layers: int mlp_dim: int - num_heads: int + num_attention_heads: int dropout_rate: float = 0.1 attention_dropout_rate: float = 0.1 add_position_embedding: bool = False @@ -208,7 +221,7 @@ def __call__(self, x, attention_mask, *, train): dropout_rate=self.dropout_rate, attention_dropout_rate=self.attention_dropout_rate, name=f"encoderblock_{lyr}", - num_heads=self.num_heads, + num_heads=self.num_attention_heads, )(x, attention_mask, deterministic=not train) encoded = nn.LayerNorm(name="encoder_norm")(x) diff --git a/orca/model/orca_model.py b/orca/model/orca_model.py index f044fee3..0c35a7d8 100644 --- a/orca/model/orca_model.py +++ b/orca/model/orca_model.py @@ -1,10 +1,12 @@ # Written by Dibya +import logging from typing import Dict, Optional import flax.linen as nn import jax import jax.numpy as jnp +from orca.model.components.base import TokenGroup from orca.model.components.block_transformer import ( AttentionRule, BlockTransformer, @@ -13,8 +15,6 @@ ) from orca.utils.typing import Data, Sequence -posemb_init = nn.initializers.normal(stddev=0.02) - class OrcaTransformer(nn.Module): """ @@ -34,9 +34,9 @@ class OrcaTransformer(nn.Module): The transformer is a blockwise-causal transformer, where each timestep only attends to the same or previous timesteps. - When called, the module requests a set of readouts, and performs a forward pass of the transformer on the following sequence: + The model performs a forward pass of the transformer on the following sequence: - [ + [ , , , , ... , , , ... 
@@ -44,23 +44,26 @@ class OrcaTransformer(nn.Module): ] The observation tokens attend to the task prefix, and to all observation tokens in the same or previous timesteps. - Readouts attend to everything observation tokens do, but are not attended to by observation or task tokens. All - tokens within the same group and same timestep (e.g. "observation ts0 tokens") fully attend to each other. + Readouts provide a mechanism for "reading out" the information in the transformer. They can attend to the task prefix, + and causally to all previous observation tokens, but + - Readouts cannot attend to other readouts + - Observation and task tokens cannot attend to readouts By this design, each readout does not influence the computation happening in the task or observation tokens, - and each readout is **independent of one another**. This allows us to hot-swap in different + and each readout is **independent of one another**. This flexible design allows us to hot-swap in different readouts at any time (e.g. we can run with the action readout or the value readout or both at the same time). - Args: - observations_tokenizers (Sequence[nn.Module]): List of flax modules for tokenizing the observations. + observations_tokenizers (Dict[str, nn.Module]): Dictionary of flax modules for tokenizing the observations. The output of each tokenizer is concatenated to form the observation tokens. - task_tokenizers (Sequence[nn.Module]): List of flax modules for tokenizing the task. + task_tokenizers (Dict[str, nn.Module]): Dictionary of flax modules for tokenizing the task. The output of each tokenizer is concatenated to form the task token prefix. - readouts (Dict[str, int]): Dictionary of {readout_name: n_tokens_for_readout} - transformer_kwargs (Dict): Dictionary of kwargs to forward to BlockTransformer. - token_embedding_size (int): Dimension of the token embeddings (default: 512) - max_horizon (int): The maximum number of timesteps that the transformer can be run with. + readouts (Dict[str, int]): Dictionary of {readout_name: n_tokens_for_readout}. + transformer_kwargs (Dict): Dictionary of kwargs to forward to the Transformer. + token_embedding_size (int): Dimension of the token embeddings + max_horizon (int): The maximum number of timesteps that the transformer can be run with. Note that while the + transformer can be run with any horizon <= max_horizon, the model will only generate sane outputs for + horizon lengths smaller or equal to the pre-training horizon. """ observation_tokenizers: Dict[str, nn.Module] @@ -79,7 +82,7 @@ def __call__( readouts: Optional[Sequence[str]] = None, train: bool = False, verbose: bool = False, - ): + ) -> Dict[str, TokenGroup]: """ Args: observations: A dictionary containing observation data for a batch of trajectory windows. @@ -92,20 +95,22 @@ def __call__( verbose: If True, prints out the transformer structure. Returns: - embedding_dict: A dictionary { - **{readout_name: embedding of shape (batch, horizon, n_tokens_for_readout, token_embedding_size) for k in readouts}, - also includes the outputs corresponding to the task and observation tokens (although this probably isn't as useful) - } + transformer_outputs: A dictionary {token_group_name: token_group}, + which contain the transformer embeddings for all observation tokens, task tokens, and readout tokens. + The special keys "task" and "obs" contain the concatenated embeddings for all task tokens and observation tokens, respectively. Note: Horizon can be anything <= max_horizon. 
""" if readouts is None: readouts = list(self.readouts.keys()) + # # Check that all inputs are valid + # + assert set(readouts).issubset( set(self.readouts.keys()) - ), "readouts must be a subset of those specified in the model config" + ), "readouts must be specified in the model config" batch_size, horizon = jax.tree_util.tree_leaves(observations)[0].shape[:2] assert horizon <= self.max_horizon, "horizon must be <= max_horizon" @@ -113,83 +118,111 @@ def __call__( jax.tree_map(lambda x: x.shape[1] == horizon, observations) ), "observations must have the same horizon" + # # Create inputs for the transformer + # + all_prefix_groups = [] all_timestep_groups = [] - all_task_names = [f"task_{name}" for name in self.task_tokenizers] - all_obs_names = [f"obs_{name}" for name in self.observation_tokenizers] - - task_attention_rules = { - task_name: AttentionRule.CAUSAL for task_name in all_task_names - } # Tasks attend to all other tasks + # Tasks attend to all other tasks, but not to observations or readouts + task_attention_rules = {"task_*": AttentionRule.CAUSAL} + # Observations attend to all tasks and previous observations causally, but not to readouts observation_attention_rules = { - name: AttentionRule.CAUSAL for name in all_task_names + all_obs_names - } # Observations attend to all tasks and previous observations causally + "task_*": AttentionRule.CAUSAL, + "obs_*": AttentionRule.CAUSAL, + } + # # First, add the task tokens + # + for name, tok in self.task_tokenizers.items(): + group_name = f"task_{name}" # Receive inputs from tokenizer and cast to embedding size - task_tokens = tok(observations, tasks, train=train) - task_tokens = nn.Dense(self.token_embedding_size)(task_tokens) - + tokenizer_output: TokenGroup = tok(observations, tasks, train=train) + if tokenizer_output is None: + logging.warning(f"Skipping task tokenizer: {group_name}") + continue + + task_tokens = nn.Dense( + self.token_embedding_size, name=f"{group_name}_projection" + )(tokenizer_output.tokens) # task_tokens shape is (batch, n_tokens, token_embedding_size) # Add positional embedding - task_pos_embedding = self._create_positional_embedding( - f"task_{name}", task_tokens.shape[1], prefix=True - ) - task_tokens += task_pos_embedding + task_tokens += self._create_positional_embedding(group_name, task_tokens) all_prefix_groups.append( - PrefixGroup(f"task_{name}", task_tokens, task_attention_rules) + PrefixGroup( + tokens=task_tokens, + mask=tokenizer_output.mask, + name=group_name, + attention_rules=task_attention_rules, + ) ) + # # Next, add the observation tokens + # + for name, tok in self.observation_tokenizers.items(): + group_name = f"obs_{name}" # Receive inputs from tokenizer and cast to embedding size - obs_tokens = tok(observations, tasks, train=train) - obs_tokens = nn.Dense(self.token_embedding_size)(obs_tokens) + tokenizer_output: TokenGroup = tok(observations, tasks, train=train) + if tokenizer_output is None: + logging.warning(f"Skipping observation tokenizer: {group_name}") + continue + + obs_tokens = nn.Dense( + self.token_embedding_size, name=f"{group_name}_projection" + )(tokenizer_output.tokens) # obs_tokens shape is (batch, horizon, n_tokens, token_embedding_size) # Add positional embedding - obs_pos_embedding = self._create_positional_embedding( - f"obs_{name}", obs_tokens.shape[2], prefix=False - ) - obs_tokens += obs_pos_embedding[:, :horizon, :, :] + obs_tokens += self._create_positional_embedding(group_name, obs_tokens) + + # Update mask to account for which timesteps are padding + 
obs_pad_mask = jnp.logical_and(pad_mask[:, :, None], tokenizer_output.mask) all_timestep_groups.append( - TimestepGroup(f"obs_{name}", obs_tokens, observation_attention_rules) + TimestepGroup( + tokens=obs_tokens, + mask=obs_pad_mask, + name=group_name, + attention_rules=observation_attention_rules, + ) ) - + # # Finally, add the readout tokens + # + for readout_name in readouts: - # Readouts do not correspond to any inputs, so we just create a bunch of zeros + group_name = f"readout_{readout_name}" + # Readouts do not correspond to any inputs, just positional embeddings n_tokens_for_readout = self.readouts[readout_name] readout_tokens = jnp.zeros( (batch_size, horizon, n_tokens_for_readout, self.token_embedding_size) ) # Add positional embedding - readout_pos_embedding = self._create_positional_embedding( - f"readout_{readout_name}", n_tokens_for_readout, prefix=False + readout_tokens += self._create_positional_embedding( + group_name, readout_tokens ) - readout_tokens += readout_pos_embedding[:, :horizon, :, :] - - attention_rules = { - **{ - name: AttentionRule.CAUSAL - for name in all_task_names + all_obs_names - }, - f"readout_{readout_name}": AttentionRule.CAUSAL, - } # Attend to tasks, all previous observations, and your own previous readout tokens + readout_mask = jnp.ones((batch_size, horizon, n_tokens_for_readout)) + readout_attention_rules = { + "task_*": AttentionRule.CAUSAL, + "obs_*": AttentionRule.CAUSAL, + group_name: AttentionRule.CAUSAL, + } # Attend to tasks, all previous observations, and *only it's own own readout* all_timestep_groups.append( TimestepGroup( - f"readout_{readout_name}", - readout_tokens, - attention_rules, + tokens=readout_tokens, + mask=readout_mask, + name=group_name, + attention_rules=readout_attention_rules, ) ) @@ -198,30 +231,34 @@ def __call__( self.transformer_kwargs.get("add_position_embedding", False) is False ), "Already added positional embeddings to the tokens" - prefix_outputs, timestep_outputs = BlockTransformer(**self.transformer_kwargs)( + prefix_outputs, timestep_outputs = BlockTransformer(self.transformer_kwargs)( all_prefix_groups, all_timestep_groups, - pad_mask, train=train, verbose=verbose, ) - - outputs = dict() - outputs.update({group.name: group.tokens for group in prefix_outputs}) + outputs = {} outputs.update( { - group.name.removeprefix("readout_"): group.tokens + group.name: TokenGroup(group.tokens, group.mask) + for group in prefix_outputs + } + ) + outputs.update( + { + group.name: TokenGroup(group.tokens, group.mask) for group in timestep_outputs } ) if len(prefix_outputs) > 0: - outputs["task"] = jnp.concatenate( - [group.tokens for group in prefix_outputs], axis=-2 + outputs["task"] = TokenGroup.concatenate( + [TokenGroup(group.tokens, group.mask) for group in prefix_outputs] ) - outputs["obs"] = jnp.concatenate( + + outputs["obs"] = TokenGroup.concatenate( [ - group.tokens + TokenGroup(group.tokens, group.mask) for group in timestep_outputs if group.name.startswith("obs_") ], @@ -230,16 +267,25 @@ def __call__( return outputs - def _create_positional_embedding(self, name, n_tokens, prefix=False): - if prefix: - shape = (1, n_tokens, self.token_embedding_size) + def _create_positional_embedding(self, name: str, tokens: jax.Array): + if tokens.ndim == 3: # for prefixes + shape = (1, *tokens.shape[-2:]) + elif ( + tokens.ndim == 4 + ): # for timesteps, create embedding for max_horizon, then truncate + shape = (1, self.max_horizon, *tokens.shape[-2:]) else: - shape = (1, self.max_horizon, n_tokens, 
self.token_embedding_size) - return self.param( + raise ValueError(f"Invalid tokens shape: {tokens.shape}") + + embedding = self.param( f"{name}_pos_embedding", - posemb_init, + nn.initializers.normal(stddev=0.02), shape, ) + if tokens.ndim == 4: + # Use only the timesteps we receive as input + embedding = embedding[:, : tokens.shape[1]] + return jnp.broadcast_to(embedding, tokens.shape) class OrcaModel(nn.Module): diff --git a/orca/utils/train_callbacks.py b/orca/utils/train_callbacks.py index 3c37c2e8..f77d53e3 100644 --- a/orca/utils/train_callbacks.py +++ b/orca/utils/train_callbacks.py @@ -112,14 +112,32 @@ def remove_text(tasks: Data, zero_text_encoding: Data): tasks["language_instruction"], zero_text_encoding, ) - tasks = flax.core.copy(tasks, {"language_instruction": new_language}) + new_pad_dict = flax.core.copy( + tasks["pad_mask_dict"], + { + "language_instruction": jnp.zeros_like( + tasks["pad_mask_dict"]["language_instruction"] + ) + }, + ) + tasks = flax.core.copy( + tasks, {"language_instruction": new_language, "pad_mask_dict": new_pad_dict} + ) return tasks def remove_images(tasks: Data): """Replaces images inside task dict with zero (black) images.""" - new_images = {k: jnp.zeros_like(v) for k, v in tasks.items() if "image" in k} - return flax.core.copy(tasks, new_images) + updates = {k: jnp.zeros_like(v) for k, v in tasks.items() if "image" in k} + updates["pad_mask_dict"] = flax.core.copy( + tasks["pad_mask_dict"], + { + k: jnp.zeros_like(v) + for k, v in tasks["pad_mask_dict"].items() + if "image" in k + }, + ) + return flax.core.copy(tasks, updates) @partial(jax.jit, static_argnames=("samples_per_state", "policy_mode")) @@ -215,11 +233,12 @@ def eval_step(state, batch): if "base" in self.modes_to_evaluate: all_tasks["base"] = batch["tasks"] if "image_conditioned" in self.modes_to_evaluate: - all_tasks["text_conditioned"] = remove_images(batch["tasks"]) - if "text_conditioned" in self.modes_to_evaluate: all_tasks["image_conditioned"] = remove_text( batch["tasks"], self.zero_text ) + if "text_conditioned" in self.modes_to_evaluate: + all_tasks["text_conditioned"] = remove_images(batch["tasks"]) + if "unconditioned" in self.modes_to_evaluate: all_tasks["unconditioned"] = remove_text( remove_images(batch["tasks"]), self.zero_text diff --git a/orca/utils/visualization_lib.py b/orca/utils/visualization_lib.py index c5e57b12..f896995c 100644 --- a/orca/utils/visualization_lib.py +++ b/orca/utils/visualization_lib.py @@ -97,6 +97,9 @@ def run_policy_on_trajectory(policy_fn, traj, *, text_processor=None): tasks["language_instruction"] = text_processor.encode( [s.decode("utf-8") for s in traj["tasks"]["language_instruction"]] ) + tasks["pad_mask_dict"]["language_instruction"] = np.array( + [len(s.decode("utf-8")) > 0 for s in traj["tasks"]["language_instruction"]] + ) actions = policy_fn(traj["observation"], tasks)
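
Illustration (not part of the patch): a minimal sketch of how the new TokenGroup container from orca/model/components/base.py composes, assuming the package above is importable; shapes follow the (batch, horizon, n_tokens, embedding) convention used throughout the diff.

    import jax.numpy as jnp

    from orca.model.components.base import TokenGroup

    batch, horizon, d = 2, 5, 384

    # Image tokens: omitting the mask makes TokenGroup.create default to all-ones (nothing is padding).
    image_tokens = TokenGroup.create(jnp.zeros((batch, horizon, 16, d)))

    # Language tokens where the second trajectory in the batch has no instruction, so all of its
    # tokens are marked as padding (in the spirit of generate_proper_pad_mask).
    lang_mask = jnp.array([1.0, 0.0])[:, None, None] * jnp.ones((batch, horizon, 8))
    lang_tokens = TokenGroup.create(jnp.zeros((batch, horizon, 8, d)), lang_mask)

    # Concatenating along the token axis (-2) keeps tokens and masks aligned; this is
    # how OrcaTransformer assembles the combined "obs" output group.
    obs = TokenGroup.concatenate([image_tokens, lang_tokens])
    assert obs.tokens.shape == (batch, horizon, 24, d)
    assert obs.mask.shape == (batch, horizon, 24)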