Skip to content

Commit

Permalink
Merge pull request #337 from ksneab7/bug_fix_for_column_header_issue
Browse files Browse the repository at this point in the history
fix for column name exclusion bug
  • Loading branch information
taylorfturner authored Sep 27, 2023
2 parents 3ac052d + a2ab09b commit df7c401
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 64 deletions.
14 changes: 14 additions & 0 deletions synthetic_data/distinct_generators/null_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Contains a random Null generator."""
import numpy as np


def null_generation(num_rows: int = 1) -> np.ndarray:
    """
    Generate a numpy array of null (None) values.

    :param num_rows: the number of rows in np array generated
    :type num_rows: int, optional
    :return: np array of null values
    :rtype: numpy.ndarray
    """
    # An all-None list produces a dtype=object array, preserving the None entries.
    return np.array([None] * num_rows)
77 changes: 40 additions & 37 deletions synthetic_data/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from synthetic_data.distinct_generators.datetime_generator import random_datetimes
from synthetic_data.distinct_generators.float_generator import random_floats
from synthetic_data.distinct_generators.int_generator import random_integers
from synthetic_data.distinct_generators.null_generator import null_generation
from synthetic_data.distinct_generators.text_generator import random_text
from synthetic_data.graph_synthetic_data import GraphDataGenerator
from synthetic_data.synthetic_data import make_data_from_report
Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(
"datetime": random_datetimes,
"string": random_text,
"text": random_text,
"null_generator": null_generation,
}

@classmethod
Expand Down Expand Up @@ -105,49 +107,46 @@ def _generate_uncorrelated_column_data(self, num_samples):
col_ = copy.deepcopy(col)

generator_name = col_.get("data_type", None)

if not generator_name:
logging.warning(
f"Generator of type {generator_name} is not implemented."
)
continue
column_header = col_.get("column_name", None)

col_["rng"] = self.rng
col_["num_rows"] = num_samples
if generator_name:
if generator_name in ["string", "text"]:
if col_.get("categorical", False):
generator_name = "categorical"
total = 0
for count in col["statistics"]["categorical_count"].values():
total += count

if generator_name in ["string", "text"]:
if col_.get("categorical", False):
generator_name = "categorical"
total = 0
for count in col["statistics"]["categorical_count"].values():
total += count

probabilities = []
for count in col["statistics"]["categorical_count"].values():
probabilities.append(count / total)
probabilities = []
for count in col["statistics"]["categorical_count"].values():
probabilities.append(count / total)

col_["probabilities"] = probabilities
col_["categories"] = col_["statistics"].get("categories", None)
col_["probabilities"] = probabilities
col_["categories"] = col_["statistics"].get("categories", None)

col_["vocab"] = col_["statistics"].get("vocab", None)
col_["vocab"] = col_["statistics"].get("vocab", None)

col_["min"] = col_["statistics"].get("min", None)
col_["max"] = col_["statistics"].get("max", None)
col_["min"] = col_["statistics"].get("min", None)
col_["max"] = col_["statistics"].get("max", None)

# edge cases for extracting data from profiler report.
if generator_name == "datetime":
col_["format"] = col_["statistics"].get("format", None)
col_["min"] = pd.to_datetime(
col_["statistics"].get("min", None), format=col_["format"][0]
)
col_["max"] = pd.to_datetime(
col_["statistics"].get("max", None), format=col_["format"][0]
)
# edge cases for extracting data from profiler report.
if generator_name == "datetime":
col_["format"] = col_["statistics"].get("format", None)
col_["min"] = pd.to_datetime(
col_["statistics"].get("min", None), format=col_["format"][0]
)
col_["max"] = pd.to_datetime(
col_["statistics"].get("max", None), format=col_["format"][0]
)

if generator_name == "float":
col_["precision"] = int(
col_["statistics"].get("precision", None).get("max", None)
)
if generator_name == "float":
col_["precision"] = int(
col_["statistics"].get("precision", None).get("max", None)
)
elif not generator_name:
generator_name = "null_generator"

generator_func = self.gen_funcs.get(generator_name, None)
params_gen_funcs = inspect.signature(generator_func)
Expand All @@ -157,7 +156,9 @@ def _generate_uncorrelated_column_data(self, num_samples):
param_build[param[0]] = col_[param[0]]

generated_data = generator_func(**param_build)
if col_["order"] in sorting_types:
if (not generator_name == "null_generator") and col_[
"order"
] in sorting_types:
dataset.append(
self.get_ordered_column(
generated_data,
Expand All @@ -166,7 +167,9 @@ def _generate_uncorrelated_column_data(self, num_samples):
)
)
else:
if col_["order"] is not None:
if (not generator_name == "null_generator") and col_[
"order"
] is not None:
logging.warning(
f"""{generator_name} is passed with sorting type of {col_["order"]}.
Ascending and descending are the only supported options.
Expand All @@ -178,7 +181,7 @@ def _generate_uncorrelated_column_data(self, num_samples):
else:
dataset.append(generated_data)

column_names.append(generator_name)
column_names.append(column_header)

return self.convert_data_to_df(dataset, column_names=column_names)

Expand Down
84 changes: 57 additions & 27 deletions tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,25 @@ def test_synthesize_uncorrelated_output(self):
np.testing.assert_array_equal(
actual_synthetic_data.columns.values,
np.array(
["datetime", "categorical", "int", "string", "float"], dtype="object"
[
"datetime",
"host",
"src",
"proto",
"type",
"srcport",
"destport",
"srcip",
"locale",
"localeabbr",
"postalcode",
"latitude",
"longitude",
"owner",
"comment",
"int_col",
],
dtype="object",
),
)

