Support stopping inference early when there is an EOS token
PiperOrigin-RevId: 579277158
tensorflower-gardener committed Nov 3, 2023
1 parent bf77e8d commit 799897c
Showing 4 changed files with 152 additions and 12 deletions.
1 change: 1 addition & 0 deletions official/projects/pix2seq/configs/pix2seq.py
@@ -132,6 +132,7 @@ class Pix2Seq(hyperparams.Config):
  temperature: float = 1.0
  top_k: int = 0
  top_p: float = 0.4
  eos_token: int | None = None


@dataclasses.dataclass
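As a usage note (not part of the diff): with this field in place, EOS-based early stopping could be switched on from the config side roughly as sketched below. The `Pix2Seq` config class is the one changed above; the `cfg` instance and the token id 0 are hypothetical, and `Config.override` is the usual Model Garden way to set fields.

from official.projects.pix2seq.configs import pix2seq as pix2seq_cfg

cfg = pix2seq_cfg.Pix2Seq()
# Hypothetical EOS id; leaving eos_token as None (the default) keeps the
# previous run-to-max-length behavior.
cfg.override({'eos_token': 0})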
49 changes: 37 additions & 12 deletions official/projects/pix2seq/modeling/pix2seq_model.py
@@ -240,6 +240,7 @@ def __init__(
      temperature=1.0,
      top_k=0,
      top_p=0.4,
      eos_token: int | None = None,
      **kwargs
  ):
    super().__init__(**kwargs)
@@ -280,6 +281,7 @@ def __init__(
    self._temperature = temperature
    self._top_k = top_k
    self._top_p = top_p
    self._eos_token = eos_token

  @property
  def backbone(self) -> tf_keras.Model:
@@ -304,6 +306,7 @@ def get_config(self):
"temperature": self._temperature,
"top_k": self._top_k,
"top_p": self._top_p,
"eos_token": self._eos_token,
"num_heads": self._num_heads,
}

@@ -369,11 +372,36 @@ def call(
          temperature=self._temperature,
          top_k=self._top_k,
          top_p=self._top_p,
          eos_token=self._eos_token,
      )

    return [tokens, logits]


def _create_cond_fn(seq_len: int, eos_token: int | None, prompt_len: int):
  """Returns a loop-termination condition for the decoder's while loop.

  Args:
    seq_len: the maximum sequence length.
    eos_token: if not None, enables early termination once every sequence in
      the batch has produced this end-of-sequence token.
    prompt_len: the length of the prompt sequence.
  """

  def cond(step, caches, tokens, logits):
    del caches
    del logits
    within_seq_len = (seq_len > prompt_len) & (step < seq_len - 1)
    if eos_token is None:
      return within_seq_len
    else:
      tokens = tokens[prompt_len:step]
      reached_eos = tf.reduce_all(tf.reduce_any(tokens == eos_token, axis=0))
      return within_seq_len & tf.logical_not(reached_eos)

  return cond
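
To see what the `reduce_any`/`reduce_all` pair in `cond` computes, here is a minimal standalone sketch (not part of the commit), assuming tokens are laid out [steps, batch] as in `infer` below: decoding stops only once every sequence in the batch has emitted the EOS token.

import tensorflow as tf

eos_token = 1
tokens = tf.constant([[5, 5], [1, 7], [9, 9]], dtype=tf.int64)  # [steps, bsz]
per_seq_done = tf.reduce_any(tokens == eos_token, axis=0)  # [True, False]
all_done = tf.reduce_all(per_seq_done)  # False: sequence 1 has no EOS yet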


class Pix2SeqTransformer(tf_keras.layers.Layer):
"""Encoder and Decoder of Pix2Seq."""

@@ -521,6 +549,7 @@ def infer(
      top_k=0,
      top_p=0.4,
      sampling_callback=None,
      eos_token: int | None = None,
  ):
    """Autoregressive (without teacher-forcing) prediction.
@@ -542,6 +571,10 @@
      sampling_callback: a callback `function` that takes `next_logits` and
        returns `next_token`. This is used when users need specific logic for
        sampling. Defaults to `None` with standard free-form sampling.
      eos_token: if not None, stop inference early once every sequence in the
        batch has produced this end-of-sequence (EOS) token. Early stopping
        does not change the output sequence length; however, the tokens and
        logits that follow a sequence's EOS token are implementation-defined
        and should be treated as undefined.

    Returns:
      sampled tokens with shape of (bsz, max_seq_len-prompt_len).
@@ -627,15 +660,6 @@ def loop_body(step, caches, tokens, logits, is_prompt=False):
      logits = tf.tensor_scatter_nd_update(logits, [[next_step]], [next_logits])
      return (next_step, caches, tokens, logits)

    def cond(step, caches, tokens, logits):
      del caches
      del tokens
      del logits
      return tf.logical_and(
          tf.greater(seq_len, prompt_len),
          tf.less(step, seq_len - 1)
      )

    caches_var = tf.zeros(
        [seq_len-1, self._num_decoder_layers, bsz, self._hidden_size])
    tokens_var = tf.zeros([seq_len, bsz], dtype=tf.int64)
@@ -649,11 +673,12 @@ def cond(step, caches, tokens, logits):
    step, caches_var, tokens_var, logits_var = loop_body(
        step, caches_var, tokens_var, logits_var, is_prompt=True
    )

    _, _, tokens_var, logits_var = tf.while_loop(
        cond=cond,
        cond=_create_cond_fn(
            seq_len=seq_len, eos_token=eos_token, prompt_len=prompt_len
        ),
        body=loop_body,
        loop_vars=[step, caches_var, tokens_var, logits_var]
        loop_vars=[step, caches_var, tokens_var, logits_var],
    )

    sampled_tokens = tf.transpose(tokens_var[prompt_len:], [1, 0])
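Because the docstring above leaves tokens and logits after a sequence's EOS undefined, callers may want to mask them out before use. The helper below is a hedged sketch, not part of this commit; `mask_after_eos` and `padding_token` are hypothetical names.

import tensorflow as tf

def mask_after_eos(sampled_tokens, eos_token, padding_token=0):
  """Overwrites every token strictly after the first EOS in each row."""
  is_eos = tf.cast(tf.equal(sampled_tokens, eos_token), tf.int32)
  # The exclusive cumsum is 0 up to and including the first EOS, and positive
  # for every position after it.
  after_first_eos = tf.cumsum(is_eos, axis=-1, exclusive=True) > 0
  padding = tf.cast(padding_token, sampled_tokens.dtype)
  return tf.where(after_first_eos, padding, sampled_tokens)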
113 changes: 113 additions & 0 deletions official/projects/pix2seq/modeling/pix2seq_model_test.py
@@ -96,6 +96,119 @@ def test_forward_infer(self):

    self.assertLen(tokens, 2)  # intermediate decoded outputs.

  def test_forward_infer_with_eos(self):
    hidden_size = 256
    num_heads = 8
    max_seq_len = 50
    vocab_size = 600
    image_size = 640
    batch_size = 2
    backbone = resnet.ResNet(50, bn_trainable=False)
    backbone_endpoint_name = '5'
    model = pix2seq_model.Pix2Seq(
        backbone,
        backbone_endpoint_name,
        max_seq_len,
        vocab_size,
        hidden_size,
        num_heads=num_heads,
        eos_token=0,
    )
    tokens, _ = model(
        tf.ones((batch_size, image_size, image_size, 3)),
        tf.ones((batch_size, 1), tf.int64) * 10,
        False,
    )

    self.assertLen(tokens, 2)  # intermediate decoded outputs.

  def test_cond_fn_without_early_stopping(self):
    tokens = tf.constant(
        # pyformat: disable
        [
            [0, 0, 0],
            [0, 0, 0],
            [0, 1, 0],
            [1, 0, 0],
            [0, 0, 1],  # Should not stop early.
            [0, 0, 0],
            [0, 0, 0],  # Should stop inference here.
        ],
        # pyformat: enable
        dtype=tf.int64
    )
    cond = pix2seq_model._create_cond_fn(
        seq_len=tokens.shape[0],
        eos_token=None,
        prompt_len=1,
    )
    expected_results = [True, True, True, True, True, True, False]

    self.assertLen(expected_results, tokens.shape[0])
    for step, expected_result in enumerate(expected_results):
      self.assertEqual(
          expected_result,
          cond(step, None, tokens, None),
          msg=f'step={step}',
      )

  def test_cond_fn_with_early_stopping(self):
    tokens = tf.constant(
        # pyformat: disable
        [
            [0, 0, 0],
            [0, 0, 0],
            [0, 1, 0],
            [1, 0, 0],
            [0, 0, 1],  # Should stop inference here.
            [0, 0, 0],
            [0, 0, 0],
        ],
        # pyformat: enable
        dtype=tf.int64
    )
    cond = pix2seq_model._create_cond_fn(
        seq_len=tokens.shape[0],
        eos_token=1,
        prompt_len=1,
    )
    expected_results = [True, True, True, True, True, False, False]

    self.assertLen(expected_results, tokens.shape[0])
    for step, expected_result in enumerate(expected_results):
      self.assertEqual(
          expected_result,
          cond(step, None, tokens, None),
          msg=f'step={step}',
      )

  def test_cond_fn_with_early_stopping_keep_inference_to_end(self):
    tokens = tf.constant(
        # pyformat: disable
        [
            [1, 1, 1],  # EOS within the prompt should be ignored.
            [0, 0, 0],
            [0, 1, 0],
            [1, 0, 0],  # Should keep inferring until the end.
        ],
        # pyformat: enable
        dtype=tf.int64
    )
    cond = pix2seq_model._create_cond_fn(
        seq_len=tokens.shape[0],
        eos_token=1,
        prompt_len=1,
    )
    expected_results = [True, True, True, False]

    self.assertLen(expected_results, tokens.shape[0])
    for step, expected_result in enumerate(expected_results):
      self.assertEqual(
          expected_result,
          cond(step, None, tokens, None),
          msg=f'step={step}',
      )


if __name__ == '__main__':
  tf.test.main()
1 change: 1 addition & 0 deletions official/projects/pix2seq/tasks/pix2seq_task.py
@@ -73,6 +73,7 @@ def build_model(self):
        temperature=config.temperature,
        top_p=config.top_p,
        top_k=config.top_k,
        eos_token=config.eos_token,
    )
    return model

