Commit

Rewrite labeler test
gliptak authored Jun 10, 2024
1 parent 56b8b2c commit e118ee8
Showing 1 changed file with 19 additions and 24 deletions.
43 changes: 19 additions & 24 deletions dataprofiler/tests/labelers/test_labeler_utils.py
@@ -1,6 +1,7 @@
 import logging
 import unittest
 from unittest import mock
+import tempfile
 
 import numpy as np
 import pandas as pd
@@ -235,9 +236,7 @@ def test_verbose(self):
         self.assertIn("f1-score ", log_output)
         self.assertIn("F1 Score: ", log_output)
 
-    @mock.patch("dataprofiler.labelers.labeler_utils.classification_report")
-    @mock.patch("pandas.DataFrame")
-    def test_save_conf_mat(self, mock_dataframe, mock_report):
+    def test_save_conf_mat(self):
 
         # ideally mock out the actual contents written to file, but
         # would be difficult to get this completely worked out.
@@ -248,29 +247,25 @@ def test_save_conf_mat(self, mock_dataframe, mock_report):
                 [0, 1, 2],
             ]
         )
-        expected_row_col_names = dict(
-            columns=["pred:PAD", "pred:UNKNOWN", "pred:OTHER"],
-            index=["true:PAD", "true:UNKNOWN", "true:OTHER"],
-        )
-        mock_instance_df = mock.Mock(spec=pd.DataFrame)()
-        mock_dataframe.return_value = mock_instance_df
+        expected_columns = ["pred:PAD", "pred:UNKNOWN", "pred:OTHER"]
+        expected_index = ["true:PAD", "true:UNKNOWN", "true:OTHER"]
 
-        # still omit bc confusion mat should include all despite omit
-        f1, f1_report = labeler_utils.evaluate_accuracy(
-            self.y_pred,
-            self.y_true,
-            self.num_labels,
-            self.reverse_label_mapping,
-            omitted_labels=["PAD"],
-            verbose=False,
-            confusion_matrix_file="test.csv",
-        )
-
-        self.assertTrue((mock_dataframe.call_args[0][0] == expected_conf_mat).all())
-        self.assertDictEqual(expected_row_col_names, mock_dataframe.call_args[1])
-
-        mock_instance_df.to_csv.assert_called()
+        with tempfile.NamedTemporaryFile() as tmpFile:
+            # still omit bc confusion mat should include all despite omit
+            f1, f1_report = labeler_utils.evaluate_accuracy(
+                self.y_pred,
+                self.y_true,
+                self.num_labels,
+                self.reverse_label_mapping,
+                omitted_labels=["PAD"],
+                verbose=False,
+                confusion_matrix_file=tmpFile.name,
+            )
+
+            df1 = pd.read_csv(tmpFile.name, index_col=0)
+            self.assertListEqual(list(df1.columns), expected_columns)
+            self.assertListEqual(list(df1.index), expected_index)
+            np.testing.assert_array_equal(df1.values, expected_conf_mat)
 
 
 class TestTFFunctions(unittest.TestCase):
     def test_get_tf_layer_index_from_name(self):
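For context, the rewritten test exercises the real CSV round trip instead of asserting against a mocked pandas.DataFrame. Below is a minimal, self-contained sketch of that pattern; the matrix values, labels, and variable names are illustrative stand-ins, not the DataProfiler implementation of evaluate_accuracy or the fixtures from the test above.

import tempfile

import numpy as np
import pandas as pd

# Placeholder confusion matrix and labels (not the test's actual fixtures).
conf_mat = np.array([[1, 0, 1], [1, 1, 0], [0, 1, 2]])
columns = ["pred:PAD", "pred:UNKNOWN", "pred:OTHER"]
index = ["true:PAD", "true:UNKNOWN", "true:OTHER"]

with tempfile.NamedTemporaryFile() as tmp:
    # Write the labeled matrix to CSV, the way a confusion_matrix_file would be written.
    pd.DataFrame(conf_mat, columns=columns, index=index).to_csv(tmp.name)

    # Read it back; index_col=0 restores the "true:*" labels as the index.
    df = pd.read_csv(tmp.name, index_col=0)
    assert list(df.columns) == columns
    assert list(df.index) == index
    np.testing.assert_array_equal(df.values, conf_mat)

Like the test, this reopens the temporary file by name, which works on POSIX systems but not on Windows while the file is still open.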
