diff --git a/test/detect/test_nudenet.py b/test/detect/test_nudenet.py index 5bf43bbf2c2..1a8ff6ed964 100644 --- a/test/detect/test_nudenet.py +++ b/test/detect/test_nudenet.py @@ -68,8 +68,18 @@ def nude_girl_detection(): class TestDetectNudeNet: def test_detect_with_nudenet_file(self, nude_girl_file, nude_girl_detection): detection = detect_with_nudenet(nude_girl_file) - assert detection == pytest.approx(nude_girl_detection) + assert [label for _, label, _ in detection] == \ + [label for _, label, _ in nude_girl_detection] + for (actual_box, _, _), (expected_box, _, _) in zip(detection, nude_girl_detection): + assert actual_box == pytest.approx(expected_box) + assert [score for _, _, score in detection] == \ + pytest.approx([score for _, _, score in nude_girl_detection], abs=1e-4) def test_detect_with_nudenet_image(self, nude_girl_image, nude_girl_detection): detection = detect_with_nudenet(nude_girl_image) - assert detection == pytest.approx(nude_girl_detection) + assert [label for _, label, _ in detection] == \ + [label for _, label, _ in nude_girl_detection] + for (actual_box, _, _), (expected_box, _, _) in zip(detection, nude_girl_detection): + assert actual_box == pytest.approx(expected_box) + assert [score for _, _, score in detection] == \ + pytest.approx([score for _, _, score in nude_girl_detection], abs=1e-4)