From 10dc0246637753ec572fcf471ca7ddaa09527b5e Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 5 Apr 2021 14:02:38 -0700 Subject: [PATCH] fix initial positions at text_seq_len --- src/dalle_mtf/sample.py | 10 ++-------- test.py | 3 ++- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py index fd9b430..486fedd 100644 --- a/src/dalle_mtf/sample.py +++ b/src/dalle_mtf/sample.py @@ -51,14 +51,8 @@ def sample_autoregressive(inputs, batch_dims = inputs.shape.dims[:-1] length_dim = inputs.shape.dims[-1] - initial_position = mtf.reduce_sum( - mtf.to_int32(mtf.not_equal(inputs, padding_id)), - reduced_dim=length_dim) # Gets position where zero padding starts - - if min_start_pos is not None: - # force the sampling to never start below a minimum starting position, say the text length. - # this will also be useful for image completion, where you can start sampling from half the image tokens - initial_position = mtf.maximum(initial_position, min_start_pos) + # Gets position (in image inputs) where zero padding starts + initial_position = mtf.zeros(inputs.mesh, batch_dims, dtype=tf.int32) + model.text_seq_len length_range = mtf.range(inputs.mesh, length_dim, tf.int32) diff --git a/test.py b/test.py index 6196995..e5d929d 100644 --- a/test.py +++ b/test.py @@ -56,7 +56,7 @@ def test_sampling(): model = DALLE( batch_size = 1, - text_seq_len = 1, + text_seq_len = 3, image_seq_len = 4, n_embd = 16, n_heads = 2, @@ -82,3 +82,4 @@ def test_sampling(): mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""]) lowering = mtf.Lowering(graph, {mesh: mesh_impl}) samples = lowering.export_to_tf_tensor(samples) + print(samples) \ No newline at end of file