From 7434bcd343cbe3c6599776a29a201ab97e3c96ed Mon Sep 17 00:00:00 2001 From: Jason Date: Wed, 22 Jan 2025 13:02:32 -0500 Subject: [PATCH] Add New Transformer Backbone for TTS Models (#11911) * add core functions and classes of Transformer blocks. * completed tests for all classes, and made bugfixes. Now all tests are passed after running, `pytest -s -vvv tests/collections/tts/modules/test_transformer_2501.py` --------- Signed-off-by: Jason Signed-off-by: blisc Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: XuesongYang Co-authored-by: blisc Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: XuesongYang --- .../tts/modules/transformer_2501.py | 710 ++++++++++++++++ .../tts/modules/test_transformer_2501.py | 797 ++++++++++++++++++ 2 files changed, 1507 insertions(+) create mode 100644 nemo/collections/tts/modules/transformer_2501.py create mode 100644 tests/collections/tts/modules/test_transformer_2501.py diff --git a/nemo/collections/tts/modules/transformer_2501.py b/nemo/collections/tts/modules/transformer_2501.py new file mode 100644 index 000000000000..dc5debc04f39 --- /dev/null +++ b/nemo/collections/tts/modules/transformer_2501.py @@ -0,0 +1,710 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from abc import abstractmethod +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from nemo.utils import logging + +# TODO: Move the cache implementation out of the Module class, and pass it as part of the forward so we can reset +# as needed in the inference pipeline. + + +class ConvolutionLayer(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 1, + stride: int = 1, + padding: Optional[int] = None, + dilation: int = 1, + bias: bool = True, + is_causal: bool = False, + ): + """ + A convolutional layer that supports causal convolutions with padding. Replaces the standard MLP layer used in + the original transformer. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (int): Size of the convolving kernel. + stride (int): Stride of the convolution. + padding (Optional[int]): Padding added to both sides of the input. If None, it's calculated automatically. + dilation (int): Spacing between kernel elements. + bias (bool): If True, adds a learnable bias to the output. + is_causal (bool): If True, uses causal convolution. + """ + super().__init__() + + # Setup up padding; should be 0 if set to causal + # If not causal and padding is None, set an appropriate value for padding + self.causal_padding = None + if is_causal: + self.causal_padding = ((kernel_size - 1) * dilation, 0) + if padding is not None: + logging.warning( + f'{self} was initialized with is_causal set to True, and padding set to {padding}. 
' + f'The provided padding value will be ignored and set to {self.causal_padding}.' + ) + padding = 0 + elif padding is None: + if kernel_size % 2 == 0: + raise ValueError("`kernel_size` must be odd when `padding` is None.") + else: + padding = int(dilation * (kernel_size - 1) / 2) + + self.is_causal = is_causal + self.kernel_size = kernel_size + self.dilation = dilation + + self.conv = torch.nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + + def forward(self, signal): + if self.is_causal: # TODO: maybe replace with identify rather than keep conditional if in forward + signal = F.pad(signal, self.causal_padding) + + conv_signal = self.conv(signal) + + return conv_signal + + +class PositionwiseConvFF(torch.nn.Module): + def __init__( + self, + d_model: int, + d_ffn: int, + p_dropout: float, + kernel_size: int = 1, + bias: bool = False, + is_causal: bool = True, + non_linearity: Callable = torch.nn.GELU(approximate="tanh"), + ): + """ + Positionwise Convolutional Feed-Forward layer to replace the MLP layer in transformers. + + Module will take the input with d_model hidden state, project it to d_ffn hidden dimension, perform nonlinear + transformation, and project the state back into d_model hidden dimension. Finally, it applied dropout. + + Args: + d_model (int): Input and output dimension of the model. + d_ffn (int): Hidden dimension of the feed-forward network (usually 4 * d_model). + p_dropout (float): Dropout probability. + kernel_size (int): Size of the convolving kernel. + bias (bool): If True, adds a learnable bias to the convolution layers. + is_causal (bool): If True, uses causal convolution. + non_linearity (Callable): Activation function to use (default: GELU). + """ + super().__init__() + # d_ffn is usually 4*d_model + self.d_model = d_model + self.non_linearity = non_linearity + + self.proj = ConvolutionLayer(d_model, d_ffn, bias=bias, kernel_size=kernel_size, is_causal=is_causal) + self.o_net = ConvolutionLayer(d_ffn, d_model, bias=bias, kernel_size=kernel_size, is_causal=is_causal) + self.dropout = torch.nn.Dropout(p_dropout) + + def forward(self, x): + """ + x (B, T, C) + """ + x = self.non_linearity(self.proj(x.transpose(1, 2))) + x = self.dropout(self.o_net(x).transpose(1, 2)) + return x + + +class Attention(torch.nn.Module): + def __init__( + self, + n_heads: int, + d_model: int, + p_dropout: float, + is_causal: bool = True, + ): + """ + Base Attention parent class. Users should not be instantiating this class, but rather use SelfAttention or + CrossAttention classes as appropriate. + Does DotProductionAttention and additionally dropout inside the module. The class does not currently support + RoPE nor ALiBi. + + Args: + n_heads (int): Number of attention heads. + d_model (int): Dimension of the model. + p_dropout (float): Dropout probability. + is_causal (bool): Whether to use causal attention. Only supported when used in SelfAttention. 
+ """ + super().__init__() + assert d_model % n_heads == 0, "d_model % n_head != 0" + self.d_head = d_model // n_heads + self.n_heads = n_heads + self.d_model = d_model + self.scale = self.d_head**-0.5 + self.is_causal = is_causal + self.o_net = torch.nn.Linear(n_heads * self.d_head, d_model, bias=False) + self.dropout = torch.nn.Dropout(p_dropout) + self.use_cache = False + self.cache = self._init_cache() + + @abstractmethod + def compute_qkv_and_mask( + self, + query: torch.Tensor, + query_mask: Optional[torch.Tensor] = None, + memory: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + ): + pass + + @staticmethod + def _init_cache() -> Dict[str, Optional[Union[bool, torch.Tensor]]]: + return { + 'is_initialized': False, + 'self_k': None, + 'self_v': None, + 'cross_kv': None, + 'cross_k': None, + 'cross_v': None, + } + + def reset_cache(self, use_cache: bool = False): + self.use_cache = use_cache + self.cache = self._init_cache() + + def attn_naive( + self, + query: torch.Tensor, + query_mask: Optional[torch.Tensor] = None, + memory: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + attn_prior: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + if self.use_cache: + if self.cache['is_initialized']: + query = query[:, -1:, :] + query_mask = query_mask[:, -1:] if query_mask is not None else None + else: + self.cache['is_initialized'] = True + + # Calls into children classes to compute qkv tensors and mask tensor + q, k, v, mask = self.compute_qkv_and_mask( + query=query, query_mask=query_mask, memory=memory, memory_mask=memory_mask + ) + + # (B, T, nh, dh) -> (B, nh, T, dh) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + B, T, _ = query.shape + attn_score = torch.matmul(q, k.transpose(2, 3)) * self.scale + if mask is not None: + # assumes there's at least one mask + attn_score.masked_fill_(mask == 0, float('-inf')) + if self.is_causal: + attn_score.masked_fill_(self.causal_mask[..., :T, :T] == 0, float('-inf')) + + # attn_prior or square mask or vanilla attention + if attn_prior is not None: + eps = 1e-8 + attn_prior = attn_prior[:, :T] # trim for inference + attn_prior = torch.log(attn_prior + eps) + attn_prior = attn_prior[:, None].repeat(1, self.n_heads, 1, 1) + attn_score_log = F.log_softmax(attn_score, dim=-1) + attn_prior + attn_prob = F.softmax(attn_score_log, dim=-1) + else: + attn_prob = F.softmax(attn_score, dim=-1) + + if mask is not None: + attn_prob = attn_prob.masked_fill(mask == 0, 0.0) + attn_prob = self.dropout(attn_prob) + + y = torch.matmul(attn_prob, v) + y = y.transpose(1, 2).contiguous().view(B, T, -1) + + return y, [attn_prob, attn_score] + + def forward( + self, + query: torch.Tensor, + query_mask: Optional[torch.Tensor] = None, + memory: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + attn_prior: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Forward pass of the Attention module. + + Args: + query (torch.Tensor): Input tensor of shape (B, T1, C). + query_mask (Optional[torch.Tensor]): Mask for query tensor of shape (B, T1). + memory (Optional[torch.Tensor]): Memory tensor for cross-attention of shape (B, T2, C). + memory_mask (Optional[torch.Tensor]): Mask for memory tensor of shape (B, T2). + attn_prior (Optional[torch.Tensor]): Prior attention weights of shape (B, T1, T2). 
+ + Returns: + Tuple[torch.Tensor, List[torch.Tensor]]: + - y: Attention module tensor output of shape (B, T1, C). + - attn_prob: List containing attention probabilities and scores. returned only in attn_naive. + [0]: Attention probabilities used for logging during validation. + [1]: Attention scores used for CTC loss (only in naive attention). + """ + + y, attn_prob = self.attn_naive(query, query_mask, memory, memory_mask, attn_prior) + y = self.dropout(self.o_net(y)) + + return y, attn_prob + + +class SelfAttention(Attention): + def __init__( + self, + n_heads: int, + d_model: int, + p_dropout: float, + is_causal: bool = True, + max_length_causal_mask: int = 4096, + ): + """ + Implements SelfAttention. See parent class for forward implementation. + + Args: + n_heads (int): Number of attention heads. + d_model (int): Dimension of the model. + p_dropout (float): Dropout probability. + is_causal (bool): Whether to use causal attention. Only supported when used in SelfAttention. + max_length_causal_mask (int): Maximum sequence length for Attention module. + """ + super().__init__( + n_heads=n_heads, + d_model=d_model, + p_dropout=p_dropout, + is_causal=is_causal, + ) + if is_causal: + if max_length_causal_mask is None or max_length_causal_mask < 0: + raise ValueError( + "Self Attention was called with is_causal True, but received an inappropriate value" + f"of {max_length_causal_mask} for max_length_causal_mask" + ) + self.register_buffer( + "causal_mask", + torch.tril(torch.ones(max_length_causal_mask, max_length_causal_mask)).view( + 1, 1, max_length_causal_mask, max_length_causal_mask + ), + ) + self.qkv_net = torch.nn.Linear(d_model, 3 * n_heads * self.d_head, bias=False) + + def compute_qkv_and_mask( + self, + query: torch.Tensor, + query_mask: Optional[torch.Tensor] = None, + memory: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + ): + B, T, _ = query.shape + qkv = self.qkv_net(query).reshape(B, T, 3, self.n_heads, self.d_head) + q, k, v = qkv.chunk(3, dim=2) + q, k, v = q.squeeze(2), k.squeeze(2), v.squeeze(2) + if self.use_cache: + if self.cache['self_k'] is not None: + k = torch.cat([self.cache['self_k'], k], dim=1) + v = torch.cat([self.cache['self_v'], v], dim=1) + self.cache['self_k'] = k + self.cache['self_v'] = v + mask = query_mask[:, None, :, None] if query_mask is not None else None + return q, k, v, mask + + +class CrossAttention(Attention): + def __init__( + self, + n_heads: int, + d_model: int, + d_memory: int, + p_dropout: float, + ): + """ + Implements CrossAttention. See parent class for forward implementation. Must be non-causal. + + Args: + n_heads (int): Number of attention heads. + d_model (int): Dimension of the model. + d_memory (int): Dimension of the conditioning / cross-attention input. + p_dropout (float): Dropout probability. 
+ """ + super().__init__( + n_heads=n_heads, + d_model=d_model, + p_dropout=p_dropout, + is_causal=False, + ) + if d_memory is None: + raise ValueError("d_memory must be provided for cross-attention") + self.q_net = torch.nn.Linear(d_model, n_heads * self.d_head, bias=False) + self.kv_net = torch.nn.Linear(d_memory, 2 * n_heads * self.d_head, bias=False) + + def compute_qkv_and_mask( + self, + query: torch.Tensor, + query_mask: Optional[torch.Tensor] = None, + memory: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + ): + Bq, Tq, _ = query.shape + Bkv, Tkv, _ = memory.shape + q = self.q_net(query).reshape(Bq, Tq, self.n_heads, self.d_head) + if self.use_cache and self.cache['cross_kv'] is not None: + kv = self.cache['cross_kv'] + else: + kv = self.kv_net(memory).reshape(Bkv, Tkv, 2, self.n_heads, self.d_head) + + if self.use_cache and self.cache['cross_k'] is not None: + k = self.cache['cross_k'] + v = self.cache['cross_v'] + else: + k, v = kv.chunk(2, dim=2) + k, v = k.squeeze(2), v.squeeze(2) + if self.use_cache: + self.cache['cross_kv'] = kv + self.cache['cross_k'] = k + self.cache['cross_v'] = v + + mask = memory_mask[:, None, None] if memory_mask is not None else None + return q, k, v, mask + + +class TransformerLayer(torch.nn.Module): + def __init__( + self, + d_model: int, + d_ffn: int, + sa_n_heads: int, + kernel_size: int, + p_dropout: float, + has_xattn: bool, + xa_d_memory: Optional[int] = None, + xa_n_heads: Optional[int] = None, + is_causal: bool = True, + apply_norm_to_cond: bool = True, + max_length_causal_mask: int = 4096, + conv_non_linearity: Callable = torch.nn.GELU(approximate="tanh"), + ): + """ + One layer of the Transformer. + Args: + d_model : Model dimension + d_ffn : Feed forward dimension (usually 4*d_model) + sa_n_heads : Number of attention heads used in self-attention + kernel_size : Convolution kernel size for FFN + p_dropout : Dropout probability + has_xattn : Whether to use cross attention + xa_d_memory : Hidden dimension for cross attention + xa_n_heads : Number of attention heads used in cross attention + is_causal : Whether to use causal attention + apply_norm_to_cond : Whether to apply normalization to conditioning tensor + max_length_causal_mask : Maximum length of causal mask + conv_non_linearity : Convolution non-linearity + """ + super().__init__() + self.has_xattn = has_xattn + + self.norm_self = torch.nn.LayerNorm(d_model, bias=False) + self.self_attention = SelfAttention( + n_heads=sa_n_heads, + d_model=d_model, + p_dropout=p_dropout, + max_length_causal_mask=max_length_causal_mask, + is_causal=is_causal, + ) + + if self.has_xattn: + self.apply_norm_to_cond = apply_norm_to_cond + self.norm_xattn_query = torch.nn.LayerNorm(d_model, bias=False) + self.cross_attention = CrossAttention( + n_heads=xa_n_heads, + d_model=d_model, + d_memory=xa_d_memory, + p_dropout=p_dropout, + ) + + if self.apply_norm_to_cond: + self.norm_xattn_memory = torch.nn.LayerNorm(xa_d_memory, bias=False) + + self.norm_pos_ff = torch.nn.LayerNorm(d_model, bias=False) + self.pos_ff = PositionwiseConvFF( + d_model, d_ffn, p_dropout, kernel_size=kernel_size, is_causal=is_causal, non_linearity=conv_non_linearity + ) + + self.use_cache = False + self.cache = self._init_cache() + + @staticmethod + def _init_cache() -> Dict: + return { + 'self_attn_output': None, + 'cross_attn_output': None, + 'memory': None, + } + + def reset_cache(self, use_cache=False): + self.use_cache = use_cache + self.cache = self._init_cache() + 
self.self_attention.reset_cache(use_cache) + if self.has_xattn: + self.cross_attention.reset_cache(use_cache) + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + cond: Optional[torch.Tensor] = None, + cond_mask: Optional[torch.Tensor] = None, + attn_prior: Optional[torch.Tensor] = None, + ) -> Dict: + """ + Args: + x (B, T1, C): Input tensor + x_mask (B, T1): Multiplicative mask where True means we keep the input, False we zero it out. + Mask for self attention input. + cond (B, T2, C): Conditioning tensor + cond_mask (B, T2): Multiplicative mask where True means we keep the input, False we zero + it out. Mask for cross attention input if it exists. + + Returns dict with keys + output (B, T1, C): Output tensor + attn_probabilities : Attention probabilities + """ + x = x * x_mask.unsqueeze(-1) + x_, s_attn_prob = self.self_attention(query=self.norm_self(x), query_mask=x_mask) + if self.use_cache: + if self.cache['self_attn_output'] is not None: + x_ = torch.cat([self.cache['self_attn_output'], x_], dim=1) + self.cache['self_attn_output'] = x_ + x = x + x_ + + x_attn_prob = None + if self.has_xattn and cond is not None: + x_normed = self.norm_xattn_query(x) + if self.use_cache and self.cache['memory'] is not None: + memory = self.cache['memory'] + else: + memory = self.norm_xattn_memory(cond) if self.apply_norm_to_cond else cond + if self.use_cache: + self.cache['memory'] = memory + + x_res, x_attn_prob = self.cross_attention( + query=x_normed, query_mask=x_mask, memory=memory, memory_mask=cond_mask, attn_prior=attn_prior + ) + if self.use_cache: + if self.cache['cross_attn_output'] is not None: + x_res = torch.cat([self.cache['cross_attn_output'], x_res], dim=1) + self.cache['cross_attn_output'] = x_res + x = x + x_res + + # mlp final projection + x = x + self.pos_ff(self.norm_pos_ff(x)) + x = x * x_mask.unsqueeze(-1) + + return { + 'output': x, + 'attn_probabilities': {'self_attn_probabilities': s_attn_prob, 'cross_attn_probabilities': x_attn_prob}, + } + + +class Transformer(torch.nn.Module): + def __init__( + self, + n_layers: int, + d_model: int, + d_ffn: int, + sa_n_heads: int, + kernel_size: int, + p_dropout: float = 0.0, + p_dropout_out: float = 0.0, + has_xattn: bool = False, + xa_d_memory: Optional[int] = None, + xa_n_heads: Optional[int] = None, + is_causal: bool = True, + apply_norm_to_cond: bool = True, + apply_norm_out: bool = False, + max_length_causal_mask: int = 4096, + use_learnable_pos_emb: bool = False, + conv_non_linearity: Callable = torch.nn.GELU(approximate="tanh"), + ): + """ + Initializes a stack of transformer layers. Can be used for both encoder and decoder. + Set is_causal is True for autoregressive models. Equivalent to TransformerBlock from Megatron-LM + Args: + n_layers : Number of transformer layers + d_model : Model dimension + d_ffn : Feed forward dimension (usually 4*d_model) + sa_n_heads : Number of attention heads used in self-attention + kernel_size : Convolution kernel size for FFN + p_dropout : Dropout probability + p_dropout_out : Dropout probability for output + has_xattn : Whether to use cross attention + xa_d_memory : Hidden dimension for cross attention; required if has_xattn is True + xa_n_heads : Number of attention heads used in cross attention; required if has_xattn is True + is_causal : Whether to make attention and the convolution feedforward networks causal. + apply_norm_to_cond : Whether to apply normalization to conditioning tensor; conditioning tensor being + the input to the memory part of cross-attention. 
+ apply_norm_out : Whether to apply normalization to output + max_length_causal_mask : Maximum length of causal mask + use_learnable_pos_emb : Whether to add a learnable positionable embedding inside the class + conv_non_linearity : Convolution non-linearity + """ + if has_xattn and (xa_d_memory is None or xa_n_heads is None): + raise ValueError("It requires that `xa_d_memory` and `xa_n_heads` are specified when `has_xattn` is True!") + + super().__init__() + self.dropout = torch.nn.Dropout(p_dropout) + self.p_dropout_out = p_dropout_out + + if self.p_dropout_out > 0.0: + self.dropout_out = torch.nn.Dropout(self.p_dropout_out) + else: + self.dropout_out = None + + self.apply_norm_out = apply_norm_out + if self.apply_norm_out: + self.norm_out = torch.nn.LayerNorm(d_model, bias=False) + else: + self.norm_out = None + + self.layers = torch.nn.ModuleList() + for _ in range(n_layers): + self.layers.append( + TransformerLayer( + d_model=d_model, + d_ffn=d_ffn, + sa_n_heads=sa_n_heads, + kernel_size=kernel_size, + p_dropout=p_dropout, + has_xattn=has_xattn, + xa_d_memory=xa_d_memory, + xa_n_heads=xa_n_heads, + is_causal=is_causal, + apply_norm_to_cond=apply_norm_to_cond, + max_length_causal_mask=max_length_causal_mask, + conv_non_linearity=conv_non_linearity, + ) + ) + + self.use_learnable_pos_emb = use_learnable_pos_emb + self.position_embeddings = None + if self.use_learnable_pos_emb: + self.position_embeddings = torch.nn.Embedding(max_length_causal_mask, d_model) + # Apply random uniform init for all layers, except for output layers: The second of the two layers in the MLP + # and the last linear projection in dot product attention. The output layers are scaled depending on the + # number of layers + self.apply(self._init_weights_gpt2) + for name, param in self.named_parameters(): + if 'o_net' in name and name.endswith('weight'): + torch.nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * n_layers)) + + def reset_cache(self, use_cache=False): + for layer in self.layers: + layer.reset_cache(use_cache) + + @staticmethod + def _init_weights_gpt2(module): + if isinstance(module, (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv1d)): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + if isinstance(module, torch.nn.Linear) and module.bias is not None: + torch.nn.init.zeros_(module.bias) + + @staticmethod + def _get_layer_inputs( + idx: int, + cond: Optional[Union[torch.Tensor, List[torch.Tensor]]], + cond_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]], + attn_prior: Optional[Union[torch.Tensor, List[torch.Tensor]]], + multi_encoder_mapping: Optional[List[Optional[int]]], + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + if multi_encoder_mapping is not None: + if multi_encoder_mapping[idx] is None: + return None, None, None + else: + return ( + cond[multi_encoder_mapping[idx]], + cond_mask[multi_encoder_mapping[idx]] if cond_mask is not None else None, + attn_prior[multi_encoder_mapping[idx]] if attn_prior is not None else None, + ) + else: + return cond, cond_mask, attn_prior + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + cond: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + cond_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + attn_prior: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + multi_encoder_mapping: Optional[List[Optional[int]]] = None, + ) -> Dict[str, Union[torch.Tensor, List]]: + """ + Args: + x (B, T1, C): + x_mask (B, T1): Multiplicative mask where True 
means we keep the input, False we zero it out. + Mostly used in non-causal self-attention to zero out padding values. In causal self-attention, the + causal mask will be used in place of this. + cond (B, T2, C) or list of such tensors (from different encoders) + cond_mask (B, T2): Multiplicative mask where True means we keep the input, False we zero it + out or list of such tensors (from different encoders) output (B, T1, C) + multi_encoder_mapping : None or Same size as n_layers, value indicates which cond input to use + for this layer + + Returns dict with keys: + output (B, T1, C): Output tensor + attn_probabilities : Attention probabilities of each layer + """ + if isinstance(cond, list) and len(self.layers) < len(cond): + raise ValueError( + f"Insufficient Transformer layers for multiple conditionals. Each layer must cross-attend one conditional." + f"Found {len(self.layers)} layers for {len(cond)} conditionals." + ) + + if self.use_learnable_pos_emb: + positions = torch.arange(x.size(1), device=x.device).unsqueeze(0) + x = x + self.position_embeddings(positions) + + attn_probabilities = [] + x = self.dropout(x) + for idx, layer in enumerate(self.layers): + _cond, _cond_mask, _attn_prior = self._get_layer_inputs( + idx, cond, cond_mask, attn_prior, multi_encoder_mapping + ) + out_dict = layer(x, x_mask, _cond, _cond_mask, attn_prior=_attn_prior) + x = out_dict['output'] + attn_probabilities.append(out_dict['attn_probabilities']) + + if self.norm_out is not None: + x = self.norm_out(x) + + if self.dropout_out is not None: + x = self.dropout_out(x) + + return {'output': x, 'attn_probabilities': attn_probabilities} diff --git a/tests/collections/tts/modules/test_transformer_2501.py b/tests/collections/tts/modules/test_transformer_2501.py new file mode 100644 index 000000000000..b7f486028aea --- /dev/null +++ b/tests/collections/tts/modules/test_transformer_2501.py @@ -0,0 +1,797 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +import random + +import numpy as np +import pytest +import torch + +from nemo.collections.tts.modules.transformer_2501 import ( + ConvolutionLayer, + CrossAttention, + PositionwiseConvFF, + SelfAttention, + Transformer, + TransformerLayer, +) +from nemo.collections.tts.parts.utils.tts_dataset_utils import beta_binomial_prior_distribution + + +def set_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +@pytest.mark.unit +class TestConvolutionLayer: + @classmethod + def setup_class(cls): + cls.in_channels = 3 + cls.out_channels = 6 + cls.kernel_size = 3 + cls.stride = 1 + cls.dilation = 1 + cls.bias = True + # fmt:off + cls.input_tensor = torch.Tensor( + [[[-1.0542, 0.2675, 0.6963, 0.4738, 0.3910, -0.1505, 0.9171, -0.1528, 3.7269, 0.1779], + [-1.0317, 1.6818, 1.4257, -0.5003, -1.7254, 0.8830, -0.4541, -0.4631, -0.0986, 0.5083], + [-0.3231, -1.0899, 0.5774, 0.1661, 0.9620, -2.3307, -0.6158, -0.3663, 1.2469, -1.0208]]] + ) + # fmt:on + + def test_non_causal_forward(self): + set_seed(0) + layer = ConvolutionLayer( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + dilation=self.dilation, + bias=self.bias, + is_causal=False, + ) + + with torch.no_grad(): + output_tensor = layer(self.input_tensor) + + # fmt:off + expected_output_tensor = torch.Tensor( + [[[ 0.1912, -0.0555, -0.2681, -0.2289, 1.0788, -0.3908, 0.0936, -0.7962, 1.3754, -0.0731], + [-0.3715, -0.6326, -0.9596, -0.0933, -0.1024, -0.2082, -0.5924, 0.1097, -0.5418, -0.0854], + [ 0.3974, 0.4537, 0.3299, 0.1471, -0.5983, -0.8645, 0.0975, 0.6063, -0.6619, -0.9711], + [-0.3048, 0.3862, -0.2462, -0.9903, -0.6189, 0.7389, 0.0785, -1.0870, -1.0018, -1.2426], + [-0.4357, -0.0446, 0.0879, 0.0930, -0.2242, 0.5285, 0.4006, -0.1846, 0.5668, -0.5242], + [-0.0625, 0.4123, -0.6289, -0.4317, 0.1595, 0.0386, -1.0774, 0.2218, 0.8483, -0.4886]]] + ) + # fmt:on + + assert torch.allclose(output_tensor, expected_output_tensor, atol=1e-4) + + def test_causal_forward(self): + set_seed(0) + layer = ConvolutionLayer( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + dilation=self.dilation, + bias=self.bias, + is_causal=True, + ) + + with torch.no_grad(): + output_tensor = layer(self.input_tensor) + + # fmt:off + expected_output_tensor = torch.Tensor( + [[[ 0.4301, 0.1912, -0.0555, -0.2681, -0.2289, 1.0788, -0.3908, 0.0936, -0.7962, 1.3754], + [-0.0501, -0.3715, -0.6326, -0.9596, -0.0933, -0.1024, -0.2082, -0.5924, 0.1097, -0.5418], + [-0.4204, 0.3974, 0.4537, 0.3299, 0.1471, -0.5983, -0.8645, 0.0975, 0.6063, -0.6619], + [ 0.1543, -0.3048, 0.3862, -0.2462, -0.9903, -0.6189, 0.7389, 0.0785, -1.0870, -1.0018], + [-0.1337, -0.4357, -0.0446, 0.0879, 0.0930, -0.2242, 0.5285, 0.4006, -0.1846, 0.5668], + [-0.6127, -0.0625, 0.4123, -0.6289, -0.4317, 0.1595, 0.0386, -1.0774, 0.2218, 0.8483]]] + ) + # fmt:on + + assert torch.allclose(output_tensor, expected_output_tensor, atol=1e-4) + + +class TestPositionwiseConvFF: + @classmethod + def setup_class(cls): + cls.d_model = 3 + cls.d_ffn = 12 + cls.p_dropout = 0.0 + cls.kernel_size = 3 + cls.bias = True + # fmt:off + cls.input_tensor = torch.Tensor( + [[[-1.6682, -0.6069, 0.1321], + [-1.5489, 0.3279, -0.9159], + [-0.7490, 1.8984, 0.5030], + [-0.8130, 0.0058, -1.9979], + [-1.4994, -0.3270, 1.4961], + [-1.6613, -1.7827, 0.8932], + [-0.6276, -1.0770, -0.9971], + [ 1.5424, 1.3590, 1.2287], + [-0.1543, 0.3365, 1.7475], + [-0.1753, 0.4115, 0.0772]]] + ) + # fmt:on + + def 
test_causal_forward(self): + set_seed(0) + layer = PositionwiseConvFF( + self.d_model, self.d_ffn, self.p_dropout, self.kernel_size, bias=self.bias, is_causal=True + ) + + with torch.no_grad(): + output_tensor = layer(self.input_tensor) + + # fmt:off + expected_output_tensor = torch.Tensor( + [[[-0.1242, -0.0114, 0.0212], + [-0.0441, -0.0555, -0.0795], + [-0.0282, 0.0366, -0.2033], + [-0.0421, 0.0305, -0.2573], + [-0.1877, -0.2492, -0.1638], + [-0.4300, -0.1160, 0.2177], + [-0.1652, -0.3130, -0.3329], + [-0.1737, 0.1133, -0.1802], + [-0.2599, -0.0381, 0.1362], + [-0.0584, -0.2936, 0.2719]]] + ) + # fmt:on + + assert torch.allclose(output_tensor, expected_output_tensor, atol=1e-4) + + def test_non_causal_forward(self): + set_seed(0) + layer = PositionwiseConvFF( + self.d_model, self.d_ffn, self.p_dropout, self.kernel_size, bias=self.bias, is_causal=False + ) + + with torch.no_grad(): + output_tensor = layer(self.input_tensor) + + # fmt:off + expected_output_tensor = torch.Tensor( + [[[-0.0617, -0.0321, -0.1646], + [-0.0421, 0.0305, -0.2573], + [-0.1877, -0.2492, -0.1638], + [-0.4300, -0.1160, 0.2177], + [-0.1652, -0.3130, -0.3329], + [-0.1737, 0.1133, -0.1802], + [-0.2599, -0.0381, 0.1362], + [-0.0584, -0.2936, 0.2719], + [ 0.0361, 0.1110, 0.0441], + [-0.0244, 0.0682, 0.0340]]] + ) + # fmt:on + + assert torch.allclose(output_tensor, expected_output_tensor, atol=1e-4) + + +class TestSelfAttention: + @classmethod + def setup_class(cls): + cls.n_heads = 2 + cls.d_model = 4 + cls.p_dropout = 0.0 + cls.max_length_causal_mask = 6 + # fmt:off + cls.query_tensor = torch.Tensor( + [[[ 0.7239, -0.2362, -0.6610, -1.3759], + [ 1.7381, 0.0793, -1.1241, 0.9529], + [-1.9809, 0.2217, 0.0795, 0.0307], + [ 0.3208, 0.4485, 0.3046, -0.0704], + [-1.4412, 0.8981, 0.1219, 0.0481], + [ 1.7811, -0.1358, 0.6073, 0.8275]]] + ) + # fmt:on + + def test_causal_forward(self): + set_seed(0) + layer = SelfAttention( + self.n_heads, + self.d_model, + self.p_dropout, + is_causal=True, + max_length_causal_mask=self.max_length_causal_mask, + ) + query_mask = torch.ones(1, self.max_length_causal_mask).bool() + with torch.no_grad(): + output_tensor, attn_output = layer(self.query_tensor, query_mask) + + # fmt:off + expected_output_tensor = torch.Tensor( + [[[-0.2569, 0.2782, -0.0348, 0.2480], + [-0.3949, 0.4054, -0.0876, 0.2574], + [-0.1033, 0.0659, -0.0259, 0.1738], + [-0.1485, 0.0995, -0.0415, 0.0684], + [-0.0123, 0.0185, -0.0027, 0.0708], + [-0.0672, 0.0566, -0.0214, -0.0021]]] + ) + expected_attn_prob = torch.Tensor( + [[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.5828, 0.4172, 0.0000, 0.0000, 0.0000, 0.0000], + [0.3216, 0.4260, 0.2525, 0.0000, 0.0000, 0.0000], + [0.2385, 0.2238, 0.2872, 0.2504, 0.0000, 0.0000], + [0.1807, 0.1973, 0.2057, 0.2045, 0.2118, 0.0000], + [0.1159, 0.1388, 0.2010, 0.1721, 0.2161, 0.1562]], + [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.2866, 0.7134, 0.0000, 0.0000, 0.0000, 0.0000], + [0.2799, 0.2472, 0.4729, 0.0000, 0.0000, 0.0000], + [0.2964, 0.2535, 0.2075, 0.2427, 0.0000, 0.0000], + [0.1864, 0.1616, 0.2394, 0.1974, 0.2152, 0.0000], + [0.1666, 0.2030, 0.1391, 0.1649, 0.1546, 0.1719]]]] + ) + expected_attn_score = torch.Tensor( + [[[[ 0.5248, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')], + [-0.1948, -0.5291, float('-inf'), float('-inf'), float('-inf'), float('-inf')], + [-0.0279, 0.2533, -0.2698, float('-inf'), float('-inf'), float('-inf')], + [-0.0508, -0.1145, 0.1350, -0.0020, float('-inf'), float('-inf')], + [-0.0985, -0.0105, 0.0315, 0.0257, 
0.0604, float('-inf')], + [-0.3253, -0.1457, 0.2250, 0.0694, 0.2971, -0.0275]], + [[ 0.5075, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')], + [-0.5541, 0.3578, float('-inf'), float('-inf'), float('-inf'), float('-inf')], + [-0.2499, -0.3738, 0.2746, float('-inf'), float('-inf'), float('-inf')], + [ 0.2215, 0.0654, -0.1351, 0.0216, float('-inf'), float('-inf')], + [-0.1011, -0.2439, 0.1488, -0.0438, 0.0425, float('-inf')], + [ 0.0526, 0.2502, -0.1277, 0.0424, -0.0221, 0.0840]]]] + ) + # fmt:on + + assert torch.allclose(output_tensor, expected_output_tensor, atol=1e-4) + assert torch.allclose(attn_output[0], expected_attn_prob, atol=1e-4) + assert torch.allclose(attn_output[1], expected_attn_score, atol=1e-4) + + def test_non_causal_forward(self): + set_seed(0) + layer = SelfAttention( + self.n_heads, + self.d_model, + self.p_dropout, + is_causal=False, + max_length_causal_mask=self.max_length_causal_mask, + ) + query_mask = torch.ones(1, self.max_length_causal_mask).bool() + with torch.no_grad(): + output_tensor, attn_output = layer(self.query_tensor, query_mask) + + # fmt:off + expected_output_tensor = torch.Tensor( + [[[-0.0954, 0.1131, -0.0195, 0.0704], + [-0.0401, 0.0364, -0.0156, 0.0088], + [ 0.0324, 0.0368, 0.0174, 0.0501], + [-0.0633, 0.0610, -0.0176, 0.0180], + [-0.0017, 0.0361, 0.0030, 0.0319], + [-0.0672, 0.0566, -0.0214, -0.0021]]] + ) + expected_attn_prob = torch.Tensor( + [[[[0.2835, 0.1482, 0.1723, 0.1426, 0.1422, 0.1111], + [0.1254, 0.0898, 0.2819, 0.1493, 0.2682, 0.0854], + [0.1549, 0.2051, 0.1216, 0.1663, 0.1294, 0.2228], + [0.1583, 0.1485, 0.1906, 0.1662, 0.1890, 0.1475], + [0.1498, 0.1635, 0.1705, 0.1696, 0.1755, 0.1711], + [0.1159, 0.1388, 0.2010, 0.1721, 0.2161, 0.1562]], + [[0.2536, 0.1945, 0.1079, 0.1628, 0.1233, 0.1579], + [0.0866, 0.2156, 0.1709, 0.1551, 0.1903, 0.1815], + [0.1361, 0.1202, 0.2300, 0.1626, 0.1941, 0.1569], + [0.2038, 0.1744, 0.1427, 0.1669, 0.1488, 0.1634], + [0.1565, 0.1357, 0.2010, 0.1657, 0.1807, 0.1604], + [0.1666, 0.2030, 0.1391, 0.1649, 0.1546, 0.1719]]]] + ) + expected_attn_score = torch.Tensor( + [[[[ 5.2482e-01, -1.2346e-01, 2.7022e-02, -1.6210e-01, -1.6488e-01, -4.1190e-01], + [-1.9484e-01, -5.2910e-01, 6.1538e-01, -2.0263e-02, 5.6540e-01, -5.7875e-01], + [-2.7873e-02, 2.5326e-01, -2.6980e-01, 4.3127e-02, -2.0759e-01, 3.3584e-01], + [-5.0756e-02, -1.1455e-01, 1.3498e-01, -2.0208e-03, 1.2687e-01, -1.2113e-01], + [-9.8482e-02, -1.0463e-02, 3.1513e-02, 2.5712e-02, 6.0431e-02, 3.4494e-02], + [-3.2530e-01, -1.4574e-01, 2.2503e-01, 6.9375e-02, 2.9711e-01, -2.7542e-02]], + [[ 5.0748e-01, 2.4198e-01, -3.4704e-01, 6.4137e-02, -2.1347e-01, 3.3762e-02], + [-5.5410e-01, 3.5778e-01, 1.2559e-01, 2.8689e-02, 2.3316e-01, 1.8558e-01], + [-2.4989e-01, -3.7383e-01, 2.7462e-01, -7.2002e-02, 1.0508e-01, -1.0771e-01], + [ 2.2154e-01, 6.5375e-02, -1.3510e-01, 2.1609e-02, -9.3194e-02, 3.4042e-04], + [-1.0105e-01, -2.4395e-01, 1.4884e-01, -4.3842e-02, 4.2481e-02, -7.6735e-02], + [ 5.2595e-02, 2.5018e-01, -1.2765e-01, 4.2375e-02, -2.2093e-02, 8.4005e-02]]]] + ) + # fmt:on + + assert torch.allclose(output_tensor, expected_output_tensor, atol=1e-4) + assert torch.allclose(attn_output[0], expected_attn_prob, atol=1e-4) + assert torch.allclose(attn_output[1], expected_attn_score, atol=1e-5) + + +class TestCrossAttention: + @classmethod + def setup_class(cls): + cls.n_heads = 2 + cls.d_model = 4 + cls.d_memory = 3 + cls.p_dropout = 0.0 + cls.max_length = 6 + # fmt:off + # shape = (1, cls.max_length, cls.d_model) + cls.query_tensor = 
torch.Tensor( + [[[0.7352, -0.5871, -0.1204, -2.0200], + [-0.4618, -0.3604, 1.2287, -0.3434], + [0.7838, -0.7646, -1.3349, -0.1538], + [-0.9749, -1.0789, -0.0126, -0.7225], + [2.6929, -0.2091, 2.1242, -1.0123], + [0.5094, -2.0566, 1.3922, -0.2156]]] + ) + # fmt:on + cls.query_mask = torch.ones(1, cls.query_tensor.shape[1]).bool() + # fmt:off + # shape = (1, 5, cls.d_memory) + cls.memory_tensor = torch.Tensor( + [[[ 2.0132e-01, -5.6582e-01, 1.1191e+00], + [-6.2371e-01, -9.3398e-02, -1.3744e+00], + [-9.8265e-01, -8.1742e-01, 4.5611e-01], + [-5.4802e-01, -1.1218e+00, 7.6138e-01], + [-1.9899e+00, -1.7910e-03, 9.0718e-01]]] + ) + # fmt:on + cls.memory_mask = torch.ones(1, cls.memory_tensor.shape[1]).bool() + # shape = (1, cls.query_tensor.shape[1], cls.memory_tensor.shape[1]) + cls.attn_prior = torch.from_numpy( + beta_binomial_prior_distribution( + phoneme_count=cls.memory_tensor.shape[1], mel_count=cls.query_tensor.shape[1] + ) + ).unsqueeze(0) + + def test_forward(self): + set_seed(0) + layer = CrossAttention(self.n_heads, self.d_model, self.d_memory, self.p_dropout) + + with torch.no_grad(): + output_tensor, attn_output = layer( + self.query_tensor, self.query_mask, self.memory_tensor, self.memory_mask, self.attn_prior + ) + + # fmt:off + expected_output_tensor = torch.Tensor( + [[[ 0.2267, -0.2271, 0.0573, -0.0681], + [ 0.2672, -0.1823, 0.0722, -0.0859], + [ 0.3212, -0.2218, 0.0835, -0.0715], + [ 0.3568, -0.2573, 0.0918, -0.0789], + [ 0.3962, -0.4112, 0.0816, -0.1972], + [ 0.3457, -0.4253, 0.0568, -0.2216]]] + ) + expected_attn_prob = torch.Tensor( + [[[[0.4220, 0.4859, 0.0709, 0.0188, 0.0025], + [0.3944, 0.3475, 0.1642, 0.0784, 0.0155], + [0.1335, 0.3448, 0.2794, 0.1752, 0.0671], + [0.0914, 0.3300, 0.2343, 0.2437, 0.1006], + [0.0256, 0.1138, 0.2145, 0.3343, 0.3119], + [0.0117, 0.0617, 0.1112, 0.3354, 0.4800]], + [[0.8045, 0.1024, 0.0661, 0.0242, 0.0028], + [0.4020, 0.2953, 0.1914, 0.0907, 0.0207], + [0.1446, 0.2798, 0.3026, 0.1965, 0.0766], + [0.0718, 0.2151, 0.2778, 0.2719, 0.1634], + [0.0673, 0.0341, 0.1929, 0.4534, 0.2522], + [0.0064, 0.0264, 0.0999, 0.2872, 0.5802]]]] + ) + expected_attn_score = torch.Tensor( + [[[[-0.5044, 0.4476, -0.4961, -0.5728, -0.8010], + [-0.0761, -0.2027, -0.5103, -0.4385, -0.6724], + [-0.1525, 0.2576, 0.0471, -0.0138, 0.0075], + [-0.3084, -0.0058, -0.7538, -0.7144, -1.0604], + [-0.0799, 0.0260, -0.1511, -0.1493, -0.2187], + [-0.1935, -0.3225, -0.9867, -0.8637, -1.3162]], + [[ 0.4704, -0.7801, -0.2374, 0.0126, -0.3430], + [ 0.0623, -0.2461, -0.2380, -0.1743, -0.2658], + [ 0.0104, 0.1313, 0.2096, 0.1834, 0.2217], + [-0.0859, 0.0300, -0.1194, -0.1410, -0.1111], + [ 0.7876, -1.2782, -0.3570, 0.0556, -0.5312], + [ 0.1046, -0.2652, -0.1856, -0.1104, -0.2180]]]] + ) + # fmt:on + + assert torch.allclose(output_tensor, expected_output_tensor, atol=1e-4) + assert torch.allclose(attn_output[0], expected_attn_prob, atol=1e-4) + assert torch.allclose(attn_output[1], expected_attn_score, atol=1e-4) + + +class TestTransformerLayer: + @classmethod + def setup_class(cls): + cls.d_model = 2 + cls.d_ffn = 8 + cls.sa_n_heads = 2 + cls.kernel_size = 3 + cls.p_dropout = 0.0 + cls.max_length_causal_mask = 5 + # fmt:off + # shape = (1, cls.max_length_causal_mask, cls.d_model) + cls.x = torch.Tensor( + [[[ 0.5115, 0.0889], + [-0.8568, -2.9632], + [-1.3728, 0.7325], + [-2.4593, -0.9018], + [ 0.9621, 0.4212]]] + ) + # fmt:on + cls.x_mask = torch.ones(1, cls.max_length_causal_mask).bool() + # fmt:off + # shape = (1, 3, cls.d_model) + cls.cond = torch.Tensor( + [[[ 1.4441, 0.1393], + [ 
0.2828, -0.2456], + [-0.3075, 0.6581]]] + ) + # fmt:on + cls.cond_mask = torch.ones(1, cls.cond.shape[1]).bool() + # shape = (1, cls.x.shape[1], cls.cond.shape[1]) + cls.attn_prior = torch.from_numpy( + beta_binomial_prior_distribution(phoneme_count=cls.cond.shape[1], mel_count=cls.x.shape[1]) + ).unsqueeze(0) + + def test_forward_causal_self_attn_and_has_xattn(self): + set_seed(0) + layer = TransformerLayer( + self.d_model, + self.d_ffn, + self.sa_n_heads, + self.kernel_size, + self.p_dropout, + has_xattn=True, + xa_n_heads=2, + xa_d_memory=2, + is_causal=True, + max_length_causal_mask=self.max_length_causal_mask, + ) + + with torch.no_grad(): + output_dict = layer(self.x, self.x_mask, self.cond, self.cond_mask, self.attn_prior) + + # fmt:off + expected_output = { + 'output': torch.Tensor( + [[[ 0.1936, 0.5387], + [-1.0270, -2.5452], + [-1.6884, 0.8765], + [-2.7496, -0.7887], + [ 1.2837, 0.1172]]] + ), + 'attn_probabilities': { + 'self_attn_probabilities': [ + torch.Tensor( + [[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.5000, 0.5000, 0.0000, 0.0000, 0.0000], + [0.3068, 0.3068, 0.3864, 0.0000, 0.0000], + [0.2213, 0.2213, 0.2787, 0.2787, 0.0000], + [0.2180, 0.2180, 0.1730, 0.1730, 0.2180]], + [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.5000, 0.5000, 0.0000, 0.0000, 0.0000], + [0.3237, 0.3237, 0.3527, 0.0000, 0.0000], + [0.2393, 0.2393, 0.2607, 0.2607, 0.0000], + [0.2068, 0.2068, 0.1898, 0.1898, 0.2068]]]] + ), + torch.Tensor( + [[[[0.1154, float('-inf'), float('-inf'), float('-inf'), float('-inf')], + [0.1154, 0.1154, float('-inf'), float('-inf'), float('-inf')], + [-0.1154, -0.1154, 0.1154, float('-inf'), float('-inf')], + [-0.1154, -0.1154, 0.1154, 0.1154, float('-inf')], + [0.1154, 0.1154, -0.1154, -0.1154, 0.1154]], + [[0.0429, float('-inf'), float('-inf'), float('-inf'), float('-inf')], + [0.0429, 0.0429, float('-inf'), float('-inf'), float('-inf')], + [-0.0429, -0.0429, 0.0429, float('-inf'), float('-inf')], + [-0.0429, -0.0429, 0.0429, 0.0429, float('-inf')], + [0.0429, 0.0429, -0.0429, -0.0429, 0.0429]]]] + ) + ], + 'cross_attn_probabilities': [ + torch.Tensor( + [[[[0.7181, 0.2394, 0.0426], + [0.4843, 0.3874, 0.1283], + [0.2753, 0.4129, 0.3118], + [0.1344, 0.3583, 0.5074], + [0.0520, 0.2599, 0.6882]], + [[0.5959, 0.1987, 0.2054], + [0.2837, 0.2270, 0.4893], + [0.3740, 0.5610, 0.0651], + [0.2355, 0.6280, 0.1365], + [0.0108, 0.0542, 0.9349]]]] + ), + torch.Tensor( + [[[[0.0586, 0.0586, -0.0586], + [0.0624, 0.0624, -0.0624], + [-0.0624, -0.0624, 0.0624], + [-0.0624, -0.0624, 0.0624], + [0.0624, 0.0624, -0.0624]], + [[-0.8214, -0.8214, 0.8214], + [-0.8745, -0.8744, 0.8745], + [0.8745, 0.8744, -0.8745], + [0.8745, 0.8744, -0.8745], + [-0.8744, -0.8744, 0.8744]]]] + ) + ] + } + } + # fmt:on + + assert torch.allclose(output_dict["output"], expected_output["output"], atol=1e-4) + for i in range(2): + assert torch.allclose( + output_dict["attn_probabilities"]["self_attn_probabilities"][i], + expected_output["attn_probabilities"]["self_attn_probabilities"][i], + atol=1e-4, + ) + assert torch.allclose( + output_dict["attn_probabilities"]["cross_attn_probabilities"][i], + expected_output["attn_probabilities"]["cross_attn_probabilities"][i], + atol=1e-4, + ) + + +@pytest.mark.unit +class TestTransformer: + @classmethod + def setup_class(cls): + cls.n_layers = 1 + cls.d_model = 4 + cls.d_ffn = 16 + cls.sa_n_heads = 2 + cls.kernel_size = 3 + cls.p_dropout = 0.0 + cls.p_dropout_out = 0.0 + cls.is_causal = True + cls.max_length_causal_mask = 6 + + # fmt:off + cls.input_tensor = 
torch.Tensor( + [[[ 0.7049, 0.0305, -0.8542, 0.5388], + [-0.5265, -1.3320, 1.5451, 0.4086], + [-2.0546, 0.5259, 0.5995, -0.4078], + [ 0.4530, -0.3918, 2.1403, -0.2062], + [-0.0984, 0.4855, 0.7076, 0.0431], + [-0.4394, -0.6761, 1.7389, -0.9423]]] + ) + # fmt:on + + def test_forward_causal_self_attn_and_no_xattn(self): + set_seed(0) + model = Transformer( + n_layers=self.n_layers, + d_model=self.d_model, + d_ffn=self.d_ffn, + sa_n_heads=self.sa_n_heads, + kernel_size=self.kernel_size, + p_dropout=self.p_dropout, + p_dropout_out=self.p_dropout_out, + has_xattn=False, + is_causal=self.is_causal, + max_length_causal_mask=self.max_length_causal_mask, + ) + + # Check model init + assert torch.isclose(torch.mean(model.layers[0].pos_ff.proj.conv.weight), torch.tensor(0.0), atol=1e-2) + assert torch.isclose(torch.std(model.layers[0].pos_ff.proj.conv.weight), torch.tensor(0.02), atol=1e-2) + assert torch.isclose(torch.mean(model.layers[0].pos_ff.o_net.conv.weight), torch.tensor(0.0), atol=1e-2) + assert torch.isclose( + torch.std(model.layers[0].pos_ff.o_net.conv.weight), torch.tensor(0.02 / math.sqrt(2.0)), atol=1e-3 + ) + + mask_tensor = torch.ones(1, self.max_length_causal_mask).bool() + with torch.no_grad(): + output_dict = model(x=self.input_tensor, x_mask=mask_tensor) + + # fmt:off + expected_output_tensor = { + 'output': torch.Tensor( + [[[0.7047, 0.0305, -0.8555, 0.5402], + [-0.5192, -1.3324, 1.5455, 0.4148], + [-2.0593, 0.5290, 0.5969, -0.4101], + [0.4517, -0.3968, 2.1392, -0.2041], + [-0.1019, 0.4854, 0.7077, 0.0431], + [-0.4458, -0.6789, 1.7447, -0.9447]]] + ), + 'attn_probabilities': [ + { + 'self_attn_probabilities': [ + torch.Tensor( + [[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.4998, 0.5002, 0.0000, 0.0000, 0.0000, 0.0000], + [0.3337, 0.3333, 0.3330, 0.0000, 0.0000, 0.0000], + [0.2498, 0.2500, 0.2501, 0.2501, 0.0000, 0.0000], + [0.2002, 0.2001, 0.2000, 0.1999, 0.1999, 0.0000], + [0.1666, 0.1666, 0.1667, 0.1667, 0.1667, 0.1667]], + [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.5005, 0.4995, 0.0000, 0.0000, 0.0000, 0.0000], + [0.3332, 0.3331, 0.3336, 0.0000, 0.0000, 0.0000], + [0.2507, 0.2494, 0.2510, 0.2489, 0.0000, 0.0000], + [0.2002, 0.1995, 0.2006, 0.1994, 0.2003, 0.0000], + [0.1671, 0.1663, 0.1674, 0.1660, 0.1670, 0.1662]]]] + ), + torch.Tensor( + [[[[-3.4823e-04, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')], + [-7.0210e-04, -5.6984e-05, float('-inf'), float('-inf'), float('-inf'), float('-inf')], + [1.2551e-03, 1.1431e-04, -5.9334e-04, float('-inf'), float('-inf'), float('-inf')], + [-8.1514e-04, 2.6650e-05, 5.2952e-04, 5.6903e-04, float('-inf'), float('-inf')], + [8.0150e-04, 1.4366e-04, -2.7793e-04, -7.0636e-04, -8.2140e-04, float('-inf')], + [-4.6137e-04, 6.2648e-05, 3.6768e-04, 2.8095e-04, 4.9188e-04, 3.6888e-04]], + [[-7.5861e-04, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')], + [8.0373e-04, -1.2745e-03, float('-inf'), float('-inf'), float('-inf'), float('-inf')], + [-4.5038e-04, -6.7648e-04, 7.5978e-04, float('-inf'), float('-inf'), float('-inf')], + [1.5376e-03, -3.6984e-03, 3.0423e-03, -5.4870e-03, float('-inf'), float('-inf')], + [3.7014e-04, -2.7010e-03, 2.4310e-03, -3.2604e-03, 9.3840e-04, float('-inf')], + [1.3868e-03, -3.8372e-03, 3.2144e-03, -5.4860e-03, 5.0273e-04, -4.4343e-03]]]] + ), + ], + 'cross_attn_probabilities': None, + } + ], + } + # fmt:on + + assert output_dict["output"].shape == expected_output_tensor["output"].shape + assert torch.allclose(output_dict["output"], 
expected_output_tensor["output"], atol=1e-4) + for i in range(2): + assert torch.allclose( + output_dict["attn_probabilities"][0]["self_attn_probabilities"][i], + expected_output_tensor["attn_probabilities"][0]["self_attn_probabilities"][i], + atol=1e-4, + ) + assert output_dict["attn_probabilities"][0]["cross_attn_probabilities"] is None + + def test_forward_causal_self_attn_and_has_xattn(self): + set_seed(0) + model = Transformer( + n_layers=2, + d_model=self.d_model, + d_ffn=self.d_ffn, + sa_n_heads=self.sa_n_heads, + kernel_size=self.kernel_size, + p_dropout=self.p_dropout, + p_dropout_out=self.p_dropout_out, + has_xattn=True, + xa_d_memory=4, + xa_n_heads=2, + is_causal=self.is_causal, + max_length_causal_mask=self.max_length_causal_mask, + ) + + # fmt:off + cond = [ + # shape (1, 3, 4) + torch.Tensor( + [[[-0.7475, 1.1461, 0.7300, 1.4471], + [ 1.8744, -0.1654, 1.2418, -1.6983], + [-0.3123, 0.2320, 0.7457, 1.9868]]] + ), + # shape (1, 5, 4) + torch.Tensor( + [[[-0.6683, -1.2178, 1.3696, 0.9941], + [ 0.0297, -0.1616, 0.1891, 0.0580], + [-1.0771, 0.2547, -1.4023, 0.0971], + [ 1.1132, 0.6311, -0.1449, 0.2351], + [ 0.8920, 2.3663, 0.2248, -0.7298]]] + ) + ] + # fmt:on + + cond_mask = [torch.ones(1, cond[0].shape[1]).bool(), torch.ones(1, cond[1].shape[1]).bool()] + mask_tensor = torch.ones(1, self.max_length_causal_mask).bool() + multi_encoder_mapping = [0, 1] + with torch.no_grad(): + output_dict = model( + x=self.input_tensor, + x_mask=mask_tensor, + cond=cond, + cond_mask=cond_mask, + multi_encoder_mapping=multi_encoder_mapping, + ) + + # fmt:off + expected_output = { + 'output': torch.Tensor( + [[[0.7043, 0.0288, -0.8547, 0.5384], + [-0.5283, -1.3311, 1.5429, 0.4083], + [-2.0560, 0.5259, 0.6020, -0.4099], + [0.4554, -0.3829, 2.1433, -0.2036], + [-0.0986, 0.4794, 0.7067, 0.0432], + [-0.4392, -0.6772, 1.7428, -0.9393]]] + ), + 'attn_probabilities': [ + { + 'self_attn_probabilities': [ + torch.Tensor( + [[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.4989, 0.5011, 0.0000, 0.0000, 0.0000, 0.0000], + [0.3331, 0.3332, 0.3336, 0.0000, 0.0000, 0.0000], + [0.2495, 0.2496, 0.2504, 0.2505, 0.0000, 0.0000], + [0.1998, 0.1994, 0.2002, 0.2001, 0.2005, 0.0000], + [0.1662, 0.1662, 0.1668, 0.1668, 0.1671, 0.1669]], + [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.5008, 0.4992, 0.0000, 0.0000, 0.0000, 0.0000], + [0.3334, 0.3336, 0.3331, 0.0000, 0.0000, 0.0000], + [0.2504, 0.2500, 0.2496, 0.2500, 0.0000, 0.0000], + [0.2001, 0.2002, 0.2000, 0.1999, 0.1998, 0.0000], + [0.1670, 0.1667, 0.1665, 0.1667, 0.1665, 0.1666]]]] + ), + ], + 'cross_attn_probabilities': [ + torch.Tensor( + [[[[0.3331, 0.3336, 0.3334], + [0.3335, 0.3331, 0.3334], + [0.3336, 0.3331, 0.3332], + [0.3335, 0.3332, 0.3334], + [0.3336, 0.3331, 0.3333], + [0.3335, 0.3332, 0.3333]], + [[0.3333, 0.3335, 0.3332], + [0.3334, 0.3335, 0.3331], + [0.3333, 0.3330, 0.3337], + [0.3334, 0.3335, 0.3332], + [0.3333, 0.3331, 0.3336], + [0.3334, 0.3334, 0.3333]]]] + ) + ] + }, + { + 'self_attn_probabilities': [ + torch.Tensor( + [[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.5005, 0.4995, 0.0000, 0.0000, 0.0000, 0.0000], + [0.3336, 0.3330, 0.3334, 0.0000, 0.0000, 0.0000], + [0.2503, 0.2499, 0.2498, 0.2500, 0.0000, 0.0000], + [0.2002, 0.1999, 0.2000, 0.2000, 0.2000, 0.0000], + [0.1669, 0.1666, 0.1666, 0.1667, 0.1666, 0.1666]], + [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.5001, 0.4999, 0.0000, 0.0000, 0.0000, 0.0000], + [0.3329, 0.3330, 0.3340, 0.0000, 0.0000, 0.0000], + [0.2499, 0.2498, 0.2505, 0.2499, 0.0000, 
0.0000], + [0.1997, 0.1997, 0.2003, 0.2000, 0.2004, 0.0000], + [0.1665, 0.1664, 0.1669, 0.1666, 0.1669, 0.1667]]]] + ), + ], + 'cross_attn_probabilities': [ + torch.Tensor( + [[[[0.1999, 0.1998, 0.2002, 0.1999, 0.2001], + [0.2000, 0.1997, 0.2004, 0.1998, 0.2002], + [0.2001, 0.2000, 0.2001, 0.1999, 0.2000], + [0.2000, 0.2001, 0.1998, 0.2001, 0.1999], + [0.2001, 0.2002, 0.1998, 0.2001, 0.1998], + [0.2000, 0.2002, 0.1998, 0.2001, 0.1999]], + [[0.1998, 0.1998, 0.2001, 0.2004, 0.2000], + [0.2003, 0.2003, 0.1998, 0.1995, 0.2001], + [0.2003, 0.2003, 0.1998, 0.1995, 0.2001], + [0.2001, 0.2001, 0.2000, 0.1998, 0.2000], + [0.2002, 0.2001, 0.1999, 0.1997, 0.2000], + [0.2002, 0.2001, 0.2000, 0.1997, 0.2000]]]] + ), + ], + } + ], + } + # fmt:on + + assert torch.allclose(output_dict["output"], expected_output["output"], atol=1e-4) + for i in range(2): + assert torch.allclose( + output_dict["attn_probabilities"][i]["self_attn_probabilities"][0], + expected_output["attn_probabilities"][i]["self_attn_probabilities"][0], + atol=1e-4, + ) + assert torch.allclose( + output_dict["attn_probabilities"][i]["cross_attn_probabilities"][0], + expected_output["attn_probabilities"][i]["cross_attn_probabilities"][0], + atol=1e-4, + )
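
Usage sketch (not part of the patch; every hyperparameter value below is an assumed, illustrative choice): based on the `Transformer.__init__`, `forward`, and `reset_cache` signatures introduced in `nemo/collections/tts/modules/transformer_2501.py`, this snippet shows one way the new backbone might be instantiated as a causal decoder that cross-attends to a single encoder memory, run for one forward pass, and wrapped with KV-cache toggling for autoregressive inference.

import torch

from nemo.collections.tts.modules.transformer_2501 import Transformer

# Assumed sizes, chosen only for illustration.
B, T_dec, T_enc, d_model = 2, 100, 40, 512

decoder = Transformer(
    n_layers=6,
    d_model=d_model,
    d_ffn=4 * d_model,           # module docstring notes d_ffn is usually 4 * d_model
    sa_n_heads=8,
    kernel_size=3,
    p_dropout=0.1,
    has_xattn=True,              # cross-attend to an encoder output
    xa_d_memory=d_model,
    xa_n_heads=8,
    is_causal=True,
    max_length_causal_mask=4096,
)

x = torch.randn(B, T_dec, d_model)                 # decoder input (B, T1, C)
x_mask = torch.ones(B, T_dec, dtype=torch.bool)    # multiplicative mask, True = keep
cond = torch.randn(B, T_enc, d_model)              # encoder memory (B, T2, C)
cond_mask = torch.ones(B, T_enc, dtype=torch.bool)

out = decoder(x=x, x_mask=x_mask, cond=cond, cond_mask=cond_mask)
y = out['output']                    # (B, T1, d_model)
attn = out['attn_probabilities']     # per-layer self-/cross-attention probabilities

# For step-by-step generation, enable the per-layer KV cache first and clear it when done.
decoder.reset_cache(use_cache=True)
# ... autoregressive decoding loop feeding the growing sequence ...
decoder.reset_cache(use_cache=False)

When several encoders condition different layers, `cond` and `cond_mask` can instead be lists of tensors with `multi_encoder_mapping` (length `n_layers`) selecting which conditional each layer attends to, as exercised in `test_forward_causal_self_attn_and_has_xattn` above.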