From ec687535cccf820f7725ef4ca80bcb24a9992b3b Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 5 Apr 2021 16:11:34 -0700
Subject: [PATCH] more cleanup

---
 src/dalle_mtf/sample.py | 26 --------------------------
 src/model_fns.py        |  6 +-----
 test.py                 |  5 +----
 3 files changed, 2 insertions(+), 35 deletions(-)

diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py
index 486fedd..68128de 100644
--- a/src/dalle_mtf/sample.py
+++ b/src/dalle_mtf/sample.py
@@ -5,14 +5,10 @@
 def sample_autoregressive(inputs,
                           model,
-                          stop_at_token=50256,
                           max_steps=None,
                           temperature=0.9,
-                          padding_id = 0,
-                          min_start_pos = None,
                           variable_dtype=mtf.VariableDType(tf.float32),
                           has_partial_sequences=True,
-                          remove_partial_sequences=False,
                           sampling_keep_top_k=-1,
                           ):
     """Sample randomly one token at a time.
 
@@ -87,25 +83,10 @@ def sample_autoregressive(inputs,
     if not has_partial_sequences:
         partial_sequences_eos_count = 0
 
-    if stop_at_token is not None:
-        partial_sequences_eos_count = mtf.reduce_sum(
-            mtf.to_int32(mtf.equal(inputs, stop_at_token)),
-            reduced_dim=length_dim)
-
     def cond_fn(position, ids, *unused_states):
         """Should we run another loop iteration?"""
         past_end = mtf.greater_equal(position, length_dim.size)
-        if max_steps:
-            past_end = mtf.logical_or(
-                past_end, mtf.greater_equal(position - initial_position, max_steps))
-
         is_done = past_end
-        if stop_at_token is not None:
-            eos_count = mtf.reduce_sum(
-                mtf.to_int32(mtf.equal(ids, stop_at_token)),
-                reduced_dim=length_dim)
-            has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count)
-            is_done = mtf.logical_or(is_done, has_additional_eos)
         all_done = mtf.reduce_all(is_done)
         return mtf.logical_not(all_done)
 
@@ -169,11 +150,4 @@ def body_fn(position, ids, *states):
     final_position, outputs = mtf.while_loop(
         cond_fn, body_fn, while_loop_inputs)[:2]
     del final_position
-    if has_partial_sequences and remove_partial_sequences:
-        # Remove partial sequences from outputs
-        partial_length = mtf.reduce_sum(
-            mtf.to_int32(mtf.not_equal(inputs, padding_id)),
-            reduced_dim=length_dim)
-        outputs = mtf.dynamic_shift(
-            outputs, -partial_length, length_dim, wrap=False)
     return outputs
diff --git a/src/model_fns.py b/src/model_fns.py
index 9ba11d8..3f1e243 100644
--- a/src/model_fns.py
+++ b/src/model_fns.py
@@ -152,14 +152,10 @@ def dalle_model_fn(features, labels, mode, params):
 
         mtf_samples = sample_autoregressive(inputs,
                                             model,
-                                            max_steps=model.total_seq_dim, # will always run until the full image is produced
-                                            stop_at_token=None,
                                             temperature=0.9,
-                                            padding_id = 0,
                                             variable_dtype=model.variable_dtype,
                                             has_partial_sequences=True,
-                                            remove_partial_sequences=True,
-                                            sampling_keep_top_k=-1,
+                                            sampling_keep_top_k=-2,
                                             )
 
         mtf_samples = mtf.anonymize(mtf_samples)
diff --git a/test.py b/test.py
index e5d929d..d954d76 100644
--- a/test.py
+++ b/test.py
@@ -73,10 +73,7 @@ def test_sampling():
                                     inputs,
                                     model,
                                     variable_dtype=mtf.VariableDType(),
-                                    max_steps = sequence_dim.size,
-                                    remove_partial_sequences=False,
-                                    stop_at_token=None,
-                                    min_start_pos=model.text_seq_len
+                                    max_steps = sequence_dim.size
                                     )
 
     mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])