forked from octo-models/octo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathocto_module.py
427 lines (363 loc) · 17 KB
/
octo_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
# Written by Dibya
import logging
from typing import Dict, Optional
import flax.linen as nn
import jax
import jax.numpy as jnp
from octo.model.components.base import TokenGroup
from octo.model.components.block_transformer import (
AttentionRule,
BlockTransformer,
PrefixGroup,
TimestepGroup,
)
from octo.utils.spec import ModuleSpec
from octo.utils.typing import Data, Sequence
class OctoTransformer(nn.Module):
    """
    This module forms the base of the Octo architecture.

    The core idea is to run a causal transformer on the following sequence,

        [task, observation 0, observation 1, observation 2, ...]

    The task is tokenized using a set of *task tokenizers* (for example, a tokenizer that processes the
    language instruction into tokens, or one that processes the goal images into tokens).

    The observation at each timestep is tokenized using a set of *observation tokenizers*
    (for example, a tokenizer that processes the primary image into tokens, or one that processes
    the wrist image into tokens).

    We introduce additional tokens ("readouts") that "read out" the information in the transformer for
    downstream action or value prediction. For example, we may have an "action" readout that provides
    embeddings that are useful for predicting actions, and a "value" readout with embeddings that are useful
    for predicting values.

    The transformer is a blockwise-causal transformer, where each timestep only attends to the same or
    previous timesteps. The easiest way to understand how the model works is to run:

    ```
        >>> model(observations, tasks, timestep_pad_mask, verbose=True)
    ```

    Generally, the model runs the transformer on something like the following sequence:

    [
        <task language tokens>,
        <t=0 "image_primary" tokens>, <t=0 "image_wrist" tokens>, <t=0 readout_action tokens>, ...
        <t=1 "image_primary" tokens>, <t=1 "image_wrist" tokens>, <t=1 readout_action tokens>, ...
        <t=2 "image_primary" tokens>, <t=2 "image_wrist" tokens>, <t=2 readout_action tokens>, ...
        ...
    ]

    The observation tokens attend to the task prefix, and to all observation tokens in the same or previous
    timesteps. So, "image_wrist" can attend to "image_primary" and vice versa.

    Readouts provide a mechanism for "reading out" the information in the transformer. They are designed to
    only *read* from the sequence before it, without the ability to influence (i.e. write) the computation for
    any of the non-readout tokens. By design, different readouts (e.g. "action" vs "value") are completely
    independent of each other, meaning they can be run separately without affecting each other.

    Args:
        observation_tokenizers (Dict[str, nn.Module]): Dictionary of flax modules for tokenizing the observations.
            The output of each tokenizer is concatenated to form the observation tokens.
        task_tokenizers (Dict[str, nn.Module]): Dictionary of flax modules for tokenizing the task.
            The output of each tokenizer is concatenated to form the task token prefix.
        readouts (Dict[str, int]): Dictionary of {readout_name: n_tokens_for_readout}.
        transformer_kwargs (Dict): Dictionary of kwargs to forward to the Transformer.
        token_embedding_size (int): Dimension of the token embeddings
        max_horizon (int): The maximum number of timesteps that the transformer can be run with. Note that while the
            transformer can be run with any horizon <= max_horizon, the model will only generate sane outputs for
            horizon lengths smaller or equal to the pre-training horizon.
        repeat_task_tokens (bool): If true, repeats the task tokens at each observation timestep (they are added
            as extra timestep groups named "obs_task_<name>").
        use_correct_attention (bool): Forwarded unchanged to BlockTransformer. Defaults to False.
    """

    observation_tokenizers: Dict[str, nn.Module]
    task_tokenizers: Dict[str, nn.Module]
    readouts: Dict[str, int]
    transformer_kwargs: Dict
    token_embedding_size: int
    max_horizon: int
    repeat_task_tokens: bool
    use_correct_attention: bool = False

    @nn.compact
    def __call__(
        self,
        observations: Data,
        tasks: Data,
        timestep_pad_mask: jax.Array,
        readouts: Optional[Sequence[str]] = None,
        train: bool = False,
        verbose: bool = False,
    ) -> Dict[str, TokenGroup]:
        """
        Args:
            observations: A dictionary containing observation data for a batch of trajectory windows.
                Each entry has shape (batch, horizon, *).
            tasks: A dictionary containing task data for the trajectory windows.
                Each entry has shape (batch, *).
            timestep_pad_mask: A boolean mask of shape (batch, horizon) where False indicates a padded timestep.
            readouts: A list of readouts to compute. If None, defaults to all readouts. Must be a subset of the
                readouts specified in the model config.
            train: Whether model is being trained.
            verbose: If True, prints out the transformer structure.

        Returns:
            transformer_outputs: A dictionary {token_group_name: token_group},
                which contain the transformer embeddings for all observation tokens, task tokens, and readout tokens.
                The special key "obs" contains the concatenated embeddings for all observation ("obs_*") groups;
                the special key "task" (present only when at least one task group was produced) contains the
                concatenated embeddings for all task tokens.

        Note: Horizon can be anything <= max_horizon.
        """
        if readouts is None:
            readouts = list(self.readouts.keys())

        #
        # Check that all inputs are valid
        #
        assert set(readouts).issubset(
            set(self.readouts.keys())
        ), "readouts must be specified in the model config"

        batch_size, horizon = jax.tree_util.tree_leaves(observations)[0].shape[:2]
        assert horizon <= self.max_horizon, "horizon must be <= max_horizon"
        # `jax.tree_map` is deprecated (and removed in recent JAX releases);
        # `jax.tree_util.tree_map` is the stable spelling.
        assert jax.tree_util.tree_all(
            jax.tree_util.tree_map(lambda x: x.shape[1] == horizon, observations)
        ), "observations must have the same horizon"

        #
        # Attention rules for the transformer
        #

        # Tasks attend to all other tasks, but not to observations or readouts
        task_attention_rules = {"task_*": AttentionRule.CAUSAL}

        # Observations attend to all tasks and all other observations tokens causally,
        # e.g. at same timestep or before, but do not attend to readouts
        observation_attention_rules = {
            "task_*": AttentionRule.CAUSAL,
            "obs_*": AttentionRule.CAUSAL,
        }

        #
        # Create inputs for the transformer
        #
        all_prefix_groups = []
        all_timestep_groups = []

        #
        # First, add the task tokens
        #
        for name, tok in self.task_tokenizers.items():
            group_name = f"task_{name}"
            # Receive inputs from tokenizer and cast to embedding size
            tokenizer_output: TokenGroup = tok(observations, tasks, train=train)
            if tokenizer_output is None:
                logging.warning(f"Skipping task tokenizer: {group_name}")
                continue

            task_tokens = nn.Dense(
                self.token_embedding_size, name=f"{group_name}_projection"
            )(tokenizer_output.tokens)
            # task_tokens shape is (batch, n_tokens, token_embedding_size)

            # Add positional embedding
            task_tokens += self._create_positional_embedding(group_name, task_tokens)

            all_prefix_groups.append(
                PrefixGroup(
                    tokens=task_tokens,
                    mask=tokenizer_output.mask,
                    name=group_name,
                    attention_rules=task_attention_rules,
                )
            )

        #
        # Next, add the observation tokens
        #
        for name, tok in self.observation_tokenizers.items():
            group_name = f"obs_{name}"
            # Receive inputs from tokenizer and cast to embedding size
            tokenizer_output: TokenGroup = tok(observations, tasks, train=train)
            if tokenizer_output is None:
                logging.warning(f"Skipping observation tokenizer: {group_name}")
                continue

            obs_tokens = nn.Dense(
                self.token_embedding_size, name=f"{group_name}_projection"
            )(tokenizer_output.tokens)
            # obs_tokens shape is (batch, horizon, n_tokens, token_embedding_size)

            # Add positional embedding
            obs_tokens += self._create_positional_embedding(group_name, obs_tokens)

            # Update mask to account for which timesteps are padding
            obs_pad_mask = jnp.logical_and(
                timestep_pad_mask[:, :, None], tokenizer_output.mask
            )

            all_timestep_groups.append(
                TimestepGroup(
                    tokens=obs_tokens,
                    mask=obs_pad_mask,
                    name=group_name,
                    attention_rules=observation_attention_rules,
                )
            )

        if self.repeat_task_tokens:
            logging.info(
                "repeating task tokens at each timestep to perform cross-modal attention"
            )
            # Copy every task prefix group into each timestep. The tile count is
            # `horizon` (the time dimension shared by all observation groups), so
            # this no longer depends on at least one observation group existing.
            # The loop variable is named so it does not shadow the `tasks` argument.
            for prefix_group in all_prefix_groups:
                # (batch, n_tokens, emb) -> (batch, horizon, n_tokens, emb)
                repeated_tokens = jnp.tile(
                    prefix_group.tokens[:, jnp.newaxis, :, :], [1, horizon, 1, 1]
                )
                # (batch, n_tokens) -> (batch, horizon, n_tokens)
                repeated_mask = jnp.tile(
                    prefix_group.mask[:, jnp.newaxis, :], [1, horizon, 1]
                )
                # NOTE(review): unlike the observation groups above, this mask is not
                # combined with `timestep_pad_mask`, so repeated task tokens remain
                # unmasked at padded timesteps. Kept as-is to preserve behavior.
                all_timestep_groups.append(
                    TimestepGroup(
                        tokens=repeated_tokens,
                        mask=repeated_mask,
                        name=f"obs_{prefix_group.name}",
                        attention_rules=observation_attention_rules,
                    )
                )

        #
        # Finally, add the readout tokens
        #
        for readout_name in readouts:
            group_name = f"readout_{readout_name}"
            # Readouts do not correspond to any inputs, just positional embeddings
            n_tokens_for_readout = self.readouts[readout_name]
            readout_tokens = jnp.zeros(
                (batch_size, horizon, n_tokens_for_readout, self.token_embedding_size)
            )

            # Add positional embedding
            readout_tokens += self._create_positional_embedding(
                group_name, readout_tokens
            )
            readout_mask = jnp.ones((batch_size, horizon, n_tokens_for_readout))
            # Attend to tasks, all previous observations, and *only its own readout*
            readout_attention_rules = {
                "task_*": AttentionRule.CAUSAL,
                "obs_*": AttentionRule.CAUSAL,
                group_name: AttentionRule.CAUSAL,
            }

            all_timestep_groups.append(
                TimestepGroup(
                    tokens=readout_tokens,
                    mask=readout_mask,
                    name=group_name,
                    attention_rules=readout_attention_rules,
                )
            )

        # Run the transformer!
        assert (
            self.transformer_kwargs.get("add_position_embedding", False) is False
        ), "Already added positional embeddings to the tokens"

        prefix_outputs, timestep_outputs = BlockTransformer(
            self.transformer_kwargs, use_correct_attention=self.use_correct_attention
        )(
            all_prefix_groups,
            all_timestep_groups,
            train=train,
            verbose=verbose,
        )
        outputs = {}
        outputs.update(
            {
                group.name: TokenGroup(group.tokens, group.mask)
                for group in prefix_outputs
            }
        )
        outputs.update(
            {
                group.name: TokenGroup(group.tokens, group.mask)
                for group in timestep_outputs
            }
        )

        if len(prefix_outputs) > 0:
            outputs["task"] = TokenGroup.concatenate(
                [TokenGroup(group.tokens, group.mask) for group in prefix_outputs]
            )

        outputs["obs"] = TokenGroup.concatenate(
            [
                TokenGroup(group.tokens, group.mask)
                for group in timestep_outputs
                if group.name.startswith("obs_")
            ],
            axis=-2,
        )

        return outputs

    def _create_positional_embedding(self, name: str, tokens: jax.Array) -> jax.Array:
        """Create a learned positional embedding and broadcast it to `tokens.shape`.

        The parameter is named f"{name}_pos_embedding" and initialized from a normal
        distribution with stddev 0.02.

        - 3-D tokens (prefix groups, (batch, n_tokens, dim)): parameter shape is
          (1, n_tokens, dim).
        - 4-D tokens (timestep groups, (batch, time, n_tokens, dim)): parameter
          shape is (1, max_horizon, n_tokens, dim), sliced to the first
          `tokens.shape[1]` timesteps.

        Raises:
            ValueError: if `tokens` is neither 3-D nor 4-D.
        """
        if tokens.ndim == 3:  # for prefixes
            shape = (1, *tokens.shape[-2:])
        elif (
            tokens.ndim == 4
        ):  # for timesteps, create embedding for max_horizon, then truncate
            shape = (1, self.max_horizon, *tokens.shape[-2:])
        else:
            raise ValueError(f"Invalid tokens shape: {tokens.shape}")

        embedding = self.param(
            f"{name}_pos_embedding",
            nn.initializers.normal(stddev=0.02),
            shape,
        )
        if tokens.ndim == 4:
            # Use only the timesteps we receive as input
            embedding = embedding[:, : tokens.shape[1]]
        return jnp.broadcast_to(embedding, tokens.shape)
