Merge pull request #87 from juanmc2005/develop
Version 0.5
juanmc2005 authored Aug 31, 2022
2 parents d4ff0ee + b75dc9f commit 2734c04
Showing 17 changed files with 451 additions and 169 deletions.
106 changes: 59 additions & 47 deletions README.md
@@ -22,8 +22,8 @@
Stream audio
</a>
<span> | </span>
-<a href="#add-your-model">
-Add your model
+<a href="#custom-models">
+Custom models
</a>
<span> | </span>
<a href="#tune-hyper-parameters">
@@ -34,6 +34,10 @@
Build pipelines
</a>
<br/>
<a href="#websockets">
WebSockets
</a>
<span> | </span>
<a href="#powered-by-research">
Research
</a>
@@ -72,10 +76,10 @@ conda install pysoundfile -c conda-forge

3) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally)

-4) Install pyannote.audio 2.0 (currently no official release)
+4) Install pyannote.audio

```shell
-pip install git+https://github.com/pyannote/pyannote-audio.git@2.0.1#egg=pyannote-audio
+pip install pyannote.audio==2.0.1
```

**Note:** Starting from version 0.4, installing pyannote.audio is mandatory to run the default system or to use pyannote-based models. Otherwise, this step can be skipped.
@@ -105,25 +109,26 @@ See `diart.stream -h` for more options.

### From python

-Run a real-time speaker diarization pipeline over an audio stream with `RealTimeInference`:
+Use `RealTimeInference` to run a pipeline on an audio source and write the results to disk:

```python
from diart.sources import MicrophoneAudioSource
from diart.inference import RealTimeInference
-from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
-
-config = PipelineConfig() # Default parameters
-pipeline = OnlineSpeakerDiarization(config)
-audio_source = MicrophoneAudioSource(config.sample_rate)
-inference = RealTimeInference("/output/path", do_plot=True)
-inference(pipeline, audio_source)
+from diart.pipelines import OnlineSpeakerDiarization
+from diart.sinks import RTTMWriter
+
+pipeline = OnlineSpeakerDiarization()
+mic = MicrophoneAudioSource(pipeline.config.sample_rate)
+inference = RealTimeInference(pipeline, mic, do_plot=True)
+inference.attach_observers(RTTMWriter("/output/file.rttm"))
+inference()
```
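
Each line of the resulting RTTM file encodes one speech turn in the standard 10-field format `SPEAKER <file> <channel> <onset> <duration> ... <speaker>`. A hypothetical excerpt (file ID, timestamps and speaker labels depend on your audio):

```
SPEAKER file 1 0.000 1.270 <NA> <NA> speaker0 <NA> <NA>
SPEAKER file 1 1.525 0.980 <NA> <NA> speaker1 <NA> <NA>
```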

-For faster inference and evaluation on a dataset we recommend to use `Benchmark` instead (see our notes on [reproducibility](#reproducibility)).
+For inference and evaluation on a dataset, we recommend using `Benchmark` (see notes on [reproducibility](#reproducibility)).

-## Add your model
+## Custom models

-Third-party segmentation and embedding models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel`:
+Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel`:

```python
import torch
@@ -148,8 +153,8 @@ class MyEmbeddingModel(EmbeddingModel):
config = PipelineConfig(embedding=MyEmbeddingModel())
pipeline = OnlineSpeakerDiarization(config)
mic = MicrophoneAudioSource(config.sample_rate)
inference = RealTimeInference("/out/dir")
inference(pipeline, mic)
inference = RealTimeInference(pipeline, mic)
inference()
```
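
The middle of this hunk is collapsed, so the subclass definitions themselves are not shown. As a rough illustration only, a minimal `EmbeddingModel` subclass might look like the sketch below; `load_my_model` is a hypothetical helper, and the exact `__call__` signature should be checked against `diart.models`:

```python
from typing import Optional

import torch
from diart.models import EmbeddingModel


class MyEmbeddingModel(EmbeddingModel):
    def __init__(self):
        super().__init__()
        # Hypothetical helper that loads a pre-trained embedding checkpoint
        self.model = load_my_model("my_model.ckpt")

    def __call__(
        self,
        waveform: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # One embedding per input chunk, optionally weighted per frame
        return self.model(waveform, weights)
```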

## Tune hyper-parameters
Expand All @@ -159,31 +164,21 @@ Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.re
### From the command line

```shell
-diart.tune /wav/dir --reference /rttm/dir --output /out/dir
+diart.tune /wav/dir --reference /rttm/dir --output /output/dir
```

See `diart.tune -h` for more options.

### From python

```python
-from diart.optim import Optimizer, TauActive, RhoUpdate, DeltaNew
-from diart.pipelines import PipelineConfig
-from diart.inference import Benchmark
+from diart.optim import Optimizer

-# Benchmark runs and evaluates the pipeline on a dataset
-benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir/tmp", show_report=False)
-# Base configuration for the pipeline we're going to tune
-base_config = PipelineConfig()
-# Hyper-parameters to optimize
-hparams = [TauActive, RhoUpdate, DeltaNew]
-# Optimizer implements the optimization loop
-optimizer = Optimizer(benchmark, base_config, hparams, "/out/dir")
-# Run optimization
-optimizer.optimize(num_iter=100, show_progress=True)
+optimizer = Optimizer("/wav/dir", "/rttm/dir", "/output/dir")
+optimizer(num_iter=100)
```

-This will use `/out/dir/tmp` as a working directory and write results to an sqlite database in `/out/dir`.
+This will write results to an SQLite database in `/output/dir`.
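
The tuned values can then be read back with optuna. A minimal sketch, assuming hypothetical study and file names (use the ones the optimizer actually reports):

```python
import optuna

# Both names below are assumptions, not diart defaults
study = optuna.load_study(
    study_name="diart-tuning",
    storage="sqlite:////output/dir/trials.db",
)
print(study.best_params)  # e.g. values for tau_active, rho_update, delta_new
```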

### Distributed optimization

@@ -195,26 +190,23 @@ mysql -u root -e "CREATE DATABASE IF NOT EXISTS example"
optuna create-study --study-name "example" --storage "mysql://root@localhost/example"
```

-Then you can run multiple identical optimizers pointing to the database:
+You can now run multiple identical optimizers pointing to this database:

```shell
-diart.tune /wav/dir --reference /rttm/dir --output /out/dir --storage mysql://root@localhost/example
+diart.tune /wav/dir --reference /rttm/dir --storage mysql://root@localhost/example
```

-If you are using the python API, make sure that worker directories are different to avoid concurrency issues:
+or in python:

```python
from diart.optim import Optimizer
-from diart.inference import Benchmark
from optuna.samplers import TPESampler
import optuna

-ID = 0 # Worker identifier
-base_config, hparams = ...
-benchmark = Benchmark("/wav/dir", "/rttm/dir", f"/out/dir/worker-{ID}", show_report=False)
-study = optuna.load_study("example", "mysql://root@localhost/example", TPESampler())
-optimizer = Optimizer(benchmark, base_config, hparams, study)
-optimizer.optimize(num_iter=100, show_progress=True)
+db = "mysql://root@localhost/example"
+study = optuna.load_study("example", db, TPESampler())
+optimizer = Optimizer("/wav/dir", "/rttm/dir", study)
+optimizer(num_iter=100)
```

## Build pipelines
@@ -256,6 +248,24 @@ torch.Size([4, 512])
...
```
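
Most of this section is collapsed in the diff, but the blocks exported by `src/diart/blocks/__init__.py` below hint at how a pipeline is assembled. A hedged, offline sketch; the block constructors and the chunk layout are assumptions (the `(batch, sample, channel)` layout is implied by `src/diart/blocks/segmentation.py` in this commit):

```python
import torch

from diart.blocks import OverlapAwareSpeakerEmbedding, SpeakerSegmentation
from diart.models import EmbeddingModel, SegmentationModel

# Assumed constructors: both blocks appear to wrap a diart model
segmentation = SpeakerSegmentation(SegmentationModel.from_pyannote("pyannote/segmentation"))
embedding = OverlapAwareSpeakerEmbedding(EmbeddingModel.from_pyannote("pyannote/embedding"))

# A batch with one 5-second mono chunk at 16kHz
chunk = torch.randn(1, 5 * 16000, 1)
seg = segmentation(chunk)    # per-frame speaker activities
emb = embedding(chunk, seg)  # one embedding per local speaker
print(emb.shape)
```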

## WebSockets

Diart also supports the WebSocket protocol, making it possible to serve pipelines on the web.

In the following example, we build a minimal server that receives audio chunks and sends back predictions in RTTM format:

```python
from diart.pipelines import OnlineSpeakerDiarization
from diart.sources import WebSocketAudioSource
from diart.inference import RealTimeInference

pipeline = OnlineSpeakerDiarization()
source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007)
inference = RealTimeInference(pipeline, source, do_plot=True)
inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm()))
inference()
```
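
A client only needs to connect to `ws://localhost:7007`, send audio, and read back RTTM text. Below is a minimal client sketch (not part of this commit); the payload encoding is an assumption, so check `WebSocketAudioSource` for the exact format it decodes:

```python
import asyncio
import base64

import soundfile as sf
import websockets


async def stream_file(path: str, sample_rate: int = 16000, step: float = 0.5):
    # Assumes a mono file already sampled at `sample_rate`
    samples, _ = sf.read(path, dtype="float32")
    chunk_size = int(step * sample_rate)
    async with websockets.connect("ws://localhost:7007") as ws:
        for i in range(0, len(samples), chunk_size):
            chunk = samples[i:i + chunk_size]
            # Assumed encoding: base64 over raw float32 bytes
            await ws.send(base64.b64encode(chunk.tobytes()).decode("utf-8"))
            # Predictions may lag behind the audio by the pipeline latency
            print(await ws.recv())


asyncio.run(stream_file("/path/to/audio.wav"))
```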

## Powered by research

Diart is the official implementation of the paper *[Overlap-aware low-latency online speaker diarization based on end-to-end local segmentation](/paper.pdf)* by [Juan Manuel Coria](https://juanmc2005.github.io/), [Hervé Bredin](https://herve.niderb.fr), [Sahar Ghannay](https://saharghannay.github.io/) and [Sophie Rosset](https://perso.limsi.fr/rosset/).
@@ -299,32 +309,34 @@ To obtain the best results, make sure to use the following hyper-parameters:
| DIHARD II | 1s | 0.619 | 0.326 | 0.997 |
| DIHARD II | 5s | 0.555 | 0.422 | 1.517 |

-`diart.benchmark` and `diart.inference.Benchmark` can quickly run and evaluate the pipeline, and even measure its real-time latency. For instance, for a DIHARD III configuration:
+`diart.benchmark` and `diart.inference.Benchmark` can run, evaluate, and measure the real-time latency of the pipeline. For instance, for a DIHARD III configuration:

```shell
-diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --output /out/dir
+diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --segmentation pyannote/segmentation@Interspeech2021
```

or using the inference API:

```python
from diart.inference import Benchmark
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
from diart.models import SegmentationModel

config = PipelineConfig(
    # Set the model used in the paper
    segmentation=SegmentationModel.from_pyannote("pyannote/segmentation@Interspeech2021"),
    step=0.5,
    latency=0.5,
    tau_active=0.555,
    rho_update=0.422,
    delta_new=1.517
)
pipeline = OnlineSpeakerDiarization(config)
-benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir")
+benchmark = Benchmark("/wav/dir", "/rttm/dir")
benchmark(pipeline)
```
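
Judging from `src/diart/benchmark.py` in this commit, the call above returns a report object that supports `to_csv` and may be `None`, so it can be persisted like this:

```python
report = benchmark(pipeline)
if report is not None:
    report.to_csv("/output/dir/benchmark_report.csv")
```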

-This runs a faster inference by pre-calculating model outputs in batches.
+This pre-calculates model outputs in batches, so it runs significantly faster.
See `diart.benchmark -h` for more options.

For convenience and to facilitate future comparisons, we also provide the [expected outputs](/expected_outputs) of the paper implementation in RTTM format for every entry of Table 1 and Figure 5. This includes the VBx offline topline as well as our proposed online approach with latencies 500ms, 1s, 2s, 3s, 4s, and 5s.
3 changes: 2 additions & 1 deletion requirements.txt
@@ -7,7 +7,8 @@ einops>=0.3.0
tqdm>=4.64.0
pandas>=1.4.2
torchaudio>=0.10,<1.0
-pyannote.core>=4.4
+pyannote.core>=4.5
+pyannote.database>=4.1.1
pyannote.metrics>=3.2
optuna>=2.10
+websockets>=10.3
5 changes: 3 additions & 2 deletions setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name=diart
-version=0.4.0
+version=0.5.0
author=Juan Manuel Coria
description=Speaker diarization in real time
long_description=file: README.md
@@ -29,10 +29,11 @@ install_requires=
tqdm>=4.64.0
pandas>=1.4.2
torchaudio>=0.10,<1.0
-pyannote.core>=4.4
+pyannote.core>=4.5
+pyannote.database>=4.1.1
pyannote.metrics>=3.2
optuna>=2.10
+websockets>=10.3

[options.packages.find]
where=src
2 changes: 2 additions & 0 deletions src/diart/argdoc.py
@@ -1,3 +1,5 @@
+SEGMENTATION = "Segmentation model name from pyannote"
+EMBEDDING = "Embedding model name from pyannote"
STEP = "Sliding window step (in seconds)"
LATENCY = "System latency (in seconds). STEP <= LATENCY <= CHUNK_DURATION"
TAU = "Probability threshold to consider a speaker as active. 0 <= TAU <= 1"
21 changes: 16 additions & 5 deletions src/diart/benchmark.py
@@ -1,16 +1,22 @@
import argparse
+from pathlib import Path

import torch

import diart.argdoc as argdoc
from diart.inference import Benchmark
+from diart.models import SegmentationModel, EmbeddingModel
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig


def run():
    parser = argparse.ArgumentParser()
-    parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
-    parser.add_argument("--reference", type=str,
+    parser.add_argument("root", type=Path, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
+    parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
+                        help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
+    parser.add_argument("--embedding", default="pyannote/embedding", type=str,
+                        help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
+    parser.add_argument("--reference", type=Path,
                        help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
    parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
    parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
@@ -23,20 +29,25 @@ def run():
parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32")
parser.add_argument("--cpu", dest="cpu", action="store_true",
help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`")
parser.add_argument("--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing")
args = parser.parse_args()
args.device = torch.device("cpu") if args.cpu else None
args.segmentation = SegmentationModel.from_pyannote(args.segmentation)
args.embedding = EmbeddingModel.from_pyannote(args.embedding)

benchmark = Benchmark(
args.root,
args.reference,
args.output,
show_progress=True,
show_report=True,
batch_size=args.batch_size
batch_size=args.batch_size,
)

benchmark(OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True))
pipeline = OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True)
report = benchmark(pipeline)
if args.output is not None and report is not None:
report.to_csv(args.output / "benchmark_report.csv")


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion src/diart/blocks/__init__.py
@@ -13,4 +13,4 @@
    OverlapAwareSpeakerEmbedding,
)
from .segmentation import SpeakerSegmentation
-from .utils import Binarize
+from .utils import Binarize, Resample, AdjustVolume
2 changes: 1 addition & 1 deletion src/diart/blocks/segmentation.py
Expand Up @@ -37,4 +37,4 @@ def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
        with torch.no_grad():
            wave = rearrange(self.formatter.cast(waveform), "batch sample channel -> batch channel sample")
            output = self.model(wave.to(self.device)).cpu()
-        return self.formatter.restore_type(output)
\ No newline at end of file
+        return self.formatter.restore_type(output)