Skip to content

Commit

Permalink
layered config tree typing (#473)
Browse files Browse the repository at this point in the history
Category: typing
JIRA issue: MIC-5376

Typing updates related to LayeredConfigTree errors.

Testing
All tests pass (including slow tests).
  • Loading branch information
hussain-jafari authored Nov 5, 2024
1 parent 8b56091 commit 49d9084
Show file tree
Hide file tree
Showing 22 changed files with 304 additions and 206 deletions.
6 changes: 3 additions & 3 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

base_dir = Path(pseudopeople.__file__).parent

about: dict = {}
about: dict[str, str] = {}
with (base_dir / "__about__.py").open() as f:
exec(f.read(), about)

Expand Down Expand Up @@ -91,7 +91,7 @@
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns: list = []
exclude_patterns: list[str] = []

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"
Expand Down Expand Up @@ -144,7 +144,7 @@

# -- Options for LaTeX output ---------------------------------------------

latex_elements: dict = {
latex_elements: dict[str, str] = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
Expand Down
20 changes: 0 additions & 20 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,6 @@ implicit_reexport = true
exclude = [
"build",
]
# ignore error codes (to allow for gradual typing)
# extract unique error codes thrown with `mypy . | grep -oP " [ \K[^\] ]+" | sort | uniq`
disable_error_code = [
# assignment is still being ignored because all of these issues stem from LayeredConfigTree
# usage and we want to make a schema object that is typed which should hopefully resolve most
# of these issues.
# TODO: remove "assignment" once typed object is implemented
"assignment",
"attr-defined",
# Call overload is still being ignored because of the same issue mentioned above
# for assignment
# TODO: remove "call-overload" once typed object is implemented
"call-overload",
"index",
# operator is still being ignored because of issues with checking containment in
# LayeredConfigTree objects which should be resolved by changes mentioned above
# TODO: remove "operator" once typed object is implemented
"operator",
"type-arg",
]

# handle mypy errors when 3rd party packages are not typed.
[[tool.mypy.overrides]]
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
base_dir = Path(__file__).parent
src_dir = base_dir / "src"

about: dict = {}
about: dict[str, str] = {}
with (src_dir / "pseudopeople" / "__about__.py").open() as f:
exec(f.read(), about)

Expand All @@ -38,7 +38,7 @@
"pyarrow",
"scipy",
"tqdm",
"layered_config_tree>=1.0.1",
"layered_config_tree>=2.1.0",
"loguru",
# type stubs
"pandas-stubs",
Expand Down
36 changes: 24 additions & 12 deletions src/pseudopeople/configuration/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

# Define non-baseline default items
# NOTE: default values are defined in entity_types.RowNoiseType and entity_types.ColumnNoiseType
DEFAULT_NOISE_VALUES: dict = {
DEFAULT_NOISE_VALUES: dict[str, Any] = {
DATASET_SCHEMAS.census.name: {
Keys.ROW_NOISE: {
NOISE_TYPES.do_not_respond.name: {
Expand Down Expand Up @@ -101,16 +101,19 @@ def get_configuration(
overrides = None
elif isinstance(overrides, (Path, str)):
with open(overrides, "r") as f:
overrides = yaml.safe_load(f)
overrides_dict: dict[str, Any] = yaml.safe_load(f)
is_no_noise = False
elif overrides is not None:
overrides_dict = overrides
is_no_noise = False
else:
is_no_noise = False
noising_configuration = _generate_configuration(is_no_noise)
if overrides is not None:
validate_overrides(overrides, noising_configuration)
validate_overrides(overrides_dict, noising_configuration)
add_overrides(
noising_configuration,
overrides, # type: ignore [arg-type]
overrides_dict,
dataset_schema,
filters,
)
Expand All @@ -129,7 +132,9 @@ def _generate_configuration(is_no_noise: bool) -> LayeredConfigTree:
baseline_dict = {}
# Loop through each dataset
for dataset_schema in DATASET_SCHEMAS:
dataset_dict = {}
# dataset_dict is extremely nested so typing it any deeper
# causes problems for typing further down
dataset_dict: dict[str, dict[str, dict]] = {} # type: ignore [type-arg]
row_noise_dict = {}
column_dict = {}

Expand Down Expand Up @@ -186,7 +191,7 @@ def get_noise_type_dict(noise_type: NoiseType, is_no_noise: bool) -> dict[str, f

def add_overrides(
noising_configuration: LayeredConfigTree,
overrides: dict,
overrides: dict[str, Any],
dataset_schema: DatasetSchema | None = None,
filters: Sequence[DataFilter] = (),
) -> None:
Expand All @@ -200,7 +205,9 @@ def add_overrides(
validate_noise_level_proportions(noising_configuration, dataset_schema, filters)


def _format_overrides(default_config: LayeredConfigTree, user_dict: dict) -> dict:
def _format_overrides(
default_config: LayeredConfigTree, user_dict: dict[str, Any]
) -> dict[str, Any]:
"""Formats the user's configuration file as necessary, so it can properly
update noising configuration to be used
"""
Expand All @@ -209,8 +216,8 @@ def _format_overrides(default_config: LayeredConfigTree, user_dict: dict) -> dic


def _format_misreport_age_perturbations(
default_config: LayeredConfigTree, user_dict: dict
) -> dict:
default_config: LayeredConfigTree, user_dict: dict[str, Any]
) -> dict[str, Any]:
# Format any age perturbation lists as a dictionary with uniform probabilities
for dataset_schema in user_dict:
user_perturbations = (
Expand All @@ -223,9 +230,14 @@ def _format_misreport_age_perturbations(
if not user_perturbations:
continue
formatted = {}
default_perturbations: dict[int, float] = default_config[dataset_schema][
Keys.COLUMN_NOISE
]["age"][NOISE_TYPES.misreport_age.name][Keys.POSSIBLE_AGE_DIFFERENCES]
default_perturbations: dict[int, float] = (
default_config.get_tree(dataset_schema)
.get_tree(Keys.COLUMN_NOISE)
.get_tree("age")
.get_tree(NOISE_TYPES.misreport_age.name)
.get(Keys.POSSIBLE_AGE_DIFFERENCES)
.to_dict()
)
# Replace default configuration with 0 probabilities
for perturbation in default_perturbations:
formatted[perturbation] = 0.0
Expand Down
5 changes: 3 additions & 2 deletions src/pseudopeople/configuration/interface.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

from pathlib import Path
from typing import Any

from pseudopeople.configuration.generator import get_configuration


def get_config(overrides: Path | str | dict | None = None) -> dict:
def get_config(overrides: Path | str | dict[str, Any] | None = None) -> dict[str, Any]:
"""
Function that returns the pseudopeople configuration containing all
default values. To get the default probability of nonresponse in the
Expand Down Expand Up @@ -49,5 +50,5 @@ def get_config(overrides: Path | str | dict | None = None) -> dict:
An invalid configuration is passed with `overrides`.
"""
config: dict = get_configuration(overrides).to_dict()
config: dict[str, Any] = get_configuration(overrides).to_dict()
return config
94 changes: 74 additions & 20 deletions src/pseudopeople/configuration/noise_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class NoiseConfiguration:
def __init__(self, config: LayeredConfigTree):
self._config = config

def to_dict(self) -> dict:
config_dict: dict = self._config.to_dict()
def to_dict(self) -> dict[str, Any]:
config_dict: dict[str, Any] = self._config.to_dict()
return config_dict

def get_value(
Expand All @@ -32,10 +32,10 @@ def get_value(
noise_type: str,
parameter_name: str,
column_name: str | None = None,
) -> float | int | list | dict:
) -> float | int | list[float] | dict[int, float]:
config = self._config
try:
dataset_config = config[dataset]
dataset_config: LayeredConfigTree = config.get_tree(dataset)
except:
raise ValueError(
f"{dataset} was not found in configuration. "
Expand All @@ -48,20 +48,20 @@ def get_value(
raise ValueError(
f"You cannot provide both a row noise type ({noise_type}) and a column name ({column_name}) simultaneously."
)
config = dataset_config["row_noise"]
config = dataset_config.get_tree("row_noise")
# column noise
elif noise_type in COLUMN_NOISE_TYPES:
if not column_name:
raise ValueError(
f"You must provide a column name when using a column noise type ({noise_type} in your case)."
)
all_column_configs: LayeredConfigTree = dataset_config["column_noise"]
all_column_configs: LayeredConfigTree = dataset_config.get_tree("column_noise")
if column_name not in all_column_configs:
raise ValueError(
f"The column name {column_name} was not found in your config. "
f"Available columns are {list(all_column_configs.keys())}."
)
config = all_column_configs[column_name]
config = all_column_configs.get_tree(column_name)
# unknown noise type
else:
raise ValueError(
Expand All @@ -70,40 +70,94 @@ def get_value(
f"Available column noise types are {COLUMN_NOISE_TYPES}."
)
# get value
parameter_tree: LayeredConfigTree = config[noise_type]
parameter_tree: LayeredConfigTree = config.get_tree(noise_type)
if parameter_name not in parameter_tree:
raise ValueError(
f"The parameter {parameter_name} was not found for {noise_type} in the configuration. "
f"Available parameters are {list(parameter_tree.keys())}."
)
noise_value: int | float | LayeredConfigTree = parameter_tree[parameter_name]
converted_noise_value: int | float | dict = (
noise_value.to_dict()
if isinstance(noise_value, LayeredConfigTree)
else noise_value
)
return converted_noise_value
noise_value: int | float | LayeredConfigTree = parameter_tree.get(parameter_name)
if isinstance(noise_value, LayeredConfigTree):
# TODO: [MIC-5500] store dicts in LayeredConfigTree without converting to LayeredConfigTree
converted_noise_value: dict[int, float] = noise_value.to_dict() # type: ignore [assignment]
return converted_noise_value
else:
return noise_value

def get_row_probability(self, dataset: str, noise_type: str) -> int | float:
value: int | float = self.get_value(
dataset, noise_type, parameter_name="row_probability"
)
value = self.get_value(dataset, noise_type, parameter_name="row_probability")
if not isinstance(value, int) and not isinstance(value, float):
raise ValueError(
f"Row probabilities are expected to contain ints or floats. Your config returned {type(value)}."
)
return value

def get_cell_probability(
self, dataset: str, noise_type: str, column_name: str
) -> int | float:
value: int | float = self.get_value(
value = self.get_value(
dataset, noise_type, parameter_name="cell_probability", column_name=column_name
)
if not isinstance(value, int) and not isinstance(value, float):
raise ValueError(
f"Cell probabilities are expected to contain ints or floats. Your config returned {type(value)}."
)
return value

def get_token_probability(
self, dataset: str, noise_type: str, column_name: str
) -> int | float:
value: int | float = self.get_value(
value = self.get_value(
dataset, noise_type, parameter_name="token_probability", column_name=column_name
)
if not isinstance(value, int) and not isinstance(value, float):
raise ValueError(
f"Token probabilities are expected to contain ints or floats. Your config returned {type(value)}."
)
return value

def get_zipcode_digit_probabilities(self, dataset: str, column_name: str) -> list[float]:
values = self.get_value(
dataset,
"write_wrong_zipcode_digits",
parameter_name="digit_probabilities",
column_name=column_name,
)
if not isinstance(values, list) or not all(
isinstance(value, float) for value in values
):
raise ValueError(
f"Zipcode digit probabilities are expected to be a list of floats. Your config returned {type(values)}."
)
return values

def get_duplicate_with_guardian_probabilities(
self, dataset: str, parameter_name: str
) -> int | float:
if (
parameter_name != "row_probability_in_households_under_18"
and parameter_name != "row_probability_in_college_group_quarters_under_24"
):
raise ValueError(
f"Parameter name must be 'row_probability_in_households_under_18' or 'row_probability_in_college_group_quarters_under_24' when getting duplicate with guardian probabilities. You provided {parameter_name}."
)
value = self.get_value(dataset, "duplicate_with_guardian", parameter_name)
if not isinstance(value, int) and not isinstance(value, float):
raise ValueError(
f"Duplicate with guardian probabilities are expected to be ints or floats. Your config returned {type(value)}."
)
return value

def get_misreport_ages_probabilities(
self, dataset: str, column_name: str
) -> dict[int, float]:
value = self.get_value(
dataset, "misreport_age", Keys.POSSIBLE_AGE_DIFFERENCES, column_name
)
if not isinstance(value, dict):
raise ValueError(
f"Misreport age probabilities are expected to be a dict. Your config returned {type(value)}."
)
return value

def has_noise_type(
Expand Down
Loading

0 comments on commit 49d9084

Please sign in to comment.