Skip to content

Commit

Permalink
fix initial positions at text_seq_len
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 5, 2021
1 parent 4dbf727 commit 10dc024
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 9 deletions.
10 changes: 2 additions & 8 deletions src/dalle_mtf/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,8 @@ def sample_autoregressive(inputs,
batch_dims = inputs.shape.dims[:-1]
length_dim = inputs.shape.dims[-1]

initial_position = mtf.reduce_sum(
mtf.to_int32(mtf.not_equal(inputs, padding_id)),
reduced_dim=length_dim) # Gets position where zero padding starts

if min_start_pos is not None:
# force the sampling to never start below a minimum starting position, say the text length.
# this will also be useful for image completion, where you can start sampling from half the image tokens
initial_position = mtf.maximum(initial_position, min_start_pos)
# Gets position (in image inputs) where zero padding starts
initial_position = mtf.zeros(inputs.mesh, batch_dims, dtype=tf.int32) + model.text_seq_len

length_range = mtf.range(inputs.mesh, length_dim, tf.int32)

Expand Down
3 changes: 2 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_sampling():

model = DALLE(
batch_size = 1,
text_seq_len = 1,
text_seq_len = 3,
image_seq_len = 4,
n_embd = 16,
n_heads = 2,
Expand All @@ -82,3 +82,4 @@ def test_sampling():
mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
samples = lowering.export_to_tf_tensor(samples)
print(samples)

0 comments on commit 10dc024

Please sign in to comment.