diff --git a/lingvo/core/attention.py b/lingvo/core/attention.py index 817c394bc..638edcdfc 100644 --- a/lingvo/core/attention.py +++ b/lingvo/core/attention.py @@ -1697,6 +1697,18 @@ def PackSource( # [time_steps, batch_size, source_dim] source_vecs = py_utils.HasRank(source_vecs, 3) time_steps, batch_size = py_utils.GetShape(source_vecs, 2) + + # calculate time_steps and batch_size * num_heads, to be used as time and + # batch size for the internal attention, but avoid dynamic size calculations + # whenever possible. + if source_vecs.shape.ndims is None or source_vecs.shape[1] is None: + # batch_size is dynamic; avoid multiplication with num_heads + time_steps_for_internal_attn = time_steps + batch_size_for_internal_attn = -1 + else: + time_steps_for_internal_attn = -1 + batch_size_for_internal_attn = batch_size * num_heads + # [time_steps, batch_size, context_dim] source_contexts = py_utils.HasShape( source_contexts, [time_steps, batch_size, -1] @@ -1739,8 +1751,8 @@ def PackSource( source_vecs = tf.reshape( source_vecs, [ - -1, - batch_size * num_heads, + time_steps_for_internal_attn, + batch_size_for_internal_attn, symbolic.ToStatic(p.hidden_dim // num_heads), ], ) @@ -1781,7 +1793,11 @@ def PackSource( # => [time_steps, batch_size * num_heads, context_dim / num_heads] source_contexts = tf.reshape( source_contexts, - [-1, batch_size * num_heads, context_dim // num_heads], + [ + time_steps_for_internal_attn, + batch_size_for_internal_attn, + context_dim // num_heads, + ], ) source_contexts = gshard_utils.MeshSplit( source_contexts, p.device_mesh, p.activation_split_dims_mapping @@ -1798,7 +1814,8 @@ def PackSource( source_padding = tf.tile(source_padding, [1, 1, num_heads]) # => [time_steps, batch_size * num_heads] source_padding = tf.reshape( - source_padding, [-1, batch_size * num_heads] + source_padding, + [time_steps_for_internal_attn, batch_size_for_internal_attn], ) with tf.name_scope('segment_id'): @@ -1809,7 +1826,8 @@ def PackSource( source_segment_id = tf.tile(source_segment_id, [1, 1, num_heads]) # => [time_steps, batch_size * num_heads] source_segment_id = tf.reshape( - source_segment_id, [-1, batch_size * num_heads] + source_segment_id, + [time_steps_for_internal_attn, batch_size_for_internal_attn], ) return self.atten.PackSource(