
Commit

Merge pull request #410 from claritychallenge/cad2-fix-in-main
Cad2 fix in main
jonbarker68 authored Sep 11, 2024
2 parents dade982 + 73ea121 commit a526a3a
Showing 3 changed files with 34 additions and 18 deletions.
recipes/cad2/task1/baseline/config.yaml (2 changes: 1 addition & 1 deletion)
@@ -10,7 +10,7 @@ path:
 scene_listeners_file: ${path.metadata_dir}/scene_listeners.valid.json
 exp_folder: ./exp_${separator.causality} # folder to store enhanced signals and final results

-input_sample_rate: 44100 # sample rate of the input mixture
+input_sample_rate: 44100 # sample rate of the input mixture
 remix_sample_rate: 44100 # sample rate for the output remixed signal
 HAAQI_sample_rate: 24000 # sample rate for computing HAAQI score
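
The three rates above drive the evaluation pipeline: signals enter and are remixed at 44.1 kHz, then are downsampled to 24 kHz for HAAQI scoring. A minimal sketch of that resampling step, assuming scipy's polyphase resampler (the baseline's actual backend may differ):

    import numpy as np
    from scipy.signal import resample_poly

    def to_rate(signal: np.ndarray, input_rate: int, target_rate: int) -> np.ndarray:
        """Resample between the config's sample rates (no-op when they match)."""
        if input_rate == target_rate:
            return signal
        return resample_poly(signal, target_rate, input_rate)

    # e.g. take a 44.1 kHz enhanced mix down to 24 kHz for HAAQI scoring
    signal_24k = to_rate(np.zeros(44100), input_rate=44100, target_rate=24000)
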
recipes/cad2/task1/baseline/evaluate.py (46 changes: 32 additions & 14 deletions)
@@ -58,7 +58,7 @@ def compute_intelligibility(
     save_intermediate: bool = False,
     path_intermediate: str | Path | None = None,
     equiv_0db_spl: float = 100,
-) -> tuple[float, float]:
+) -> tuple[float, float, dict]:
     """
     Compute the Intelligibility score for the enhanced signal
     using the Whisper model.
@@ -79,6 +79,9 @@
     Returns:
         The intelligibility score for the left and right channels
     """
+
+    lyrics = {}
+
     if path_intermediate is None:
         path_intermediate = Path.cwd()
     if isinstance(path_intermediate, str):
@@ -90,6 +93,7 @@
     )

     reference = segment_metadata["text"]
+    lyrics["reference"] = reference

     # Compute left ear
     ear.set_audiogram(listener.audiogram_left)
@@ -101,8 +105,10 @@
         44100,
         sample_rate,
     )
-    hipothesis = scorer.transcribe(left_path, fp16=False)["text"]
-    left_results = compute_measures(reference, hipothesis)
+    hypothesis = scorer.transcribe(left_path.as_posix(), fp16=False)["text"]
+    lyrics["hypothesis_left"] = hypothesis
+
+    left_results = compute_measures(reference, hypothesis)

     # Compute right ear
     ear.set_audiogram(listener.audiogram_right)
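
Two fixes land in this hunk: the misspelled hipothesis is renamed, and the pathlib.Path is converted to a plain string, since openai-whisper's transcribe() accepts a str file path (or an already-loaded array) rather than a Path. A minimal sketch of the call, assuming the stock whisper package and a hypothetical file name:

    from pathlib import Path

    import whisper  # openai-whisper; assumed to match the scorer used above

    model = whisper.load_model("base")
    left_path = Path("intermediate_left.wav")  # hypothetical intermediate file

    # Path -> str mirrors the .as_posix() fix; fp16=False avoids half precision on CPU
    hypothesis = model.transcribe(left_path.as_posix(), fp16=False)["text"]
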
@@ -114,8 +120,10 @@
         44100,
         sample_rate,
     )
-    hipothesis = scorer.transcribe(right_path, fp16=False)["text"]
-    right_results = compute_measures(reference, hipothesis)
+    hypothesis = scorer.transcribe(right_path.as_posix(), fp16=False)["text"]
+    lyrics["hypothesis_right"] = hypothesis
+
+    right_results = compute_measures(reference, hypothesis)

     # Compute the average score for both ears
     total_words = (
@@ -136,7 +144,11 @@
     Path(left_path).unlink()
     Path(right_path).unlink()

-    return left_results["hits"] / total_words, right_results["hits"] / total_words
+    return (
+        left_results["hits"] / total_words,
+        right_results["hits"] / total_words,
+        lyrics,
+    )


 def compute_quality(
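
Each per-ear score returned here is hits / total_words, the fraction of reference words recognised exactly. A small worked example, assuming compute_measures comes from jiwer (whose result dict includes "hits", "substitutions" and "deletions"):

    from jiwer import compute_measures

    reference = "you are the sunshine of my life"
    hypothesis = "you are sunshine of my life"  # one word dropped

    results = compute_measures(reference, hypothesis)
    total_words = results["hits"] + results["substitutions"] + results["deletions"]
    print(results["hits"] / total_words)  # 6/7, approximately 0.857
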
@@ -203,14 +215,14 @@ def load_reference_signal(


 def normalise_luft(
-    signal: np.ndarray, sample_rate: float, target_luft=-40
+    signal: np.ndarray, sample_rate: float, target_luft: float = -40.0
 ) -> np.ndarray:
     """
     Normalise the signal to a target loudness level.
     Args:
         signal: input signal to normalise
         sample_rate: sample rate of the signal
-        target_luft: target loudness level in LUFS
+        target_luft: target loudness level in LUFS.
     Returns:
         np.ndarray: normalised signal
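
The function body is collapsed in this view; one plausible implementation of the documented behaviour, assuming pyloudnorm as the loudness backend (an assumption, since the actual code is not shown):

    import numpy as np
    import pyloudnorm as pyln  # assumed backend

    def normalise_luft(
        signal: np.ndarray, sample_rate: float, target_luft: float = -40.0
    ) -> np.ndarray:
        """Scale the signal so its integrated loudness equals target_luft LUFS."""
        meter = pyln.Meter(int(sample_rate))
        loudness = meter.integrated_loudness(signal)
        return pyln.normalize.loudness(signal, loudness, target_luft)
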
@@ -254,6 +266,9 @@ def run_compute_scores(config: DictConfig) -> None:
"scene",
"song",
"listener",
"lyrics",
"hypothesis_left",
"hypothesis_right",
"haaqi_left",
"haaqi_right",
"haaqi_avg",
@@ -363,7 +378,7 @@ def run_compute_scores(config: DictConfig) -> None:

         # Compute the HAAQI and Whisper scores
         haaqi_scores = compute_quality(reference, enhanced_signal, listener, config)
-        whisper_scores = compute_intelligibility(
+        whisper_left, whisper_right, lyrics_text = compute_intelligibility(
             enhanced_signal=enhanced_signal,
             segment_metadata=songs[scene["segment_id"]],
             scorer=intelligibility_scorer,
@@ -375,20 +390,23 @@ def run_compute_scores(config: DictConfig) -> None:
             equiv_0db_spl=config.evaluate.equiv_0db_spl,
         )

+        max_whisper = np.max([whisper_left, whisper_right])
         results_file.add_result(
             {
                 "scene": scene_id,
                 "song": songs[scene["segment_id"]]["track_name"],
                 "listener": listener_id,
+                "lyrics": lyrics_text["reference"],
+                "hypothesis_left": lyrics_text["hypothesis_left"],
+                "hypothesis_right": lyrics_text["hypothesis_right"],
                 "haaqi_left": haaqi_scores[0],
                 "haaqi_right": haaqi_scores[1],
                 "haaqi_avg": np.mean(haaqi_scores),
-                "whisper_left": whisper_scores[0],
-                "whisper_rigth": whisper_scores[1],
-                "whisper_be": np.max(whisper_scores),
+                "whisper_left": whisper_left,
+                "whisper_rigth": whisper_right,
+                "whisper_be": max_whisper,
                 "alpha": alpha,
-                "score": alpha * np.max(whisper_scores)
-                + (1 - alpha) * np.mean(haaqi_scores),
+                "score": alpha * max_whisper + (1 - alpha) * np.mean(haaqi_scores),
             }
         )

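Hoisting max_whisper into a variable removes a duplicated np.max call and makes the final metric explicit: better-ear Whisper intelligibility blended with mean HAAQI quality. A worked example with made-up scores:

    import numpy as np

    alpha = 0.5                            # made-up trade-off weight
    haaqi_scores = (0.80, 0.70)            # (left, right) quality
    whisper_left, whisper_right = 0.90, 0.60

    max_whisper = np.max([whisper_left, whisper_right])  # better ear: 0.90
    score = alpha * max_whisper + (1 - alpha) * np.mean(haaqi_scores)
    print(score)  # 0.5 * 0.90 + 0.5 * 0.75 = 0.825
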
recipes/cad2/task2/baseline/config.yaml (4 changes: 1 addition & 3 deletions)
@@ -1,7 +1,5 @@
-# Zenodo download path: path to the zenodo download folder.
-# root: root path of the dataset. This path will contain the audio and metadata folders
 path:
-  root: /media/gerardoroadabike/Extreme SSD1/Challenges/CAD2/cadenza_data/cad2/task2
+  root: ??? # Set to the root of the dataset
   metadata_dir: ${path.root}/metadata
   music_dir: ${path.root}/audio
   gains_file: ${path.metadata_dir}/gains.json
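
Swapping the author's local path for ??? uses OmegaConf's mandatory-missing marker, so Hydra now fails loudly until the user supplies path.root, rather than silently pointing at someone else's disk layout. A minimal sketch of the behaviour:

    from omegaconf import OmegaConf
    from omegaconf.errors import MissingMandatoryValue

    cfg = OmegaConf.create({"path": {"root": "???"}})
    try:
        _ = cfg.path.root  # "???" marks a required value that was never set
    except MissingMandatoryValue:
        print("path.root must be provided, e.g. path.root=/data/cad2/task2")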
