diff --git a/lingvo/core/conformer_layer.py b/lingvo/core/conformer_layer.py
index b76cf1b8c..acd783cbe 100644
--- a/lingvo/core/conformer_layer.py
+++ b/lingvo/core/conformer_layer.py
@@ -377,6 +377,8 @@ def _NormalizeStep(self, theta, inputs, paddings, state0, state1):
       inputs, paddings, norm_state1 = self.norm.StreamStep(
           theta.norm, inputs, paddings, state0.norm_state)
       state1.norm_state = norm_state1
+    elif isinstance(self.norm, bn_layers.BatchNormLayer):
+      inputs = self.norm.FProp(theta.norm, inputs)
     elif isinstance(self.norm, layers.LayerNorm):
       inputs = self.norm.FProp(theta.norm, inputs)
     else:
diff --git a/lingvo/core/conformer_layer_test.py b/lingvo/core/conformer_layer_test.py
index c8fd2cb9a..085e7a516 100644
--- a/lingvo/core/conformer_layer_test.py
+++ b/lingvo/core/conformer_layer_test.py
@@ -72,6 +72,8 @@ def _GetParams(self, **kwargs):
         input_dim=input_dim, is_causal=True, kernel_size=kernel)
     if norm_type == 'ln':
       p.conv_norm_layer_tpl = lingvo_layers.LayerNorm.Params()
+    elif norm_type == 'bn':
+      p.conv_norm_layer_tpl = bn_layers.BatchNormLayer.Params()
     else:
       p.conv_norm_layer_tpl = bn_layers.GroupNormLayer.Params().Set(
           num_groups=2, cumulative=True)
@@ -90,12 +92,13 @@ def _GetFPropOutput(self, fprop_out):
   @parameterized.named_parameters(
       ('Basic',),
       ('BasicGN', False, 'gn'),
+      ('BasicBN', False, 'bn'),
       ('SkipNorm', True),
   )
   def testLeftContext(self, testonly_skip_norm_layers=False, norm_type='ln'):
     with flagsaver.flagsaver(testonly_skip_norm_layers=testonly_skip_norm_layers
                             ), cluster_factory.SetEval(True):
-      assert norm_type in ('ln', 'gn')
+      assert norm_type in ('ln', 'gn', 'bn')
       input_dim, kernel = 2, 3
       self._TestStreamStepHelper(
           num_heads=2, input_dim=input_dim, kernel=kernel, norm_type=norm_type)