-
-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add testing configuration and diarization tests * Add aggregation tests * Add end-to-end test for a sample wav file and several latencies * Fix rounding error in min latency unit test * Improve CI workflows and add pytest. Fix matplotlib colormap error * Install missing dependencies in CI * Add onnxruntime as a test dependency * Update expected timestamp tolerance to up to 50ms
- Loading branch information
1 parent
9e6c2e9
commit 467997d
Showing
18 changed files
with
513 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
name: Pytest | ||
|
||
on: | ||
pull_request: | ||
branches: | ||
- main | ||
- develop | ||
|
||
jobs: | ||
test: | ||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- name: Checkout code | ||
uses: actions/checkout@v3 | ||
|
||
- name: Set up Python | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: '3.10' | ||
|
||
- name: Install apt dependencies | ||
run: | | ||
sudo add-apt-repository ppa:savoury1/ffmpeg4 | ||
sudo apt-get update | ||
sudo apt-get -y install ffmpeg libportaudio2=19.6.0-1.1 | ||
- name: Install pip dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install .[tests] | ||
- name: Run tests | ||
run: | | ||
pytest |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
numpy>=1.20.2 | ||
matplotlib>=3.3.3 | ||
matplotlib>=3.3.3,<3.6.0 | ||
rx>=3.2.0 | ||
scipy>=1.6.0 | ||
sounddevice>=0.4.2 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import random | ||
|
||
import pytest | ||
import torch | ||
|
||
from diart.models import SegmentationModel, EmbeddingModel | ||
|
||
|
||
class DummySegmentationModel: | ||
def to(self, device): | ||
pass | ||
|
||
def __call__(self, waveform: torch.Tensor) -> torch.Tensor: | ||
assert waveform.ndim == 3 | ||
|
||
batch_size, num_channels, num_samples = waveform.shape | ||
num_frames = random.randint(250, 500) | ||
num_speakers = random.randint(3, 5) | ||
|
||
return torch.rand(batch_size, num_frames, num_speakers) | ||
|
||
|
||
class DummyEmbeddingModel: | ||
def to(self, device): | ||
pass | ||
|
||
def __call__(self, waveform: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: | ||
assert waveform.ndim == 3 | ||
assert weights.ndim == 2 | ||
|
||
batch_size, num_channels, num_samples = waveform.shape | ||
batch_size_weights, num_frames = weights.shape | ||
|
||
assert batch_size == batch_size_weights | ||
|
||
embedding_dim = random.randint(128, 512) | ||
|
||
return torch.randn(batch_size, embedding_dim) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def segmentation_model() -> SegmentationModel: | ||
return SegmentationModel(DummySegmentationModel) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def embedding_model() -> EmbeddingModel: | ||
return EmbeddingModel(DummyEmbeddingModel) |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
SPEAKER sample 1 6.675 0.533 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 7.625 1.883 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 9.508 1.000 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 10.508 0.567 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 10.625 4.133 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 14.325 3.733 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 18.058 3.450 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 18.325 0.183 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 21.508 0.017 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 21.775 0.233 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 22.008 6.633 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 28.508 1.500 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 29.958 0.050 <NA> <NA> speaker0 <NA> <NA> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
SPEAKER sample 1 6.708 0.450 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 7.625 1.383 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 9.008 1.500 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 10.008 1.067 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 10.592 4.200 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 14.308 3.700 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 18.042 3.250 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 18.508 0.033 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 21.108 0.383 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 21.508 0.033 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 21.775 6.817 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 28.008 2.000 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 29.975 0.033 <NA> <NA> speaker0 <NA> <NA> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
SPEAKER sample 1 6.725 0.433 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 7.592 0.817 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 8.475 1.617 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 9.892 1.150 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 10.625 4.133 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 14.292 3.667 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 18.008 3.533 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 18.225 0.283 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 21.758 6.867 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 27.875 2.133 <NA> <NA> speaker1 <NA> <NA> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
SPEAKER sample 1 6.725 0.433 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 7.625 0.467 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 8.008 2.050 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 10.592 4.167 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 14.292 3.667 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 18.192 0.367 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 21.758 6.833 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 27.825 2.183 <NA> <NA> speaker1 <NA> <NA> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
SPEAKER sample 1 6.742 0.400 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 7.625 0.650 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 8.092 1.950 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 10.575 4.183 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 14.308 3.667 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 18.208 0.333 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 21.758 6.817 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 27.808 2.200 <NA> <NA> speaker1 <NA> <NA> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
SPEAKER sample 1 6.742 0.383 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 7.625 0.667 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 8.092 1.967 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 10.558 4.200 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 14.308 3.667 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA> | ||
SPEAKER sample 1 18.208 0.317 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 21.758 6.817 <NA> <NA> speaker0 <NA> <NA> | ||
SPEAKER sample 1 27.808 2.200 <NA> <NA> speaker1 <NA> <NA> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import numpy as np | ||
import pytest | ||
from pyannote.core import SlidingWindow, SlidingWindowFeature | ||
|
||
from diart.blocks.aggregation import ( | ||
AggregationStrategy, | ||
HammingWeightedAverageStrategy, | ||
FirstOnlyStrategy, | ||
AverageStrategy, | ||
DelayedAggregation, | ||
) | ||
|
||
|
||
def test_strategy_build(): | ||
strategy = AggregationStrategy.build("mean") | ||
assert isinstance(strategy, AverageStrategy) | ||
|
||
strategy = AggregationStrategy.build("hamming") | ||
assert isinstance(strategy, HammingWeightedAverageStrategy) | ||
|
||
strategy = AggregationStrategy.build("first") | ||
assert isinstance(strategy, FirstOnlyStrategy) | ||
|
||
with pytest.raises(Exception): | ||
AggregationStrategy.build("invalid") | ||
|
||
|
||
def test_aggregation(): | ||
duration = 5 | ||
frames = 500 | ||
step = 0.5 | ||
speakers = 2 | ||
start_time = 10 | ||
resolution = duration / frames | ||
|
||
dagg1 = DelayedAggregation(step=step, latency=2, strategy="mean") | ||
dagg2 = DelayedAggregation(step=step, latency=2, strategy="hamming") | ||
dagg3 = DelayedAggregation(step=step, latency=2, strategy="first") | ||
|
||
for dagg in [dagg1, dagg2, dagg3]: | ||
assert dagg.num_overlapping_windows == 4 | ||
|
||
buffers = [ | ||
SlidingWindowFeature( | ||
np.random.rand(frames, speakers), | ||
SlidingWindow( | ||
start=(i + start_time) * step, duration=resolution, step=resolution | ||
), | ||
) | ||
for i in range(dagg1.num_overlapping_windows) | ||
] | ||
|
||
for dagg in [dagg1, dagg2, dagg3]: | ||
assert dagg(buffers).data.shape == (51, 2) |
Oops, something went wrong.