Expand Down Expand Up @@ -301,13 +319,13 @@ def test_generate_uncorrelated_column_data(
else:
self.assertEqual(call_args_list[key], expected_calls[j][key])

@mock.patch("synthetic_data.generators.logging.warning")
@mock.patch("dataprofiler.profilers.StructuredProfiler.report")
def test_get_ordered_column_integration(self, mock_report, mock_warning):
def test_get_ordered_column_integration(self, mock_report):
mock_report.return_value = {
"data_stats": [
{
"data_type": "int",
"column_name": "test_column_1",
"order": "ascending",
"statistics": {
"min": 1.0,
Expand All @@ -316,6 +334,7 @@ def test_get_ordered_column_integration(self, mock_report, mock_warning):
},
{
"data_type": "string",
"column_name": "test_column_2",
"categorical": False,
"order": "ascending",
"statistics": {
Expand All @@ -326,6 +345,7 @@ def test_get_ordered_column_integration(self, mock_report, mock_warning):
},
{
"data_type": "string",
"column_name": "test_column_3",
"categorical": True,
"order": "ascending",
"statistics": {
Expand All @@ -342,11 +362,13 @@ def test_get_ordered_column_integration(self, mock_report, mock_warning):
},
{
"data_type": "float",
"column_name": "test_column_4",
"order": "ascending",
"statistics": {"min": 2.11234, "max": 8.0, "precision": {"max": 6}},
},
{
"data_type": "datetime",
"column_name": "test_column_5",
"order": "ascending",
"statistics": {
"format": ["%Y-%m-%d"],
Expand All @@ -355,6 +377,7 @@ def test_get_ordered_column_integration(self, mock_report, mock_warning):
},
},
{
"column_name": "test_column_6",
"data_type": None,
},
]
Expand All @@ -363,36 +386,43 @@ def test_get_ordered_column_integration(self, mock_report, mock_warning):
self.assertFalse(generator.is_correlated)

expected_array = [
[1, "arif", "blue", 2.246061, "2003-06-02"],
[1, "daips", "blue", 2.628393, "2003-10-08"],
[1, "dree", "orange", 2.642511, "2006-02-17"],
[1, "drqs", "orange", 2.807119, "2006-11-18"],
[1, "dwdaa", "orange", 3.009102, "2008-12-07"],
[2, "fswfe", "orange", 3.061853, "2009-12-03"],
[2, "fwqe", "orange", 3.677692, "2013-02-24"],
[2, "ipdpd", "orange", 3.887541, "2013-08-18"],
[3, "pdis", "red", 4.24257, "2014-02-19"],
[3, "peii", "red", 4.355663, "2014-04-29"],
[3, "pepie", "red", 4.739156, "2017-12-13"],
[3, "qrdq", "red", 4.831716, "2018-02-03"],
[3, "qrps", "yellow", 5.062321, "2019-05-13"],
[3, "rrqp", "yellow", 5.82323, "2020-01-09"],
[4, "sasr", "yellow", 6.212038, "2021-12-29"],
[4, "sspwe", "yellow", 6.231978, "2022-01-25"],
[4, "sssi", "yellow", 6.365346, "2023-03-20"],
[4, "wpfsi", "yellow", 7.461754, "2023-10-23"],
[4, "wqfed", "yellow", 7.775666, "2026-02-04"],
[4, "wsde", "yellow", 7.818521, "2027-06-13"],
[1, "arif", "blue", 2.246061, "2003-06-02", None],
[1, "daips", "blue", 2.628393, "2003-10-08", None],
[1, "dree", "orange", 2.642511, "2006-02-17", None],
[1, "drqs", "orange", 2.807119, "2006-11-18", None],
[1, "dwdaa", "orange", 3.009102, "2008-12-07", None],
[2, "fswfe", "orange", 3.061853, "2009-12-03", None],
[2, "fwqe", "orange", 3.677692, "2013-02-24", None],
[2, "ipdpd", "orange", 3.887541, "2013-08-18", None],
[3, "pdis", "red", 4.24257, "2014-02-19", None],
[3, "peii", "red", 4.355663, "2014-04-29", None],
[3, "pepie", "red", 4.739156, "2017-12-13", None],
[3, "qrdq", "red", 4.831716, "2018-02-03", None],
[3, "qrps", "yellow", 5.062321, "2019-05-13", None],
[3, "rrqp", "yellow", 5.82323, "2020-01-09", None],
[4, "sasr", "yellow", 6.212038, "2021-12-29", None],
[4, "sspwe", "yellow", 6.231978, "2022-01-25", None],
[4, "sssi", "yellow", 6.365346, "2023-03-20", None],
[4, "wpfsi", "yellow", 7.461754, "2023-10-23", None],
[4, "wqfed", "yellow", 7.775666, "2026-02-04", None],
[4, "wsde", "yellow", 7.818521, "2027-06-13", None],
]
expected_column_names = [
"test_column_1",
"test_column_2",
"test_column_3",
"test_column_4",
"test_column_5",
"test_column_6",
]
categories = ["int", "string", "categorical", "float", "datetime"]

expected_data = [dict(zip(categories, item)) for item in expected_array]
expected_data = [
dict(zip(expected_column_names, item)) for item in expected_array
]
expected_df = pd.DataFrame(expected_data)

actual_df = generator.synthesize(20)

self.assertEqual(mock_warning.call_count, 1)
mock_warning.assert_called_with(f"Generator of type None is not implemented.")
pd.testing.assert_frame_equal(expected_df, actual_df)

@mock.patch("dataprofiler.profilers.StructuredProfiler.report")
Expand Down

0 comments on commit df7c401

Please sign in to comment.