Skip to content

Commit

Permalink
allow RI and USA populations
Browse files Browse the repository at this point in the history
  • Loading branch information
hussain-jafari committed Nov 26, 2024
1 parent 923b0be commit 296d540
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
33 changes: 27 additions & 6 deletions tests/release/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import pytest
from memory_profiler import memory_usage # type: ignore

from pseudopeople.configuration import Keys, get_configuration
from pseudopeople.configuration import get_configuration
from pseudopeople.dataset import Dataset
from pseudopeople.interface import (
generate_american_community_survey,
generate_current_population_survey,
Expand All @@ -19,9 +20,7 @@
generate_taxes_w2_and_1099,
generate_women_infants_and_children,
)
from pseudopeople.noise_entities import NOISE_TYPES
from pseudopeople.schema_entities import COLUMNS, DATASET_SCHEMAS
from tests.integration.conftest import CELL_PROBABILITY
from pseudopeople.schema_entities import DATASET_SCHEMAS

DATASET_GENERATION_FUNCS: dict[str, Callable[..., Any]] = {
"census": generate_decennial_census,
Expand All @@ -42,6 +41,7 @@
"wic": "women_infants_and_children",
}

SEED = 0
DEFAULT_YEAR = None
DEFAULT_STATE = None
DEFAULT_POP = "sample"
Expand Down Expand Up @@ -100,7 +100,8 @@ def output_dir() -> Path:
# output_dir_name = (
# "/mnt/team/simulation_science/priv/engineering/pseudopeople_release_testing"
# )
output_dir_name = "/home/hjafari/ppl_testing"
# output_dir_name = "/home/hjafari/ppl_testing"
output_dir_name = "/ihme/homes/hjafari/ppl_testing"
if not output_dir_name:
raise ValueError("PSP_TEST_OUTPUT_DIR environment variable not set")
output_dir = Path(output_dir_name) / f"{time.strftime('%Y%m%d_%H%M%S')}"
Expand All @@ -109,7 +110,7 @@ def output_dir() -> Path:


@pytest.fixture(scope="session")
def dataset(
def data(
output_dir: Path, request: pytest.FixtureRequest, config: dict[str, Any]
) -> pd.DataFrame:
dataset_name, dataset_func, source, engine, state, year = _parse_dataset_params(request)
Expand All @@ -124,6 +125,26 @@ def dataset(
)


@pytest.fixture(scope="session")
def unnoised_dataset(
output_dir: Path, request: pytest.FixtureRequest, config: dict[str, Any]
) -> pd.DataFrame:
dataset_name, dataset_func, source, engine, state, year = _parse_dataset_params(request)
no_noise_config = get_configuration("no_noise")

if dataset_func == generate_social_security:
unnoised_data = dataset_func(
source=source, year=year, engine=engine, config=no_noise_config
)
else:
unnoised_data = dataset_func(
source=source, year=year, state=state, engine=engine, config=no_noise_config
)

dataset_schema = DATASET_SCHEMAS.get_dataset_schema(dataset_name)
return Dataset(dataset_schema, unnoised_data, SEED)


@pytest.fixture(scope="session")
def dataset_name(request: pytest.FixtureRequest) -> str:
dataset_arg = request.config.getoption("--dataset")
Expand Down
6 changes: 3 additions & 3 deletions tests/release/test_release.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_column_noising(
) -> None:
"""Tests that columns are noised as expected"""
original = _initialize_dataset_with_sample(dataset_name)
noised_data = request.getfixturevalue("dataset")
noised_data = request.getfixturevalue("data")
check_noised, check_original, shared_idx = _get_common_datasets(original, noised_data)

run_column_noising_tests(
Expand All @@ -86,7 +86,7 @@ def test_row_noising_omit_row_or_do_not_respond(
idx_cols = IDX_COLS.get(dataset_name)
original = get_unnoised_data(dataset_name)
original_data = original.data.set_index(idx_cols)
noised_data = request.getfixturevalue("dataset")
noised_data = request.getfixturevalue("data")
noised_data = noised_data.set_index(idx_cols)

run_omit_row_or_do_not_respond_tests(dataset_name, config, original_data, noised_data)
Expand All @@ -100,7 +100,7 @@ def test_unnoised_id_cols(dataset_name: str, request: FixtureRequest) -> None:
if dataset_name != DATASET_SCHEMAS.ssa.name:
unnoised_id_cols.append(COLUMNS.household_id.name)
original = _initialize_dataset_with_sample(dataset_name)
noised_data = request.getfixturevalue("dataset")
noised_data = request.getfixturevalue("data")
check_noised, check_original, _ = _get_common_datasets(original, noised_data)
assert (
(
Expand Down

0 comments on commit 296d540

Please sign in to comment.