From a828661f0e9efa4776051c1842c191f3e00efced Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Tue, 26 Jul 2022 16:57:00 +0200 Subject: [PATCH 01/23] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7cf7030b..fdc5f57c 100644 --- a/README.md +++ b/README.md @@ -72,10 +72,10 @@ conda install pysoundfile -c conda-forge 3) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally) -4) Install pyannote.audio 2.0 (currently in development) +4) Install pyannote.audio 2.0 from a compatible commit (currently in active development) ```shell -pip install git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio +pip install git+https://github.com/pyannote/pyannote-audio.git@3147e2bfe9a7af388d0c01f3bba3d0578ba60c67#egg=pyannote-audio ``` **Note:** starting from version 0.4, installing pyannote.audio is mandatory to run the default system or to use pyannote-based models. In any other case, this step can be ignored. From ea609faeb550d4abf6dda2f0c7457f9dde19bfbb Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Tue, 26 Jul 2022 17:40:22 +0200 Subject: [PATCH 02/23] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index fdc5f57c..a0f45b85 100644 --- a/README.md +++ b/README.md @@ -72,10 +72,10 @@ conda install pysoundfile -c conda-forge 3) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally) -4) Install pyannote.audio 2.0 from a compatible commit (currently in active development) +4) Install pyannote.audio 2.0 (currently no official release) ```shell -pip install git+https://github.com/pyannote/pyannote-audio.git@3147e2bfe9a7af388d0c01f3bba3d0578ba60c67#egg=pyannote-audio +pip install git+https://github.com/pyannote/pyannote-audio.git@2.0.1#egg=pyannote-audio ``` **Note:** starting from version 0.4, installing pyannote.audio is mandatory to run the default system or to use pyannote-based models. In any other case, this step can be ignored. 
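These first two patches pin the pyannote.audio dependency, first to an exact development commit and then to the `2.0.1` tag. After installing, a quick sanity check of which version actually landed in the environment (a sketch; assumes the package exposes `__version__`, as the 2.x releases do):

```python
# Print the installed pyannote.audio version (expected: 2.0.1 with the tag pin above)
import pyannote.audio

print(pyannote.audio.__version__)
```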
From a7befd775c3e90e83b3bd8e208bcdb08679c0d44 Mon Sep 17 00:00:00 2001 From: Amit Kesari Date: Wed, 27 Jul 2022 00:26:06 +0530 Subject: [PATCH 03/23] add `study_or_path` as a Path for conversion from string updated according to suggestion for converting str into Path --- src/diart/optim.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diart/optim.py b/src/diart/optim.py index db616db5..2f7dd10c 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -51,6 +51,7 @@ def __init__( if isinstance(study_or_path, Study): self.study = study_or_path elif isinstance(study_or_path, str) or isinstance(study_or_path, Path): + study_or_path = Path(study_or_path) self.study = create_study( storage="sqlite:///" + str(study_or_path / f"{study_or_path.stem}.db"), sampler=TPESampler(), From b6e048cb9d41ccf77c3830797d20ae2d40d0062d Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Wed, 27 Jul 2022 17:06:14 +0200 Subject: [PATCH 04/23] Add WebSocketAudioSource --- requirements.txt | 1 + setup.cfg | 1 + src/diart/sources.py | 73 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 74 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index d2259fb5..a72491e7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,4 @@ pyannote.core>=4.4 pyannote.database>=4.1.1 pyannote.metrics>=3.2 optuna>=2.10 +websockets>=10.3 diff --git a/setup.cfg b/setup.cfg index 5ba6a171..c22a7aff 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,6 +33,7 @@ install_requires= pyannote.database>=4.1.1 pyannote.metrics>=3.2 optuna>=2.10 + websockets>=10.3 [options.packages.find] where=src diff --git a/src/diart/sources.py b/src/diart/sources.py index 89fe2634..f8fe9850 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -1,11 +1,14 @@ +import asyncio +import base64 import math from pathlib import Path from queue import SimpleQueue -from typing import Text, Optional, Callable +from typing import Text, Optional, Callable, AnyStr import numpy as np import sounddevice as sd import torch +import websockets from einops import rearrange from pyannote.core import SlidingWindowFeature, SlidingWindow from rx.subject import Subject @@ -222,3 +225,71 @@ def read(self): self.stream.on_error(e) break self.stream.on_completed() + + +class WebSocketAudioSource(AudioSource): + """Represents a source of audio coming from the network using the WebSocket protocol. + + Parameters + ---------- + sample_rate: int + Sample rate of the chunks emitted. + host: Text + The host to run the websocket server. Defaults to ``None`` (all interfaces). + port: int + The port to run the websocket server. Defaults to 7007. 
+ """ + def __init__(self, sample_rate: int, host: Text = None, port: int = 7007): + name = host if host is not None and host else "localhost" + uri = f"ws://{name}:{port}" + super().__init__(uri, sample_rate) + self.host = host + self.port = port + self.websocket = None + + async def _ws_handler(self, websocket): + self.websocket = websocket + try: + async for message in websocket: + # Decode chunk encoded in base64 + byte_samples = base64.decodebytes(message.encode("utf-8")) + # Recover array from bytes + samples = np.frombuffer(byte_samples, dtype=np.float32) + # Reshape and send through + self.stream.on_next(samples.reshape(1, -1)) + self.stream.on_completed() + except websockets.ConnectionClosedError as e: + self.stream.on_error(e) + + async def _async_read(self): + async with websockets.serve(self._ws_handler, self.host, self.port): + await asyncio.Future() + + async def _async_send(self, message: AnyStr): + await self.websocket.send(message) + + def read(self): + """Starts running the websocket server and listening for audio chunks""" + asyncio.run(self._async_read()) + + def send(self, message: AnyStr): + """Send a message through the current websocket. + + Parameters + ---------- + message: AnyStr + Bytes or string to send. + """ + # A running loop must exist in order to send back a message + ws_closed = "Websocket isn't open, try calling `read()` first" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + raise RuntimeError(ws_closed) + + if not loop.is_running(): + raise RuntimeError(ws_closed) + + # TODO support broadcasting to many clients + # Schedule a coroutine to send back the message + asyncio.run_coroutine_threadsafe(self._async_send(message), loop=loop) From 7a4114f0a5f176deed8d64633fe53bf2752e6a9d Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Wed, 27 Jul 2022 17:26:43 +0200 Subject: [PATCH 05/23] Add WebSocket section to README.md --- README.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/README.md b/README.md index a0f45b85..76e34f44 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,10 @@ Build pipelines
+ + WebSockets + + | Research @@ -256,6 +260,31 @@ torch.Size([4, 512]) ... ``` +## WebSockets + +Diart is also compatible with the WebSocket protocol so you can serve your pipeline on the web. + +In the following example we build a minimal server so a client can send audio to the remote pipeline and then receive a prediction in RTTM format: + +```python +import rx.operators as ops +import diart.operators as dops +from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig +from diart.sources import WebSocketAudioSource + +config = PipelineConfig() +source = WebSocketAudioSource(config.sample_rate, "localhost", 7007) +pipeline = OnlineSpeakerDiarization(config) + +pipeline.from_audio_source(source).pipe( + dops.progress(f"Streaming from {source.uri}", unit="chunk"), + ops.starmap(lambda ann, _: ann.to_rttm()), + ops.filter(lambda rttm: bool(rttm)), # Ignore non-speech +).subscribe(source.send) + +source.read() +``` + ## Powered by research Diart is the official implementation of the paper *[Overlap-aware low-latency online speaker diarization based on end-to-end local segmentation](/paper.pdf)* by [Juan Manuel Coria](https://juanmc2005.github.io/), [Hervé Bredin](https://herve.niderb.fr), [Sahar Ghannay](https://saharghannay.github.io/) and [Sophie Rosset](https://perso.limsi.fr/rosset/). From 2262b04c0eea5ac3d65960c4b2cee6bfb26ad806 Mon Sep 17 00:00:00 2001 From: ckliao-nccu <71756659+ckliao-nccu@users.noreply.github.com> Date: Thu, 28 Jul 2022 14:08:40 +0800 Subject: [PATCH 06/23] Replace uri to avoid path error --- src/diart/sources.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diart/sources.py b/src/diart/sources.py index f8fe9850..241576a7 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -241,7 +241,7 @@ class WebSocketAudioSource(AudioSource): """ def __init__(self, sample_rate: int, host: Text = None, port: int = 7007): name = host if host is not None and host else "localhost" - uri = f"ws://{name}:{port}" + uri = f"{name}:{port}" super().__init__(uri, sample_rate) self.host = host self.port = port From 7223b54f22d5e924058e9161fd7078dc8c746129 Mon Sep 17 00:00:00 2001 From: ckliao-nccu <71756659+ckliao-nccu@users.noreply.github.com> Date: Thu, 28 Jul 2022 14:14:43 +0800 Subject: [PATCH 07/23] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 76e34f44..20fb45db 100644 --- a/README.md +++ b/README.md @@ -76,10 +76,10 @@ conda install pysoundfile -c conda-forge 3) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally) -4) Install pyannote.audio 2.0 (currently no official release) +4) Install pyannote.audio 2.0 ```shell -pip install git+https://github.com/pyannote/pyannote-audio.git@2.0.1#egg=pyannote-audio +pip install pyannote.audio ``` **Note:** starting from version 0.4, installing pyannote.audio is mandatory to run the default system or to use pyannote-based models. In any other case, this step can be ignored. 
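A note on the websocket source introduced above: the server decodes every incoming message as a base64-encoded buffer of 32-bit float samples, so a compatible client only needs to mirror that encoding. A minimal sketch (host, port and chunk contents are hypothetical; `websockets` is the same dependency the server uses):

```python
import asyncio
import base64

import numpy as np
import websockets


async def stream_chunk(uri: str, samples: np.ndarray):
    # Mirror the server side: it decodes with base64.decodebytes(...)
    # and recovers the chunk with np.frombuffer(..., dtype=np.float32)
    payload = base64.encodebytes(samples.astype(np.float32).tobytes()).decode("utf-8")
    async with websockets.connect(uri) as ws:
        await ws.send(payload)
        try:
            # The example server only replies when it has a prediction to send back
            print(await asyncio.wait_for(ws.recv(), timeout=2))
        except asyncio.TimeoutError:
            print("No prediction for this chunk")


# 5 seconds of silence at an assumed 16kHz pipeline sample rate;
# a real client would send microphone or file chunks instead
asyncio.run(stream_chunk("ws://localhost:7007", np.zeros(5 * 16000, dtype=np.float32)))
```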
From a5948dc34e1ea48d5ce1c1d9588128e65b2dc8f3 Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Fri, 29 Jul 2022 14:47:49 +0200 Subject: [PATCH 08/23] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 20fb45db..cc0055db 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,7 @@ conda install pysoundfile -c conda-forge 3) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally) -4) Install pyannote.audio 2.0 +4) Install pyannote.audio ```shell pip install pyannote.audio From 9294a06bf28f38765a78388cfdc79cebf86676b1 Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Fri, 29 Jul 2022 14:47:58 +0200 Subject: [PATCH 09/23] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index cc0055db..6f59e533 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,7 @@ conda install pysoundfile -c conda-forge 4) Install pyannote.audio ```shell -pip install pyannote.audio +pip install pyannote.audio==2.0.1 ``` **Note:** starting from version 0.4, installing pyannote.audio is mandatory to run the default system or to use pyannote-based models. In any other case, this step can be ignored. From e0ebb960692895d7c03ee9d4a7524d554bb0bee0 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Fri, 29 Jul 2022 17:47:46 +0200 Subject: [PATCH 10/23] Make RealTimeInference compatible with websockets. RealTimeInference and Benchmark can now run without writing to disk --- README.md | 84 ++++++++++++------------- src/diart/benchmark.py | 4 +- src/diart/inference.py | 137 +++++++++++++++++++++++++++-------------- src/diart/optim.py | 9 +-- src/diart/pipelines.py | 4 +- src/diart/sinks.py | 42 +++++++++---- src/diart/sources.py | 3 +- src/diart/stream.py | 5 +- src/diart/tune.py | 13 +--- 9 files changed, 175 insertions(+), 126 deletions(-) diff --git a/README.md b/README.md index 6f59e533..37e20eaa 100644 --- a/README.md +++ b/README.md @@ -22,8 +22,8 @@ Stream audio | - - Add your model + + Custom models | @@ -109,25 +109,28 @@ See `diart.stream -h` for more options. ### From python -Run a real-time speaker diarization pipeline over an audio stream with `RealTimeInference`: +Use `RealTimeInference` to easily run a pipeline on an audio source and write the results to disk: ```python from diart.sources import MicrophoneAudioSource from diart.inference import RealTimeInference -from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig +from diart.pipelines import OnlineSpeakerDiarization +from diart.sinks import RTTMWriter -config = PipelineConfig() # Default parameters -pipeline = OnlineSpeakerDiarization(config) -audio_source = MicrophoneAudioSource(config.sample_rate) -inference = RealTimeInference("/output/path", do_plot=True) -inference(pipeline, audio_source) +pipeline = OnlineSpeakerDiarization() +mic = MicrophoneAudioSource(pipeline.config.sample_rate) +inference = RealTimeInference(pipeline, mic, do_plot=True) +# Optionally stream predictions to an RTTM file +inference.attach_observers(RTTMWriter("/output/file.rttm")) + +inference() ``` -For faster inference and evaluation on a dataset we recommend to use `Benchmark` instead (see our notes on [reproducibility](#reproducibility)). +For inference and evaluation on a dataset we recommend to use `Benchmark` (see notes on [reproducibility](#reproducibility)). 
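Both `attach_observers` and `attach_hooks` accept any number of consumers for the `(prediction, waveform)` pairs the pipeline emits, so custom side effects are one function away. A small sketch (the printing hook is hypothetical):

```python
# Print the speakers detected in each emitted prediction
def print_speakers(prediction_with_audio):
    annotation, waveform = prediction_with_audio
    print(f"{annotation.uri}: {annotation.labels()}")

inference.attach_hooks(print_speakers)
```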
-## Add your model +## Custom models -Third-party segmentation and embedding models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel`: +Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel`: ```python import torch @@ -152,8 +155,9 @@ class MyEmbeddingModel(EmbeddingModel): config = PipelineConfig(embedding=MyEmbeddingModel()) pipeline = OnlineSpeakerDiarization(config) mic = MicrophoneAudioSource(config.sample_rate) -inference = RealTimeInference("/out/dir") -inference(pipeline, mic) +inference = RealTimeInference(pipeline, mic) + +inference() ``` ## Tune hyper-parameters @@ -172,22 +176,19 @@ See `diart.tune -h` for more options. ```python from diart.optim import Optimizer, TauActive, RhoUpdate, DeltaNew -from diart.pipelines import PipelineConfig from diart.inference import Benchmark # Benchmark runs and evaluates the pipeline on a dataset -benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir/tmp", show_report=False) -# Base configuration for the pipeline we're going to tune -base_config = PipelineConfig() +benchmark = Benchmark("/wav/dir", "/rttm/dir", show_report=False) # Hyper-parameters to optimize hparams = [TauActive, RhoUpdate, DeltaNew] # Optimizer implements the optimization loop -optimizer = Optimizer(benchmark, base_config, hparams, "/out/dir") -# Run optimization +optimizer = Optimizer(benchmark, hparams, "/out/dir") + optimizer.optimize(num_iter=100, show_progress=True) ``` -This will use `/out/dir/tmp` as a working directory and write results to an sqlite database in `/out/dir`. +This will write results to an sqlite database in `/out/dir`. ### Distributed optimization @@ -199,13 +200,13 @@ mysql -u root -e "CREATE DATABASE IF NOT EXISTS example" optuna create-study --study-name "example" --storage "mysql://root@localhost/example" ``` -Then you can run multiple identical optimizers pointing to the database: +Then you can run multiple identical optimizers pointing to this database: ```shell diart.tune /wav/dir --reference /rttm/dir --output /out/dir --storage mysql://root@localhost/example ``` -If you are using the python API, make sure that worker directories are different to avoid concurrency issues: +or in python: ```python from diart.optim import Optimizer @@ -213,11 +214,11 @@ from diart.inference import Benchmark from optuna.samplers import TPESampler import optuna -ID = 0 # Worker identifier -base_config, hparams = ... -benchmark = Benchmark("/wav/dir", "/rttm/dir", f"/out/dir/worker-{ID}", show_report=False) +hparams = ... +benchmark = Benchmark("/wav/dir", "/rttm/dir", show_report=False) study = optuna.load_study("example", "mysql://root@localhost/example", TPESampler()) -optimizer = Optimizer(benchmark, base_config, hparams, study) +optimizer = Optimizer(benchmark, hparams, study) + optimizer.optimize(num_iter=100, show_progress=True) ``` @@ -262,27 +263,21 @@ torch.Size([4, 512]) ## WebSockets -Diart is also compatible with the WebSocket protocol so you can serve your pipeline on the web. +Diart is also compatible with the WebSocket protocol to serve your pipeline on the web. 
-In the following example we build a minimal server so a client can send audio to the remote pipeline and then receive a prediction in RTTM format: +In the following example we build a minimal server for a client to send audio and receive a prediction in RTTM format: ```python -import rx.operators as ops -import diart.operators as dops -from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig +from diart.pipelines import OnlineSpeakerDiarization from diart.sources import WebSocketAudioSource +from diart.inference import RealTimeInference -config = PipelineConfig() -source = WebSocketAudioSource(config.sample_rate, "localhost", 7007) -pipeline = OnlineSpeakerDiarization(config) - -pipeline.from_audio_source(source).pipe( - dops.progress(f"Streaming from {source.uri}", unit="chunk"), - ops.starmap(lambda ann, _: ann.to_rttm()), - ops.filter(lambda rttm: bool(rttm)), # Ignore non-speech -).subscribe(source.send) +pipeline = OnlineSpeakerDiarization() +source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007) +inference = RealTimeInference(pipeline, source, do_plot=True) +inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm())) -source.read() +inference() ``` ## Powered by research @@ -331,7 +326,7 @@ To obtain the best results, make sure to use the following hyper-parameters: `diart.benchmark` and `diart.inference.Benchmark` can quickly run and evaluate the pipeline, and even measure its real-time latency. For instance, for a DIHARD III configuration: ```shell -diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --output /out/dir +diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 ``` or using the inference API: @@ -348,8 +343,7 @@ config = PipelineConfig( delta_new=1.517 ) pipeline = OnlineSpeakerDiarization(config) -benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir") - +benchmark = Benchmark("/wav/dir", "/rttm/dir") benchmark(pipeline) ``` diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py index ad0969fb..ab60be0d 100644 --- a/src/diart/benchmark.py +++ b/src/diart/benchmark.py @@ -23,7 +23,7 @@ def run(): parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32") parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") - parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`") + parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. 
Defaults to no writing")
     args = parser.parse_args()
     args.device = torch.device("cpu") if args.cpu else None
@@ -33,7 +33,7 @@ def run():
         args.output,
         show_progress=True,
         show_report=True,
-        batch_size=args.batch_size
+        batch_size=args.batch_size,
     )
     benchmark(OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True))
diff --git a/src/diart/inference.py b/src/diart/inference.py
index e58240e9..2f770084 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -1,74 +1,97 @@
 from pathlib import Path
-from typing import Union, Text, Optional
+from typing import Union, Text, Optional, Callable, Tuple

 import pandas as pd
 import rx.operators as ops
-from pyannote.core import Annotation
+from pyannote.core import Annotation, SlidingWindowFeature
 from pyannote.database.util import load_rttm
 from pyannote.metrics.diarization import DiarizationErrorRate
+from rx.core import Observer

 import diart.operators as dops
 import diart.sources as src
 from diart.pipelines import OnlineSpeakerDiarization
-from diart.sinks import RTTMWriter, RealTimePlot
+from diart.sinks import RTTMAccumulator, RTTMWriter, RealTimePlot


 class RealTimeInference:
     """
+    Simplifies inference in real time for users that do not want to play with the reactivex interface.
     Streams an audio source to an online speaker diarization pipeline.
-    It writes predictions to an output directory in RTTM format and plots them in real time.
+    It allows users to attach a chain of operations in the form of hooks.

     Parameters
     ----------
-    output_path: Text or Path
-        Output directory to store predictions in RTTM format.
+    pipeline: OnlineSpeakerDiarization
+        Configured speaker diarization pipeline.
+    source: AudioSource
+        Audio source to be read and streamed.
     do_plot: bool
         Whether to draw predictions in a moving plot. Defaults to True.
     """
-    def __init__(self, output_path: Union[Text, Path], do_plot: bool = True):
-        self.output_path = Path(output_path).expanduser()
-        self.output_path.mkdir(parents=True, exist_ok=True)
+    def __init__(
+        self,
+        pipeline: OnlineSpeakerDiarization,
+        source: src.AudioSource,
+        do_plot: bool = True
+    ):
+        self.pipeline = pipeline
+        self.source = source
         self.do_plot = do_plot
+        self.accumulator = RTTMAccumulator()
+        self.stream = self.pipeline.from_audio_source(source).pipe(
+            dops.progress(f"Streaming {source.uri}", total=source.length, leave=True),
+            ops.do(self.accumulator),
+        )

-    def __call__(self, pipeline: OnlineSpeakerDiarization, source: src.AudioSource) -> Annotation:
+    def attach_hooks(self, *hooks: Callable[[Tuple[Annotation, SlidingWindowFeature]], None]):
         """
-        Stream audio chunks from `source` to `pipeline` and write predictions to disk.
+        Attach hooks to the pipeline.

         Parameters
         ----------
-        pipeline: OnlineSpeakerDiarization
-            Configured speaker diarization pipeline.
-        source: AudioSource
-            Audio source to be read and streamed.
+        *hooks: (Tuple[Annotation, SlidingWindowFeature]) -> None
+            Hook functions to consume emitted annotations and audio.
+        """
+        self.stream = self.stream.pipe(*[ops.do_action(hook) for hook in hooks])
+
+    def attach_observers(self, *observers: Observer):
+        """
+        Attach rx observers to the pipeline.
+
+        Parameters
+        ----------
+        *observers: Observer
+            Observers to consume emitted annotations and audio.
+        """
+        self.stream = self.stream.pipe(*[ops.do(sink) for sink in observers])
+
+    def __call__(self) -> Annotation:
+        """
+        Stream audio chunks from `source` to `pipeline`.

         Returns
         -------
         predictions: Annotation
             Speaker diarization pipeline predictions
         """
+        config = self.pipeline.config
+        observable = self.stream
+        if self.do_plot:
+            # Buffering is needed for the real-time plot, so we do this at the very end
+            observable = self.stream.pipe(
+                dops.buffer_output(
+                    duration=config.duration,
+                    step=config.step,
+                    latency=config.latency,
+                    sample_rate=config.sample_rate,
+                ),
+                ops.do(RealTimePlot(config.duration, config.latency)),
+            )
+        observable.subscribe()
+        self.source.read()
+        return self.accumulator.annotation


 class Benchmark:
@@ -82,10 +105,27 @@ class Benchmark:
     ----------
     speech_path: Text or Path
         Directory with audio files.
-    reference_path: Text or Path
+    reference_path: Text, Path or None
        Directory with reference RTTM files (same names as audio files).
-    output_path: Text or Path
+        If None, performance will not be calculated.
+        Defaults to None.
+    output_path: Text, Path or None
        Output directory to store predictions in RTTM format.
+        If None, predictions will not be written to disk.
+        Defaults to None.
+    show_progress: bool
+        Whether to show progress bars.
+        Defaults to True.
+    show_report: bool
+        Whether to print a performance report to stdout.
+        Defaults to True.
+    batch_size: int
+        Inference batch size.
+        If < 2, then it will run in real time.
+        If >= 2, then it will pre-calculate segmentation and
+        embeddings, running the rest in real time.
+        Performance between these two modes does not differ.
+        Defaults to 32. 
""" def __init__( self, @@ -104,9 +144,8 @@ def __init__( self.reference_path = Path(self.reference_path).expanduser() assert self.reference_path.is_dir(), "Reference path must be a directory" - if output_path is None: - self.output_path = self.speech_path - else: + self.output_path = output_path + if self.output_path is not None: self.output_path = Path(output_path).expanduser() self.output_path.mkdir(parents=True, exist_ok=True) @@ -133,6 +172,7 @@ def __call__(self, pipeline: OnlineSpeakerDiarization) -> Optional[pd.DataFrame] loader = src.AudioLoader(pipeline.config.sample_rate, mono=True) audio_file_paths = list(self.speech_path.iterdir()) num_audio_files = len(audio_file_paths) + predictions = [] for i, filepath in enumerate(audio_file_paths): num_chunks = loader.get_num_sliding_chunks( filepath, pipeline.config.duration, pipeline.config.step @@ -164,22 +204,27 @@ def __call__(self, pipeline: OnlineSpeakerDiarization) -> Optional[pd.DataFrame] if self.show_progress: observable = observable.pipe( dops.progress( - desc=f"Streaming {filepath.stem} ({i + 1}/{num_audio_files})", + desc=f"Streaming {source.uri} ({i + 1}/{num_audio_files})", total=num_chunks, leave=False, ) ) - observable.subscribe(RTTMWriter(path=self.output_path / f"{filepath.stem}.rttm")) + if self.output_path is not None: + observable = observable.pipe( + ops.do(RTTMWriter(self.output_path / f"{source.uri}.rttm")) + ) + accumulator = RTTMAccumulator() + observable.subscribe(accumulator) source.read() + predictions.append(accumulator.annotation) # Run evaluation if self.reference_path is not None: metric = DiarizationErrorRate(collar=0, skip_overlap=False) - for ref_path in self.reference_path.iterdir(): - ref = load_rttm(ref_path).popitem()[1] - hyp = load_rttm(self.output_path / ref_path.name).popitem()[1] + for hyp in predictions: + ref = load_rttm(self.reference_path / f"{hyp.uri}.rttm").popitem()[1] metric(ref, hyp) return metric.report(display=self.show_report) diff --git a/src/diart/optim.py b/src/diart/optim.py index 2f7dd10c..576cfd6c 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -39,12 +39,12 @@ class Optimizer: def __init__( self, benchmark: Benchmark, - base_config: PipelineConfig, hparams: Iterable[HyperParameter], study_or_path: Union[FilePath, Study], + base_config: Optional[PipelineConfig] = None, ): self.benchmark = benchmark - self.base_config = base_config + self.base_config = PipelineConfig() if base_config is None else base_config self.hparams = hparams self._progress: Optional[tqdm] = None @@ -99,11 +99,6 @@ def objective(self, trial: Trial) -> float: # Run pipeline over the dataset report = self.benchmark(pipeline) - # Clean RTTM files - for tmp_file in self.benchmark.output_path.iterdir(): - if tmp_file.name.endswith(".rttm"): - tmp_file.unlink() - # Extract DER from report return report.loc["TOTAL", "diarization error rate"]["%"] diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py index 391d60b3..fab217c5 100644 --- a/src/diart/pipelines.py +++ b/src/diart/pipelines.py @@ -130,8 +130,8 @@ def get_operators(self, source: src.AudioSource) -> List[dops.Operator]: class OnlineSpeakerDiarization: - def __init__(self, config: PipelineConfig, profile: bool = False): - self.config = config + def __init__(self, config: Optional[PipelineConfig] = None, profile: bool = False): + self.config = PipelineConfig() if config is None else config self.profile = profile self.segmentation = blocks.SpeakerSegmentation(config.segmentation, config.device) self.embedding = 
blocks.OverlapAwareSpeakerEmbedding(
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index e859aa77..3fcde11e 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -1,5 +1,4 @@
 from pathlib import Path
-from traceback import print_exc
 from typing import Union, Text, Optional, Tuple

 import matplotlib.pyplot as plt
@@ -18,11 +17,11 @@ class RTTMWriter(Observer):
     def __init__(self, path: Union[Path, Text], patch_collar: float = 0.05):
         super().__init__()
         self.patch_collar = patch_collar
-        self.path = Path(path)
+        self.path = Path(path).expanduser()
         if self.path.exists():
             self.path.unlink()

-    def patch_rttm(self):
+    def patch(self):
         """Stitch same-speaker turns that are close to each other"""
         annotation = list(load_rttm(self.path).values())[0]
         with open(self.path, 'w') as file:
@@ -33,15 +32,36 @@ def on_next(self, value: Tuple[Annotation, SlidingWindowFeature]):
             value[0].write_rttm(file)

     def on_error(self, error: Exception):
-        try:
-            self.patch_rttm()
-        except Exception:
-            print("Error while patching RTTM file:")
-            print_exc()
-            exit(1)
+        self.patch()
+        raise error

     def on_completed(self):
-        self.patch_rttm()
+        self.patch()
+
+
+class RTTMAccumulator(Observer):
+    def __init__(self, patch_collar: float = 0.05):
+        super().__init__()
+        self.patch_collar = patch_collar
+        self.annotation = None
+
+    def patch(self):
+        """Stitch same-speaker turns that are close to each other"""
+        # support() returns a new annotation instead of modifying in place,
+        # so keep the stitched copy
+        if self.annotation is not None:
+            self.annotation = self.annotation.support(self.patch_collar)
+
+    def on_next(self, value: Tuple[Annotation, SlidingWindowFeature]):
+        annotation, waveform = value
+        if self.annotation is None:
+            self.annotation = annotation
+        else:
+            self.annotation.update(annotation)
+
+    def on_error(self, error: Exception):
+        self.patch()
+        raise error
+
+    def on_completed(self):
+        self.patch()


 class RealTimePlot(Observer):
@@ -124,4 +144,4 @@ def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]):

     def on_error(self, error: Exception):
         if not isinstance(error, WindowClosedException):
-            print_exc()
+            raise error
diff --git a/src/diart/sources.py b/src/diart/sources.py
index 241576a7..aeca5cb7 100644
--- a/src/diart/sources.py
+++ b/src/diart/sources.py
@@ -292,4 +292,5 @@ def send(self, message: AnyStr):

         # TODO support broadcasting to many clients
         # Schedule a coroutine to send back the message
-        asyncio.run_coroutine_threadsafe(self._async_send(message), loop=loop)
+        if message:
+            asyncio.run_coroutine_threadsafe(self._async_send(message), loop=loop)
diff --git a/src/diart/stream.py b/src/diart/stream.py
index 63232e43..271164b9 100644
--- a/src/diart/stream.py
+++ b/src/diart/stream.py
@@ -7,6 +7,7 @@
 import diart.sources as src
 from diart.inference import RealTimeInference
 from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
+from diart.sinks import RTTMWriter


 def run():
@@ -44,7 +45,9 @@ def run():
     audio_source = src.MicrophoneAudioSource(config.sample_rate)

     # Run online inference
-    RealTimeInference(args.output, do_plot=not args.no_plot)(pipeline, audio_source)
+    inference = RealTimeInference(pipeline, audio_source, do_plot=not args.no_plot)
+    inference.attach_observers(RTTMWriter(args.output / f"{audio_source.uri}.rttm"))
+    inference()


 if __name__ == "__main__":
diff --git a/src/diart/tune.py b/src/diart/tune.py
index 000e3da8..d048b83c 100644
--- a/src/diart/tune.py
+++ b/src/diart/tune.py
@@ -1,6 +1,5 @@
 import argparse
 from pathlib import Path
-from uuid import uuid4

 import optuna
 import torch
@@ -39,18 +38,13 @@ def run():
     args.output.mkdir(parents=True, exist_ok=True)
     args.device = torch.device("cpu") 
if args.cpu else None - # Assign unique worker ID - idx = uuid4() - # Create benchmark object to run the pipeline on a set of files - work_path = args.output / f"worker-{idx}" benchmark = Benchmark( args.root, args.reference, - work_path, show_progress=True, show_report=False, - batch_size=args.batch_size + batch_size=args.batch_size, ) # Create the base configuration for each trial @@ -66,12 +60,9 @@ def run(): study_or_path = optuna.load_study(db_name, args.storage, TPESampler()) # Run optimization - optimizer = Optimizer(benchmark, base_config, hparams, study_or_path) + optimizer = Optimizer(benchmark, hparams, study_or_path, base_config) optimizer.optimize(num_iter=args.num_iter, show_progress=True) - # Clean temporary directory - work_path.rmdir() - if __name__ == "__main__": run() From 1b2b2899d9dbcf2a84f12e2baf3d647f4ec42ed1 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Fri, 29 Jul 2022 18:22:47 +0200 Subject: [PATCH 11/23] Greatly simplify the optim API by setting sensible defaults --- README.md | 40 +++++++++++++--------------------------- src/diart/optim.py | 21 ++++++++++++++++----- src/diart/tune.py | 36 +++++++++++++++++++----------------- 3 files changed, 48 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index 37e20eaa..0684b04d 100644 --- a/README.md +++ b/README.md @@ -120,9 +120,7 @@ from diart.sinks import RTTMWriter pipeline = OnlineSpeakerDiarization() mic = MicrophoneAudioSource(pipeline.config.sample_rate) inference = RealTimeInference(pipeline, mic, do_plot=True) -# Optionally stream predictions to an RTTM file inference.attach_observers(RTTMWriter("/output/file.rttm")) - inference() ``` @@ -156,7 +154,6 @@ config = PipelineConfig(embedding=MyEmbeddingModel()) pipeline = OnlineSpeakerDiarization(config) mic = MicrophoneAudioSource(config.sample_rate) inference = RealTimeInference(pipeline, mic) - inference() ``` @@ -167,7 +164,7 @@ Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.re ### From the command line ```shell -diart.tune /wav/dir --reference /rttm/dir --output /out/dir +diart.tune /wav/dir --reference /rttm/dir --output /output/dir ``` See `diart.tune -h` for more options. @@ -175,20 +172,13 @@ See `diart.tune -h` for more options. ### From python ```python -from diart.optim import Optimizer, TauActive, RhoUpdate, DeltaNew -from diart.inference import Benchmark - -# Benchmark runs and evaluates the pipeline on a dataset -benchmark = Benchmark("/wav/dir", "/rttm/dir", show_report=False) -# Hyper-parameters to optimize -hparams = [TauActive, RhoUpdate, DeltaNew] -# Optimizer implements the optimization loop -optimizer = Optimizer(benchmark, hparams, "/out/dir") +from diart.optim import Optimizer -optimizer.optimize(num_iter=100, show_progress=True) +optimizer = Optimizer("/wav/dir", "/rttm/dir", "/output/dir") +optimizer(num_iter=100) ``` -This will write results to an sqlite database in `/out/dir`. +This will write results to an sqlite database in `/output/dir`. 
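Because the study lives in ordinary Optuna storage, a finished (or running) optimization can be inspected with the regular Optuna API. A sketch, assuming the defaults used by `Optimizer` (database file and study named after the stem of the output directory):

```python
import optuna

# With output directory /output/dir, the study database is /output/dir/dir.db
study = optuna.load_study("dir", "sqlite:////output/dir/dir.db")
print(study.best_params)  # e.g. tau_active, rho_update and delta_new
print(study.best_value)   # the corresponding diarization error rate
```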
### Distributed optimization @@ -200,26 +190,23 @@ mysql -u root -e "CREATE DATABASE IF NOT EXISTS example" optuna create-study --study-name "example" --storage "mysql://root@localhost/example" ``` -Then you can run multiple identical optimizers pointing to this database: +You can now run multiple identical optimizers pointing to this database: ```shell -diart.tune /wav/dir --reference /rttm/dir --output /out/dir --storage mysql://root@localhost/example +diart.tune /wav/dir --reference /rttm/dir --storage mysql://root@localhost/example ``` or in python: ```python from diart.optim import Optimizer -from diart.inference import Benchmark from optuna.samplers import TPESampler import optuna -hparams = ... -benchmark = Benchmark("/wav/dir", "/rttm/dir", show_report=False) -study = optuna.load_study("example", "mysql://root@localhost/example", TPESampler()) -optimizer = Optimizer(benchmark, hparams, study) - -optimizer.optimize(num_iter=100, show_progress=True) +db = "mysql://root@localhost/example" +study = optuna.load_study("example", db, TPESampler()) +optimizer = Optimizer("/wav/dir", "/rttm/dir", study) +optimizer(num_iter=100) ``` ## Build pipelines @@ -276,7 +263,6 @@ pipeline = OnlineSpeakerDiarization() source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007) inference = RealTimeInference(pipeline, source, do_plot=True) inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm())) - inference() ``` @@ -323,7 +309,7 @@ To obtain the best results, make sure to use the following hyper-parameters: | DIHARD II | 1s | 0.619 | 0.326 | 0.997 | | DIHARD II | 5s | 0.555 | 0.422 | 1.517 | -`diart.benchmark` and `diart.inference.Benchmark` can quickly run and evaluate the pipeline, and even measure its real-time latency. For instance, for a DIHARD III configuration: +`diart.benchmark` and `diart.inference.Benchmark` can run, evaluate and measure the real-time latency of the pipeline. For instance, for a DIHARD III configuration: ```shell diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 @@ -347,7 +333,7 @@ benchmark = Benchmark("/wav/dir", "/rttm/dir") benchmark(pipeline) ``` -This runs a faster inference by pre-calculating model outputs in batches. +This pre-calculates model outputs in batches, so it runs a lot faster. See `diart.benchmark -h` for more options. For convenience and to facilitate future comparisons, we also provide the [expected outputs](/expected_outputs) of the paper implementation in RTTM format for every entry of Table 1 and Figure 5. This includes the VBx offline topline as well as our proposed online approach with latencies 500ms, 1s, 2s, 3s, 4s, and 5s. 
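To verify a local run against these expected outputs, the two RTTM files can be scored directly with the same pyannote tools the codebase already imports (file paths are hypothetical):

```python
from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import DiarizationErrorRate

# Same metric settings Benchmark uses internally
metric = DiarizationErrorRate(collar=0, skip_overlap=False)
reference = load_rttm("/rttm/dir/conversation.rttm").popitem()[1]
hypothesis = load_rttm("/out/dir/conversation.rttm").popitem()[1]
print(f"DER = {metric(reference, hypothesis):.3f}")
```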
diff --git a/src/diart/optim.py b/src/diart/optim.py index 576cfd6c..5351e16f 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -1,7 +1,7 @@ from collections import OrderedDict from dataclasses import dataclass from pathlib import Path -from typing import Iterable, Text, Optional, Union +from typing import Sequence, Text, Optional, Union from optuna import TrialPruned, Study, create_study from optuna.samplers import TPESampler @@ -38,14 +38,25 @@ def from_name(name: Text) -> 'HyperParameter': class Optimizer: def __init__( self, - benchmark: Benchmark, - hparams: Iterable[HyperParameter], + speech_path: Union[Text, Path], + reference_path: Optional[Union[Text, Path]], study_or_path: Union[FilePath, Study], + batch_size: int = 32, + hparams: Optional[Sequence[HyperParameter]] = None, base_config: Optional[PipelineConfig] = None, ): - self.benchmark = benchmark + self.benchmark = Benchmark( + speech_path, + reference_path, + show_progress=True, + show_report=False, + batch_size=batch_size, + ) self.base_config = PipelineConfig() if base_config is None else base_config self.hparams = hparams + if self.hparams is None: + self.hparams = [TauActive, RhoUpdate, DeltaNew] + self._progress: Optional[tqdm] = None if isinstance(study_or_path, Study): @@ -102,7 +113,7 @@ def objective(self, trial: Trial) -> float: # Extract DER from report return report.loc["TOTAL", "diarization error rate"]["%"] - def optimize(self, num_iter: int, show_progress: bool = True): + def __call__(self, num_iter: int, show_progress: bool = True): self._progress = None if show_progress: self._progress = trange(num_iter) diff --git a/src/diart/tune.py b/src/diart/tune.py index d048b83c..5dac1f79 100644 --- a/src/diart/tune.py +++ b/src/diart/tune.py @@ -6,7 +6,6 @@ from optuna.samplers import TPESampler import diart.argdoc as argdoc -from diart.inference import Benchmark from diart.optim import Optimizer, HyperParameter from diart.pipelines import PipelineConfig @@ -32,21 +31,10 @@ def run(): parser.add_argument("--num-iter", default=100, type=int, help="Number of optimization trials") parser.add_argument("--storage", type=str, help="Optuna storage string. If provided, continue a previous study instead of creating one. 
The database name must match the study name") - parser.add_argument("--output", required=True, type=str, help="Working directory") + parser.add_argument("--output", type=str, help="Working directory") args = parser.parse_args() - args.output = Path(args.output) - args.output.mkdir(parents=True, exist_ok=True) args.device = torch.device("cpu") if args.cpu else None - # Create benchmark object to run the pipeline on a set of files - benchmark = Benchmark( - args.root, - args.reference, - show_progress=True, - show_report=False, - batch_size=args.batch_size, - ) - # Create the base configuration for each trial base_config = PipelineConfig.from_namespace(args) @@ -54,14 +42,28 @@ def run(): hparams = [HyperParameter.from_name(name) for name in args.hparams] # Use a custom storage if given - study_or_path = args.output - if args.storage is not None: + if args.output is not None: + msg = "Both `output` and `storage` were set, but only one was expected" + assert args.storage is None, msg + args.output = Path(args.output) + args.output.mkdir(parents=True, exist_ok=True) + study_or_path = args.output + elif args.storage is not None: db_name = Path(args.storage).stem study_or_path = optuna.load_study(db_name, args.storage, TPESampler()) + else: + msg = "Please provide either `output` or `storage`" + raise ValueError(msg) # Run optimization - optimizer = Optimizer(benchmark, hparams, study_or_path, base_config) - optimizer.optimize(num_iter=args.num_iter, show_progress=True) + Optimizer( + speech_path=args.root, + reference_path=args.reference, + study_or_path=study_or_path, + batch_size=args.batch_size, + hparams=hparams, + base_config=base_config, + )(num_iter=args.num_iter, show_progress=True) if __name__ == "__main__": From 97a5b59deaef084f0c5d9b955ed936b5c777607d Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Mon, 8 Aug 2022 18:39:46 +0200 Subject: [PATCH 12/23] Add on-the-fly resampling --- src/diart/blocks/__init__.py | 2 +- src/diart/blocks/utils.py | 14 ++++++++++++++ src/diart/pipelines.py | 19 ++++++++++++++----- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py index a2d88c00..254b9291 100644 --- a/src/diart/blocks/__init__.py +++ b/src/diart/blocks/__init__.py @@ -13,4 +13,4 @@ OverlapAwareSpeakerEmbedding, ) from .segmentation import SpeakerSegmentation -from .utils import Binarize +from .utils import Binarize, Resample diff --git a/src/diart/blocks/utils.py b/src/diart/blocks/utils.py index 9a495ca8..a8057347 100644 --- a/src/diart/blocks/utils.py +++ b/src/diart/blocks/utils.py @@ -2,6 +2,9 @@ import numpy as np from pyannote.core import Annotation, Segment, SlidingWindowFeature +import torchaudio.transforms as T + +from ..features import TemporalFeatures, TemporalFeatureFormatter class Binarize: @@ -53,3 +56,14 @@ def __call__(self, segmentation: SlidingWindowFeature) -> Annotation: region = Segment(start_times[spk], timestamps[t + 1].middle) annotation[region, spk] = f"speaker{spk}" return annotation + + +class Resample: + def __init__(self, sample_rate: int, resample_rate: int): + self.resample = T.Resample(sample_rate, resample_rate) + self.formatter = TemporalFeatureFormatter() + + def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: + wav = self.formatter.cast(waveform) # shape (batch, samples, 1) + resampled_wav = self.resample(wav.transpose(-1, -2)).transpose(-1, -2) + return self.formatter.restore_type(resampled_wav) diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py 
index fab217c5..430ecba6 100644 --- a/src/diart/pipelines.py +++ b/src/diart/pipelines.py @@ -1,4 +1,6 @@ -from typing import Optional, List, Any +import logging +from typing import Optional, List, Any, Union +from typing_extensions import Literal import rx import rx.operators as ops @@ -18,7 +20,7 @@ def __init__( embedding: Optional[m.EmbeddingModel] = None, duration: Optional[float] = None, step: float = 0.5, - latency: Optional[float] = None, + latency: Optional[Union[float, Literal["max", "min"]]] = None, tau_active: float = 0.6, rho_update: float = 0.3, delta_new: float = 1, @@ -49,8 +51,10 @@ def __init__( # Latency defaults to the step duration self.step = step self.latency = latency - if self.latency is None: + if self.latency is None or self.latency == "min": self.latency = self.step + elif latency == "max": + self.latency = self.duration self.tau_active = tau_active self.rho_update = rho_update @@ -142,14 +146,19 @@ def __init__(self, config: Optional[PipelineConfig] = None, profile: bool = Fals assert config.step <= config.latency <= config.duration, msg def from_audio_source(self, source: src.AudioSource) -> rx.Observable: - msg = f"Audio source has sample rate {source.sample_rate}, expected {self.config.sample_rate}" - assert source.sample_rate == self.config.sample_rate, msg operators = [] # Regularize the stream to a specific chunk duration and step if not source.is_regular: operators.append(dops.regularize_audio_stream( self.config.duration, self.config.step, source.sample_rate )) + # Dynamic resampling if the audio source isn't compatible + if self.config.sample_rate != source.sample_rate: + msg = f"Audio source has sample rate {source.sample_rate}, " \ + f"but pipeline's is {self.config.sample_rate}. Will resample." + logging.warning(msg) + resample = blocks.Resample(source.sample_rate, self.config.sample_rate) + operators.append(ops.map(resample)) operators += [ # Extract segmentation and keep audio ops.map(lambda wav: (wav, self.segmentation(wav))), From f1aa182f6e6feae2d7d772cf78200eed36fdeaf7 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Tue, 16 Aug 2022 16:17:31 +0200 Subject: [PATCH 13/23] Add method to convert SpeakerMap into a dictionary. 
Bug fixes and docs --- src/diart/mapping.py | 8 +++++++- src/diart/sinks.py | 7 ++++--- src/diart/sources.py | 5 +++-- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/diart/mapping.py b/src/diart/mapping.py index 2795ba0b..3023da4d 100644 --- a/src/diart/mapping.py +++ b/src/diart/mapping.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable, Iterable, List, Optional, Text, Tuple, Union +from typing import Callable, Iterable, List, Optional, Text, Tuple, Union, Dict import numpy as np from pyannote.core.utils.distance import cdist @@ -226,6 +226,12 @@ def valid_assignments( source, target = np.array(source), np.array(target) return source, target + def to_dict(self, strict: bool = False) -> Dict[int, int]: + return {src: tgt for src, tgt in zip(*self.valid_assignments(strict))} + + def to_inverse_dict(self, strict: bool = False) -> Dict[int, int]: + return {tgt: src for src, tgt in zip(*self.valid_assignments(strict))} + def is_source_speaker_mapped(self, source_speaker: int) -> bool: return source_speaker in self.mapped_source_speakers diff --git a/src/diart/sinks.py b/src/diart/sinks.py index 3fcde11e..194031f0 100644 --- a/src/diart/sinks.py +++ b/src/diart/sinks.py @@ -7,6 +7,7 @@ from pyannote.metrics.diarization import DiarizationErrorRate from rx.core import Observer from typing_extensions import Literal +from traceback import print_exc class WindowClosedException(Exception): @@ -33,7 +34,7 @@ def on_next(self, value: Tuple[Annotation, SlidingWindowFeature]): def on_error(self, error: Exception): self.patch() - raise error + print_exc() def on_completed(self): self.patch() @@ -58,7 +59,7 @@ def on_next(self, value: Tuple[Annotation, SlidingWindowFeature]): def on_error(self, error: Exception): self.patch() - raise error + print_exc() def on_completed(self): self.patch() @@ -144,4 +145,4 @@ def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): def on_error(self, error: Exception): if not isinstance(error, WindowClosedException): - raise error + print_exc() diff --git a/src/diart/sources.py b/src/diart/sources.py index aeca5cb7..de04916f 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -234,14 +234,15 @@ class WebSocketAudioSource(AudioSource): ---------- sample_rate: int Sample rate of the chunks emitted. - host: Text + host: Text | None The host to run the websocket server. Defaults to ``None`` (all interfaces). port: int The port to run the websocket server. Defaults to 7007. 
""" - def __init__(self, sample_rate: int, host: Text = None, port: int = 7007): + def __init__(self, sample_rate: int, host: Optional[Text] = None, port: int = 7007): name = host if host is not None and host else "localhost" uri = f"{name}:{port}" + # FIXME sample_rate is not being used, this can be confusing and lead to incompatibilities super().__init__(uri, sample_rate) self.host = host self.port = port From 6bbc2262f311a48bb193e6c48246d08bc28bb745 Mon Sep 17 00:00:00 2001 From: Khaled Zaouk Date: Wed, 17 Aug 2022 15:24:13 +0200 Subject: [PATCH 14/23] Fix bug with empty RTTMs (#81) * Update README.md * Update README.md * Fix bug with empty RTTMs * Add uri to empty annotation for consistency * Check non-empty annotation list in a more pythonic way Co-authored-by: Juan Coria --- src/diart/inference.py | 6 +++++- src/diart/sinks.py | 9 ++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/diart/inference.py b/src/diart/inference.py index e58240e9..107f9d66 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -68,7 +68,11 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, source: src.AudioSource) # Stream audio through the pipeline source.read() - return load_rttm(rttm_path)[source.uri] + annotations = load_rttm(rttm_path) + if source.uri in annotations: + return annotations[source.uri] + else: + return Annotation(uri=source.uri) class Benchmark: diff --git a/src/diart/sinks.py b/src/diart/sinks.py index e859aa77..7b764f69 100644 --- a/src/diart/sinks.py +++ b/src/diart/sinks.py @@ -24,9 +24,12 @@ def __init__(self, path: Union[Path, Text], patch_collar: float = 0.05): def patch_rttm(self): """Stitch same-speaker turns that are close to each other""" - annotation = list(load_rttm(self.path).values())[0] - with open(self.path, 'w') as file: - annotation.support(self.patch_collar).write_rttm(file) + + annotations = list(load_rttm(self.path).values()) + if annotations: + annotation = annotations[0] + with open(self.path, 'w') as file: + annotation.support(self.patch_collar).write_rttm(file) def on_next(self, value: Tuple[Annotation, SlidingWindowFeature]): with open(self.path, 'a') as file: From a4742f41023e8aa1d4d8a36833013af658cac599 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Wed, 17 Aug 2022 18:01:45 +0200 Subject: [PATCH 15/23] Add SetVolume block to change the volume of audio chunks --- src/diart/blocks/__init__.py | 2 +- src/diart/blocks/utils.py | 60 +++++++++++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py index 254b9291..d665d022 100644 --- a/src/diart/blocks/__init__.py +++ b/src/diart/blocks/__init__.py @@ -13,4 +13,4 @@ OverlapAwareSpeakerEmbedding, ) from .segmentation import SpeakerSegmentation -from .utils import Binarize, Resample +from .utils import Binarize, Resample, SetVolume diff --git a/src/diart/blocks/utils.py b/src/diart/blocks/utils.py index a8057347..23dfb847 100644 --- a/src/diart/blocks/utils.py +++ b/src/diart/blocks/utils.py @@ -1,6 +1,7 @@ from typing import Text import numpy as np +import torch from pyannote.core import Annotation, Segment, SlidingWindowFeature import torchaudio.transforms as T @@ -59,11 +60,68 @@ def __call__(self, segmentation: SlidingWindowFeature) -> Annotation: class Resample: + """Dynamically resample audio chunks. 
+
+    Parameters
+    ----------
+    sample_rate: int
+        Original sample rate of the input audio
+    resample_rate: int
+        Sample rate of the output
+    """
     def __init__(self, sample_rate: int, resample_rate: int):
         self.resample = T.Resample(sample_rate, resample_rate)
         self.formatter = TemporalFeatureFormatter()
 
     def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
         wav = self.formatter.cast(waveform)  # shape (batch, samples, 1)
-        resampled_wav = self.resample(wav.transpose(-1, -2)).transpose(-1, -2)
+        with torch.no_grad():
+            resampled_wav = self.resample(wav.transpose(-1, -2)).transpose(-1, -2)
         return self.formatter.restore_type(resampled_wav)
+
+
+class SetVolume:
+    """Change the volume of an audio chunk.
+
+    Notice that the output volume might be different to avoid saturation.
+
+    Parameters
+    ----------
+    volume_in_db: float
+        Target volume in dB. Must be positive.
+    """
+    def __init__(self, volume_in_db: float):
+        msg = "Volume dB must be greater than 0"
+        assert volume_in_db > 0, msg
+        self.target_db = volume_in_db
+        self.formatter = TemporalFeatureFormatter()
+
+    @staticmethod
+    def get_volumes(waveforms: torch.Tensor) -> torch.Tensor:
+        """Compute the volumes of a set of audio chunks.
+
+        Parameters
+        ----------
+        waveforms: torch.Tensor
+            Audio chunks. Shape (batch, samples, channels).
+
+        Returns
+        -------
+        volumes: torch.Tensor
+            Audio chunk volumes per channel. Shape (batch, 1, channels)
+        """
+        return 10 * torch.log10(torch.mean(np.abs(waveforms) ** 2, dim=1, keepdim=True))
+
+    def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
+        wav = self.formatter.cast(waveform)  # shape (batch, samples, channels)
+        with torch.no_grad():
+            # Compute current volume per chunk, shape (batch, 1, channels)
+            current_volumes = self.get_volumes(wav)
+            # Determine gain to reach the target volume
+            gains = 10 ** ((-self.target_db - current_volumes) / 20)
+            # Apply gain
+            wav = gains * wav
+            # If maximum value is greater than one, normalize chunk
+            maximums = torch.clamp(torch.amax(torch.abs(wav), dim=1, keepdim=True), 1)
+            wav = wav / maximums
+        return self.formatter.restore_type(wav)

From 0ef3a11b34bb429a2ff4b869bf05780a3af59acb Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Wed, 17 Aug 2022 18:14:45 +0200
Subject: [PATCH 16/23] Fix inverted decibels in SetVolume

---
 src/diart/blocks/utils.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/diart/blocks/utils.py b/src/diart/blocks/utils.py
index 23dfb847..9264014e 100644
--- a/src/diart/blocks/utils.py
+++ b/src/diart/blocks/utils.py
@@ -88,11 +88,9 @@ class SetVolume:
     Parameters
     ----------
     volume_in_db: float
-        Target volume in dB. Must be positive.
+        Target volume in dB.
     """
     def __init__(self, volume_in_db: float):
-        msg = "Volume dB must be greater than 0"
-        assert volume_in_db > 0, msg
         self.target_db = volume_in_db
         self.formatter = TemporalFeatureFormatter()
 
@@ -110,7 +108,7 @@ def get_volumes(waveforms: torch.Tensor) -> torch.Tensor:
         volumes: torch.Tensor
             Audio chunk volumes per channel. Shape (batch, 1, channels)
         """
-        return 10 * torch.log10(torch.mean(np.abs(waveforms) ** 2, dim=1, keepdim=True))
+        return 10 * torch.log10(torch.mean(torch.abs(waveforms) ** 2, dim=1, keepdim=True))
 
     def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
         wav = self.formatter.cast(waveform)  # shape (batch, samples, channels)
@@ -118,7 +116,7 @@ def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
             # Compute current volume per chunk, shape (batch, 1, channels)
             current_volumes = self.get_volumes(wav)
             # Determine gain to reach the target volume
-            gains = 10 ** ((-self.target_db - current_volumes) / 20)
+            gains = 10 ** ((self.target_db - current_volumes) / 20)
             # Apply gain
             wav = gains * wav
             # If maximum value is greater than one, normalize chunk

From b66e05c69f4eadbdac4fe11a12a0fe9824229851 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Wed, 17 Aug 2022 18:30:25 +0200
Subject: [PATCH 17/23] Rename SetVolume to AdjustVolume

---
 src/diart/blocks/__init__.py | 2 +-
 src/diart/blocks/utils.py    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py
index d665d022..7f44869a 100644
--- a/src/diart/blocks/__init__.py
+++ b/src/diart/blocks/__init__.py
@@ -13,4 +13,4 @@
     OverlapAwareSpeakerEmbedding,
 )
 from .segmentation import SpeakerSegmentation
-from .utils import Binarize, Resample, SetVolume
+from .utils import Binarize, Resample, AdjustVolume
diff --git a/src/diart/blocks/utils.py b/src/diart/blocks/utils.py
index 9264014e..c8cdc443 100644
--- a/src/diart/blocks/utils.py
+++ b/src/diart/blocks/utils.py
@@ -80,7 +80,7 @@ def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
         return self.formatter.restore_type(resampled_wav)
 
 
-class SetVolume:
+class AdjustVolume:
     """Change the volume of an audio chunk.
 
     Notice that the output volume might be different to avoid saturation.

From ca15311f084f5b5e49dc69fd03468b0bff9f8640 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Tue, 23 Aug 2022 15:02:47 +0200
Subject: [PATCH 18/23] Add diart.stream arguments to change pyannote models

---
 README.md           | 5 ++++-
 src/diart/argdoc.py | 2 ++
 src/diart/stream.py | 7 +++++++
 3 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 0684b04d..a0fa09f5 100644
--- a/README.md
+++ b/README.md
@@ -312,7 +312,7 @@ To obtain the best results, make sure to use the following hyper-parameters:
 `diart.benchmark` and `diart.inference.Benchmark` can run, evaluate and measure the real-time latency of the pipeline. For instance, for a DIHARD III configuration:
 
 ```shell
-diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517
+diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --segmentation pyannote/segmentation@Interspeech2021
 ```
 
 or using the inference API:
@@ -320,8 +320,11 @@
 ```python
 from diart.inference import Benchmark
 from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
+from diart.models import SegmentationModel
 
 config = PipelineConfig(
+    # Set the model used in the paper
+    segmentation=SegmentationModel.from_pyannote("pyannote/segmentation@Interspeech2021"),
     step=0.5,
     latency=0.5,
     tau_active=0.555,
diff --git a/src/diart/argdoc.py b/src/diart/argdoc.py
index 0fb2a1a6..19784a59 100644
--- a/src/diart/argdoc.py
+++ b/src/diart/argdoc.py
@@ -1,3 +1,5 @@
+SEGMENTATION = "Segmentation model name from pyannote"
+EMBEDDING = "Embedding model name from pyannote"
 STEP = "Sliding window step (in seconds)"
 LATENCY = "System latency (in seconds). STEP <= LATENCY <= CHUNK_DURATION"
 TAU = "Probability threshold to consider a speaker as active. 0 <= TAU <= 1"
diff --git a/src/diart/stream.py b/src/diart/stream.py
index 271164b9..05a911ce 100644
--- a/src/diart/stream.py
+++ b/src/diart/stream.py
@@ -8,11 +8,16 @@
 from diart.inference import RealTimeInference
 from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
 from diart.sinks import RTTMWriter
+from diart.models import SegmentationModel, EmbeddingModel
 
 
 def run():
     parser = argparse.ArgumentParser()
     parser.add_argument("source", type=str, help="Path to an audio file | 'microphone'")
+    parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
+                        help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
+    parser.add_argument("--embedding", default="pyannote/embedding", type=str,
+                        help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
     parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
     parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
     parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
@@ -28,6 +33,8 @@ def run():
                         help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file")
     args = parser.parse_args()
     args.device = torch.device("cpu") if args.cpu else None
+    args.segmentation = SegmentationModel.from_pyannote(args.segmentation)
+    args.embedding = EmbeddingModel.from_pyannote(args.embedding)
 
     # Define online speaker diarization pipeline
     config = PipelineConfig.from_namespace(args)

From 9880003ee9aca99592cdc10696483ae6230b3a17 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Fri, 26 Aug 2022 17:27:08 +0200
Subject: [PATCH 19/23] Add model arguments in diart.benchmark and diart.tune.
 Other improvements

---
 requirements.txt                 |  2 +-
 setup.cfg                        |  2 +-
 src/diart/benchmark.py           |  7 +++++++
 src/diart/blocks/segmentation.py |  2 +-
 src/diart/features.py            |  3 ---
 src/diart/sinks.py               | 18 ++++++++++++++----
 src/diart/tune.py                |  7 +++++++
 7 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index a72491e7..42f406c4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,7 @@ einops>=0.3.0
 tqdm>=4.64.0
 pandas>=1.4.2
 torchaudio>=0.10,<1.0
-pyannote.core>=4.4
+pyannote.core>=4.5
 pyannote.database>=4.1.1
 pyannote.metrics>=3.2
 optuna>=2.10
diff --git a/setup.cfg b/setup.cfg
index c22a7aff..4f8830d6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -29,7 +29,7 @@ install_requires=
     tqdm>=4.64.0
    pandas>=1.4.2
     torchaudio>=0.10,<1.0
-    pyannote.core>=4.4
+    pyannote.core>=4.5
     pyannote.database>=4.1.1
     pyannote.metrics>=3.2
     optuna>=2.10
diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py
index ab60be0d..f4689da3 100644
--- a/src/diart/benchmark.py
+++ b/src/diart/benchmark.py
@@ -4,12 +4,17 @@
 
 import diart.argdoc as argdoc
 from diart.inference import Benchmark
+from diart.models import SegmentationModel, EmbeddingModel
 from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
 
 
 def run():
     parser = argparse.ArgumentParser()
     parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
+    parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
+                        help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
+    parser.add_argument("--embedding", default="pyannote/embedding", type=str,
+                        help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
     parser.add_argument("--reference", type=str,
                         help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
     parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
@@ -26,6 +31,8 @@ def run():
     parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to no writing")
     args = parser.parse_args()
     args.device = torch.device("cpu") if args.cpu else None
+    args.segmentation = SegmentationModel.from_pyannote(args.segmentation)
+    args.embedding = EmbeddingModel.from_pyannote(args.embedding)
 
     benchmark = Benchmark(
         args.root,
diff --git a/src/diart/blocks/segmentation.py b/src/diart/blocks/segmentation.py
index 1b310944..3796441a 100644
--- a/src/diart/blocks/segmentation.py
+++ b/src/diart/blocks/segmentation.py
@@ -37,4 +37,4 @@ def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
         with torch.no_grad():
             wave = rearrange(self.formatter.cast(waveform), "batch sample channel -> batch channel sample")
             output = self.model(wave.to(self.device)).cpu()
-        return self.formatter.restore_type(output)
\ No newline at end of file
+        return self.formatter.restore_type(output)
diff --git a/src/diart/features.py b/src/diart/features.py
index 7068682b..2489027a 100644
--- a/src/diart/features.py
+++ b/src/diart/features.py
@@ -78,9 +78,6 @@ def __init__(self):
         self.state: Optional[TemporalFeatureFormatterState] = None
 
     def set_state(self, features: TemporalFeatures):
-        if self.state is not None:
-            return
-
         if isinstance(features, SlidingWindowFeature):
             msg = "Features sliding window duration and step must be equal"
             assert features.sliding_window.duration == features.sliding_window.step, msg
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index d235a45e..d84dc653 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -14,6 +14,15 @@ class WindowClosedException(Exception):
     pass
 
 
+def _extract_annotation(value: Union[Tuple, Annotation]) -> Annotation:
+    if isinstance(value, tuple):
+        return value[0]
+    if isinstance(value, Annotation):
+        return value
+    msg = f"Expected tuple or Annotation, but got {type(value)}"
+    raise ValueError(msg)
+
+
 class RTTMWriter(Observer):
     def __init__(self, path: Union[Path, Text], patch_collar: float = 0.05):
         super().__init__()
@@ -29,9 +38,10 @@ def patch(self):
         with open(self.path, 'w') as file:
             annotations[0].support(self.patch_collar).write_rttm(file)
 
-    def on_next(self, value: Tuple[Annotation, SlidingWindowFeature]):
+    def on_next(self, value: Union[Tuple, Annotation]):
+        annotation = _extract_annotation(value)
         with open(self.path, 'a') as file:
-            value[0].write_rttm(file)
+            annotation.write_rttm(file)
 
     def on_error(self, error: Exception):
         self.patch()
@@ -51,8 +61,8 @@ def patch(self):
         """Stitch same-speaker turns that are close to each other"""
         self.annotation.support(self.patch_collar)
 
-    def on_next(self, value: Tuple[Annotation, SlidingWindowFeature]):
-        annotation, waveform = value
+    def on_next(self, value: Union[Tuple, Annotation]):
+        annotation = _extract_annotation(value)
         if self.annotation is None:
             self.annotation = annotation
         else:
diff --git a/src/diart/tune.py b/src/diart/tune.py
index 5dac1f79..978f94f5 100644
--- a/src/diart/tune.py
+++ b/src/diart/tune.py
@@ -6,6 +6,7 @@
 from optuna.samplers import TPESampler
 
 import diart.argdoc as argdoc
+from diart.models import SegmentationModel, EmbeddingModel
 from diart.optim import Optimizer, HyperParameter
 from diart.pipelines import PipelineConfig
 
@@ -15,6 +16,10 @@ def run():
     parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
     parser.add_argument("--reference", required=True, type=str,
                         help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
+    parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
+                        help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
+    parser.add_argument("--embedding", default="pyannote/embedding", type=str,
+                        help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
     parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
     parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
     parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
@@ -34,6 +39,8 @@ def run():
     parser.add_argument("--output", type=str, help="Working directory")
     args = parser.parse_args()
     args.device = torch.device("cpu") if args.cpu else None
+    args.segmentation = SegmentationModel.from_pyannote(args.segmentation)
+    args.embedding = EmbeddingModel.from_pyannote(args.embedding)
 
     # Create the base configuration for each trial
     base_config = PipelineConfig.from_namespace(args)

From 00c493650f7615ec89b437eb9ee059c6b2b9233e Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Tue, 30 Aug 2022 21:19:06 +0200
Subject: [PATCH 20/23] Rename RTTMAccumulator to DiarizationPredictionAccumulator

---
 src/diart/inference.py | 10 +++++-----
 src/diart/sinks.py     | 17 +++++++++++------
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/src/diart/inference.py b/src/diart/inference.py
index c3301d34..7156f554 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -11,7 +11,7 @@
 import diart.operators as dops
 import diart.sources as src
 from diart.pipelines import OnlineSpeakerDiarization
-from diart.sinks import RTTMAccumulator, RTTMWriter, RealTimePlot
+from diart.sinks import DiarizationPredictionAccumulator, RTTMWriter, RealTimePlot
 
 
 class RealTimeInference:
@@ -38,7 +38,7 @@ def __init__(
         self.pipeline = pipeline
         self.source = source
         self.do_plot = do_plot
-        self.accumulator = RTTMAccumulator()
+        self.accumulator = DiarizationPredictionAccumulator()
         self.stream = self.pipeline.from_audio_source(source).pipe(
             dops.progress(f"Streaming {source.uri}", total=source.length, leave=True),
             ops.do(self.accumulator),
@@ -87,7 +87,7 @@ def __call__(self) -> Annotation:
         )
         observable.subscribe()
         self.source.read()
-        return self.accumulator.annotation
+        return self.accumulator.get_prediction()
 
 
 class Benchmark:
@@ -211,10 +211,10 @@ def __call__(self, pipeline: OnlineSpeakerDiarization) -> Optional[pd.DataFrame]
                 ops.do(RTTMWriter(self.output_path / f"{source.uri}.rttm"))
             )
 
-            accumulator = RTTMAccumulator()
+            accumulator = DiarizationPredictionAccumulator()
             observable.subscribe(accumulator)
             source.read()
-            predictions.append(accumulator.annotation)
+            predictions.append(accumulator.get_prediction())
 
         # Run evaluation
         if self.reference_path is not None:
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index d84dc653..13f2dbec 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -51,22 +51,27 @@ def on_completed(self):
         self.patch()
 
 
-class RTTMAccumulator(Observer):
+class DiarizationPredictionAccumulator(Observer):
     def __init__(self, patch_collar: float = 0.05):
         super().__init__()
         self.patch_collar = patch_collar
-        self.annotation = None
+        self._annotation = None
 
     def patch(self):
         """Stitch same-speaker turns that are close to each other"""
-        self.annotation.support(self.patch_collar)
+        self._annotation.support(self.patch_collar)
+
+    def get_prediction(self) -> Annotation:
+        # Patch again in case this is called before on_completed
+        self.patch()
+        return self._annotation
 
     def on_next(self, value: Union[Tuple, Annotation]):
         annotation = _extract_annotation(value)
-        if self.annotation is None:
-            self.annotation = annotation
+        if self._annotation is None:
+            self._annotation = annotation
         else:
-            self.annotation.update(annotation)
+            self._annotation.update(annotation)
 
     def on_error(self, error: Exception):
         self.patch()

From fd237c4d07f3fafff4b56a05fa2f53c82009ad93 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Wed, 31 Aug 2022 11:25:01 +0200
Subject: [PATCH 21/23] Improve websocket section in README and clarify a TODO comment

---
 README.md            | 4 ++--
 src/diart/sources.py | 3 ++-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index a0fa09f5..b39ac8e7 100644
--- a/README.md
+++ b/README.md
@@ -250,9 +250,9 @@ torch.Size([4, 512])
 
 ## WebSockets
 
-Diart is also compatible with the WebSocket protocol to serve your pipeline on the web.
+Diart is also compatible with the WebSocket protocol to serve pipelines on the web.
 
-In the following example we build a minimal server for a client to send audio and receive a prediction in RTTM format:
+In the following example we build a minimal server that receives audio chunks and sends back predictions in RTTM format:
 
 ```python
 from diart.pipelines import OnlineSpeakerDiarization
diff --git a/src/diart/sources.py b/src/diart/sources.py
index de04916f..bdd1dab7 100644
--- a/src/diart/sources.py
+++ b/src/diart/sources.py
@@ -242,7 +242,8 @@ class WebSocketAudioSource(AudioSource):
     def __init__(self, sample_rate: int, host: Optional[Text] = None, port: int = 7007):
         name = host if host is not None and host else "localhost"
         uri = f"{name}:{port}"
-        # FIXME sample_rate is not being used, this can be confusing and lead to incompatibilities
+        # FIXME sample_rate is not being used, this can be confusing and lead to incompatibilities.
+        #  I would prefer the client to send a JSON with data and sample rate, then resample if needed
         super().__init__(uri, sample_rate)
         self.host = host
         self.port = port

From e4a11b06115104ee3888df13919c4ec8f8fc585f Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Wed, 31 Aug 2022 12:48:43 +0200
Subject: [PATCH 22/23] Export csv report in diart.benchmark when output is provided

---
 src/diart/benchmark.py | 12 ++++++++----
 src/diart/inference.py |  4 ++++
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py
index f4689da3..5cf5808b 100644
--- a/src/diart/benchmark.py
+++ b/src/diart/benchmark.py
@@ -1,4 +1,5 @@
 import argparse
+from pathlib import Path
 
 import torch
 
@@ -10,12 +11,12 @@ def run():
     parser = argparse.ArgumentParser()
-    parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
+    parser.add_argument("root", type=Path, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
     parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
                         help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
     parser.add_argument("--embedding", default="pyannote/embedding", type=str,
                         help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
-    parser.add_argument("--reference", type=str,
+    parser.add_argument("--reference", type=Path,
                         help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
     parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
     parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
@@ -28,7 +29,7 @@ def run():
     parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32")
     parser.add_argument("--cpu", dest="cpu", action="store_true",
                         help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
-    parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to no writing")
+    parser.add_argument("--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing")
     args = parser.parse_args()
     args.device = torch.device("cpu") if args.cpu else None
     args.segmentation = SegmentationModel.from_pyannote(args.segmentation)
@@ -43,7 +44,10 @@ def run():
         batch_size=args.batch_size,
     )
 
-    benchmark(OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True))
+    pipeline = OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True)
+    report = benchmark(pipeline)
+    if args.output is not None and report is not None:
+        report.to_csv(args.output / "benchmark_report.csv")
 
 
 if __name__ == "__main__":
diff --git a/src/diart/inference.py b/src/diart/inference.py
index 7156f554..15b703c2 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -135,6 +135,10 @@ def __init__(
         self.speech_path = Path(speech_path).expanduser()
         assert self.speech_path.is_dir(), "Speech path must be a directory"
 
+        # If there's no reference and no output, then benchmark has no output
+        msg = "Benchmark expected reference path, output path or both"
+        assert reference_path is not None or output_path is not None, msg
+
         self.reference_path = reference_path
         if reference_path is not None:
             self.reference_path = Path(self.reference_path).expanduser()

From b75dc9f33d258e3b43e21799c8e11fc342c988d7 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Wed, 31 Aug 2022 13:42:44 +0200
Subject: [PATCH 23/23] Change version to 0.5.0

---
 setup.cfg | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.cfg b/setup.cfg
index 4f8830d6..bc0ccc8b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,6 @@
 [metadata]
 name=diart
-version=0.4.0
+version=0.5.0
 author=Juan Manuel Coria
 description=Speaker diarization in real time
 long_description=file: README.md