Skip to content

Commit

Permalink
Merge pull request #409 from claritychallenge/cad2-fix-evaluation
Browse files Browse the repository at this point in the history
Fix Task1 evaluation
  • Loading branch information
groadabike authored Sep 10, 2024
2 parents bebd5fc + 0e73be3 commit 4cca4ab
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 18 deletions.
61 changes: 46 additions & 15 deletions recipes/cad2/task1/baseline/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import hashlib
import json
import logging
from pathlib import Path
Expand All @@ -25,6 +26,13 @@
logger = logging.getLogger(__name__)


def set_song_seed(song: str) -> None:
    """Seed numpy's global RNG deterministically from a song identifier.

    The song string is hashed with MD5 and the hash is reduced modulo
    1e8 so that the same song always maps to the same (valid 32-bit)
    seed across runs, making per-song evaluation reproducible.

    Args:
        song: identifier of the song (any string).
    """
    digest = hashlib.md5(song.encode("utf-8")).hexdigest()
    np.random.seed(int(digest, 16) % 100_000_000)


def make_scene_listener_list(scenes_listeners: dict, small_test: bool = False) -> list:
"""Make the list of scene-listener pairing to process
Expand Down Expand Up @@ -58,7 +66,7 @@ def compute_intelligibility(
save_intermediate: bool = False,
path_intermediate: str | Path | None = None,
equiv_0db_spl: float = 100,
) -> tuple[float, float]:
) -> tuple[float, float, dict]:
"""
Compute the Intelligibility score for the enhanced signal
using the Whisper model.
Expand All @@ -79,6 +87,9 @@ def compute_intelligibility(
Returns:
The intelligibility score for the left and right channels
"""

lyrics = {}

if path_intermediate is None:
path_intermediate = Path.cwd()
if isinstance(path_intermediate, str):
Expand All @@ -90,6 +101,7 @@ def compute_intelligibility(
)

reference = segment_metadata["text"]
lyrics["reference"] = reference

# Compute left ear
ear.set_audiogram(listener.audiogram_left)
Expand All @@ -101,8 +113,10 @@ def compute_intelligibility(
44100,
sample_rate,
)
hipothesis = scorer.transcribe(left_path, fp16=False)["text"]
left_results = compute_measures(reference, hipothesis)
hypothesis = scorer.transcribe(left_path.as_posix(), fp16=False)["text"]
lyrics["hypothesis_left"] = hypothesis

left_results = compute_measures(reference, hypothesis)

# Compute right ear
ear.set_audiogram(listener.audiogram_right)
Expand All @@ -114,8 +128,10 @@ def compute_intelligibility(
44100,
sample_rate,
)
hipothesis = scorer.transcribe(right_path, fp16=False)["text"]
right_results = compute_measures(reference, hipothesis)
hypothesis = scorer.transcribe(right_path.as_posix(), fp16=False)["text"]
lyrics["hypothesis_right"] = hypothesis

right_results = compute_measures(reference, hypothesis)

# Compute the average score for both ears
total_words = (
Expand All @@ -136,7 +152,11 @@ def compute_intelligibility(
Path(left_path).unlink()
Path(right_path).unlink()

return left_results["hits"] / total_words, right_results["hits"] / total_words
return (
left_results["hits"] / total_words,
right_results["hits"] / total_words,
lyrics,
)


def compute_quality(
Expand Down Expand Up @@ -203,14 +223,14 @@ def load_reference_signal(


def normalise_luft(
signal: np.ndarray, sample_rate: float, target_luft=-40
signal: np.ndarray, sample_rate: float, target_luft: float = -40.0
) -> np.ndarray:
"""
Normalise the signal to a target loudness level.
Args:
signal: input signal to normalise
sample_rate: sample rate of the signal
target_luft: target loudness level in LUFS
target_luft: target loudness level in LUFS.
Returns:
np.ndarray: normalised signal
Expand Down Expand Up @@ -254,6 +274,9 @@ def run_compute_scores(config: DictConfig) -> None:
"scene",
"song",
"listener",
"lyrics",
"hypothesis_left",
"hypothesis_right",
"haaqi_left",
"haaqi_right",
"haaqi_avg",
Expand Down Expand Up @@ -302,6 +325,10 @@ def run_compute_scores(config: DictConfig) -> None:

scene_id, listener_id = scene_listener_ids

# Set the random seed for the scene
if config.evaluate.set_random_seed:
set_song_seed(scene_id)

# Load scene details
scene = scenes[scene_id]
listener = listener_dict[listener_id]
Expand Down Expand Up @@ -363,7 +390,7 @@ def run_compute_scores(config: DictConfig) -> None:

# Compute the HAAQI and Whisper scores
haaqi_scores = compute_quality(reference, enhanced_signal, listener, config)
whisper_scores = compute_intelligibility(
whisper_left, whisper_right, lyrics_text = compute_intelligibility(
enhanced_signal=enhanced_signal,
segment_metadata=songs[scene["segment_id"]],
scorer=intelligibility_scorer,
Expand All @@ -375,20 +402,24 @@ def run_compute_scores(config: DictConfig) -> None:
equiv_0db_spl=config.evaluate.equiv_0db_spl,
)

max_whisper = np.max([whisper_left, whisper_right])
mean_haaqi = np.mean(haaqi_scores)
results_file.add_result(
{
"scene": scene_id,
"song": songs[scene["segment_id"]]["track_name"],
"listener": listener_id,
"lyrics": lyrics_text["reference"],
"hypothesis_left": lyrics_text["hypothesis_left"],
"hypothesis_right": lyrics_text["hypothesis_right"],
"haaqi_left": haaqi_scores[0],
"haaqi_right": haaqi_scores[1],
"haaqi_avg": np.mean(haaqi_scores),
"whisper_left": whisper_scores[0],
"whisper_rigth": whisper_scores[1],
"whisper_be": np.max(whisper_scores),
"haaqi_avg": mean_haaqi,
"whisper_left": whisper_left,
"whisper_rigth": whisper_right,
"whisper_be": max_whisper,
"alpha": alpha,
"score": alpha * np.max(whisper_scores)
+ (1 - alpha) * np.mean(haaqi_scores),
"score": alpha * max_whisper + (1 - alpha) * mean_haaqi,
}
)

Expand Down
4 changes: 1 addition & 3 deletions recipes/cad2/task2/baseline/config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# Zenodo download path: path to the zenodo download folder.
# root: root path of the dataset. This path will contain the audio and metadata folders
path:
root: /media/gerardoroadabike/Extreme SSD1/Challenges/CAD2/cadenza_data/cad2/task2
root: ?? # Set to the root of the dataset
metadata_dir: ${path.root}/metadata
music_dir: ${path.root}/audio
gains_file: ${path.metadata_dir}/gains.json
Expand Down

0 comments on commit 4cca4ab

Please sign in to comment.