Allow a scalar tasks tensor in MultiTaskProjection.
If `tasks` is unbatched, the same task is used for all batches of the `inputs` tensor.

PiperOrigin-RevId: 595498478
lingvo-bot authored and copybara-github committed Jan 3, 2024
1 parent ab71210 commit 17d1b70
Showing 1 changed file with 32 additions and 16 deletions.
48 changes: 32 additions & 16 deletions lingvo/core/py_utils.py
@@ -6946,8 +6946,10 @@ def MultiTaskProjection(
     inputs: input tensor, size [batch, input_dim] or [batch_dim, time_dim,
       input_dim]
     tasks: An int32 tensor containing the task ID for each input. Tensor size is
-      [batch_dim] or [batch_dim, time_dim] (allowed only when inputs also has a
-      time dimension), no elements are larger than num_tasks.
+      [] (scalar) or [batch_dim] or [batch_dim, time_dim], no elements are
+      larger than num_tasks. Time dimension is only allowed if `inputs` also has
+      a time dimension. If a dimension is missing, the same task is used for the
+      different batch and/or time "slots" of the `inputs` tensor.
     einsum_order: the algorithm to use, either 'select_and_multiply' or
       'multiply_and_select'.
     quant_layer: QuantizableLayer used for AQT (pass `self`)
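
A minimal sketch in plain TensorFlow (not a call into the lingvo API; the `weights` shape here is an assumption inferred from the einsums further down) of what the documented unbatched-task case means: a scalar task is equivalent to repeating that task for every batch slot.

    import tensorflow as tf

    num_tasks, input_dim, output_dim = 3, 4, 5
    weights = tf.random.normal([num_tasks, input_dim, output_dim])  # assumed [num_tasks, input_dim, output_dim]
    inputs = tf.random.normal([2, 6, input_dim])                    # [batch_dim, time_dim, input_dim]

    # Scalar task: select one shared [input_dim, output_dim] projection.
    w_shared = tf.einsum('k,kio->io', tf.one_hot(1, num_tasks), weights)
    out_scalar = tf.einsum('bti,io->bto', inputs, w_shared)

    # Equivalent per-batch form, with task 1 repeated for each batch element.
    onehot_b = tf.one_hot(tf.fill([2], 1), num_tasks)
    w_per_b = tf.einsum('bk,kio->bio', onehot_b, weights)
    out_batched = tf.einsum('bti,bio->bto', inputs, w_per_b)

    tf.debugging.assert_near(out_scalar, out_batched)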
@@ -6975,20 +6977,25 @@ def MultiTaskProjection(
     inputs = HasShape(inputs, [-1, -1, input_dim])
     batch_size, time_size = GetShape(inputs, 2)
     t_input = 't'
-  if GetRank(tasks) == 1:
+  if GetRank(tasks) == 0:
+    b_task = ''
+    t_task = ''
+  elif GetRank(tasks) == 1:
     tasks = HasShape(tasks, [batch_size])
+    b_task = 'b'
     t_task = ''
   else:
     assert time_size is not None
     tasks = HasShape(tasks, [batch_size, time_size])
+    b_task = 'b'
     t_task = 't'
 
-  # [batch, num_tasks] or [batch, time, num_tasks]
+  # [num_tasks] or [batch, num_tasks] or [batch, time, num_tasks]
   tasks_onehot = tf.one_hot(tasks, num_tasks, axis=-1, dtype=inputs.dtype)
 
   # Einsum axis names:
-  # b - batch
-  # t - time (t_input and t_task, if corresponding tensor has a time dimension)
+  # b - batch (b_task, if the corresponding tensor has batch dimension)
+  # t - time (t_input and t_task, if corresponding tensors have time dimension)
   # k - task
   # i - input_dim
   # o - output_dim
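
The rank dispatch above only decides which einsum subscripts are present; the snippet below (illustrative shapes, plain TensorFlow) shows the `tasks_onehot` shape produced in each branch:

    import tensorflow as tf

    num_tasks = 4
    # Rank 0: b_task = '', t_task = ''   ->  tasks_onehot is [num_tasks]
    tf.one_hot(tf.constant(2), num_tasks, axis=-1).shape                 # (4,)
    # Rank 1: b_task = 'b', t_task = ''  ->  tasks_onehot is [batch, num_tasks]
    tf.one_hot(tf.constant([0, 3]), num_tasks, axis=-1).shape            # (2, 4)
    # Rank 2: b_task = 'b', t_task = 't' ->  tasks_onehot is [batch, time, num_tasks]
    tf.one_hot(tf.constant([[0, 1], [2, 3]]), num_tasks, axis=-1).shape  # (2, 2, 4)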
@@ -6998,15 +7005,15 @@
     weights = quant_layer.QWeight(weights, domain=w_q_domain)
     weights = quant_layer.ToAqtWeight(w_q_name, weights, feature_axis=-1)
     # select..
-    # [batch, {time,} input_dim, output_dim]
+    # [{batch,} {time,} input_dim, output_dim]
     selected_weights = tf.einsum(
-        f'b{t_task}k,kio->b{t_task}io', tasks_onehot, weights
+        f'{b_task}{t_task}k,kio->{b_task}{t_task}io', tasks_onehot, weights
     )
     selected_weights = quant_layer.FromAqtWeight(w_q_name, selected_weights)
     # .. and multiply
     # [batch, {time,} output_dim]
     out = tf.einsum(
-        f'b{t_input}i,b{t_task}io->b{t_input}o', inputs, selected_weights
+        f'b{t_input}i,{b_task}{t_task}io->b{t_input}o', inputs, selected_weights
     )
   elif einsum_order == 'multiply_and_select':
     # multiply..
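
To make the f-strings in the 'select_and_multiply' branch concrete, this small illustrative loop (not lingvo code) prints the weight-selection subscripts each `tasks` rank produces:

    # b_task/t_task values as set by the rank dispatch in the previous hunk.
    for b_task, t_task in [('', ''), ('b', ''), ('b', 't')]:
      print(f'{b_task}{t_task}k,kio->{b_task}{t_task}io')
    # k,kio->io       scalar task: one weight matrix shared by all slots
    # bk,kio->bio     per-batch tasks
    # btk,kio->btio   per-batch, per-time tasks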
@@ -7015,17 +7022,26 @@
     # .. and select
     # [batch, {time,} output_dim]
     out = tf.einsum(
-        f'b{t_input}ko,b{t_task}k->b{t_input}o', all_projected, tasks_onehot
+        f'b{t_input}ko,{b_task}{t_task}k->b{t_input}o',
+        all_projected,
+        tasks_onehot,
     )
   else:
     raise ValueError(
         'einsum_order must be select_and_multiply or multiply_and_select.'
     )
   if biases is not None:
-    # [batch, {time,} output_dim]
-    bias = tf.einsum(f'b{t_task}k,ko->b{t_task}o', tasks_onehot, biases)
-    if GetRank(out) == GetRank(bias):
-      out += bias
-    else:
-      out += tf.expand_dims(bias, 1)
+    # [{batch,} {time,} output_dim]
+    bias = tf.einsum(
+        f'{b_task}{t_task}k,ko->{b_task}{t_task}o', tasks_onehot, biases
+    )
+
+    # If `out` has time dimension (`bto`), and `tasks` has batch but no time
+    # (`bo`), then we need to expand the bias on the second dimension for
+    # broadcasting. All other combinations are already broadcastable. (It is not
+    # valid for bias to have time dimension without inputs also having time,
+    # checked above.)
+    if t_input and b_task and not t_task:
+      bias = tf.expand_dims(bias, 1)
+    out += bias
   return out
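
The rewritten bias handling replaces the old rank comparison with an explicit check of the single non-broadcastable combination. A short sketch (illustrative shapes, plain TensorFlow) of why only the `bto` plus `bo` case needs the expand:

    import tensorflow as tf

    batch, time, output_dim = 2, 6, 5
    out = tf.zeros([batch, time, output_dim])      # 'bto': inputs have a time dimension

    bias_scalar = tf.zeros([output_dim])           # tasks rank 0 -> 'o', broadcasts as-is
    bias_bt = tf.zeros([batch, time, output_dim])  # tasks rank 2 -> 'bto', broadcasts as-is
    bias_b = tf.zeros([batch, output_dim])         # tasks rank 1 -> 'bo': not broadcastable
    bias_b = tf.expand_dims(bias_b, 1)             # ... until expanded to [batch, 1, output_dim]

    for bias in (bias_scalar, bias_bt, bias_b):
      _ = out + bias                               # all three now broadcast against 'bto'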
