diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
new file mode 100644
index 00000000..067b7907
--- /dev/null
+++ b/.github/workflows/pytest.yml
@@ -0,0 +1,35 @@
+name: Pytest
+
+on:
+  pull_request:
+    branches:
+      - main
+      - develop
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+
+      - name: Set up Python
+        uses: actions/setup-python@v3
+        with:
+          python-version: '3.10'
+
+      - name: Install apt dependencies
+        run: |
+          sudo add-apt-repository ppa:savoury1/ffmpeg4
+          sudo apt-get update
+          sudo apt-get -y install ffmpeg libportaudio2=19.6.0-1.1
+
+      - name: Install pip dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install .[tests]
+
+      - name: Run tests
+        run: |
+          pytest
diff --git a/.github/workflows/quick-runs.yml b/.github/workflows/quick-runs.yml
index 0540897a..24b7f387 100644
--- a/.github/workflows/quick-runs.yml
+++ b/.github/workflows/quick-runs.yml
@@ -38,6 +38,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install .
+        pip install onnxruntime==1.18.0
     - name: Crop audio and rttm
       run: |
         sox audio/ES2002a_long.wav audio/ES2002a.wav trim 00:40 00:30
@@ -50,10 +51,10 @@ jobs:
         rm rttms/ES2002b_long.rttm
     - name: Run stream
      run: |
-        diart.stream audio/ES2002a.wav --output trash --no-plot --hf-token ${{ secrets.HUGGINGFACE }}
+        diart.stream audio/ES2002a.wav --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx --output trash --no-plot
     - name: Run benchmark
       run: |
-        diart.benchmark audio --reference rttms --batch-size 4 --hf-token ${{ secrets.HUGGINGFACE }}
+        diart.benchmark audio --reference rttms --batch-size 4 --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx
     - name: Run tuning
       run: |
-        diart.tune audio --reference rttms --batch-size 4 --num-iter 2 --output trash --hf-token ${{ secrets.HUGGINGFACE }}
+        diart.tune audio --reference rttms --batch-size 4 --num-iter 2 --output trash --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx
diff --git a/assets/models/embedding_uint8.onnx b/assets/models/embedding_uint8.onnx
new file mode 100644
index 00000000..ac5ab44d
Binary files /dev/null and b/assets/models/embedding_uint8.onnx differ
diff --git a/assets/models/segmentation_uint8.onnx b/assets/models/segmentation_uint8.onnx
new file mode 100644
index 00000000..8daa3751
Binary files /dev/null and b/assets/models/segmentation_uint8.onnx differ
diff --git a/requirements.txt b/requirements.txt
index 2d3e4611..d6a7aee6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 numpy>=1.20.2
-matplotlib>=3.3.3
+matplotlib>=3.3.3,<3.6.0
 rx>=3.2.0
 scipy>=1.6.0
 sounddevice>=0.4.2
diff --git a/setup.cfg b/setup.cfg
index 66b255f2..315c44ef 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -21,7 +21,7 @@ package_dir=
 packages=find:
 install_requires=
     numpy>=1.20.2
-    matplotlib>=3.3.3
+    matplotlib>=3.3.3,<3.6.0
     rx>=3.2.0
     scipy>=1.6.0
     sounddevice>=0.4.2
@@ -41,6 +41,11 @@ install_requires=
     websocket-client>=0.58.0
     rich>=12.5.1
 
+[options.extras_require]
+tests=
+    pytest>=7.4.0,<8.0.0
+    onnxruntime==1.18.0
+
 [options.packages.find]
 where=src
 
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..3c5a2915
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,48 @@
+import random
+
+import pytest
+import torch
+
+from diart.models import SegmentationModel, EmbeddingModel
+
+
+class DummySegmentationModel:
+    def to(self, device):
+        pass
+
+    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
+        assert waveform.ndim == 3
+
+        batch_size, num_channels, num_samples = waveform.shape
+        num_frames = random.randint(250, 500)
+        num_speakers = random.randint(3, 5)
+
+        return torch.rand(batch_size, num_frames, num_speakers)
+
+
+class DummyEmbeddingModel:
+    def to(self, device):
+        pass
+
+    def __call__(self, waveform: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
+        assert waveform.ndim == 3
+        assert weights.ndim == 2
+
+        batch_size, num_channels, num_samples = waveform.shape
+        batch_size_weights, num_frames = weights.shape
+
+        assert batch_size == batch_size_weights
+
+        embedding_dim = random.randint(128, 512)
+
+        return torch.randn(batch_size, embedding_dim)
+
+
+@pytest.fixture(scope="session")
+def segmentation_model() -> SegmentationModel:
+    return SegmentationModel(DummySegmentationModel)
+
+
+@pytest.fixture(scope="session")
+def embedding_model() -> EmbeddingModel:
+    return EmbeddingModel(DummyEmbeddingModel)
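A note on the loader pattern above: the fixtures hand `SegmentationModel` and `EmbeddingModel` the dummy class itself, which works because the wrappers accept a zero-argument callable and instantiate it lazily on first use. The output contract the dummies promise can be checked directly, as in this standalone sketch (assumes the conftest definitions above; the input sizes are arbitrary test values):

```python
import torch

model = DummySegmentationModel()
scores = model(torch.randn(2, 1, 80000))  # 2 chunks of 5 s mono audio at 16 kHz
assert scores.shape[0] == 2               # batch size is preserved
assert 250 <= scores.shape[1] <= 500      # randomized frame count
assert 3 <= scores.shape[2] <= 5          # randomized speaker count
```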
diff --git a/tests/data/audio/sample.wav b/tests/data/audio/sample.wav
new file mode 100644
index 00000000..150d49a6
Binary files /dev/null and b/tests/data/audio/sample.wav differ
diff --git a/tests/data/rttm/latency_0.5.rttm b/tests/data/rttm/latency_0.5.rttm
new file mode 100644
index 00000000..058ed2e2
--- /dev/null
+++ b/tests/data/rttm/latency_0.5.rttm
@@ -0,0 +1,13 @@
+SPEAKER sample 1 6.675 0.533 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 7.625 1.883 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 9.508 1.000 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 10.508 0.567 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 10.625 4.133 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 14.325 3.733 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 18.058 3.450 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 18.325 0.183 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 21.508 0.017 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 21.775 0.233 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 22.008 6.633 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 28.508 1.500 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 29.958 0.050 <NA> <NA> speaker0 <NA> <NA>
diff --git a/tests/data/rttm/latency_1.rttm b/tests/data/rttm/latency_1.rttm
new file mode 100644
index 00000000..40c591e8
--- /dev/null
+++ b/tests/data/rttm/latency_1.rttm
@@ -0,0 +1,13 @@
+SPEAKER sample 1 6.708 0.450 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 7.625 1.383 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 9.008 1.500 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 10.008 1.067 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 10.592 4.200 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 14.308 3.700 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 18.042 3.250 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 18.508 0.033 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 21.108 0.383 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 21.508 0.033 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 21.775 6.817 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 28.008 2.000 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 29.975 0.033 <NA> <NA> speaker0 <NA> <NA>
diff --git a/tests/data/rttm/latency_2.rttm b/tests/data/rttm/latency_2.rttm
new file mode 100644
index 00000000..dacd8453
--- /dev/null
+++ b/tests/data/rttm/latency_2.rttm
@@ -0,0 +1,10 @@
+SPEAKER sample 1 6.725 0.433 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 7.592 0.817 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 8.475 1.617 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 9.892 1.150 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 10.625 4.133 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 14.292 3.667 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 18.008 3.533 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 18.225 0.283 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 21.758 6.867 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 27.875 2.133 <NA> <NA> speaker1 <NA> <NA>
diff --git a/tests/data/rttm/latency_3.rttm b/tests/data/rttm/latency_3.rttm
new file mode 100644
index 00000000..95d432dc
--- /dev/null
+++ b/tests/data/rttm/latency_3.rttm
@@ -0,0 +1,10 @@
+SPEAKER sample 1 6.725 0.433 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 7.625 0.467 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 8.008 2.050 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 10.592 4.167 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 14.292 3.667 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 18.192 0.367 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 21.758 6.833 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 27.825 2.183 <NA> <NA> speaker1 <NA> <NA>
diff --git a/tests/data/rttm/latency_4.rttm b/tests/data/rttm/latency_4.rttm
new file mode 100644
index 00000000..2a73c427
--- /dev/null
+++ b/tests/data/rttm/latency_4.rttm
@@ -0,0 +1,10 @@
+SPEAKER sample 1 6.742 0.400 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 7.625 0.650 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 8.092 1.950 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 10.575 4.183 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 14.308 3.667 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 18.208 0.333 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 21.758 6.817 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 27.808 2.200 <NA> <NA> speaker1 <NA> <NA>
diff --git a/tests/data/rttm/latency_5.rttm b/tests/data/rttm/latency_5.rttm
new file mode 100644
index 00000000..78b1f1e1
--- /dev/null
+++ b/tests/data/rttm/latency_5.rttm
@@ -0,0 +1,10 @@
+SPEAKER sample 1 6.742 0.383 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 7.625 0.667 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 8.092 1.967 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 10.558 4.200 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 14.308 3.667 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA>
+SPEAKER sample 1 18.208 0.317 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 21.758 6.817 <NA> <NA> speaker0 <NA> <NA>
+SPEAKER sample 1 27.808 2.200 <NA> <NA> speaker1 <NA> <NA>
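The reference files above are standard RTTM speech turns, one per row: `SPEAKER <uri> <channel> <onset> <duration> <NA> <NA> <speaker> <NA> <NA>`, with boundaries drifting slightly as the latency grows. They can be inspected with `load_rttm`, the same helper the end-to-end test below imports, which returns a dict of `pyannote.core.Annotation` keyed by URI:

```python
from pyannote.database.util import load_rttm

# All fixtures in this PR use the URI "sample"
annotation = load_rttm("tests/data/rttm/latency_0.5.rttm")["sample"]
for segment, _, speaker in annotation.itertracks(yield_label=True):
    print(f"{speaker}: {segment.start:.3f}s -> {segment.end:.3f}s")
```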
diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py
new file mode 100644
index 00000000..21d40322
--- /dev/null
+++ b/tests/test_aggregation.py
@@ -0,0 +1,54 @@
+import numpy as np
+import pytest
+from pyannote.core import SlidingWindow, SlidingWindowFeature
+
+from diart.blocks.aggregation import (
+    AggregationStrategy,
+    HammingWeightedAverageStrategy,
+    FirstOnlyStrategy,
+    AverageStrategy,
+    DelayedAggregation,
+)
+
+
+def test_strategy_build():
+    strategy = AggregationStrategy.build("mean")
+    assert isinstance(strategy, AverageStrategy)
+
+    strategy = AggregationStrategy.build("hamming")
+    assert isinstance(strategy, HammingWeightedAverageStrategy)
+
+    strategy = AggregationStrategy.build("first")
+    assert isinstance(strategy, FirstOnlyStrategy)
+
+    with pytest.raises(Exception):
+        AggregationStrategy.build("invalid")
+
+
+def test_aggregation():
+    duration = 5
+    frames = 500
+    step = 0.5
+    speakers = 2
+    start_time = 10
+    resolution = duration / frames
+
+    dagg1 = DelayedAggregation(step=step, latency=2, strategy="mean")
+    dagg2 = DelayedAggregation(step=step, latency=2, strategy="hamming")
+    dagg3 = DelayedAggregation(step=step, latency=2, strategy="first")
+
+    for dagg in [dagg1, dagg2, dagg3]:
+        assert dagg.num_overlapping_windows == 4
+
+    buffers = [
+        SlidingWindowFeature(
+            np.random.rand(frames, speakers),
+            SlidingWindow(
+                start=(i + start_time) * step, duration=resolution, step=resolution
+            ),
+        )
+        for i in range(dagg1.num_overlapping_windows)
+    ]
+
+    for dagg in [dagg1, dagg2, dagg3]:
+        assert dagg(buffers).data.shape == (51, 2)
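The constants asserted in `test_aggregation` follow from the window geometry: with `latency=2` and `step=0.5`, four consecutive windows overlap each output region, and each aggregated output spans roughly one step at the buffer's frame resolution. A back-of-envelope check (the extra row in the asserted `(51, 2)` shape is read off the test itself, under the assumption that both boundary frames are kept):

```python
duration, frames, step, latency = 5, 500, 0.5, 2
resolution = duration / frames             # 0.01 s per output frame
num_overlapping = int(latency / step)      # 4, matching the assert above
frames_per_step = int(step / resolution)   # 50 frames cover one step
print(num_overlapping, frames_per_step)    # 4 50, and the test observes 51 rows
```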
diff --git a/tests/test_diarization.py b/tests/test_diarization.py
new file mode 100644
index 00000000..1895c26a
--- /dev/null
+++ b/tests/test_diarization.py
@@ -0,0 +1,204 @@
+from __future__ import annotations
+
+import random
+
+import pytest
+
+from diart import SpeakerDiarizationConfig, SpeakerDiarization
+from utils import build_waveform_swf
+
+
+@pytest.fixture
+def random_diarization_config(
+    segmentation_model, embedding_model
+) -> SpeakerDiarizationConfig:
+    duration = round(random.uniform(1, 10), 1)
+    step = round(random.uniform(0.1, duration), 1)
+    latency = round(random.uniform(step, duration), 1)
+    return SpeakerDiarizationConfig(
+        segmentation=segmentation_model,
+        embedding=embedding_model,
+        duration=duration,
+        step=step,
+        latency=latency,
+    )
+
+
+@pytest.fixture(scope="session")
+def min_latency_config(segmentation_model, embedding_model) -> SpeakerDiarizationConfig:
+    return SpeakerDiarizationConfig(
+        segmentation=segmentation_model,
+        embedding=embedding_model,
+        duration=5,
+        step=0.5,
+        latency="min",
+    )
+
+
+@pytest.fixture(scope="session")
+def max_latency_config(segmentation_model, embedding_model) -> SpeakerDiarizationConfig:
+    return SpeakerDiarizationConfig(
+        segmentation=segmentation_model,
+        embedding=embedding_model,
+        duration=5,
+        step=0.5,
+        latency="max",
+    )
+
+
+def test_config(
+    segmentation_model, embedding_model, min_latency_config, max_latency_config
+):
+    duration = round(random.uniform(1, 10), 1)
+    step = round(random.uniform(0.1, duration), 1)
+    latency = round(random.uniform(step, duration), 1)
+    config = SpeakerDiarizationConfig(
+        segmentation=segmentation_model,
+        embedding=embedding_model,
+        duration=duration,
+        step=step,
+        latency=latency,
+    )
+
+    assert config.duration == duration
+    assert config.step == step
+    assert config.latency == latency
+    assert min_latency_config.latency == min_latency_config.step
+    assert max_latency_config.latency == max_latency_config.duration
+
+
+def test_bad_latency(segmentation_model, embedding_model):
+    duration = round(random.uniform(1, 10), 1)
+    step = round(random.uniform(0.5, duration - 0.2), 1)
+    latency_too_low = round(random.uniform(0, step - 0.1), 1)
+    latency_too_high = round(random.uniform(duration + 0.1, 100), 1)
+
+    config1 = SpeakerDiarizationConfig(
+        segmentation=segmentation_model,
+        embedding=embedding_model,
+        duration=duration,
+        step=step,
+        latency=latency_too_low,
+    )
+    config2 = SpeakerDiarizationConfig(
+        segmentation=segmentation_model,
+        embedding=embedding_model,
+        duration=duration,
+        step=step,
+        latency=latency_too_high,
+    )
+
+    with pytest.raises(AssertionError):
+        SpeakerDiarization(config1)
+
+    with pytest.raises(AssertionError):
+        SpeakerDiarization(config2)
+
+
+def test_pipeline_build(random_diarization_config):
+    pipeline = SpeakerDiarization(random_diarization_config)
+
+    assert pipeline.get_config_class() == SpeakerDiarizationConfig
+
+    hparams = pipeline.hyper_parameters()
+    hp_names = [hp.name for hp in hparams]
+    assert len(set(hp_names)) == 3
+
+    for hparam in hparams:
+        assert hparam.low == 0
+        if hparam.name in ["tau_active", "rho_update"]:
+            assert hparam.high == 1
+        elif hparam.name == "delta_new":
+            assert hparam.high == 2
+        else:
+            assert False
+
+    assert pipeline.config == random_diarization_config
+
+
+def test_timestamp_shift(random_diarization_config):
+    pipeline = SpeakerDiarization(random_diarization_config)
+
+    assert pipeline.timestamp_shift == 0
+
+    new_shift = round(random.uniform(-10, 10), 1)
+    pipeline.set_timestamp_shift(new_shift)
+    assert pipeline.timestamp_shift == new_shift
+
+    waveform = build_waveform_swf(
+        random_diarization_config.duration,
+        random_diarization_config.sample_rate,
+    )
+    prediction, _ = pipeline([waveform])[0]
+
+    for segment, _, label in prediction.itertracks(yield_label=True):
+        assert segment.start >= new_shift
+        assert segment.end >= new_shift
+
+    pipeline.reset()
+    assert pipeline.timestamp_shift == 0
+
+
+def test_call_min_latency(min_latency_config):
+    pipeline = SpeakerDiarization(min_latency_config)
+    waveform1 = build_waveform_swf(
+        min_latency_config.duration,
+        min_latency_config.sample_rate,
+        start_time=0,
+    )
+    waveform2 = build_waveform_swf(
+        min_latency_config.duration,
+        min_latency_config.sample_rate,
+        min_latency_config.step,
+    )
+
+    batch = [waveform1, waveform2]
+    output = pipeline(batch)
+
+    pred1, wave1 = output[0]
+    pred2, wave2 = output[1]
+
+    assert waveform1.data.shape[0] == wave1.data.shape[0]
+    assert wave1.data.shape[0] > wave2.data.shape[0]
+
+    pred1_timeline = pred1.get_timeline()
+    pred2_timeline = pred2.get_timeline()
+    pred1_duration = round(pred1_timeline[-1].end - pred1_timeline[0].start, 3)
+    pred2_duration = round(pred2_timeline[-1].end - pred2_timeline[0].start, 3)
+
+    expected_duration = round(min_latency_config.duration, 3)
+    expected_step = round(min_latency_config.step, 3)
+    assert not pred1_timeline or pred1_duration <= expected_duration
+    assert not pred2_timeline or pred2_duration <= expected_step
+
+
+def test_call_max_latency(max_latency_config):
+    pipeline = SpeakerDiarization(max_latency_config)
+    waveform1 = build_waveform_swf(
+        max_latency_config.duration,
+        max_latency_config.sample_rate,
+        start_time=0,
+    )
+    waveform2 = build_waveform_swf(
+        max_latency_config.duration,
+        max_latency_config.sample_rate,
+        max_latency_config.step,
+    )
+
+    batch = [waveform1, waveform2]
+    output = pipeline(batch)
+
+    pred1, wave1 = output[0]
+    pred2, wave2 = output[1]
+
+    assert waveform1.data.shape[0] > wave1.data.shape[0]
+    assert wave1.data.shape[0] == wave2.data.shape[0]
+
+    pred1_timeline = pred1.get_timeline()
+    pred2_timeline = pred2.get_timeline()
+    pred1_duration = pred1_timeline[-1].end - pred1_timeline[0].start
+    pred2_duration = pred2_timeline[-1].end - pred2_timeline[0].start
+
+    expected_step = round(max_latency_config.step, 3)
+    assert not pred1_timeline or round(pred1_duration, 3) <= expected_step
+    assert not pred2_timeline or round(pred2_duration, 3) <= expected_step
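The two `test_call_*` tests above pin down how much output each chunk produces at the latency extremes. Restated outside the suite (these figures are derived from the asserts, not from pipeline internals):

```python
duration, step, sample_rate = 5.0, 0.5, 16000

# latency == "min" (= step): the first chunk emits the whole buffered
# duration, and every later chunk only one new step.
first_chunk_out = int(duration * sample_rate)  # 80000 samples
later_chunk_out = int(step * sample_rate)      # 8000 samples

# latency == "max" (= duration): every chunk emits exactly one step.
max_latency_out = int(step * sample_rate)      # 8000 samples

assert first_chunk_out > later_chunk_out == max_latency_out
```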
diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py
new file mode 100644
index 00000000..de5ee8b5
--- /dev/null
+++ b/tests/test_end_to_end.py
@@ -0,0 +1,78 @@
+import math
+from pathlib import Path
+
+import pytest
+from pyannote.database.util import load_rttm
+
+from diart import SpeakerDiarization, SpeakerDiarizationConfig
+from diart.inference import StreamingInference
+from diart.models import SegmentationModel, EmbeddingModel
+from diart.sources import FileAudioSource
+
+MODEL_DIR = Path(__file__).parent.parent / "assets" / "models"
+DATA_DIR = Path(__file__).parent / "data"
+
+
+@pytest.fixture(scope="session")
+def segmentation():
+    model_path = MODEL_DIR / "segmentation_uint8.onnx"
+    return SegmentationModel.from_pretrained(model_path)
+
+
+@pytest.fixture(scope="session")
+def embedding():
+    model_path = MODEL_DIR / "embedding_uint8.onnx"
+    return EmbeddingModel.from_pretrained(model_path)
+
+
+@pytest.fixture(scope="session")
+def make_config(segmentation, embedding):
+    def _config(latency):
+        return SpeakerDiarizationConfig(
+            segmentation=segmentation,
+            embedding=embedding,
+            step=0.5,
+            latency=latency,
+            tau_active=0.507,
+            rho_update=0.006,
+            delta_new=1.057,
+        )
+    return _config
+
+
+@pytest.mark.parametrize("source_file", [DATA_DIR / "audio" / "sample.wav"])
+@pytest.mark.parametrize("latency", [0.5, 1, 2, 3, 4, 5])
+def test_benchmark(make_config, source_file, latency):
+    config = make_config(latency)
+    pipeline = SpeakerDiarization(config)
+
+    padding = pipeline.config.get_file_padding(source_file)
+    source = FileAudioSource(
+        source_file,
+        pipeline.config.sample_rate,
+        padding,
+        pipeline.config.step,
+    )
+
+    pipeline.set_timestamp_shift(-padding[0])
+    inference = StreamingInference(
+        pipeline,
+        source,
+        do_profile=False,
+        do_plot=False,
+        show_progress=False,
+    )
+
+    pred = inference()
+
+    expected_file = DATA_DIR / "rttm" / f"latency_{latency}.rttm"
+    expected = load_rttm(expected_file).popitem()[1]
+
+    assert len(pred) == len(expected)
+    for track1, track2 in zip(pred.itertracks(yield_label=True), expected.itertracks(yield_label=True)):
+        pred_segment, _, pred_spk = track1
+        expected_segment, _, expected_spk = track2
+        # We can tolerate a difference of up to 50ms
+        assert math.isclose(pred_segment.start, expected_segment.start, abs_tol=0.05)
+        assert math.isclose(pred_segment.end, expected_segment.end, abs_tol=0.05)
+        assert pred_spk == expected_spk
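The 50 ms tolerance in `test_benchmark` is an absolute bound applied independently to each predicted onset and offset. In isolation:

```python
import math

# Two boundaries match if they differ by at most 0.05 s
assert math.isclose(6.675, 6.700, abs_tol=0.05)      # 25 ms apart: accepted
assert not math.isclose(6.675, 6.750, abs_tol=0.05)  # 75 ms apart: rejected
```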
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 00000000..e8ae41c2
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+import random
+import numpy as np
+from pyannote.core import SlidingWindowFeature, SlidingWindow
+
+
+def build_waveform_swf(
+    duration: float, sample_rate: int, start_time: float | None = None
+) -> SlidingWindowFeature:
+    start_time = round(random.uniform(0, 600), 1) if start_time is None else start_time
+    chunk_size = int(duration * sample_rate)
+    resolution = duration / chunk_size
+    samples = np.random.randn(chunk_size, 1)
+    sliding_window = SlidingWindow(
+        start=start_time, step=resolution, duration=resolution
+    )
+    return SlidingWindowFeature(samples, sliding_window)
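For reference, `build_waveform_swf` returns a `(duration * sample_rate, 1)` random waveform wrapped in a `SlidingWindowFeature` whose window resolution is exactly one sample, with a random start time unless one is given. A quick sanity check (run from the `tests` directory, mirroring how the test modules import `utils`):

```python
from utils import build_waveform_swf

swf = build_waveform_swf(duration=5.0, sample_rate=16000, start_time=0.0)
assert swf.data.shape == (80000, 1)                      # 5 s at 16 kHz
assert abs(swf.sliding_window.step - 1 / 16000) < 1e-12  # one sample per frame
assert swf.sliding_window.start == 0.0
```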