Move things around. Simplify code and wording
juanmc2005 committed Nov 15, 2023
1 parent 2108280 commit 4bb955a
Showing 1 changed file with 21 additions and 23 deletions: README.md
@@ -19,28 +19,24 @@

<div align="center">
<h4>
<a href="#-installation">
💾 Installation
</a>
<span> | </span>
<a href="#%EF%B8%8F-stream-audio">
🎙️ Stream audio
</a>
<span> | </span>
<a href="#-models">
🧠 Available models
</a>
<br />
<a href="#-tune-hyper-parameters">
📈 Tuning
</a>
<span> | </span>
<a href="#-build-pipelines">
🧠🔗 Pipelines
</a>
<span> | </span>
<a href="#-websockets">
@@ -77,14 +73,22 @@
create your own AI pipeline, benchmark it, tune its hyper-parameters, and even s

## 💾 Installation

**1) Make sure your system has the following dependencies:**

```
ffmpeg < 4.4
portaudio == 19.6.X
libsndfile >= 1.2.2
```

Alternatively, we provide an `environment.yml` file for a pre-configured conda environment:

```shell
conda env create -f diart/environment.yml
conda activate diart
```

**2) Install the package:**
```shell
pip install diart
```
@@ -138,10 +142,9 @@
prediction = inference()

For inference and evaluation on a dataset, we recommend using `Benchmark` (see notes on [reproducibility](#-reproducibility)).

## 🧠 Models

You can use other models with the `--segmentation` and `--embedding` arguments.
Or in Python:

```python
import diart.models as m

segmentation = m.SegmentationModel.from_pretrained("model_name")
embedding = m.EmbeddingModel.from_pretrained("model_name")
```

### Available pre-trained models

Below is a list of all the models currently supported by diart:

| Model Name | Model Type | CPU Time* | GPU Time* |
| ---------- | ---------- | --------- | --------- |

@@ -171,16 +176,13 @@
The latency of embedding models is measured in a diarization pipeline using `pya

\* CPU: AMD Ryzen 9 - GPU: RTX 4060 Max-Q

### Custom models

Third-party models can be integrated by providing a loader function:

```python
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.models import EmbeddingModel, SegmentationModel
from diart.sources import MicrophoneAudioSource
from diart.inference import StreamingInference


def segmentation_loader():
    # It should take a waveform and return a segmentation tensor
    return load_pretrained_model("my_model.ckpt")

def embedding_loader():
    # It should take (waveform, weights) and return per-speaker embeddings
    return load_pretrained_model("my_other_model.ckpt")

segmentation = SegmentationModel(segmentation_loader)
embedding = EmbeddingModel(embedding_loader)
config = SpeakerDiarizationConfig(
    segmentation=segmentation,
    embedding=embedding,
)
pipeline = SpeakerDiarization(config)
mic = MicrophoneAudioSource()
inference = StreamingInference(pipeline, mic)
prediction = inference()
```

If you have an ONNX model, you can use `from_onnx()`:
