Skip to content

Commit

Permalink
redact function preserve nulls
Browse files Browse the repository at this point in the history
  • Loading branch information
Egor Goloshchapov committed Jul 26, 2022
1 parent 58cc6ae commit 8c308fb
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
3 changes: 2 additions & 1 deletion nestedfunctions/functions/redact.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def process(self, df: DataFrame) -> DataFrame:

def transform_primitive(self, primitive_value: Column, fieldType: AtomicType) -> Column:
try:
return column_name_with_dedicated_field_type(fieldType)
return F.when(F.isnull(primitive_value), primitive_value) \
.otherwise(column_name_with_dedicated_field_type(fieldType))
except KeyError:
raise Exception(
f'Unknown type {fieldType.simpleString()} for field {self.column_to_process}. '
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/functions/redact/fixtures/redact_sample.json
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
[
{
"root level string": "yo",
"root level with null": null,
"root-level-number": 1.123521322,
"customDimensions": [
{
"Metabolics Conditions": 13,
"null value in nested": null,
"value": "value1"
},
{
"Metabolics Conditions": 2,
"null value in nested": 5,
"value": "value2"
}
]
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/functions/redact/test_redact.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@


class RedactTest(SparkBaseTest):
def test_redact_root_level_preserve_null(self):
    """Redacting a root-level field that holds null must keep it null."""
    field_name = "root level with null"

    def column_values(frame: DataFrame) -> List[str]:
        # Take the first (only) selected column from each collected row.
        rows = frame.select(field_name).collect()
        return [row[0] for row in rows]

    source_df = self.parse_df()
    # Pin the pre-redaction state: a single row whose value is null.
    self.assertEqual([None], column_values(source_df))

    redacted_df = redact(source_df, field=field_name)
    # The null must survive redaction instead of being replaced.
    self.assertEqual([None], column_values(redacted_df))

def test_redact_root_level(self):
def parse_data(df: DataFrame) -> List[str]:
Expand All @@ -37,6 +45,18 @@ def parse_data(df_to_extract: DataFrame) -> List[str]:
expected_value = SPARK_TYPE_TO_REDACT_VALUE[primitive_value_type]
self.assertEqual([expected_value, expected_value], flatten(parse_data(processed)))


def test_redact_nested_structure_null_value_preserved(self):
    """Redacting a nested field keeps nulls and replaces only real values."""
    nested_field = "customDimensions.null value in nested"

    def nested_values(frame: DataFrame) -> List[str]:
        # First selected column of every collected row; the callers pass
        # the result through flatten() before comparing.
        rows = frame.select(nested_field).collect()
        return [row[0] for row in rows]

    source_df = self.parse_df()
    # Pin the pre-redaction state: one null and one concrete value.
    self.assertEqual([None, 5], flatten(nested_values(source_df)))

    redacted_df = redact(source_df, field=nested_field)

    # Look up the Spark type of the nested leaf field so the expected
    # replacement value can be read from the redact-value mapping.
    array_type = source_df.schema["customDimensions"].dataType
    leaf_field = array_type.elementType['null value in nested']
    leaf_type = leaf_field.dataType
    replacement = SPARK_TYPE_TO_REDACT_VALUE[leaf_type]

    self.assertEqual([None, replacement], flatten(nested_values(redacted_df)))

def test_redact_throws_exception_if_field_is_not_primitive(self):
df = self.parse_df()
with pytest.raises(Exception) as excinfo:
Expand Down

0 comments on commit 8c308fb

Please sign in to comment.