Skip to content

Commit

Permalink
feat: add "soft" option to Powerset.to_multilabel conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin committed Oct 22, 2023
1 parent 03f8265 commit 49b89e2
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions pyannote/audio/utils/powerset.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,26 +84,32 @@ def build_cardinality(self) -> torch.Tensor:
powerset_k += 1
return cardinality

def to_multilabel(self, powerset: torch.Tensor) -> torch.Tensor:
"""Convert predictions from (soft) powerset to (hard) multi-label
def to_multilabel(self, powerset: torch.Tensor, soft: bool = False) -> torch.Tensor:
"""Convert predictions from powerset to multi-label
Parameter
---------
powerset : (batch_size, num_frames, num_powerset_classes) torch.Tensor
Soft predictions in "powerset" space.
soft : bool, optional
Return soft multi-label predictions. Defaults to False (i.e. hard predictions)
Assumes that `powerset` are "logits" (not "probabilities").
Returns
-------
multi_label : (batch_size, num_frames, num_classes) torch.Tensor
Hard predictions in "multi-label" space.
Predictions in "multi-label" space.
"""

hard_powerset = torch.nn.functional.one_hot(
torch.argmax(powerset, dim=-1),
self.num_powerset_classes,
).float()
if soft:
powerset_probs = torch.exp(powerset)
else:
powerset_probs = torch.nn.functional.one_hot(
torch.argmax(powerset, dim=-1),
self.num_powerset_classes,
).float()

return torch.matmul(hard_powerset, self.mapping)
return torch.matmul(powerset_probs, self.mapping)

def forward(self, powerset: torch.Tensor) -> torch.Tensor:
"""Alias for `to_multilabel`"""
Expand Down

0 comments on commit 49b89e2

Please sign in to comment.