def zero_state(self, batch_size):
  """Returns the initial state given the batch size.

  Args:
    batch_size: Unused; the returned state does not depend on it.

  Returns:
    An empty `py_utils.NestedMap`.

  Raises:
    NotImplementedError: If `p.left_context != 1` or `p.right_context != 0`.
  """
  del batch_size  # Unused: the returned state is empty for any batch size.
  p = self.params
  if p.left_context != 1 or p.right_context != 0:
    # Adjacent string literals are joined with no separator, so the space
    # before 'or' must be written explicitly at the end of the first piece.
    msg = ('Streaming implementation of chunkwise attention with left context '
           'or right context is not supported yet')
    raise NotImplementedError(msg)
  return py_utils.NestedMap()