diff --git a/pyannote/audio/__init__.py b/pyannote/audio/__init__.py index 462a15d77..827cb1473 100644 --- a/pyannote/audio/__init__.py +++ b/pyannote/audio/__init__.py @@ -27,8 +27,9 @@ from .core.inference import Inference +from .core.streaming_inference import StreamingInference from .core.io import Audio from .core.model import Model from .core.pipeline import Pipeline -__all__ = ["Audio", "Model", "Inference", "Pipeline"] +__all__ = ["Audio", "Model", "Inference", "Pipeline", "StreamingInference"] diff --git a/pyannote/audio/core/inference.py b/pyannote/audio/core/inference.py index 0c3e9b212..df33a90e6 100644 --- a/pyannote/audio/core/inference.py +++ b/pyannote/audio/core/inference.py @@ -225,6 +225,7 @@ def infer(self, chunks: torch.Tensor) -> Union[np.ndarray, Tuple[np.ndarray]]: def __convert(output: torch.Tensor, conversion: nn.Module, **kwargs): return conversion(output).cpu().numpy() + return map_with_specifications( self.model.specifications, __convert, outputs, self.conversion ) @@ -549,7 +550,7 @@ def aggregate( aggregated_scores : SlidingWindowFeature Aggregated scores. Shape is (num_frames, num_classes) """ - + print("aggregate") num_chunks, num_frames_per_chunk, num_classes = scores.data.shape chunks = scores.sliding_window @@ -596,6 +597,7 @@ def aggregate( ) + 1 ) + aggregated_output: np.ndarray = np.zeros( (num_frames, num_classes), dtype=np.float32 ) @@ -611,7 +613,6 @@ def aggregate( aggregated_mask: np.ndarray = np.zeros( (num_frames, num_classes), dtype=np.float32 ) - # loop on the scores of sliding chunks for (chunk, score), (_, mask) in zip(scores, masks): # chunk ~ Segment @@ -620,6 +621,7 @@ def aggregate( start_frame = frames.closest_frame(chunk.start + 0.5 * frames.duration) + aggregated_output[start_frame : start_frame + num_frames_per_chunk] += ( score * mask * hamming_window * warm_up_window ) @@ -644,6 +646,7 @@ def aggregate( return SlidingWindowFeature(average, frames) + @staticmethod def trim( scores: SlidingWindowFeature, diff --git a/pyannote/audio/core/streaming_inference.py b/pyannote/audio/core/streaming_inference.py new file mode 100644 index 000000000..1a8bfcdcf --- /dev/null +++ b/pyannote/audio/core/streaming_inference.py @@ -0,0 +1,911 @@ +# MIT License +# +# Copyright (c) 2020- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
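For context, a minimal usage sketch of the `StreamingInference` class exported in `pyannote/audio/__init__.py` above; the checkpoint path and audio file name are hypothetical placeholders, and the keyword arguments simply mirror the constructor defined later in this new file:

    from pyannote.audio import StreamingInference

    # "path/to/segmentation.ckpt" and "audio.wav" are placeholders, not files from this patch
    inference = StreamingInference("path/to/segmentation.ckpt", duration=5.0, step=0.5, batch_size=32)
    scores = inference("audio.wav")  # SlidingWindowFeature of aggregated frame-level scores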
+ +import math +import warnings +from pathlib import Path +from typing import Callable, List, Optional, Text, Tuple, Union +from pyannote.audio import Inference +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature +from pytorch_lightning.utilities.memory import is_oom_error + +from pyannote.audio.core.io import AudioFile +from pyannote.audio.core.model import Model, Specifications +from pyannote.audio.core.task import Resolution +from pyannote.audio.utils.multi_task import map_with_specifications +from pyannote.audio.utils.permutation import mae_cost_func, permutate +from pyannote.audio.utils.powerset import Powerset +from pyannote.audio.utils.reproducibility import fix_reproducibility + + +class BaseInference: + pass + + +class StreamingInference(BaseInference): + """Inference + + Parameters + ---------- + model : Model + Model. Will be automatically set to eval() mode and moved to `device` when provided. + window : {"sliding", "whole"}, optional + Use a "sliding" window and aggregate the corresponding outputs (default) + or just one (potentially long) window covering the "whole" file or chunk. + duration : float, optional + Chunk duration, in seconds. Defaults to duration used for training the model. + Has no effect when `window` is "whole". + step : float, optional + Step between consecutive chunks, in seconds. Defaults to warm-up duration when + greater than 0s, otherwise 10% of duration. Has no effect when `window` is "whole". + pre_aggregation_hook : callable, optional + When a callable is provided, it is applied to the model output, just before aggregation. + Takes a (num_chunks, num_frames, dimension) numpy array as input and returns a modified + (num_chunks, num_frames, other_dimension) numpy array passed to overlap-add aggregation. + skip_aggregation : bool, optional + Do not aggregate outputs when using "sliding" window. Defaults to False. + skip_conversion: bool, optional + In case a task has been trained with `powerset` mode, output is automatically + converted to `multi-label`, unless `skip_conversion` is set to True. + batch_size : int, optional + Batch size. Larger values (should) make inference faster. Defaults to 32. + device : torch.device, optional + Device used for inference. Defaults to `model.device`. + In case `device` and `model.device` are different, model is sent to device. 
+ use_auth_token : str, optional + When loading a private huggingface.co model, set `use_auth_token` + to True or to a string containing your hugginface.co authentication + token that can be obtained by running `huggingface-cli login` + """ + + def __init__( + self, + model: Union[Model, Text, Path], + window: Text = "sliding", + duration: float = None, + step: float = None, + pre_aggregation_hook: Callable[[np.ndarray], np.ndarray] = None, + skip_aggregation: bool = False, + skip_conversion: bool = False, + device: torch.device = None, + batch_size: int = 32, + use_auth_token: Union[Text, None] = None, + ): + # ~~~~ model ~~~~~ + + self.model = ( + model + if isinstance(model, Model) + else Model.from_pretrained( + model, + map_location=device, + strict=False, + use_auth_token=use_auth_token, + ) + ) + + if device is None: + device = self.model.device + self.device = device + + self.model.eval() + self.model.to(self.device) + + specifications = self.model.specifications + + # ~~~~ sliding window ~~~~~ + + if window not in ["sliding", "whole"]: + raise ValueError('`window` must be "sliding" or "whole".') + + if window == "whole" and any( + s.resolution == Resolution.FRAME for s in specifications + ): + warnings.warn( + 'Using "whole" `window` inference with a frame-based model might lead to bad results ' + 'and huge memory consumption: it is recommended to set `window` to "sliding".' + ) + self.window = window + + training_duration = next(iter(specifications)).duration + duration = duration or training_duration + if training_duration != duration: + warnings.warn( + f"Model was trained with {training_duration:g}s chunks, and you requested " + f"{duration:g}s chunks for inference: this might lead to suboptimal results." + ) + self.duration = duration + + # ~~~~ powerset to multilabel conversion ~~~~ + + self.skip_conversion = skip_conversion + + conversion = list() + for s in specifications: + if s.powerset and not skip_conversion: + c = Powerset(len(s.classes), s.powerset_max_classes) + else: + c = nn.Identity() + conversion.append(c.to(self.device)) + + if isinstance(specifications, Specifications): + self.conversion = conversion[0] + else: + self.conversion = nn.ModuleList(conversion) + + # ~~~~ overlap-add aggregation ~~~~~ + + self.skip_aggregation = skip_aggregation + self.pre_aggregation_hook = pre_aggregation_hook + + self.warm_up = next(iter(specifications)).warm_up + # Use that many seconds on the left- and rightmost parts of each chunk + # to warm up the model. While the model does process those left- and right-most + # parts, only the remaining central part of each chunk is used for aggregating + # scores during inference. + + # step between consecutive chunks + step = step or ( + 0.1 * self.duration if self.warm_up[0] == 0.0 else self.warm_up[0] + ) + + if step > self.duration: + raise ValueError( + f"Step between consecutive chunks is set to {step:g}s, while chunks are " + f"only {self.duration:g}s long, leading to gaps between consecutive chunks. " + f"Either decrease step or increase duration." 
+            )
+        self.step = step
+
+        self.batch_size = batch_size
+
+    def to(self, device: torch.device) -> "StreamingInference":
+        """Send internal model to `device`"""
+
+        if not isinstance(device, torch.device):
+            raise TypeError(
+                f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`"
+            )
+
+        self.model.to(device)
+        self.conversion.to(device)
+        self.device = device
+        return self
+
+    def infer(self, chunks: torch.Tensor) -> Union[np.ndarray, Tuple[np.ndarray]]:
+        """Forward pass
+
+        Takes care of sending chunks to the right device and outputs back to CPU
+
+        Parameters
+        ----------
+        chunks : (batch_size, num_channels, num_samples) torch.Tensor
+            Batch of audio chunks.
+
+        Returns
+        -------
+        outputs : (tuple of) (batch_size, ...) np.ndarray
+            Model output.
+        """
+
+        with torch.inference_mode():
+            try:
+                outputs = self.model(chunks.to(self.device))
+            except RuntimeError as exception:
+                if is_oom_error(exception):
+                    raise MemoryError(
+                        f"batch_size ({self.batch_size: d}) is probably too large. "
+                        f"Try with a smaller value until memory error disappears."
+                    )
+                else:
+                    raise exception
+
+        def __convert(output: torch.Tensor, conversion: nn.Module, **kwargs):
+            return conversion(output).cpu().numpy()
+
+        return map_with_specifications(
+            self.model.specifications, __convert, outputs, self.conversion
+        )
+
+    def slide(
+        self,
+        waveform: torch.Tensor,
+        sample_rate: int,
+        hook: Optional[Callable],
+    ) -> Union[SlidingWindowFeature, Tuple[SlidingWindowFeature]]:
+        """Slide model on a waveform
+
+        Parameters
+        ----------
+        waveform: (num_channels, num_samples) torch.Tensor
+            Waveform.
+        sample_rate : int
+            Sample rate.
+        hook: Optional[Callable]
+            When a callable is provided, it is called every time a batch is
+            processed with two keyword arguments:
+            - `completed`: the number of chunks that have been processed so far
+            - `total`: the total number of chunks
+
+        Returns
+        -------
+        output : (tuple of) SlidingWindowFeature
+            Model output. Shape is (num_chunks, dimension) for chunk-level tasks,
+            and (num_frames, dimension) for frame-level tasks.
+ """ + + window_size: int = self.model.audio.get_num_samples(self.duration) + step_size: int = round(self.step * sample_rate) + _, num_samples = waveform.shape + + def __frames( + example_output, specifications: Optional[Specifications] = None + ) -> SlidingWindow: + if specifications.resolution == Resolution.CHUNK: + return SlidingWindow(start=0.0, duration=self.duration, step=self.step) + return example_output.frames + + frames: Union[SlidingWindow, Tuple[SlidingWindow]] = map_with_specifications( + self.model.specifications, + __frames, + self.model.example_output, + ) + + # prepare complete chunks + if num_samples >= window_size: + chunks: torch.Tensor = rearrange( + waveform.unfold(1, window_size, step_size), + "channel chunk frame -> chunk channel frame", + ) + num_chunks, _, _ = chunks.shape + else: + num_chunks = 0 + + # prepare last incomplete chunk + has_last_chunk = (num_samples < window_size) or ( + num_samples - window_size + ) % step_size > 0 + if has_last_chunk: + # pad last chunk with zeros + last_chunk: torch.Tensor = waveform[:, num_chunks * step_size :] + _, last_window_size = last_chunk.shape + last_pad = window_size - last_window_size + last_chunk = F.pad(last_chunk, (0, last_pad)) + + def __empty_list(**kwargs): + return list() + + outputs: Union[ + List[np.ndarray], Tuple[List[np.ndarray]] + ] = map_with_specifications(self.model.specifications, __empty_list) + + if hook is not None: + hook(completed=0, total=num_chunks + has_last_chunk) + + def __append_batch(output, batch_output, **kwargs) -> None: + output.append(batch_output) + return + + # slide over audio chunks in batch + for c in np.arange(0, num_chunks, self.batch_size): + batch: torch.Tensor = chunks[c : c + self.batch_size] + + batch_outputs: Union[np.ndarray, Tuple[np.ndarray]] = self.infer(batch) + + _ = map_with_specifications( + self.model.specifications, __append_batch, outputs, batch_outputs + ) + + if hook is not None: + hook(completed=c + self.batch_size, total=num_chunks + has_last_chunk) + + # process orphan last chunk + if has_last_chunk: + last_outputs = self.infer(last_chunk[None]) + + _ = map_with_specifications( + self.model.specifications, __append_batch, outputs, last_outputs + ) + + if hook is not None: + hook( + completed=num_chunks + has_last_chunk, + total=num_chunks + has_last_chunk, + ) + + def __vstack(output: List[np.ndarray], **kwargs) -> np.ndarray: + return np.vstack(output) + + outputs: Union[np.ndarray, Tuple[np.ndarray]] = map_with_specifications( + self.model.specifications, __vstack, outputs + ) + + def __aggregate( + outputs: np.ndarray, + frames: SlidingWindow, + specifications: Optional[Specifications] = None, + ) -> SlidingWindowFeature: + # skip aggregation when requested, + # or when model outputs just one vector per chunk + # or when model is permutation-invariant (and not post-processed) + if ( + self.skip_aggregation + or specifications.resolution == Resolution.CHUNK + or ( + specifications.permutation_invariant + and self.pre_aggregation_hook is None + ) + ): + frames = SlidingWindow( + start=0.0, duration=self.duration, step=self.step + ) + return SlidingWindowFeature(outputs, frames) + + if self.pre_aggregation_hook is not None: + outputs = self.pre_aggregation_hook(outputs) + + aggregated = self.concatenate_end_chunk( + SlidingWindowFeature( + outputs, + SlidingWindow(start=0.0, duration=self.duration, step=self.step), + ), + frames=frames, + warm_up=self.warm_up, + hamming=True, + missing=0.0, + ) + + # remove padding that was added to last chunk + if 
has_last_chunk: + aggregated.data = aggregated.crop( + Segment(0.0, num_samples / sample_rate), mode="loose" + ) + + return aggregated + + return map_with_specifications( + self.model.specifications, __aggregate, outputs, frames + ) + + def __call__( + self, file: AudioFile, hook: Optional[Callable] = None + ) -> Union[ + Tuple[Union[SlidingWindowFeature, np.ndarray]], + Union[SlidingWindowFeature, np.ndarray], + ]: + """Run inference on a whole file + + Parameters + ---------- + file : AudioFile + Audio file. + hook : callable, optional + When a callable is provided, it is called everytime a batch is processed + with two keyword arguments: + - `completed`: the number of chunks that have been processed so far + - `total`: the total number of chunks + + Returns + ------- + output : (tuple of) SlidingWindowFeature or np.ndarray + Model output, as `SlidingWindowFeature` if `window` is set to "sliding" + and `np.ndarray` if is set to "whole". + + """ + + fix_reproducibility(self.device) + + waveform, sample_rate = self.model.audio(file) + + if self.window == "sliding": + return self.slide(waveform, sample_rate, hook=hook) + + outputs: Union[np.ndarray, Tuple[np.ndarray]] = self.infer(waveform[None]) + + def __first_sample(outputs: np.ndarray, **kwargs) -> np.ndarray: + return outputs[0] + + return map_with_specifications( + self.model.specifications, __first_sample, outputs + ) + + def crop( + self, + file: AudioFile, + chunk: Union[Segment, List[Segment]], + duration: Optional[float] = None, + hook: Optional[Callable] = None, + ) -> Union[ + Tuple[Union[SlidingWindowFeature, np.ndarray]], + Union[SlidingWindowFeature, np.ndarray], + ]: + """Run inference on a chunk or a list of chunks + + Parameters + ---------- + file : AudioFile + Audio file. + chunk : Segment or list of Segment + Apply model on this chunk. When a list of chunks is provided and + window is set to "sliding", this is equivalent to calling crop on + the smallest chunk that contains all chunks. In case window is set + to "whole", this is equivalent to concatenating each chunk into one + (artifical) chunk before processing it. + duration : float, optional + Enforce chunk duration (in seconds). This is a hack to avoid rounding + errors that may result in a different number of audio samples for two + chunks of the same duration. + hook : callable, optional + When a callable is provided, it is called everytime a batch is processed + with two keyword arguments: + - `completed`: the number of chunks that have been processed so far + - `total`: the total number of chunks + + Returns + ------- + output : (tuple of) SlidingWindowFeature or np.ndarray + Model output, as `SlidingWindowFeature` if `window` is set to "sliding" + and `np.ndarray` if is set to "whole". 
+ + Notes + ----- + If model needs to be warmed up, remember to extend the requested chunk with the + corresponding amount of time so that it is actually warmed up when processing the + chunk of interest: + >>> chunk_of_interest = Segment(10, 15) + >>> extended_chunk = Segment(10 - warm_up, 15 + warm_up) + >>> inference.crop(file, extended_chunk).crop(chunk_of_interest, returns_data=False) + """ + + fix_reproducibility(self.device) + + if self.window == "sliding": + if not isinstance(chunk, Segment): + start = min(c.start for c in chunk) + end = max(c.end for c in chunk) + chunk = Segment(start=start, end=end) + + waveform, sample_rate = self.model.audio.crop( + file, chunk, duration=duration + ) + outputs: Union[ + SlidingWindowFeature, Tuple[SlidingWindowFeature] + ] = self.slide(waveform, sample_rate, hook=hook) + + def __shift(output: SlidingWindowFeature, **kwargs) -> SlidingWindowFeature: + frames = output.sliding_window + shifted_frames = SlidingWindow( + start=chunk.start, duration=frames.duration, step=frames.step + ) + return SlidingWindowFeature(output.data, shifted_frames) + + return map_with_specifications(self.model.specifications, __shift, outputs) + + if isinstance(chunk, Segment): + waveform, sample_rate = self.model.audio.crop( + file, chunk, duration=duration + ) + else: + waveform = torch.cat( + [self.model.audio.crop(file, c)[0] for c in chunk], dim=1 + ) + + outputs: Union[np.ndarray, Tuple[np.ndarray]] = self.infer(waveform[None]) + + def __first_sample(outputs: np.ndarray, **kwargs) -> np.ndarray: + return outputs[0] + + return map_with_specifications( + self.model.specifications, __first_sample, outputs + ) + + @staticmethod + def aggregate( + scores: SlidingWindowFeature, + frames: SlidingWindow = None, + warm_up: Tuple[float, float] = (0.0, 0.0), + epsilon: float = 1e-12, + hamming: bool = False, + missing: float = np.NaN, + skip_average: bool = False, + ) -> SlidingWindowFeature: + """Aggregation + + Parameters + ---------- + scores : SlidingWindowFeature + Raw (unaggregated) scores. Shape is (num_chunks, num_frames_per_chunk, num_classes). + frames : SlidingWindow, optional + Frames resolution. Defaults to estimate it automatically based on `scores` shape + and chunk size. Providing the exact frame resolution (when known) leads to better + temporal precision. + warm_up : (float, float) tuple, optional + Left/right warm up duration (in seconds). + missing : float, optional + Value used to replace missing (ie all NaNs) values. + skip_average : bool, optional + Skip final averaging step. + + Returns + ------- + aggregated_scores : SlidingWindowFeature + Aggregated scores. 
Shape is (num_frames, num_classes) + """ + + num_chunks, num_frames_per_chunk, num_classes = scores.data.shape + + chunks = scores.sliding_window + if frames is None: + duration = step = chunks.duration / num_frames_per_chunk + frames = SlidingWindow(start=chunks.start, duration=duration, step=step) + else: + frames = SlidingWindow( + start=chunks.start, + duration=frames.duration, + step=frames.step, + ) + + masks = 1 - np.isnan(scores) + scores.data = np.nan_to_num(scores.data, copy=True, nan=0.0) + + # Hamming window used for overlap-add aggregation + hamming_window = ( + np.hamming(num_frames_per_chunk).reshape(-1, 1) + if hamming + else np.ones((num_frames_per_chunk, 1)) + ) + + # anything before warm_up_left (and after num_frames_per_chunk - warm_up_right) + # will not be used in the final aggregation + + # warm-up windows used for overlap-add aggregation + warm_up_window = np.ones((num_frames_per_chunk, 1)) + # anything before warm_up_left will not contribute to aggregation + warm_up_left = round( + warm_up[0] / scores.sliding_window.duration * num_frames_per_chunk + ) + warm_up_window[:warm_up_left] = epsilon + # anything after num_frames_per_chunk - warm_up_right either + warm_up_right = round( + warm_up[1] / scores.sliding_window.duration * num_frames_per_chunk + ) + warm_up_window[num_frames_per_chunk - warm_up_right :] = epsilon + + # aggregated_output[i] will be used to store the sum of all predictions + # for frame #i + num_frames = ( + frames.closest_frame( + scores.sliding_window.start + + scores.sliding_window.duration + + (num_chunks - 1) * scores.sliding_window.step + ) + + 1 + ) + + aggregated_output: np.ndarray = np.zeros( + (num_frames, num_classes), dtype=np.float32 + ) + + # overlapping_chunk_count[i] will be used to store the number of chunks + # that contributed to frame #i + overlapping_chunk_count: np.ndarray = np.zeros( + (num_frames, num_classes), dtype=np.float32 + ) + + # aggregated_mask[i] will be used to indicate whether + # at least one non-NAN frame contributed to frame #i + aggregated_mask: np.ndarray = np.zeros( + (num_frames, num_classes), dtype=np.float32 + ) + # loop on the scores of sliding chunks + for (chunk, score), (_, mask) in zip(scores, masks): + # chunk ~ Segment + # score ~ (num_frames_per_chunk, num_classes)-shaped np.ndarray + # mask ~ (num_frames_per_chunk, num_classes)-shaped np.ndarray + start_frame = frames.closest_frame(chunk.start) + + + aggregated_output[start_frame : start_frame + num_frames_per_chunk] += ( + score * mask * hamming_window * warm_up_window + ) + + overlapping_chunk_count[ + start_frame : start_frame + num_frames_per_chunk + ] += (mask * hamming_window * warm_up_window) + + aggregated_mask[ + start_frame : start_frame + num_frames_per_chunk + ] = np.maximum( + aggregated_mask[start_frame : start_frame + num_frames_per_chunk], + mask, + ) + + if skip_average: + average = aggregated_output + else: + average = aggregated_output / np.maximum(overlapping_chunk_count, epsilon) + + average[aggregated_mask == 0.0] = missing + + return SlidingWindowFeature(average, frames) + + @staticmethod + def concatenate_end_chunk( + scores: SlidingWindowFeature, + frames: SlidingWindow = None, + warm_up: Tuple[float, float] = (0.0, 0.0), + epsilon: float = 1e-12, + hamming: bool = False, + missing: float = np.NaN, + skip_average: bool = False, + ) -> SlidingWindowFeature: + """Aggregation + + Parameters + ---------- + scores : SlidingWindowFeature + Raw (unaggregated) scores. Shape is (num_chunks, num_frames_per_chunk, num_classes). 
+        frames : SlidingWindow, optional
+            Frames resolution. Defaults to estimating it automatically based on `scores`
+            shape and chunk size. Providing the exact frame resolution (when known) leads
+            to better temporal precision.
+        warm_up : (float, float) tuple, optional
+            Left/right warm up duration (in seconds).
+        missing : float, optional
+            Value used to replace missing (i.e. all NaNs) values.
+        skip_average : bool, optional
+            Skip final averaging step.
+
+        Returns
+        -------
+        aggregated_scores : SlidingWindowFeature
+            Aggregated scores. Shape is (num_frames, num_classes)
+        """
+        num_chunks, num_frames_per_chunk, num_classes = scores.data.shape
+
+        chunks = scores.sliding_window
+        if frames is None:
+            duration = step = chunks.duration / num_frames_per_chunk
+            frames = SlidingWindow(start=chunks.start, duration=duration, step=step)
+        else:
+            frames = SlidingWindow(
+                start=chunks.start,
+                duration=frames.duration,
+                step=frames.step,
+            )
+        masks = 1 - np.isnan(scores)
+        scores.data = np.nan_to_num(scores.data, copy=True, nan=0.0)
+
+        # aggregated_output[i] will be used to store the sum of all predictions
+        # for frame #i
+        num_frames = (
+            frames.closest_frame(
+                scores.sliding_window.start
+                + scores.sliding_window.duration
+                + (num_chunks - 1) * scores.sliding_window.step
+            )
+            + 1
+        )
+        step_frames = frames.closest_frame(scores.sliding_window.step)
+        aggregated_output: np.ndarray = np.zeros(
+            (num_frames, num_classes), dtype=np.float32
+        )
+        # keep the leading (non-overlapping) part of the very first chunk as-is
+        aggregated_output[0 : num_frames_per_chunk - step_frames] = scores[0][
+            : num_frames_per_chunk - step_frames
+        ]
+        end = scores.sliding_window.duration - scores.sliding_window.step
+
+        # loop on the scores of sliding chunks and concatenate the trailing
+        # `step_frames` frames of each chunk (i.e. the part not covered by the
+        # previous chunk)
+        for chunk, score in scores:
+            # chunk ~ Segment
+            # score ~ (num_frames_per_chunk, num_classes)-shaped np.ndarray
+            start_frame = frames.closest_frame(end)
+            aggregated_output[start_frame : start_frame + step_frames] = score[
+                num_frames_per_chunk - step_frames :
+            ]
+            end = chunk.end
+
+        return SlidingWindowFeature(aggregated_output, frames)
+
+    @staticmethod
+    def trim(
+        scores: SlidingWindowFeature,
+        warm_up: Tuple[float, float] = (0.1, 0.1),
+    ) -> SlidingWindowFeature:
+        """Trim left and right warm-up regions
+
+        Parameters
+        ----------
+        scores : SlidingWindowFeature
+            (num_chunks, num_frames, num_classes)-shaped scores.
+        warm_up : (float, float) tuple
+            Left/right warm up ratio of chunk duration.
+            Defaults to (0.1, 0.1), i.e. 10% on both sides.
+ + Returns + ------- + trimmed : SlidingWindowFeature + (num_chunks, trimmed_num_frames, num_speakers)-shaped scores + """ + + assert ( + scores.data.ndim == 3 + ), "Inference.trim expects (num_chunks, num_frames, num_classes)-shaped `scores`" + _, num_frames, _ = scores.data.shape + + chunks = scores.sliding_window + + num_frames_left = round(num_frames * warm_up[0]) + num_frames_right = round(num_frames * warm_up[1]) + + num_frames_step = round(num_frames * chunks.step / chunks.duration) + if num_frames - num_frames_left - num_frames_right < num_frames_step: + warnings.warn( + f"Total `warm_up` is so large ({sum(warm_up) * 100:g}% of each chunk) " + f"that resulting trimmed scores does not cover a whole step ({chunks.step:g}s)" + ) + new_data = scores.data[:, num_frames_left : num_frames - num_frames_right] + + new_chunks = SlidingWindow( + start=chunks.start + warm_up[0] * chunks.duration, + step=chunks.step, + duration=(1 - warm_up[0] - warm_up[1]) * chunks.duration, + ) + + return SlidingWindowFeature(new_data, new_chunks) + + @staticmethod + def stitch( + activations: SlidingWindowFeature, + frames: SlidingWindow = None, + lookahead: Optional[Tuple[int, int]] = None, + cost_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None, + match_func: Callable[[np.ndarray, np.ndarray, float], bool] = None, + ) -> SlidingWindowFeature: + """ + + Parameters + ---------- + activations : SlidingWindowFeature + (num_chunks, num_frames, num_classes)-shaped scores. + frames : SlidingWindow, optional + Frames resolution. Defaults to estimate it automatically based on `activations` + shape and chunk size. Providing the exact frame resolution (when known) leads to better + temporal precision. + lookahead : (int, int) tuple + Number of past and future adjacent chunks to use for stitching. + Defaults to (k, k) with k = chunk_duration / chunk_step - 1 + cost_func : callable + Cost function used to find the optimal mapping between two chunks. + Expects two (num_frames, num_classes) torch.tensor as input + and returns cost as a (num_classes, ) torch.tensor + Defaults to mean absolute error (utils.permutations.mae_cost_func) + match_func : callable + Function used to decide whether two speakers mapped by the optimal + mapping actually are a match. + Expects two (num_frames, ) np.ndarray and the cost (from cost_func) + and returns a boolean. Defaults to always returning True. 
+ """ + + num_chunks, num_frames, num_classes = activations.data.shape + + chunks: SlidingWindow = activations.sliding_window + + if frames is None: + duration = step = chunks.duration / num_frames + frames = SlidingWindow(start=chunks.start, duration=duration, step=step) + else: + frames = SlidingWindow( + start=chunks.start, + duration=frames.duration, + step=frames.step, + ) + + max_lookahead = math.floor(chunks.duration / chunks.step - 1) + if lookahead is None: + lookahead = 2 * (max_lookahead,) + + assert all(L <= max_lookahead for L in lookahead) + + if cost_func is None: + cost_func = mae_cost_func + + if match_func is None: + + def always_match(this: np.ndarray, that: np.ndarray, cost: float): + return True + + match_func = always_match + + stitches = [] + for C, (chunk, activation) in enumerate(activations): + local_stitch = np.NAN * np.zeros( + (sum(lookahead) + 1, num_frames, num_classes) + ) + + for c in range( + max(0, C - lookahead[0]), min(num_chunks, C + lookahead[1] + 1) + ): + # extract common temporal support + shift = round((C - c) * num_frames * chunks.step / chunks.duration) + + if shift < 0: + shift = -shift + this_activations = activation[shift:] + that_activations = activations[c, : num_frames - shift] + else: + this_activations = activation[: num_frames - shift] + that_activations = activations[c, shift:] + + # find the optimal one-to-one mapping + _, (permutation,), (cost,) = permutate( + this_activations[np.newaxis], + that_activations, + cost_func=cost_func, + return_cost=True, + ) + + for this, that in enumerate(permutation): + # only stitch under certain condiditions + matching = (c == C) or ( + match_func( + this_activations[:, this], + that_activations[:, that], + cost[this, that], + ) + ) + + if matching: + local_stitch[c - C + lookahead[0], :, this] = activations[ + c, :, that + ] + + # TODO: do not lookahead further once a mismatch is found + + stitched_chunks = SlidingWindow( + start=chunk.start - lookahead[0] * chunks.step, + duration=chunks.duration, + step=chunks.step, + ) + + local_stitch = Inference.aggregate( + SlidingWindowFeature(local_stitch, stitched_chunks), + frames=frames, + hamming=True, + ) + + stitches.append(local_stitch.data) + + stitches = np.stack(stitches) + stitched_chunks = SlidingWindow( + start=chunks.start - lookahead[0] * chunks.step, + duration=chunks.duration + sum(lookahead) * chunks.step, + step=chunks.step, + ) + + return SlidingWindowFeature(stitches, stitched_chunks) diff --git a/pyannote/audio/models/blocks/sincnet.py b/pyannote/audio/models/blocks/sincnet.py index b46549bb3..fea5b5e78 100644 --- a/pyannote/audio/models/blocks/sincnet.py +++ b/pyannote/audio/models/blocks/sincnet.py @@ -38,7 +38,7 @@ class SincNet(nn.Module): - def __init__(self, sample_rate: int = 16000, stride: int = 1): + def __init__(self, sample_rate: int = 16000, stride: int = 1, streaming: bool = False): super().__init__() if sample_rate != 16000: @@ -48,12 +48,14 @@ def __init__(self, sample_rate: int = 16000, stride: int = 1): self.sample_rate = sample_rate self.stride = stride - - self.wav_norm1d = nn.InstanceNorm1d(1, affine=True) + self.streaming = streaming + if self.streaming == False: + self.wav_norm1d = nn.InstanceNorm1d(1, affine=True) self.conv1d = nn.ModuleList() self.pool1d = nn.ModuleList() - self.norm1d = nn.ModuleList() + if self.streaming == False: + self.norm1d = nn.ModuleList() self.conv1d.append( Encoder( @@ -68,15 +70,18 @@ def __init__(self, sample_rate: int = 16000, stride: int = 1): ) ) self.pool1d.append(nn.MaxPool1d(3, 
stride=3, padding=0, dilation=1)) - self.norm1d.append(nn.InstanceNorm1d(80, affine=True)) + if self.streaming == False: + self.norm1d.append(nn.InstanceNorm1d(80, affine=True)) self.conv1d.append(nn.Conv1d(80, 60, 5, stride=1)) self.pool1d.append(nn.MaxPool1d(3, stride=3, padding=0, dilation=1)) - self.norm1d.append(nn.InstanceNorm1d(60, affine=True)) + if self.streaming == False: + self.norm1d.append(nn.InstanceNorm1d(60, affine=True)) self.conv1d.append(nn.Conv1d(60, 60, 5, stride=1)) self.pool1d.append(nn.MaxPool1d(3, stride=3, padding=0, dilation=1)) - self.norm1d.append(nn.InstanceNorm1d(60, affine=True)) + if self.streaming == False: + self.norm1d.append(nn.InstanceNorm1d(60, affine=True)) @lru_cache def num_frames(self, num_samples: int) -> int: @@ -165,18 +170,32 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor: ---------- waveforms : (batch, channel, sample) """ + if self.streaming == False: + outputs = self.wav_norm1d(waveforms) + for c, (conv1d, pool1d, norm1d) in enumerate( + zip(self.conv1d, self.pool1d, self.norm1d) + ): + + outputs = conv1d(outputs) + + # https://github.com/mravanelli/SincNet/issues/4 + if c == 0: + outputs = torch.abs(outputs) + + outputs = F.leaky_relu(norm1d(pool1d(outputs))) + return outputs - outputs = self.wav_norm1d(waveforms) + else: + outputs = waveforms + for c, (conv1d, pool1d) in enumerate( + zip(self.conv1d, self.pool1d) + ): - for c, (conv1d, pool1d, norm1d) in enumerate( - zip(self.conv1d, self.pool1d, self.norm1d) - ): - outputs = conv1d(outputs) + outputs = conv1d(outputs) - # https://github.com/mravanelli/SincNet/issues/4 - if c == 0: - outputs = torch.abs(outputs) + if c == 0: + outputs = torch.abs(outputs) - outputs = F.leaky_relu(norm1d(pool1d(outputs))) + outputs = F.leaky_relu(pool1d(outputs)) - return outputs + return outputs diff --git a/pyannote/audio/models/segmentation/MultilatencyPyanNet.py b/pyannote/audio/models/segmentation/MultilatencyPyanNet.py new file mode 100644 index 000000000..c436ffa1c --- /dev/null +++ b/pyannote/audio/models/segmentation/MultilatencyPyanNet.py @@ -0,0 +1,292 @@ +# MIT License +# +# Copyright (c) 2020 CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
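To illustrate the new `streaming` flag on `SincNet` (a sketch, not part of the patch): with `streaming=True`, the waveform-level and per-layer `InstanceNorm1d` modules are skipped entirely, presumably because instance normalization needs the full chunk along the time axis, while the convolution/pooling stack is unchanged, so output shapes match the offline front-end:

    import torch
    from pyannote.audio.models.blocks.sincnet import SincNet

    waveforms = torch.randn(1, 1, 16000)  # (batch, channel, sample), 1 s at 16 kHz

    offline = SincNet(stride=10)                    # with InstanceNorm1d layers
    streaming = SincNet(stride=10, streaming=True)  # normalization-free variant

    # same conv/pool stack, so both variants produce identically shaped outputs
    assert offline(waveforms).shape == streaming(waveforms).shape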
+
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from pyannote.core import SlidingWindow
+from pyannote.core.utils.generators import pairwise
+
+from pyannote.audio.core.model import Model
+from pyannote.audio.core.task import Task
+from pyannote.audio.models.blocks.sincnet import SincNet
+from pyannote.audio.utils.params import merge_dict
+
+
+@dataclass
+class Output:
+    num_frames: int
+    dimension: int
+    frames: SlidingWindow
+
+
+class MultilatencyPyanNet(Model):
+    """Multi-latency PyanNet segmentation model
+
+    SincNet > LSTM > Feed forward > Classifier
+
+    The classifier jointly predicts speaker activities for every latency in
+    `latency_list`; at inference time, `latency_index` selects which latency
+    to return (-1 returns all of them, concatenated along the class dimension).
+
+    Parameters
+    ----------
+    sample_rate : int, optional
+        Audio sample rate. Defaults to 16kHz (16000).
+    num_channels : int, optional
+        Number of channels. Defaults to mono (1).
+    sincnet : dict, optional
+        Keyword arguments passed to the SincNet block.
+        Defaults to {"stride": 10}.
+    lstm : dict, optional
+        Keyword arguments passed to the LSTM layer.
+        Defaults to {"hidden_size": 128, "num_layers": 2, "bidirectional": False},
+        i.e. two unidirectional layers with 128 units each.
+        Set "monolithic" to False to split monolithic multi-layer LSTM into multiple mono-layer LSTMs.
+        This may prove useful for probing LSTM internals.
+    linear : dict, optional
+        Keyword arguments used to initialize linear layers.
+        Defaults to {"hidden_size": 128, "num_layers": 2},
+        i.e. two linear layers with 128 units each.
+    latency_index : int, optional
+        Index of the latency (in `latency_list`) to return at inference time.
+        Defaults to -1, i.e. return predictions for all latencies.
+    latency_list : list of float, optional
+        List of latencies (in seconds). Defaults to the task's `latency_list`.
+ """ + + SINCNET_DEFAULTS = {"stride": 10} + LSTM_DEFAULTS = { + "hidden_size": 128, + "num_layers": 2, + "bidirectional": False, + "monolithic": True, + "dropout": 0.0, + } + LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2} + + def __init__( + self, + sincnet: dict = None, + lstm: dict = None, + linear: dict = None, + sample_rate: int = 16000, + num_channels: int = 1, + latency_index: int = -1, + task: Optional[Task] = None, + latency_list: Optional[List[float]] = None, + ): + super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) + + self.latency_index = latency_index + sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet) + sincnet["sample_rate"] = sample_rate + lstm = merge_dict(self.LSTM_DEFAULTS, lstm) + lstm["batch_first"] = True + linear = merge_dict(self.LINEAR_DEFAULTS, linear) + self.save_hyperparameters("sincnet", "lstm", "linear") + self.hparams.latency_list = latency_list or self.task.latency_list + if self.task is not None: + self.latency_list = self.task.latency_list + else: + self.latency_list = self.hparams.latency_list + + self.sincnet = SincNet(**self.hparams.sincnet, streaming=True) + monolithic = lstm["monolithic"] + if monolithic: + multi_layer_lstm = dict(lstm) + del multi_layer_lstm["monolithic"] + self.lstm = nn.LSTM(60, **multi_layer_lstm) + + else: + num_layers = lstm["num_layers"] + if num_layers > 1: + self.dropout = nn.Dropout(p=lstm["dropout"]) + + one_layer_lstm = dict(lstm) + one_layer_lstm["num_layers"] = 1 + one_layer_lstm["dropout"] = 0.0 + del one_layer_lstm["monolithic"] + + self.lstm = nn.ModuleList( + [ + nn.LSTM( + 60 + if i == 0 + else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), + **one_layer_lstm + ) + for i in range(num_layers) + ] + ) + + if linear["num_layers"] < 1: + return + + lstm_out_features: int = self.hparams.lstm["hidden_size"] * ( + 2 if self.hparams.lstm["bidirectional"] else 1 + ) + + self.linear = nn.ModuleList( + [ + nn.Linear(in_features, out_features) + for in_features, out_features in pairwise( + [ + lstm_out_features, + ] + + [self.hparams.linear["hidden_size"]] + * self.hparams.linear["num_layers"] + ) + ]) + + @property + def dimension(self) -> int: + """Dimension of output""" + if isinstance(self.specifications, tuple): + raise ValueError("PyanNet does not support multi-tasking.") + + if self.specifications.powerset: + return self.specifications.num_powerset_classes + else: + return len(self.specifications.classes) + + @lru_cache + def num_frames(self, num_samples: int) -> int: + """Compute number of output frames for a given number of input samples + + Parameters + ---------- + num_samples : int + Number of input samples + + Returns + ------- + num_frames : int + Number of output frames + """ + + return self.sincnet.num_frames(num_samples) + + def receptive_field_size(self, num_frames: int = 1) -> int: + """Compute size of receptive field + + Parameters + ---------- + num_frames : int, optional + Number of frames in the output signal + + Returns + ------- + receptive_field_size : int + Receptive field size. + """ + return self.sincnet.receptive_field_size(num_frames=num_frames) + + def receptive_field_center(self, frame: int = 0) -> int: + """Compute center of receptive field + + Parameters + ---------- + frame : int, optional + Frame index + + Returns + ------- + receptive_field_center : int + Index of receptive field center. 
+ """ + + return self.sincnet.receptive_field_center(frame=frame) + + + def build(self): + if self.hparams.linear["num_layers"] > 0: + in_features = self.hparams.linear["hidden_size"] + else: + in_features = self.hparams.lstm["hidden_size"] * ( + 2 if self.hparams.lstm["bidirectional"] else 1 + ) + + if isinstance(self.specifications, tuple): + raise ValueError("PyanNet does not support multi-tasking.") + + if self.specifications.powerset: + out_features = self.specifications.num_powerset_classes + else: + out_features = len(self.specifications.classes) + + self.classifier = nn.Linear(in_features, out_features * len(self.latency_list)) + self.activation = self.default_activation() + + def forward(self, waveforms: torch.Tensor) -> torch.Tensor: + """Pass forward + + Parameters + ---------- + waveforms : (batch, channel, sample) + + Returns + ------- + scores : (batch, frame, classes) + """ + outputs = self.sincnet(waveforms) + if self.hparams.lstm["monolithic"]: + outputs, _ = self.lstm( + rearrange(outputs, "batch feature frame -> batch frame feature") + ) + else: + outputs = rearrange(outputs, "batch feature frame -> batch frame feature") + for i, lstm in enumerate(self.lstm): + outputs, _ = lstm(outputs) + if i + 1 < self.hparams.lstm["num_layers"]: + outputs = self.dropout(outputs) + + if self.hparams.linear["num_layers"] > 0: + for linear in self.linear: + outputs = F.leaky_relu(linear(outputs)) + # tensor of size (batch_size, num_frames, num_speakers * K) where K is the number of latencies + predictions = self.activation(self.classifier(outputs)) + num_classes_powerset = predictions.size(2) //len(self.latency_list) + + if self.latency_index == -1: + # return all latencies + return predictions + + # return only the corresponding latency + return predictions[:,:, self.latency_index * num_classes_powerset : self.latency_index*num_classes_powerset + num_classes_powerset] \ No newline at end of file diff --git a/pyannote/audio/models/segmentation/__init__.py b/pyannote/audio/models/segmentation/__init__.py index 9f6f5f6e3..e549f0353 100644 --- a/pyannote/audio/models/segmentation/__init__.py +++ b/pyannote/audio/models/segmentation/__init__.py @@ -22,5 +22,8 @@ from .PyanNet import PyanNet from .SSeRiouSS import SSeRiouSS +from .MultilatencyPyanNet import MultilatencyPyanNet -__all__ = ["PyanNet", "SSeRiouSS"] + + +__all__ = ["PyanNet", "SSeRiouSS", "MultilatencyPyanNet"] diff --git a/pyannote/audio/tasks/__init__.py b/pyannote/audio/tasks/__init__.py index 6cbba258f..814b3e5ce 100644 --- a/pyannote/audio/tasks/__init__.py +++ b/pyannote/audio/tasks/__init__.py @@ -22,6 +22,8 @@ from .segmentation.multilabel import MultiLabelSegmentation # isort:skip from .segmentation.speaker_diarization import SpeakerDiarization # isort:skip +from .segmentation.multilatency_speaker_diarization import MultilatencySpeakerDiarization # isort:skip + from .segmentation.voice_activity_detection import VoiceActivityDetection # isort:skip from .segmentation.overlapped_speech_detection import ( # isort:skip OverlappedSpeechDetection, @@ -41,4 +43,5 @@ "MultiLabelSegmentation", "SpeakerEmbedding", "Segmentation", + "MultilatencySpeakerDiarization", ] diff --git a/pyannote/audio/tasks/segmentation/multilatency_speaker_diarization.py b/pyannote/audio/tasks/segmentation/multilatency_speaker_diarization.py new file mode 100644 index 000000000..57fe6c2d6 --- /dev/null +++ b/pyannote/audio/tasks/segmentation/multilatency_speaker_diarization.py @@ -0,0 +1,985 @@ +# MIT License +# +# Copyright (c) 2020- CNRS +# +# 
Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import sys +import math +import warnings +from collections import Counter +from typing import Dict, Literal, Sequence, Text, Tuple, Union, List, Optional +import torch.nn.functional as F + +import numpy as np +import torch +import torch.nn.functional +from matplotlib import pyplot as plt +from pyannote.core import Segment, SlidingWindowFeature +from pyannote.database.protocol import SpeakerDiarizationProtocol +from pyannote.database.protocol.protocol import Scope, Subset +from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger +from rich.progress import track +from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torchmetrics import Metric + +from pyannote.audio.core.task import Problem, Resolution, Specifications, Task +from pyannote.audio.tasks.segmentation.mixins import SegmentationTask +from pyannote.audio.torchmetrics import ( + DiarizationErrorRate, + FalseAlarmRate, + MissedDetectionRate, + OptimalDiarizationErrorRate, + OptimalDiarizationErrorRateThreshold, + OptimalFalseAlarmRate, + OptimalMissedDetectionRate, + OptimalSpeakerConfusionRate, + SpeakerConfusionRate, +) +from pyannote.audio.utils.loss import binary_cross_entropy, mse_loss, interpolate +# from pyannote.audio.utils.loss import nll_loss + +from pyannote.audio.utils.permutation import permutate +from pyannote.audio.utils.powerset import Powerset + +Subsets = list(Subset.__args__) +Scopes = list(Scope.__args__) + +def nll_loss( + prediction: torch.Tensor, + target: torch.Tensor, + class_weight: Optional[torch.Tensor] = None, + weight: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Frame-weighted negative log-likelihood loss + + Parameters + ---------- + prediction : torch.Tensor + Prediction with shape (batch_size, num_frames, num_classes). + target : torch.Tensor + Target with shape (batch_size, num_frames) + class_weight : (num_classes, ) torch.Tensor, optional + Class weight with shape (num_classes, ) + weight : (batch_size, num_frames, 1) torch.Tensor, optional + Frame weight with shape (batch_size, num_frames, 1). 
+ + Returns + ------- + loss : torch.Tensor + """ + + num_classes = prediction.shape[2] + + losses = F.nll_loss( + prediction.reshape(-1, num_classes), + # (batch_size x num_frames, num_classes) + target.view(-1), + # (batch_size x num_frames, ) + weight=class_weight, + # (num_classes, ) + reduction="none", + ).view(target.shape) + # (batch_size, num_frames) + + if weight is None: + return torch.mean(losses) + + else: + # interpolate weight + weight = interpolate(target, weight=weight).squeeze(dim=2) + # (batch_size, num_frames) + + return torch.sum(losses * weight) / torch.sum(weight) + + + +class MultilatencySpeakerDiarization(SegmentationTask, Task): + """Speaker diarization + Parameters + ---------- + protocol : SpeakerDiarizationProtocol + pyannote.database protocol + duration : float, optional + Chunks duration. Defaults to 2s. + max_speakers_per_chunk : int, optional + Maximum number of speakers per chunk (must be at least 2). + Defaults to estimating it from the training set. + max_speakers_per_frame : int, optional + Maximum number of (overlapping) speakers per frame. + Setting this value to 1 or more enables `powerset multi-class` training. + Default behavior is to use `multi-label` training. + weigh_by_cardinality: bool, optional + Weigh each powerset classes by the size of the corresponding speaker set. + In other words, {0, 1} powerset class weight is 2x bigger than that of {0} + or {1} powerset classes. Note that empty (non-speech) powerset class is + assigned the same weight as mono-speaker classes. Defaults to False (i.e. use + same weight for every class). Has no effect with `multi-label` training. + warm_up : float or (float, float), optional + Use that many seconds on the left- and rightmost parts of each chunk + to warm up the model. While the model does process those left- and right-most + parts, only the remaining central part of each chunk is used for computing the + loss during training, and for aggregating scores during inference. + Defaults to 0. (i.e. no warm-up). + balance: Sequence[Text], optional + When provided, training samples are sampled uniformly with respect to these keys. + For instance, setting `balance` to ["database","subset"] will make sure that each + database & subset combination will be equally represented in the training samples. + weight: str, optional + When provided, use this key as frame-wise weight in loss function. + batch_size : int, optional + Number of training samples per batch. Defaults to 32. + num_workers : int, optional + Number of workers used for generating training samples. + Defaults to multiprocessing.cpu_count() // 2. + pin_memory : bool, optional + If True, data loaders will copy tensors into CUDA pinned + memory before returning them. See pytorch documentation + for more details. Defaults to False. + augmentation : BaseWaveformTransform, optional + torch_audiomentations waveform transform, used by dataloader + during training. + vad_loss : {"bce", "mse"}, optional + Add voice activity detection loss. + Cannot be used in conjunction with `max_speakers_per_frame`. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). + References + ---------- + Hervé Bredin and Antoine Laurent + "End-To-End Speaker Segmentation for Overlap-Aware Resegmentation." + Proc. 
Interspeech 2021 + Zhihao Du, Shiliang Zhang, Siqi Zheng, and Zhijie Yan + "Speaker Embedding-aware Neural Diarization: an Efficient Framework for Overlapping + Speech Diarization in Meeting Scenarios" + https://arxiv.org/abs/2203.09767 + """ + + def __init__( + self, + protocol: SpeakerDiarizationProtocol, + cache: Optional[Union[str, None]] = None, + duration: float = 2.0, + max_speakers_per_chunk: Optional[int] = None, + max_speakers_per_frame: Optional[int] = None, + weigh_by_cardinality: bool = False, + warm_up: Union[float, Tuple[float, float]] = 0.0, + balance: Optional[Sequence[Text]] = None, + weight: Optional[Text] = None, + batch_size: int = 32, + num_workers: Optional[int] = None, + pin_memory: bool = False, + augmentation: Optional[BaseWaveformTransform] = None, + vad_loss: Literal["bce", "mse"] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + max_num_speakers: Optional[ + int + ] = None, # deprecated in favor of `max_speakers_per_chunk`` + loss: Literal["bce", "mse"] = None, # deprecated + latency_list: List[float] = [0.0], + + ): + super().__init__( + protocol, + duration=duration, + warm_up=warm_up, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + augmentation=augmentation, + metric=metric, + cache=cache, + ) + + if not isinstance(protocol, SpeakerDiarizationProtocol): + raise ValueError( + "SpeakerDiarization task requires a SpeakerDiarizationProtocol." + ) + + # deprecation warnings + if max_speakers_per_chunk is None and max_num_speakers is not None: + max_speakers_per_chunk = max_num_speakers + warnings.warn( + "`max_num_speakers` has been deprecated in favor of `max_speakers_per_chunk`." + ) + if loss is not None: + warnings.warn("`loss` has been deprecated and has no effect.") + + # parameter validation + if max_speakers_per_frame is not None: + if max_speakers_per_frame < 1: + raise ValueError( + f"`max_speakers_per_frame` must be 1 or more (you used {max_speakers_per_frame})." 
+ ) + if vad_loss is not None: + raise ValueError( + "`vad_loss` cannot be used jointly with `max_speakers_per_frame`" + ) + + self.max_speakers_per_chunk = max_speakers_per_chunk + self.max_speakers_per_frame = max_speakers_per_frame + self.weigh_by_cardinality = weigh_by_cardinality + self.balance = balance + self.weight = weight + self.vad_loss = vad_loss + self.latency_list=latency_list + + + def setup(self, stage=None): + super().setup(stage) + + # estimate maximum number of speakers per chunk when not provided + if self.max_speakers_per_chunk is None: + training = self.prepared_data["audio-metadata"]["subset"] == Subsets.index( + "train" + ) + + num_unique_speakers = [] + progress_description = f"Estimating maximum number of speakers per {self.duration:g}s chunk in the training set" + for file_id in track( + np.where(training)[0], description=progress_description + ): + annotations = self.prepared_data["annotations-segments"][ + np.where( + self.prepared_data["annotations-segments"]["file_id"] == file_id + )[0] + ] + annotated_regions = self.prepared_data["annotations-regions"][ + np.where( + self.prepared_data["annotations-regions"]["file_id"] == file_id + )[0] + ] + for region in annotated_regions: + # find annotations within current region + region_start = region["start"] + region_end = region["start"] + region["duration"] + region_annotations = annotations[ + np.where( + (annotations["start"] >= region_start) + * (annotations["end"] <= region_end) + )[0] + ] + + for window_start in np.arange( + region_start, region_end - self.duration, 0.25 * self.duration + ): + window_end = window_start + self.duration + window_annotations = region_annotations[ + np.where( + (region_annotations["start"] <= window_end) + * (region_annotations["end"] >= window_start) + )[0] + ] + num_unique_speakers.append( + len(np.unique(window_annotations["file_label_idx"])) + ) + + # because there might a few outliers, estimate the upper bound for the + # number of speakers as the 97th percentile + + num_speakers, counts = zip(*list(Counter(num_unique_speakers).items())) + num_speakers, counts = np.array(num_speakers), np.array(counts) + + sorting_indices = np.argsort(num_speakers) + num_speakers = num_speakers[sorting_indices] + counts = counts[sorting_indices] + + ratios = np.cumsum(counts) / np.sum(counts) + + for k, ratio in zip(num_speakers, ratios): + if k == 0: + print(f" - {ratio:7.2%} of all chunks contain no speech at all.") + elif k == 1: + print(f" - {ratio:7.2%} contain 1 speaker or less") + else: + print(f" - {ratio:7.2%} contain {k} speakers or less") + + self.max_speakers_per_chunk = max( + 2, + num_speakers[np.where(ratios > 0.97)[0][0]], + ) + + print( + f"Setting `max_speakers_per_chunk` to {self.max_speakers_per_chunk}. " + f"You can override this value (or avoid this estimation step) by passing `max_speakers_per_chunk={self.max_speakers_per_chunk}` to the task constructor." 
+ ) + + if ( + self.max_speakers_per_frame is not None + and self.max_speakers_per_frame > self.max_speakers_per_chunk + ): + raise ValueError( + f"`max_speakers_per_frame` ({self.max_speakers_per_frame}) must be smaller " + f"than `max_speakers_per_chunk` ({self.max_speakers_per_chunk})" + ) + + # now that we know about the number of speakers upper bound + # we can set task specifications + self.specifications = Specifications( + problem=Problem.MULTI_LABEL_CLASSIFICATION + if self.max_speakers_per_frame is None + else Problem.MONO_LABEL_CLASSIFICATION, + resolution=Resolution.FRAME, + duration=self.duration, + min_duration=self.min_duration, + warm_up=self.warm_up, + classes=[f"speaker#{i+1}" for i in range(self.max_speakers_per_chunk)], + powerset_max_classes=self.max_speakers_per_frame, + permutation_invariant=True, + ) + + def setup_loss_func(self): + if self.specifications.powerset: + self.model.powerset = Powerset( + len(self.specifications.classes), + self.specifications.powerset_max_classes, + ) + + def prepare_chunk(self, file_id: int, start_time: float, duration: float): + """Prepare chunk + + Parameters + ---------- + file_id : int + File index + start_time : float + Chunk start time + duration : float + Chunk duration. + + Returns + ------- + sample : dict + Dictionary containing the chunk data with the following keys: + - `X`: waveform + - `y`: target as a SlidingWindowFeature instance where y.labels is + in meta.scope space. + - `meta`: + - `scope`: target scope (0: file, 1: database, 2: global) + - `database`: database index + - `file`: file index + """ + + file = self.get_file(file_id) + + # get label scope + label_scope = Scopes[self.prepared_data["audio-metadata"][file_id]["scope"]] + label_scope_key = f"{label_scope}_label_idx" + + # + chunk = Segment(start_time, start_time + duration) + + sample = dict() + sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) + + # gather all annotations of current file + annotations = self.prepared_data["annotations-segments"][ + self.prepared_data["annotations-segments"]["file_id"] == file_id + ] + + # gather all annotations with non-empty intersection with current chunk + chunk_annotations = annotations[ + (annotations["start"] < chunk.end) & (annotations["end"] > chunk.start) + ] + + # discretize chunk annotations at model output resolution + step = self.model.receptive_field.step + half = 0.5 * self.model.receptive_field.duration + + start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - half + start_idx = np.maximum(0, np.round(start / step)).astype(int) + + end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - half + end_idx = np.round(end / step).astype(int) + + # get list and number of labels for current scope + labels = list(np.unique(chunk_annotations[label_scope_key])) + num_labels = len(labels) + + if num_labels > self.max_speakers_per_chunk: + pass + + # initial frame-level targets + num_frames = self.model.num_frames( + round(duration * self.model.hparams.sample_rate) + ) + y = np.zeros((num_frames, num_labels), dtype=np.uint8) + + # map labels to indices + mapping = {label: idx for idx, label in enumerate(labels)} + + for start, end, label in zip( + start_idx, end_idx, chunk_annotations[label_scope_key] + ): + mapped_label = mapping[label] + y[start : end + 1, mapped_label] = 1 + + sample["y"] = SlidingWindowFeature(y, self.model.receptive_field, labels=labels) + + metadata = self.prepared_data["audio-metadata"][file_id] + sample["meta"] = {key: metadata[key] for key 
in metadata.dtype.names}
+        sample["meta"]["file"] = file_id
+
+        return sample
+
+    def collate_y(self, batch) -> torch.Tensor:
+        """
+
+        Parameters
+        ----------
+        batch : list
+            List of samples to collate.
+            "y" field is expected to be a SlidingWindowFeature.
+
+        Returns
+        -------
+        y : torch.Tensor
+            Collated target tensor of shape (batch_size, num_frames, self.max_speakers_per_chunk).
+            If one chunk has more than `self.max_speakers_per_chunk` speakers, we keep
+            the max_speakers_per_chunk most talkative ones. If it has fewer, we pad with
+            zeros (artificial inactive speakers).
+        """
+
+        collated_y = []
+        for b in batch:
+            y = b["y"].data
+            num_speakers = len(b["y"].labels)
+            if num_speakers > self.max_speakers_per_chunk:
+                # sort speakers in descending talkativeness order
+                indices = np.argsort(-np.sum(y, axis=0), axis=0)
+                # keep only the most talkative speakers
+                y = y[:, indices[: self.max_speakers_per_chunk]]
+
+                # TODO: we should also sort the speaker labels in the same way
+
+            elif num_speakers < self.max_speakers_per_chunk:
+                # create inactive speakers by zero padding
+                y = np.pad(
+                    y,
+                    ((0, 0), (0, self.max_speakers_per_chunk - num_speakers)),
+                    mode="constant",
+                )
+
+            else:
+                # we have exactly the right number of speakers
+                pass
+
+            collated_y.append(y)
+
+        return torch.from_numpy(np.stack(collated_y))
+
+    def segmentation_loss(
+        self,
+        permutated_prediction: torch.Tensor,
+        target: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Permutation-invariant segmentation loss
+
+        Parameters
+        ----------
+        permutated_prediction : (batch_size, num_frames, num_classes) torch.Tensor
+            Permutated speaker activity predictions.
+        target : (batch_size, num_frames, num_speakers) torch.Tensor
+            Speaker activity.
+        weight : (batch_size, num_frames, 1) torch.Tensor, optional
+            Frames weight.
+
+        Returns
+        -------
+        seg_loss : torch.Tensor
+            Permutation-invariant segmentation loss
+        """
+
+        if self.specifications.powerset:
+            # `clamp_min` is needed to set non-speech weight to 1.
+            class_weight = (
+                torch.clamp_min(self.model.powerset.cardinality, 1.0)
+                if self.weigh_by_cardinality
+                else None
+            )
+            seg_loss = nll_loss(
+                permutated_prediction,
+                torch.argmax(target, dim=-1),
+                class_weight=class_weight,
+            )
+        else:
+            seg_loss = binary_cross_entropy(
+                permutated_prediction, target.float(), weight=weight
+            )
+
+        return seg_loss
+
+    def voice_activity_detection_loss(
+        self,
+        permutated_prediction: torch.Tensor,
+        target: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Voice activity detection loss
+
+        Parameters
+        ----------
+        permutated_prediction : (batch_size, num_frames, num_classes) torch.Tensor
+            Speaker activity predictions.
+        target : (batch_size, num_frames, num_speakers) torch.Tensor
+            Speaker activity.
+        weight : (batch_size, num_frames, 1) torch.Tensor, optional
+            Frames weight.
+
+        Returns
+        -------
+        vad_loss : torch.Tensor
+            Voice activity detection loss.
+        """
+
+        vad_prediction, _ = torch.max(permutated_prediction, dim=2, keepdim=True)
+        # (batch_size, num_frames, 1)
+
+        vad_target, _ = torch.max(target.float(), dim=2, keepdim=False)
+        # (batch_size, num_frames)
+
+        if self.vad_loss == "bce":
+            loss = binary_cross_entropy(vad_prediction, vad_target, weight=weight)
+
+        elif self.vad_loss == "mse":
+            loss = mse_loss(vad_prediction, vad_target, weight=weight)
+
+        return loss
+
+    def training_step(self, batch, batch_idx: int):
+        """Compute permutation-invariant segmentation loss
+
+        Parameters
+        ----------
+        batch : (usually) dict of torch.Tensor
+            Current batch.
+        batch_idx: int
+            Batch index.
+
+        Returns
+        -------
+        loss : {str: torch.tensor}
+            {"loss": loss}
+        """
+
+        # target
+        target = batch["y"]
+        # (batch_size, num_frames, num_speakers)
+
+        waveform = batch["X"]
+        # (batch_size, num_channels, num_samples)
+
+        # drop samples that contain too many speakers
+        num_speakers: torch.Tensor = torch.sum(torch.any(target, dim=1), dim=1)
+        keep: torch.Tensor = num_speakers <= self.max_speakers_per_chunk
+        target = target[keep]
+        waveform = waveform[keep]
+
+        # corner case
+        if not keep.any():
+            return None
+
+        # forward pass
+        # output tensor of size (batch_size, num_frames, num_classes_powerset * K)
+        # where K is the number of latencies
+        predictions = self.model(waveform)
+        num_classes_powerset = predictions.size(2) // len(self.latency_list)
+        seg_loss = 0
+        for k in range(len(self.latency_list)):
+            # select only one latency at a time
+            prediction = predictions[
+                :, :, k * num_classes_powerset : (k + 1) * num_classes_powerset
+            ]
+            batch_size, num_frames, _ = prediction.shape
+            # (batch_size, num_frames, num_classes)
+
+            # frames weight
+            weight_key = getattr(self, "weight", None)
+            weight = batch.get(
+                weight_key,
+                torch.ones(batch_size, num_frames, 1, device=self.model.device),
+            )
+            # (batch_size, num_frames, 1)
+
+            # warm-up
+            warm_up_left = round(self.warm_up[0] / self.duration * num_frames)
+            weight[:, :warm_up_left] = 0.0
+            warm_up_right = round(self.warm_up[1] / self.duration * num_frames)
+            weight[:, num_frames - warm_up_right :] = 0.0
+
+            # shift prediction and target so that they are aligned for this latency.
+            # note: use a local `reference` instead of overwriting `target`, otherwise
+            # the truncation would leak into the next latency iteration.
+            if self.latency_list[k] >= 0:
+                # round down
+                delay = int(np.floor(num_frames * self.latency_list[k] / self.duration))
+                prediction = prediction[:, delay:, :]
+                reference = target[:, : num_frames - delay, :]
+                # keep `weight` aligned with the trimmed prediction
+                weight = weight[:, delay:, :]
+            else:
+                # round down
+                delay = int(
+                    np.floor(num_frames * (-1.0 * self.latency_list[k]) / self.duration)
+                )
+                prediction = prediction[:, : num_frames - delay, :]
+                reference = target[:, delay:, :]
+                weight = weight[:, : num_frames - delay, :]
+
+            # compute loss (the K per-latency losses are summed)
+            if self.specifications.powerset:
+                multilabel = self.model.powerset.to_multilabel(prediction)
+                permutated_target, _ = permutate(multilabel, reference)
+                permutated_target_powerset = self.model.powerset.to_powerset(
+                    permutated_target.float()
+                )
+                seg_loss += self.segmentation_loss(
+                    prediction, permutated_target_powerset, weight=weight
+                )
+
+            else:
+                permutated_prediction, _ = permutate(reference, prediction)
+                seg_loss += self.segmentation_loss(
+                    permutated_prediction, reference, weight=weight
+                )
+
+        self.model.log(
+            "loss/train/segmentation",
+            seg_loss,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=False,
+            logger=True,
+        )
+
+        if self.vad_loss is None:
+            vad_loss = 0.0
+
+        else:
+            # TODO: vad_loss probably does not make sense in powerset mode
+            # because first class (empty set of labels) does exactly this...
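+            # note: `prediction`, `permutated_target_powerset` / `permutated_prediction`
+            # and `weight` below come from the last iteration of the latency loop, so
+            # the VAD loss is only computed for the final latency in `latency_list`.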
+ if self.specifications.powerset: + vad_loss = self.voice_activity_detection_loss( + prediction, permutated_target_powerset, weight=weight + ) + + else: + vad_loss = self.voice_activity_detection_loss( + permutated_prediction, target, weight=weight + ) + + self.model.log( + "loss/train/vad", + vad_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + loss = seg_loss + vad_loss + + # skip batch if something went wrong for some reason + if torch.isnan(loss): + return None + + self.model.log( + "loss/train", + loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + return {"loss": loss} + + def default_metric( + self, + ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: + """Returns diarization error rate and its components""" + + if self.specifications.powerset: + return { + "DiarizationErrorRate": DiarizationErrorRate(0.5), + "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5), + "DiarizationErrorRate/Miss": MissedDetectionRate(0.5), + "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5), + } + + return { + "DiarizationErrorRate": OptimalDiarizationErrorRate(), + "DiarizationErrorRate/Threshold": OptimalDiarizationErrorRateThreshold(), + "DiarizationErrorRate/Confusion": OptimalSpeakerConfusionRate(), + "DiarizationErrorRate/Miss": OptimalMissedDetectionRate(), + "DiarizationErrorRate/FalseAlarm": OptimalFalseAlarmRate(), + } + + # TODO: no need to compute gradient in this method + def validation_step(self, batch, batch_idx: int): + """Compute validation loss and metric + Parameters + ---------- + batch : dict of torch.Tensor + Current batch. + batch_idx: int + Batch index. + """ + + # target + target = batch["y"] + # (batch_size, num_frames, num_speakers) + + waveform = batch["X"] + # (batch_size, num_channels, num_samples) + + # TODO: should we handle validation samples with too many speakers + # waveform = waveform[keep] + # target = target[keep] + + # forward pass + # tensor of size (batch_size, num_frames, num_speakers * K) where K is the number of latencies + predictions = self.model(waveform) + losses=[] + num_classes_powerset = predictions.size(2) //len(self.latency_list) + for k in range(len(self.latency_list)): + # select one latency + prediction = predictions[:,:,k*num_classes_powerset:k*num_classes_powerset+num_classes_powerset] + batch_size, num_frames, _ = prediction.shape + + # frames weight + weight_key = getattr(self, "weight", None) + weight = batch.get( + weight_key, + torch.ones(batch_size, num_frames, 1, device=self.model.device), + ) + # (batch_size, num_frames, 1) + + # warm-up + warm_up_left = round(self.warm_up[0] / self.duration * num_frames) + weight[:, :warm_up_left] = 0.0 + warm_up_right = round(self.warm_up[1] / self.duration * num_frames) + weight[:, num_frames - warm_up_right :] = 0.0 + + # shift prediction and target + if self.latency_list[k] >= 0: + delay = int(np.floor(num_frames * (self.latency_list[k]) / self.duration)) # round down + prediction = prediction[:, delay:, :] + reference = target[:, :num_frames-delay, :] + else: + delay = int(np.floor(num_frames * (-1.0 * self.latency_list[k]) / self.duration)) # round down + prediction = prediction[:, :num_frames-delay, :] + reference = target[:, delay:, :] + + if self.specifications.powerset: + multilabel = self.model.powerset.to_multilabel(prediction) + permutated_target, _ = permutate(multilabel, reference) + + # FIXME: handle case where target have too many speakers? 
+ # since we don't need + permutated_target_powerset = self.model.powerset.to_powerset( + permutated_target.float() + ) + losses.append(self.segmentation_loss( + prediction, permutated_target_powerset, weight=weight + )) + + else: + permutated_prediction, _ = permutate(reference, prediction) + losses.append(self.segmentation_loss( + permutated_prediction, reference, weight=weight + )) + + # with the following line, the validation DER is calculated on the first latency prediction + multilabel = self.model.powerset.to_multilabel(predictions[:,:,:num_classes_powerset]) + + + seg_loss = torch.sum(torch.tensor(losses)) + + self.model.log( + "loss/val/segmentation", + seg_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + if self.vad_loss is None: + vad_loss = 0.0 + + else: + # TODO: vad_loss probably does not make sense in powerset mode + # because first class (empty set of labels) does exactly this... + if self.specifications.powerset: + vad_loss = self.voice_activity_detection_loss( + prediction, permutated_target_powerset, weight=weight + ) + + else: + vad_loss = self.voice_activity_detection_loss( + permutated_prediction, target, weight=weight + ) + + self.model.log( + "loss/val/vad", + vad_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + loss = seg_loss + vad_loss + + self.model.log( + "loss/val", + loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + if self.specifications.powerset: + self.model.validation_metric( + torch.transpose( + multilabel[:, warm_up_left : num_frames - warm_up_right], 1, 2 + ), + torch.transpose( + target[:, warm_up_left : num_frames - warm_up_right], 1, 2 + ), + ) + else: + self.model.validation_metric( + torch.transpose( + prediction[:, warm_up_left : num_frames - warm_up_right], 1, 2 + ), + torch.transpose( + target[:, warm_up_left : num_frames - warm_up_right], 1, 2 + ), + ) + + self.model.log_dict( + self.model.validation_metric, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + # log first batch visualization every 2^n epochs. 
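+        # (i.e. only plot when the current epoch is a power of two and this is the
+        # first batch of the validation loop)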
+ if ( + self.model.current_epoch == 0 + or math.log2(self.model.current_epoch) % 1 > 0 + or batch_idx > 0 + ): + return + + # visualize first 9 validation samples of first batch in Tensorboard/MLflow + + if self.specifications.powerset: + y = permutated_target.float().cpu().numpy() + y_pred = multilabel.cpu().numpy() + else: + y = target.float().cpu().numpy() + y_pred = permutated_prediction.cpu().numpy() + + # prepare 3 x 3 grid (or smaller if batch size is smaller) + num_samples = min(self.batch_size, 9) + nrows = math.ceil(math.sqrt(num_samples)) + ncols = math.ceil(num_samples / nrows) + fig, axes = plt.subplots( + nrows=2 * nrows, ncols=ncols, figsize=(8, 5), squeeze=False + ) + + # reshape target so that there is one line per class when plotting it + y[y == 0] = np.NaN + if len(y.shape) == 2: + y = y[:, :, np.newaxis] + y *= np.arange(y.shape[2]) + + # plot each sample + for sample_idx in range(num_samples): + # find where in the grid it should be plotted + row_idx = sample_idx // nrows + col_idx = sample_idx % ncols + + # plot target + ax_ref = axes[row_idx * 2 + 0, col_idx] + sample_y = y[sample_idx] + ax_ref.plot(sample_y) + ax_ref.set_xlim(0, len(sample_y)) + ax_ref.set_ylim(-1, sample_y.shape[1]) + ax_ref.get_xaxis().set_visible(False) + ax_ref.get_yaxis().set_visible(False) + + # plot predictions + ax_hyp = axes[row_idx * 2 + 1, col_idx] + sample_y_pred = y_pred[sample_idx] + ax_hyp.axvspan(0, warm_up_left, color="k", alpha=0.5, lw=0) + ax_hyp.axvspan( + num_frames - warm_up_right, num_frames, color="k", alpha=0.5, lw=0 + ) + ax_hyp.plot(sample_y_pred) + ax_hyp.set_ylim(-0.1, 1.1) + ax_hyp.set_xlim(0, len(sample_y)) + ax_hyp.get_xaxis().set_visible(False) + + plt.tight_layout() + + for logger in self.model.loggers: + if isinstance(logger, TensorBoardLogger): + logger.experiment.add_figure("samples", fig, self.model.current_epoch) + elif isinstance(logger, MLFlowLogger): + logger.experiment.log_figure( + run_id=logger.run_id, + figure=fig, + artifact_file=f"samples_epoch{self.model.current_epoch}.png", + ) + + plt.close(fig) + + +def main(protocol: str, subset: str = "test", model: str = "pyannote/segmentation"): + """Evaluate a segmentation model""" + + from pyannote.database import FileFinder, get_protocol + from rich.progress import Progress + + from pyannote.audio import Inference + from pyannote.audio.pipelines.utils import get_devices + from pyannote.audio.utils.metric import DiscreteDiarizationErrorRate + from pyannote.audio.utils.signal import binarize + + (device,) = get_devices(needs=1) + metric = DiscreteDiarizationErrorRate() + protocol = get_protocol(protocol, preprocessors={"audio": FileFinder()}) + files = list(getattr(protocol, subset)()) + + with Progress() as progress: + main_task = progress.add_task(protocol.name, total=len(files)) + file_task = progress.add_task("Processing", total=1.0) + + def progress_hook(completed: Optional[int] = None, total: Optional[int] = None): + progress.update(file_task, completed=completed / total) + + inference = Inference(model, device=device) + + for file in files: + progress.update(file_task, description=file["uri"]) + reference = file["annotation"] + hypothesis = binarize(inference(file, hook=progress_hook)) + uem = file["annotated"] + _ = metric(reference, hypothesis, uem=uem) + progress.advance(main_task) + + _ = metric.report(display=True) + + +if __name__ == "__main__": + import typer + + typer.run(main) \ No newline at end of file
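Reviewer note: the multi-latency output layout and the time alignment used in `training_step` / `validation_step` above can be summarized with the following self-contained sketch. Tensor sizes, latency values and chunk duration are illustrative only, not taken from the actual model.

import numpy as np
import torch

# illustrative values only
latency_list = [0.0, 0.5]      # seconds
duration = 10.0                # chunk duration, seconds
num_frames = 293               # output frames per chunk (model-dependent)
num_classes_powerset = 7       # powerset classes per latency head
batch_size = 4

# the model stacks one block of powerset scores per latency along the last axis
predictions = torch.randn(
    batch_size, num_frames, num_classes_powerset * len(latency_list)
)

for k, latency in enumerate(latency_list):
    # select the k-th latency block
    prediction = predictions[
        :, :, k * num_classes_powerset : (k + 1) * num_classes_powerset
    ]

    # number of frames the prediction is shifted for this latency (rounded down)
    delay = int(np.floor(num_frames * abs(latency) / duration))

    # a prediction emitted at frame t + delay is compared with the target at frame t
    # (and the other way around for negative latencies)
    if latency >= 0:
        aligned_prediction = prediction[:, delay:, :]
        # aligned target would be target[:, : num_frames - delay, :]
    else:
        aligned_prediction = prediction[:, : num_frames - delay, :]
        # aligned target would be target[:, delay:, :]

    print(k, latency, aligned_prediction.shape)

In `validation_step` above, only the first latency block (k = 0) is converted to multi-label scores for the DER metric.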