Skip to content

Commit

Permalink
black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
kirilklein committed Oct 12, 2024
1 parent c998957 commit 61c4bd5
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
16 changes: 11 additions & 5 deletions CausalEstimate/filter/propensity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd
from CausalEstimate.utils.utils import get_treated_ps, get_untreated_ps, filter_column


def filter_common_support(
df: pd.DataFrame,
ps_col: str = "propensity_score",
Expand All @@ -20,11 +21,16 @@ def filter_common_support(
Returns:
DataFrame after removing individuals without common support.
"""
common_min, common_max = get_common_support_range(df, treatment_col, ps_col, threshold)
common_min, common_max = get_common_support_range(
df, treatment_col, ps_col, threshold
)
filtered_df = filter_column(df, ps_col, common_min, common_max)
return filtered_df

def get_common_support_range(df: pd.DataFrame, treatment_col: str, ps_col: str, threshold: float = 0.05) -> tuple[float, float]:

def get_common_support_range(
df: pd.DataFrame, treatment_col: str, ps_col: str, threshold: float = 0.05
) -> tuple[float, float]:
"""
Calculate the common support range for propensity scores.
Expand All @@ -42,9 +48,9 @@ def get_common_support_range(df: pd.DataFrame, treatment_col: str, ps_col: str,
min_ps_treated, max_ps_treated = get_treated_ps(df, treatment_col, ps_col).quantile(
[threshold, 1 - threshold]
)
min_ps_control, max_ps_control = get_untreated_ps(df, treatment_col, ps_col).quantile(
[threshold, 1 - threshold]
)
min_ps_control, max_ps_control = get_untreated_ps(
df, treatment_col, ps_col
).quantile([threshold, 1 - threshold])
common_min = max(min_ps_treated, min_ps_control)
common_max = min(max_ps_treated, max_ps_control)
return common_min, common_max
1 change: 1 addition & 0 deletions CausalEstimate/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def get_treated(df: pd.DataFrame, treatment_col: str) -> pd.DataFrame:
def get_untreated(df: pd.DataFrame, treatment_col: str) -> pd.DataFrame:
return df[df[treatment_col] == 0]


def filter_column(df: pd.DataFrame, col: str, min: float, max: float) -> pd.DataFrame:
"""
Filters a DataFrame to keep only rows where a specified column is within a given range.
Expand Down

0 comments on commit 61c4bd5

Please sign in to comment.