Refactor 2 (#36)
Refactor how data is filtered to the region of common support in
propensity scores. `filter_common_support` now delegates to two new
helpers: `get_common_support_range`, which computes the bounds, and
`filter_column`, which applies them. The module also imports
`get_treated_ps`, `get_untreated_ps`, and `filter_column` from
`CausalEstimate/utils/utils.py`.

### Addition of new utility functions:

* [`CausalEstimate/filter/propensity.py`](diffhunk://#diff-73e7b9da6dfdab6a8a1711b59df52d4bffac8ad5d5a5db07f7ddc00b69acfa89L23-R50):
  Added `get_common_support_range` function to calculate the common
  support range for propensity scores, improving the readability and
  maintainability of the code.
* [`CausalEstimate/utils/utils.py`](diffhunk://#diff-47ddc6f8fde028e7680ca9bc39bbe60f9539d833069d20a5d4ce9fac195f2396R18-R23):
  Added `filter_column` function to filter a DataFrame based on a
  specified column range, enhancing the modularity of the code.
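
As a rough illustration of what the two helpers compute together, the sketch below re-implements the same logic on a toy DataFrame. It does not import from `CausalEstimate`: the function names `common_support_range` and `keep_in_range`, the column names `treatment` and `ps`, and the data are all assumptions made for this example.

```python
import pandas as pd


def common_support_range(df, treatment_col, ps_col, threshold=0.05):
    """Hypothetical stand-in that mirrors the logic of get_common_support_range."""
    treated_ps = df.loc[df[treatment_col] == 1, ps_col]
    control_ps = df.loc[df[treatment_col] == 0, ps_col]
    # Trim each group's score distribution at the given quantiles.
    lo_t, hi_t = treated_ps.quantile([threshold, 1 - threshold])
    lo_c, hi_c = control_ps.quantile([threshold, 1 - threshold])
    # The common support is the overlap of the two trimmed ranges.
    return max(lo_t, lo_c), min(hi_t, hi_c)


def keep_in_range(df, col, lower, upper):
    """Hypothetical stand-in for filter_column (inclusive bounds)."""
    return df[(df[col] >= lower) & (df[col] <= upper)]


# Toy data (assumed): four treated and four control individuals.
toy = pd.DataFrame(
    {
        "treatment": [1, 1, 1, 1, 0, 0, 0, 0],
        "ps": [0.3, 0.5, 0.7, 0.9, 0.1, 0.2, 0.4, 0.6],
    }
)

# threshold=0.0 disables trimming, so the range is the plain min/max overlap.
common_min, common_max = common_support_range(toy, "treatment", "ps", threshold=0.0)
filtered = keep_in_range(toy, "ps", common_min, common_max)
```

With `threshold=0.0` the treated scores span 0.3 to 0.9 and the control scores span 0.1 to 0.6, so the overlap is 0.3 to 0.6 and four of the eight rows are kept. A larger threshold trims the tails of each group before the overlap is taken.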
kirilklein authored Oct 12, 2024
1 parent 510a668 commit 6c0b65b
Showing 3 changed files with 36 additions and 11 deletions.
40 changes: 29 additions & 11 deletions CausalEstimate/filter/propensity.py
@@ -1,4 +1,5 @@
 import pandas as pd
+from CausalEstimate.utils.utils import get_treated_ps, get_untreated_ps, filter_column
 
 
 def filter_common_support(
@@ -20,19 +21,36 @@ def filter_common_support(
     Returns:
         DataFrame after removing individuals without common support.
     """
-    # Split the dataframe into treated and control groups
-    treated = df[df[treatment_col] == 1]
-    control = df[df[treatment_col] == 0]
-    # Get the range of propensity scores for treated and control groups
-    min_ps_treated, max_ps_treated = treated[ps_col].quantile(
-        [threshold, 1 - threshold]
+    common_min, common_max = get_common_support_range(
+        df, treatment_col, ps_col, threshold
     )
-    min_ps_control, max_ps_control = control[ps_col].quantile(
+    filtered_df = filter_column(df, ps_col, common_min, common_max)
+    return filtered_df
+
+
+def get_common_support_range(
+    df: pd.DataFrame, treatment_col: str, ps_col: str, threshold: float = 0.05
+) -> tuple[float, float]:
+    """
+    Calculate the common support range for propensity scores.
+    Parameters:
+    -----------
+    df : Input DataFrame with treatment and propensity score columns.
+    treatment_col : Name of the treatment status column.
+    ps_col : Name of the propensity score column.
+    threshold : Quantile threshold for trimming score distribution tails. Default is 0.05.
+    Returns:
+    --------
+    Lower and upper bounds of the common support range.
+    """
+    min_ps_treated, max_ps_treated = get_treated_ps(df, treatment_col, ps_col).quantile(
         [threshold, 1 - threshold]
     )
-    # Define the common support range
+    min_ps_control, max_ps_control = get_untreated_ps(
+        df, treatment_col, ps_col
+    ).quantile([threshold, 1 - threshold])
     common_min = max(min_ps_treated, min_ps_control)
     common_max = min(max_ps_treated, max_ps_control)
-    # Filter individuals to keep only those within the common support range
-    filtered_df = df[(df[ps_col] >= common_min) & (df[ps_col] <= common_max)]
-    return filtered_df
+    return common_min, common_max
7 changes: 7 additions & 0 deletions CausalEstimate/utils/utils.py
@@ -15,3 +15,10 @@ def get_treated(df: pd.DataFrame, treatment_col: str) -> pd.DataFrame:
 
 def get_untreated(df: pd.DataFrame, treatment_col: str) -> pd.DataFrame:
     return df[df[treatment_col] == 0]
+
+
+def filter_column(df: pd.DataFrame, col: str, min: float, max: float) -> pd.DataFrame:
+    """
+    Filters a DataFrame to keep only rows where a specified column is within a given range.
+    """
+    return df[(df[col] >= min) & (df[col] <= max)]
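
Both bounds in `filter_column` are inclusive, so rows sitting exactly on `min` or `max` are kept. The sketch below shows the same behaviour with pandas' `Series.between`, which is inclusive on both ends by default. The function name `filter_column_between` and the parameter names `lower` and `upper` are hypothetical, chosen here to avoid shadowing the built-ins `min` and `max`; this is an alternative phrasing, not the code from the commit.

```python
import pandas as pd


def filter_column_between(
    df: pd.DataFrame, col: str, lower: float, upper: float
) -> pd.DataFrame:
    """Hypothetical variant: keep rows where lower <= df[col] <= upper."""
    # Series.between includes both endpoints by default.
    return df[df[col].between(lower, upper)]


# Toy propensity scores (assumed values for illustration).
scores = pd.DataFrame({"ps": [0.05, 0.2, 0.5, 0.8, 0.95]})
kept = filter_column_between(scores, "ps", 0.2, 0.8)
```

Here `kept` holds the rows with scores 0.2, 0.5, and 0.8, and the original index labels are preserved because boolean-mask filtering does not reset the index.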
Empty file added tests/test_vis/__init__.py
Empty file.