class OctoModule(nn.Module):
    """
    Container that pairs an OctoTransformer with a dictionary of heads, so that all
    of their parameters live under a single flax module.
    """

    octo_transformer: OctoTransformer
    heads: Dict[str, nn.Module]

    def __call__(
        self, observations, tasks, timestep_pad_mask, train=True, verbose=False
    ):
        """Run the transformer, then call every head on its outputs (handy for `init`).

        Args:
            observations: Dictionary of observation data; every entry has shape
                (batch, horizon, *).
            tasks: Dictionary of task data; every entry has shape (batch, *).
            timestep_pad_mask: Boolean mask of shape (batch, horizon); False marks a
                padded timestep.
            train: Whether to run in training mode (forwarded to the transformer and
                to every head).
            verbose: If True, the OctoTransformer prints its structure (useful for
                debugging!).

        Returns:
            A pair (transformer_outputs, head_outputs):
                transformer_outputs: the dict returned by OctoTransformer.__call__.
                head_outputs: {head_name: output of that head}.
        """
        transformer_outputs = self.octo_transformer(
            observations, tasks, timestep_pad_mask, train=train, verbose=verbose
        )
        # Every head receives the complete transformer output dictionary; heads
        # are called in the dict's iteration order.
        head_outputs = {
            head_name: head(transformer_outputs, train=train)
            for head_name, head in self.heads.items()
        }
        return transformer_outputs, head_outputs

    @classmethod
    def create(
        cls,
        observation_tokenizers: Dict[str, ModuleSpec],
        task_tokenizers: Dict[str, ModuleSpec],
        heads: Dict[str, ModuleSpec],
        readouts: Dict[str, int],
        transformer_kwargs: Dict,
        token_embedding_size: int,
        max_horizon: int,
        repeat_task_tokens: bool = False,
        use_correct_attention: bool = False,
    ) -> "OctoModule":
        """
        Build an OctoModule from configuration (the canonical constructor).

        Args:
            observation_tokenizers: dict of {tokenizer_name: tokenizer_spec} (see tokenizers.py)
            task_tokenizers: dict of {tokenizer_name: tokenizer_spec} (see tokenizers.py)
            heads: dict of {head_name: head_spec} (see heads.py)
            readouts: dict of {readout_name (str): n_tokens_for_readout (int)}
            token_embedding_size (int): The latent dimension of the token embeddings
            max_horizon (int): Sets the size of positional embeddings, and provides an upper limit on the
                maximum horizon of the model
            repeat_task_tokens (bool): If true, repeats the task tokens at each observation timestep.
            use_correct_attention (bool): Passed through unchanged to OctoTransformer.
            transformer_kwargs: additional kwargs to forward to the transformer, which include:
                num_layers (int): number of layers
                mlp_dim (int): hidden dimension of the MLPs
                num_heads (int): Number of heads in nn.MultiHeadDotProductAttention
                dropout_rate (float): dropout rate.
                attention_dropout_rate (float): dropout rate in self attention.

        Returns:
            A new, unbound OctoModule.
        """

        def _instantiate_all(specs):
            # Turn {name: ModuleSpec} into {name: module instance}, keeping key order.
            return {key: ModuleSpec.instantiate(spec)() for key, spec in specs.items()}

        # Same instantiation order as before: observation tokenizers, task
        # tokenizers, then heads.
        observation_modules = _instantiate_all(observation_tokenizers)
        task_modules = _instantiate_all(task_tokenizers)
        head_modules = _instantiate_all(heads)

        transformer = OctoTransformer(
            observation_tokenizers=observation_modules,
            task_tokenizers=task_modules,
            readouts=readouts,
            token_embedding_size=token_embedding_size,
            max_horizon=max_horizon,
            repeat_task_tokens=repeat_task_tokens,
            transformer_kwargs=transformer_kwargs,
            use_correct_attention=use_correct_attention,
        )
        return cls(octo_transformer=transformer, heads=head_modules)