diff --git a/tests/test_wer.py b/tests/test_wer.py index 6017bc4..337b85f 100644 --- a/tests/test_wer.py +++ b/tests/test_wer.py @@ -107,6 +107,19 @@ def test_wer_example_4(self): self.assertEqual(wer(ref, hyp), expected_result) + def test_wer_empty_strings(self): + """ + Test the wer function with empty reference and hypothesis strings. + + Passes a one-element list holding an empty string as both inputs and + verifies that wer returns exactly 0.0 for this identical, empty pair. + """ + ref = [""] + hyp = [""] + expected_result = 0.0 + + self.assertEqual(wer(ref, hyp), expected_result) + if __name__ == "__main__": # pragma: no cover unittest.main() diff --git a/werpy/metrics.pyx b/werpy/metrics.pyx index 3d06daf..7198204 100644 --- a/werpy/metrics.pyx +++ b/werpy/metrics.pyx @@ -50,7 +50,7 @@ cpdef np.ndarray calculations(object reference, object hypothesis): ) ld = ldm[m][n] - wer = ld / m + wer = ld / max(m, 1) # Avoid division by 0 when m is 0 insertions, deletions, substitutions = 0, 0, 0 inserted_words, deleted_words, substituted_words = [], [], [] diff --git a/werpy/wer.py b/werpy/wer.py index 4b873ab..d283a94 100644 --- a/werpy/wer.py +++ b/werpy/wer.py @@ -56,8 +56,9 @@ def wer(reference, hypothesis) -> float: transform_word_error_rate_breakdown = np.transpose( word_error_rate_breakdown.tolist() ) - wer_result = (np.sum(transform_word_error_rate_breakdown[1])) / ( - np.sum(transform_word_error_rate_breakdown[2]) + total_words = np.sum(transform_word_error_rate_breakdown[2]) + wer_result = np.sum(transform_word_error_rate_breakdown[1]) / max( + total_words, 1 ) else: wer_result = word_error_rate_breakdown[0]