diff --git a/modules/dvector.py b/modules/dvector.py index 56260b5..9b4bfa9 100644 --- a/modules/dvector.py +++ b/modules/dvector.py @@ -1,6 +1,7 @@ """Build a model for d-vector speaker embedding.""" import abc +import math from typing import List import torch @@ -48,10 +49,18 @@ def embed_utterance(self, utterance: Tensor) -> Tensor: if utterance.ndim == 3: utterance = utterance.squeeze(0) - if utterance.size(1) <= self.seg_len: + if utterance.size(0) <= self.seg_len: embed = self.forward(utterance.unsqueeze(0)).squeeze(0) else: - segments = utterance.unfold(0, self.seg_len, self.seg_len // 2) + # Pad to multiple of hop length + hop_len = self.seg_len // 2 + tgt_len = math.ceil(utterance.size(0) / hop_len) * hop_len + padded = torch.cat( + [utterance, torch.zeros(tgt_len - utterance.size(0), utterance.size(1))] + ) + + segments = padded.unfold(0, self.seg_len, self.seg_len // 2) + segments = segments.transpose(1, 2) # (batch, seg_len, mel_dim) embeds = self.forward(segments) embed = embeds.mean(dim=0) embed = embed.div(embed.norm(p=2, dim=-1, keepdim=True))