Skip to content

Commit

Permalink
Add BatchNorm to supported normalizations.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 561187464
  • Loading branch information
lingvo-bot authored and copybara-github committed Aug 30, 2023
1 parent 74ac32f commit c00a74b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 2 additions & 0 deletions lingvo/core/conformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,8 @@ def _NormalizeStep(self, theta, inputs, paddings, state0, state1):
inputs, paddings, norm_state1 = self.norm.StreamStep(
theta.norm, inputs, paddings, state0.norm_state)
state1.norm_state = norm_state1
elif isinstance(self.norm, bn_layers.BatchNormLayer):
inputs = self.norm.FProp(theta.norm, inputs)
elif isinstance(self.norm, layers.LayerNorm):
inputs = self.norm.FProp(theta.norm, inputs)
else:
Expand Down
5 changes: 4 additions & 1 deletion lingvo/core/conformer_layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def _GetParams(self, **kwargs):
input_dim=input_dim, is_causal=True, kernel_size=kernel)
if norm_type == 'ln':
p.conv_norm_layer_tpl = lingvo_layers.LayerNorm.Params()
elif norm_type == 'bn':
p.conv_norm_layer_tpl = bn_layers.BatchNormLayer.Params()
else:
p.conv_norm_layer_tpl = bn_layers.GroupNormLayer.Params().Set(
num_groups=2, cumulative=True)
Expand All @@ -90,12 +92,13 @@ def _GetFPropOutput(self, fprop_out):
@parameterized.named_parameters(
('Basic',),
('BasicGN', False, 'gn'),
('BasicBN', False, 'bn'),
('SkipNorm', True),
)
def testLeftContext(self, testonly_skip_norm_layers=False, norm_type='ln'):
with flagsaver.flagsaver(testonly_skip_norm_layers=testonly_skip_norm_layers
), cluster_factory.SetEval(True):
assert norm_type in ('ln', 'gn')
assert norm_type in ('ln', 'gn', 'bn')
input_dim, kernel = 2, 3
self._TestStreamStepHelper(
num_heads=2, input_dim=input_dim, kernel=kernel, norm_type=norm_type)
Expand Down

0 comments on commit c00a74b

Please sign in to comment.