Skip to content

Commit

Permalink
Add Test Suite (#237)
Browse files Browse the repository at this point in the history
* Add testing configuration and diarization tests

* Add aggregation tests

* Add end-to-end test for a sample wav file and several latencies

* Fix rounding error in min latency unit test

* Improve CI workflows and add pytest. Fix matplotlib colormap error

* Install missing dependencies in CI

* Add onnxruntime as a test dependency

* Update expected timestamp tolerance to up to 50ms
  • Loading branch information
juanmc2005 committed May 25, 2024
1 parent 9e6c2e9 commit 467997d
Show file tree
Hide file tree
Showing 18 changed files with 513 additions and 5 deletions.
35 changes: 35 additions & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
name: Pytest

on:
pull_request:
branches:
- main
- develop

jobs:
test:
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.10'

- name: Install apt dependencies
run: |
sudo add-apt-repository ppa:savoury1/ffmpeg4
sudo apt-get update
sudo apt-get -y install ffmpeg libportaudio2=19.6.0-1.1
- name: Install pip dependencies
run: |
python -m pip install --upgrade pip
pip install .[tests]
- name: Run tests
run: |
pytest
7 changes: 4 additions & 3 deletions .github/workflows/quick-runs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install .
pip install onnxruntime==1.18.0
- name: Crop audio and rttm
run: |
sox audio/ES2002a_long.wav audio/ES2002a.wav trim 00:40 00:30
Expand All @@ -50,10 +51,10 @@ jobs:
rm rttms/ES2002b_long.rttm
- name: Run stream
run: |
diart.stream audio/ES2002a.wav --output trash --no-plot --hf-token ${{ secrets.HUGGINGFACE }}
diart.stream audio/ES2002a.wav --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx --output trash --no-plot
- name: Run benchmark
run: |
diart.benchmark audio --reference rttms --batch-size 4 --hf-token ${{ secrets.HUGGINGFACE }}
diart.benchmark audio --reference rttms --batch-size 4 --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx
- name: Run tuning
run: |
diart.tune audio --reference rttms --batch-size 4 --num-iter 2 --output trash --hf-token ${{ secrets.HUGGINGFACE }}
diart.tune audio --reference rttms --batch-size 4 --num-iter 2 --output trash --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx
Binary file added assets/models/embedding_uint8.onnx
Binary file not shown.
Binary file added assets/models/segmentation_uint8.onnx
Binary file not shown.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy>=1.20.2
matplotlib>=3.3.3
matplotlib>=3.3.3,<3.6.0
rx>=3.2.0
scipy>=1.6.0
sounddevice>=0.4.2
Expand Down
7 changes: 6 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package_dir=
packages=find:
install_requires=
numpy>=1.20.2
matplotlib>=3.3.3
matplotlib>=3.3.3,<3.6.0
rx>=3.2.0
scipy>=1.6.0
sounddevice>=0.4.2
Expand All @@ -41,6 +41,11 @@ install_requires=
websocket-client>=0.58.0
rich>=12.5.1

[options.extras_require]
tests=
pytest>=7.4.0,<8.0.0
onnxruntime==1.18.0

[options.packages.find]
where=src

Expand Down
48 changes: 48 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import random

import pytest
import torch

from diart.models import SegmentationModel, EmbeddingModel


class DummySegmentationModel:
def to(self, device):
pass

def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
assert waveform.ndim == 3

batch_size, num_channels, num_samples = waveform.shape
num_frames = random.randint(250, 500)
num_speakers = random.randint(3, 5)

return torch.rand(batch_size, num_frames, num_speakers)


class DummyEmbeddingModel:
def to(self, device):
pass

def __call__(self, waveform: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
assert waveform.ndim == 3
assert weights.ndim == 2

batch_size, num_channels, num_samples = waveform.shape
batch_size_weights, num_frames = weights.shape

assert batch_size == batch_size_weights

embedding_dim = random.randint(128, 512)

return torch.randn(batch_size, embedding_dim)


@pytest.fixture(scope="session")
def segmentation_model() -> SegmentationModel:
return SegmentationModel(DummySegmentationModel)


@pytest.fixture(scope="session")
def embedding_model() -> EmbeddingModel:
return EmbeddingModel(DummyEmbeddingModel)
Binary file added tests/data/audio/sample.wav
Binary file not shown.
13 changes: 13 additions & 0 deletions tests/data/rttm/latency_0.5.rttm
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
SPEAKER sample 1 6.675 0.533 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 7.625 1.883 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 9.508 1.000 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 10.508 0.567 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 10.625 4.133 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 14.325 3.733 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 18.058 3.450 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 18.325 0.183 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 21.508 0.017 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 21.775 0.233 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 22.008 6.633 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 28.508 1.500 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 29.958 0.050 <NA> <NA> speaker0 <NA> <NA>
13 changes: 13 additions & 0 deletions tests/data/rttm/latency_1.rttm
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
SPEAKER sample 1 6.708 0.450 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 7.625 1.383 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 9.008 1.500 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 10.008 1.067 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 10.592 4.200 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 14.308 3.700 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 18.042 3.250 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 18.508 0.033 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 21.108 0.383 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 21.508 0.033 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 21.775 6.817 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 28.008 2.000 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 29.975 0.033 <NA> <NA> speaker0 <NA> <NA>
10 changes: 10 additions & 0 deletions tests/data/rttm/latency_2.rttm
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
SPEAKER sample 1 6.725 0.433 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 7.592 0.817 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 8.475 1.617 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 9.892 1.150 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 10.625 4.133 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 14.292 3.667 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 18.008 3.533 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 18.225 0.283 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 21.758 6.867 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 27.875 2.133 <NA> <NA> speaker1 <NA> <NA>
10 changes: 10 additions & 0 deletions tests/data/rttm/latency_3.rttm
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
SPEAKER sample 1 6.725 0.433 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 7.625 0.467 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 8.008 2.050 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 10.592 4.167 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 14.292 3.667 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 18.192 0.367 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 21.758 6.833 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 27.825 2.183 <NA> <NA> speaker1 <NA> <NA>
10 changes: 10 additions & 0 deletions tests/data/rttm/latency_4.rttm
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
SPEAKER sample 1 6.742 0.400 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 7.625 0.650 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 8.092 1.950 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 10.575 4.183 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 14.308 3.667 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 18.208 0.333 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 21.758 6.817 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 27.808 2.200 <NA> <NA> speaker1 <NA> <NA>
10 changes: 10 additions & 0 deletions tests/data/rttm/latency_5.rttm
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
SPEAKER sample 1 6.742 0.383 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 7.625 0.667 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 8.092 1.967 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 10.558 4.200 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 14.308 3.667 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 18.208 0.317 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 21.758 6.817 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 27.808 2.200 <NA> <NA> speaker1 <NA> <NA>
54 changes: 54 additions & 0 deletions tests/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np
import pytest
from pyannote.core import SlidingWindow, SlidingWindowFeature

from diart.blocks.aggregation import (
AggregationStrategy,
HammingWeightedAverageStrategy,
FirstOnlyStrategy,
AverageStrategy,
DelayedAggregation,
)


def test_strategy_build():
strategy = AggregationStrategy.build("mean")
assert isinstance(strategy, AverageStrategy)

strategy = AggregationStrategy.build("hamming")
assert isinstance(strategy, HammingWeightedAverageStrategy)

strategy = AggregationStrategy.build("first")
assert isinstance(strategy, FirstOnlyStrategy)

with pytest.raises(Exception):
AggregationStrategy.build("invalid")


def test_aggregation():
duration = 5
frames = 500
step = 0.5
speakers = 2
start_time = 10
resolution = duration / frames

dagg1 = DelayedAggregation(step=step, latency=2, strategy="mean")
dagg2 = DelayedAggregation(step=step, latency=2, strategy="hamming")
dagg3 = DelayedAggregation(step=step, latency=2, strategy="first")

for dagg in [dagg1, dagg2, dagg3]:
assert dagg.num_overlapping_windows == 4

buffers = [
SlidingWindowFeature(
np.random.rand(frames, speakers),
SlidingWindow(
start=(i + start_time) * step, duration=resolution, step=resolution
),
)
for i in range(dagg1.num_overlapping_windows)
]

for dagg in [dagg1, dagg2, dagg3]:
assert dagg(buffers).data.shape == (51, 2)
Loading

0 comments on commit 467997d

Please sign in to comment.