From 799897c261a2e0a3b0898cdfdb4eef3b668e41f0 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 3 Nov 2023 12:54:50 -0700
Subject: [PATCH] Support stopping inference early when there is an EOS token

PiperOrigin-RevId: 579277158
---
 official/projects/pix2seq/configs/pix2seq.py  |   1 +
 .../pix2seq/modeling/pix2seq_model.py         |  49 ++++++--
 .../pix2seq/modeling/pix2seq_model_test.py    | 113 ++++++++++++++++++
 .../projects/pix2seq/tasks/pix2seq_task.py    |   1 +
 4 files changed, 152 insertions(+), 12 deletions(-)

diff --git a/official/projects/pix2seq/configs/pix2seq.py b/official/projects/pix2seq/configs/pix2seq.py
index 79f96857bba..410dd370e7e 100644
--- a/official/projects/pix2seq/configs/pix2seq.py
+++ b/official/projects/pix2seq/configs/pix2seq.py
@@ -132,6 +132,7 @@ class Pix2Seq(hyperparams.Config):
   temperature: float = 1.0
   top_k: int = 0
   top_p: float = 0.4
+  eos_token: int | None = None
 
 
 @dataclasses.dataclass
diff --git a/official/projects/pix2seq/modeling/pix2seq_model.py b/official/projects/pix2seq/modeling/pix2seq_model.py
index 9eaf105901c..b4bf59d3e66 100644
--- a/official/projects/pix2seq/modeling/pix2seq_model.py
+++ b/official/projects/pix2seq/modeling/pix2seq_model.py
@@ -240,6 +240,7 @@ def __init__(
       temperature=1.0,
       top_k=0,
       top_p=0.4,
+      eos_token: int | None = None,
       **kwargs
   ):
     super().__init__(**kwargs)
@@ -280,6 +281,7 @@ def __init__(
     self._temperature = temperature
     self._top_k = top_k
     self._top_p = top_p
+    self._eos_token = eos_token
 
   @property
   def backbone(self) -> tf_keras.Model:
@@ -304,6 +306,7 @@ def get_config(self):
         "temperature": self._temperature,
         "top_k": self._top_k,
         "top_p": self._top_p,
+        "eos_token": self._eos_token,
         "num_heads": self._num_heads,
     }
 
@@ -369,11 +372,36 @@ def call(
           temperature=self._temperature,
           top_k=self._top_k,
           top_p=self._top_p,
+          eos_token=self._eos_token,
       )
 
       return [tokens, logits]
 
 
+def _create_cond_fn(seq_len: int, eos_token: int | None, prompt_len: int):
+  """Returns a loop condition for the decoder.
+
+  Args:
+    seq_len: the maximum sequence length.
+    eos_token: if not None, enables early termination once every sequence in
+      the batch has produced the end-of-sequence token.
+    prompt_len: the length of the prompt sequence.
+  """
+
+  def cond(step, caches, tokens, logits):
+    del caches
+    del logits
+    within_seq_len = (seq_len > prompt_len) & (step < seq_len - 1)
+    if eos_token is None:
+      return within_seq_len
+    else:
+      tokens = tokens[prompt_len:step]
+      reached_eos = tf.reduce_all(tf.reduce_any(tokens == eos_token, axis=0))
+      return within_seq_len & tf.logical_not(reached_eos)
+
+  return cond
+
+
 class Pix2SeqTransformer(tf_keras.layers.Layer):
   """Encoder and Decoder of Pix2Seq."""
 
@@ -521,6 +549,7 @@ def infer(
       top_k=0,
       top_p=0.4,
       sampling_callback=None,
+      eos_token: int | None = None,
   ):
     """Autoregressive (without teacher-forcing) prediction.
 
@@ -542,6 +571,10 @@ def infer(
       sampling_callback: a callback `function` that takes `next_logits` and
         returns `next_token`. This is used when users need a specific logic
        for sampling. Defaults to `None` with standard free-form sampling.
+      eos_token: if not None, stop inference early once every sequence in the
+        batch has produced this end-of-sequence (EOS) token. The output length
+        is unchanged, but for each sequence the token and logit values after
+        the EOS token are implementation-defined and should not be relied on.
 
     Returns:
       sampled tokens with shape of (bsz, max_seq_len-prompt_len).
@@ -627,15 +660,6 @@ def loop_body(step, caches, tokens, logits, is_prompt=False):
       logits = tf.tensor_scatter_nd_update(logits, [[next_step]], [next_logits])
       return (next_step, caches, tokens, logits)
 
-    def cond(step, caches, tokens, logits):
-      del caches
-      del tokens
-      del logits
-      return tf.logical_and(
-          tf.greater(seq_len, prompt_len),
-          tf.less(step, seq_len - 1)
-      )
-
     caches_var = tf.zeros(
         [seq_len-1, self._num_decoder_layers, bsz, self._hidden_size])
     tokens_var = tf.zeros([seq_len, bsz], dtype=tf.int64)
@@ -649,11 +673,12 @@ def cond(step, caches, tokens, logits):
     step, caches_var, tokens_var, logits_var = loop_body(
         step, caches_var, tokens_var, logits_var, is_prompt=True
     )
-    _, _, tokens_var, logits_var = tf.while_loop(
-        cond=cond,
+    _, _, tokens_var, logits_var = tf.while_loop(
+        cond=_create_cond_fn(
+            seq_len=seq_len, eos_token=eos_token, prompt_len=prompt_len
+        ),
         body=loop_body,
-        loop_vars=[step, caches_var, tokens_var, logits_var]
+        loop_vars=[step, caches_var, tokens_var, logits_var],
     )
 
     sampled_tokens = tf.transpose(tokens_var[prompt_len:], [1, 0])
diff --git a/official/projects/pix2seq/modeling/pix2seq_model_test.py b/official/projects/pix2seq/modeling/pix2seq_model_test.py
index 245d12539d7..7fe7a340665 100644
--- a/official/projects/pix2seq/modeling/pix2seq_model_test.py
+++ b/official/projects/pix2seq/modeling/pix2seq_model_test.py
@@ -96,6 +96,119 @@ def test_forward_infer(self):
 
     self.assertLen(tokens, 2)  # intermediate decoded outputs.
 
+  def test_forward_infer_with_eos(self):
+    hidden_size = 256
+    num_heads = 8
+    max_seq_len = 50
+    vocab_size = 600
+    image_size = 640
+    batch_size = 2
+    backbone = resnet.ResNet(50, bn_trainable=False)
+    backbone_endpoint_name = '5'
+    model = pix2seq_model.Pix2Seq(
+        backbone,
+        backbone_endpoint_name,
+        max_seq_len,
+        vocab_size,
+        hidden_size,
+        num_heads=num_heads,
+        eos_token=0,
+    )
+    tokens, _ = model(
+        tf.ones((batch_size, image_size, image_size, 3)),
+        tf.ones((batch_size, 1), tf.int64) * 10,
+        False,
+    )
+
+    self.assertLen(tokens, 2)  # intermediate decoded outputs.
+
+  def test_cond_fn_without_early_stopping(self):
+    tokens = tf.constant(
+        # pyformat: disable
+        [
+            [0, 0, 0],
+            [0, 0, 0],
+            [0, 1, 0],
+            [1, 0, 0],
+            [0, 0, 1],  # Should not stop early.
+            [0, 0, 0],
+            [0, 0, 0],  # Should stop inference here.
+        ],
+        # pyformat: enable
+        dtype=tf.int64
+    )
+    cond = pix2seq_model._create_cond_fn(
+        seq_len=tokens.shape[0],
+        eos_token=None,
+        prompt_len=1,
+    )
+    expected_results = [True, True, True, True, True, True, False]
+
+    self.assertLen(expected_results, tokens.shape[0])
+    for step, expected_result in enumerate(expected_results):
+      self.assertEqual(
+          expected_result,
+          cond(step, None, tokens, None),
+          msg=f'step={step}',
+      )
+
+  def test_cond_fn_with_early_stopping(self):
+    tokens = tf.constant(
+        # pyformat: disable
+        [
+            [0, 0, 0],
+            [0, 0, 0],
+            [0, 1, 0],
+            [1, 0, 0],
+            [0, 0, 1],  # Should stop inference here.
+            [0, 0, 0],
+            [0, 0, 0],
+        ],
+        # pyformat: enable
+        dtype=tf.int64
+    )
+    cond = pix2seq_model._create_cond_fn(
+        seq_len=tokens.shape[0],
+        eos_token=1,
+        prompt_len=1,
+    )
+    expected_results = [True, True, True, True, True, False, False]
+
+    self.assertLen(expected_results, tokens.shape[0])
+    for step, expected_result in enumerate(expected_results):
+      self.assertEqual(
+          expected_result,
+          cond(step, None, tokens, None),
+          msg=f'step={step}',
+      )
+
+  def test_cond_fn_with_early_stopping_keep_inference_to_end(self):
+    tokens = tf.constant(
+        # pyformat: disable
+        [
+            [1, 1, 1],  # EOS within prompt should be ignored.
+            [0, 0, 0],
+            [0, 1, 0],
+            [1, 0, 0],  # Should keep running inference until the end.
+        ],
+        # pyformat: enable
+        dtype=tf.int64
+    )
+    cond = pix2seq_model._create_cond_fn(
+        seq_len=tokens.shape[0],
+        eos_token=1,
+        prompt_len=1,
+    )
+    expected_results = [True, True, True, False]
+
+    self.assertLen(expected_results, tokens.shape[0])
+    for step, expected_result in enumerate(expected_results):
+      self.assertEqual(
+          expected_result,
+          cond(step, None, tokens, None),
+          msg=f'step={step}',
+      )
+
 
 if __name__ == '__main__':
   tf.test.main()
diff --git a/official/projects/pix2seq/tasks/pix2seq_task.py b/official/projects/pix2seq/tasks/pix2seq_task.py
index d8b492b4416..cb8bf746ecc 100644
--- a/official/projects/pix2seq/tasks/pix2seq_task.py
+++ b/official/projects/pix2seq/tasks/pix2seq_task.py
@@ -73,6 +73,7 @@ def build_model(self):
         temperature=config.temperature,
         top_p=config.top_p,
         top_k=config.top_k,
+        eos_token=config.eos_token,
     )
 
     return model
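
For reference, the minimal standalone sketch below mirrors the early-stopping
condition that _create_cond_fn builds. It assumes only TensorFlow; the
sequence length, prompt length, EOS id, and token values are illustrative and
are not taken from the patch.

    import tensorflow as tf

    # Illustrative values; in the model they come from the decoding setup and
    # the new `eos_token` field on the Pix2Seq config.
    seq_len, prompt_len, eos_token = 6, 1, 0

    # Shape [seq_len, bsz]: row 0 is the prompt, later rows are decoded
    # tokens. Rows not decoded yet are placeholders (-1).
    tokens = tf.constant(
        [
            [10, 10],   # prompt (ignored by the EOS check)
            [3, 7],
            [0, 5],     # sequence 0 emits EOS here
            [2, 0],     # sequence 1 emits EOS here
            [-1, -1],
            [-1, -1],
        ],
        dtype=tf.int64,
    )

    def should_continue(step):
      """Mirrors the condition built by pix2seq_model._create_cond_fn."""
      within_seq_len = (seq_len > prompt_len) & (step < seq_len - 1)
      generated = tokens[prompt_len:step]  # only the tokens decoded so far
      reached_eos = tf.reduce_all(         # every sequence in the batch saw EOS
          tf.reduce_any(generated == eos_token, axis=0)
      )
      return within_seq_len & tf.logical_not(reached_eos)

    for step in range(prompt_len, seq_len):
      print(step, bool(should_continue(step)))
    # Prints True for steps 1-3 and False from step 4 on: by step 4 both
    # sequences have produced an EOS among tokens[1:4], so decoding stops
    # early even though seq_len would allow further steps.

With the patch applied, the same behavior is enabled end to end by setting the
new `eos_token` field on the `Pix2Seq` config to the vocabulary id that marks
end of sequence; leaving it as None keeps the previous behavior of always
decoding until the maximum sequence length is reached.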