Skip to content

Commit

Permalink
Fix API (#26)
Browse files Browse the repository at this point in the history
* renamed API to interface
* make sure all estimators are imported, in order for registry to work.
* improved tests
* fix IPW with stabilized weights
  • Loading branch information
kirilklein authored Oct 4, 2024
1 parent 878a6cb commit 2c4bd53
Show file tree
Hide file tree
Showing 13 changed files with 168 additions and 34 deletions.
9 changes: 9 additions & 0 deletions CausalEstimate/core/imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import importlib
import os
import pkgutil


def import_all_estimators():
    """Import every module found in the sibling ``estimators`` directory.

    The directory is located relative to this file (``../estimators``) and
    each discovered module is imported under the ``CausalEstimate.estimators``
    package. Importing runs each module's top-level code; per the commit
    notes this is what allows the estimator registry to be populated.

    Returns:
        None. The function is called for its import side effects only.
    """
    this_dir = os.path.dirname(__file__)
    estimators_dir = os.path.join(this_dir, "..", "estimators")
    for module_info in pkgutil.iter_modules([estimators_dir]):
        qualified_name = f"CausalEstimate.estimators.{module_info.name}"
        importlib.import_module(qualified_name)
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Maps the ``__name__`` of each registered class to the class object itself.
ESTIMATOR_REGISTRY = {}


def register_estimator(cls: type) -> type:
    """Store ``cls`` in ``ESTIMATOR_REGISTRY`` under its class name.

    Args:
        cls: The class to register. It is keyed by ``cls.__name__``, which is
            why the parameter is annotated as ``type`` rather than ``object``.

    Returns:
        The same ``cls`` object, unchanged, so the function can be applied as
        a class decorator.

    Note:
        Registering a second class with the same ``__name__`` replaces the
        earlier entry.
    """
    ESTIMATOR_REGISTRY[cls.__name__] = cls
    return cls
10 changes: 0 additions & 10 deletions CausalEstimate/estimators/__init__.py

This file was deleted.

2 changes: 1 addition & 1 deletion CausalEstimate/estimators/aipw.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pandas as pd

from CausalEstimate.api.registry import register_estimator
from CausalEstimate.core.registry import register_estimator
from CausalEstimate.estimators.base import BaseEstimator
from CausalEstimate.estimators.functional.aipw import compute_aipw_ate
from CausalEstimate.utils.checks import check_inputs
Expand Down
4 changes: 2 additions & 2 deletions CausalEstimate/estimators/functional/ipw.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def compute_ipw_ate_stabilized(A, Y, ps):
A: treatment assignment, Y: outcome, ps: propensity score
"""
W = compute_stabilized_ate_weights(A, ps)
Y1_weighed = W * A * Y
Y0_weighed = W * (1 - A) * Y
Y1_weighed = (W * A * Y).mean()
Y0_weighed = (W * (1 - A) * Y).mean()
return Y1_weighed - Y0_weighed


Expand Down
2 changes: 1 addition & 1 deletion CausalEstimate/estimators/ipw.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pandas as pd

from CausalEstimate.api.registry import register_estimator
from CausalEstimate.core.registry import register_estimator
from CausalEstimate.estimators.base import BaseEstimator
from CausalEstimate.estimators.functional.ipw import (
compute_ipw_ate,
Expand Down
2 changes: 1 addition & 1 deletion CausalEstimate/estimators/matching.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from CausalEstimate.api.registry import register_estimator
from CausalEstimate.core.registry import register_estimator
from CausalEstimate.estimators.base import BaseEstimator
from CausalEstimate.estimators.functional.matching import compute_matching_ate
from CausalEstimate.matching.matching import match_optimal
Expand Down
2 changes: 1 addition & 1 deletion CausalEstimate/estimators/tmle.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pandas as pd

from CausalEstimate.api.registry import register_estimator
from CausalEstimate.core.registry import register_estimator
from CausalEstimate.estimators.base import BaseEstimator
from CausalEstimate.estimators.functional.tmle import compute_tmle_ate
from CausalEstimate.utils.checks import check_inputs
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import List, Union

import numpy as np
import pandas as pd
from typing import Union, List
from CausalEstimate.api.registry import ESTIMATOR_REGISTRY

from CausalEstimate.core.imports import import_all_estimators
from CausalEstimate.core.registry import ESTIMATOR_REGISTRY

# !TODO: Write test for all functions

Expand All @@ -21,7 +24,7 @@ def __init__(
"""
if methods is None:
methods = ["AIPW"] # Default to AIPW if no method is provided.

import_all_estimators()
# Allow single method or list of methods
self.methods = methods if isinstance(methods, list) else [methods]
self.effect_type = effect_type
Expand Down
56 changes: 55 additions & 1 deletion tests/test_functional/test_aipw.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,50 @@
import unittest
import numpy as np
from scipy.special import expit
from CausalEstimate.estimators.functional.aipw import compute_aipw_ate
from CausalEstimate.simulation.binary_simulation import (
simulate_binary_data,
compute_ATE_theoretical_from_data,
)


class TestComputeAIPWATE(unittest.TestCase):
@classmethod
def setUpClass(cls):
    """Create the shared fixture: simulated data plus noisy predictions.

    Stores ``A``, ``Y``, ``ps``, ``Y1_hat``, ``Y0_hat`` and ``true_ate`` on
    the class so every test method reuses the same arrays.
    """
    sample_size = 2000
    noise_rng = np.random.default_rng(42)

    def jitter():
        # Gaussian perturbation scaled by 0.01, one value per sample.
        return 0.01 * noise_rng.normal(size=sample_size)

    simulated = simulate_binary_data(
        sample_size,
        alpha=[0.1, 0.2, -0.3, 0],
        beta=[0.5, 0.8, -0.6, 0.3, 0],
        seed=42,
    )
    cls.true_ate = compute_ATE_theoretical_from_data(
        simulated, beta=[0.5, 0.8, -0.6, 0.3, 0]
    )

    covariates = simulated[["X1", "X2"]].values
    x1 = covariates[:, 0]
    x2 = covariates[:, 1]
    cls.A = simulated["A"].values
    cls.Y = simulated["Y"].values

    # The three jitter() calls share one generator, so the order below
    # (ps, then Y1_hat, then Y0_hat) determines which noise each one gets.
    cls.ps = expit(0.1 + 0.2 * x1 - 0.3 * x2) + jitter()
    cls.Y1_hat = expit(0.5 + 0.8 * 1 + -0.6 * x1 + 0.3 * x2) + jitter()
    cls.Y0_hat = expit(0.5 + -0.6 * x1 + 0.3 * x2) + jitter()

def test_aipw_ate_computation(self):
    """The AIPW estimate is a float within 0.1 of the reference ATE."""
    estimate = compute_aipw_ate(
        self.A,
        self.Y,
        self.ps,
        self.Y0_hat,
        self.Y1_hat,
    )
    self.assertIsInstance(estimate, float)
    # Compare against the reference value stored on the class as true_ate.
    self.assertAlmostEqual(estimate, self.true_ate, delta=0.1)

def test_invalid_input_shapes(self):
# Test for mismatched input shapes
Expand All @@ -17,6 +58,19 @@ def test_invalid_input_shapes(self):
with self.assertRaises(ValueError):
compute_aipw_ate(A, Y, ps, Y0_hat, Y1_hat)

def test_edge_case_ps_close_to_zero_or_one(self):
    """The estimate stays near the reference ATE with ps bounded to [0.01, 0.99]."""
    # Work on a bounded copy; the shared class-level ps array is left untouched.
    bounded_ps = np.clip(self.ps, 0.01, 0.99)

    estimate = compute_aipw_ate(
        self.A, self.Y, bounded_ps, self.Y0_hat, self.Y1_hat
    )
    self.assertIsInstance(estimate, float)
    # A wider tolerance (0.15) is used here than in the plain ATE test.
    self.assertAlmostEqual(estimate, self.true_ate, delta=0.15)


# Run the unittests
unittest.main(argv=[""], exit=False)
if __name__ == "__main__":
unittest.main()
82 changes: 82 additions & 0 deletions tests/test_functional/test_ipw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import unittest
import numpy as np
from CausalEstimate.estimators.functional.ipw import (
compute_ipw_ate,
compute_ipw_ate_stabilized,
compute_ipw_att,
compute_ipw_risk_ratio,
compute_ipw_risk_ratio_treated,
)


class TestIPWEstimators(unittest.TestCase):
    """Type and range checks for the functional IPW estimators."""

    @classmethod
    def setUpClass(cls):
        """Draw one shared random sample used by all tests in this class."""
        sample_size = 1000
        generator = np.random.default_rng(42)
        # All draws come from one generator, so their order fixes the data.
        cls.A = generator.binomial(1, 0.5, size=sample_size)  # Treatment assignment
        cls.Y = generator.binomial(1, 0.3, size=sample_size)  # Outcome
        raw_ps = generator.uniform(0.1, 0.9, size=sample_size)
        # NOTE(review): the draws already lie within [0.1, 0.9], so this
        # clip to [0.01, 0.99] leaves them unchanged.
        cls.ps = np.clip(raw_ps, 0.01, 0.99)  # Propensity score

    def _check_float_within_unit_band(self, value):
        """Assert that ``value`` is a float in the closed interval [-1, 1]."""
        self.assertIsInstance(value, float)
        self.assertTrue(-1 <= value <= 1)

    def _check_positive_float(self, value):
        """Assert that ``value`` is a float strictly greater than zero."""
        self.assertIsInstance(value, float)
        self.assertTrue(value > 0)

    def test_ipw_ate(self):
        """ATE estimate is a float within [-1, 1]."""
        self._check_float_within_unit_band(
            compute_ipw_ate(self.A, self.Y, self.ps)
        )

    def test_ipw_ate_stabilized(self):
        """Stabilized-weight ATE estimate is a float within [-1, 1]."""
        self._check_float_within_unit_band(
            compute_ipw_ate_stabilized(self.A, self.Y, self.ps)
        )

    def test_ipw_att(self):
        """ATT estimate is a float within [-1, 1]."""
        self._check_float_within_unit_band(
            compute_ipw_att(self.A, self.Y, self.ps)
        )

    def test_ipw_risk_ratio(self):
        """Risk ratio is a positive float."""
        self._check_positive_float(
            compute_ipw_risk_ratio(self.A, self.Y, self.ps)
        )

    def test_ipw_risk_ratio_treated(self):
        """Risk ratio for the treated is a positive float."""
        self._check_positive_float(
            compute_ipw_risk_ratio_treated(self.A, self.Y, self.ps)
        )

    def test_edge_case_ps_near_0_or_1(self):
        """ATE and ATT remain bounded floats with ps clipped to [0.01, 0.99]."""
        # NOTE(review): self.ps is drawn from uniform(0.1, 0.9), so this clip
        # changes nothing and the scores never get near 0 or 1.
        clipped_ps = np.clip(self.ps, 0.01, 0.99)

        self._check_float_within_unit_band(
            compute_ipw_ate(self.A, self.Y, clipped_ps)
        )
        self._check_float_within_unit_band(
            compute_ipw_att(self.A, self.Y, clipped_ps)
        )

    def test_mismatched_shapes(self):
        """Inputs of different lengths raise ValueError."""
        treatment = np.array([1, 0, 1])
        outcome = np.array([3, 1, 4])
        short_ps = np.array([0.8, 0.6])  # Mismatched length

        with self.assertRaises(ValueError):
            compute_ipw_ate(treatment, outcome, short_ps)

    def test_single_value_input(self):
        """A single observation still yields a float."""
        result = compute_ipw_ate(np.array([1]), np.array([1]), np.array([0.5]))
        self.assertIsInstance(result, float)


# Run the unittests
if __name__ == "__main__":
unittest.main()
20 changes: 8 additions & 12 deletions tests/test_functional/test_tmle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ class TestTMLEFunctions(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Simulate realistic data for testing
np.random.seed(42)
n = 1000
rng = np.random.default_rng(42)
n = 2000
# Covariates
data = simulate_binary_data(
n, alpha=[0.1, 0.2, -0.3, 0], beta=[0.5, 0.8, -0.6, 0.3, 0]
n, alpha=[0.1, 0.2, -0.3, 0], beta=[0.5, 0.8, -0.6, 0.3, 0], seed=42
)
true_ate = compute_ATE_theoretical_from_data(
data, beta=[0.5, 0.8, -0.6, 0.3, 0]
Expand All @@ -32,18 +32,14 @@ def setUpClass(cls):
X = data[["X1", "X2"]].values
A = data["A"].values
Y = data["Y"].values
ps = expit(0.1 + 0.2 * X[:, 0] - 0.3 * X[:, 1]) + 0.01 * np.random.normal(
size=n
)
ps = expit(0.1 + 0.2 * X[:, 0] - 0.3 * X[:, 1]) + 0.01 * rng.normal(size=n)
Y1_hat = expit(
0.5 + 0.8 * 1 + -0.6 * X[:, 0] + 0.3 * X[:, 1]
) + 0.01 * np.random.normal(size=n)
Y0_hat = expit(0.5 + -0.6 * X[:, 0] + 0.3 * X[:, 1]) + 0.01 * np.random.normal(
size=n
)
) + 0.01 * rng.normal(size=n)
Y0_hat = expit(0.5 + -0.6 * X[:, 0] + 0.3 * X[:, 1]) + 0.01 * rng.normal(size=n)
Yhat = expit(
0.5 + 0.8 * A + -0.6 * X[:, 0] + 0.3 * X[:, 1]
) + 0.01 * np.random.normal(size=n)
) + 0.01 * rng.normal(size=n)

cls.A = A
cls.Y = Y
Expand Down Expand Up @@ -83,7 +79,7 @@ def test_compute_tmle_ate_edge_cases(self):
self.A, self.Y, ps_edge, self.Y0_hat, self.Y1_hat, self.Yhat
)
self.assertIsInstance(ate_tmle, float)
self.assertAlmostEqual(ate_tmle, self.true_ate, delta=0.1)
self.assertAlmostEqual(ate_tmle, self.true_ate, delta=0.15)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import unittest
import pandas as pd
import numpy as np
from CausalEstimate.api.estimator import Estimator
from CausalEstimate.interface.estimator import Estimator
from CausalEstimate.estimators.aipw import AIPW
from CausalEstimate.estimators.tmle import TMLE

Expand Down

0 comments on commit 2c4bd53

Please sign in to comment.