diff --git a/alphafold/model/modules.py b/alphafold/model/modules.py index 554c078c0..353edb207 100644 --- a/alphafold/model/modules.py +++ b/alphafold/model/modules.py @@ -558,8 +558,8 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None): q_data: A tensor of queries, shape [batch_size, N_queries, q_channels]. m_data: A tensor of memories from which the keys and values are projected, shape [batch_size, N_keys, m_channels]. - mask: A mask for the attention, shape [batch_size, N_queries, N_keys]. - nonbatched_bias: Shared bias, shape [N_queries, N_keys]. + mask: A mask for the attention, shape [batch_size, N_heads, N_queries, N_keys]. + nonbatched_bias: Shared bias, shape [N_heads, N_queries, N_keys]. Returns: A float32 tensor of shape [batch_size, N_queries, output_dim].