-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* renamed API to interface * make sure all estimators are imported, in order for registry to work. * improved tests * fix IPW with stabilized weights
- Loading branch information
1 parent
878a6cb
commit 2c4bd53
Showing
13 changed files
with
168 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import importlib | ||
import os | ||
import pkgutil | ||
|
||
|
||
def import_all_estimators(): | ||
package_dir = os.path.join(os.path.dirname(__file__), "..", "estimators") | ||
for _, module_name, _ in pkgutil.iter_modules([package_dir]): | ||
importlib.import_module(f"CausalEstimate.estimators.{module_name}") |
2 changes: 1 addition & 1 deletion
2
CausalEstimate/api/registry.py → CausalEstimate/core/registry.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
ESTIMATOR_REGISTRY = {} | ||
|
||
|
||
def register_estimator(cls): | ||
def register_estimator(cls: object) -> object: | ||
ESTIMATOR_REGISTRY[cls.__name__] = cls | ||
return cls |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import unittest | ||
import numpy as np | ||
from CausalEstimate.estimators.functional.ipw import ( | ||
compute_ipw_ate, | ||
compute_ipw_ate_stabilized, | ||
compute_ipw_att, | ||
compute_ipw_risk_ratio, | ||
compute_ipw_risk_ratio_treated, | ||
) | ||
|
||
|
||
class TestIPWEstimators(unittest.TestCase): | ||
|
||
@classmethod | ||
def setUpClass(cls): | ||
# Simulate simple data for testing | ||
rng = np.random.default_rng(42) | ||
n = 1000 | ||
cls.A = rng.binomial(1, 0.5, size=n) # Treatment assignment | ||
cls.Y = rng.binomial(1, 0.3, size=n) # Outcome | ||
cls.ps = np.clip(rng.uniform(0.1, 0.9, size=n), 0.01, 0.99) # Propensity score | ||
|
||
def test_ipw_ate(self): | ||
ate = compute_ipw_ate(self.A, self.Y, self.ps) | ||
self.assertIsInstance(ate, float) | ||
self.assertTrue(-1 <= ate <= 1) # Check ATE is within reasonable range | ||
|
||
def test_ipw_ate_stabilized(self): | ||
ate_stabilized = compute_ipw_ate_stabilized(self.A, self.Y, self.ps) | ||
self.assertIsInstance(ate_stabilized, float) | ||
self.assertTrue(-1 <= ate_stabilized <= 1) # Check ATE with stabilized weights | ||
|
||
def test_ipw_att(self): | ||
att = compute_ipw_att(self.A, self.Y, self.ps) | ||
self.assertIsInstance(att, float) | ||
self.assertTrue(-1 <= att <= 1) # Check ATT is within reasonable range | ||
|
||
def test_ipw_risk_ratio(self): | ||
risk_ratio = compute_ipw_risk_ratio(self.A, self.Y, self.ps) | ||
self.assertIsInstance(risk_ratio, float) | ||
self.assertTrue(risk_ratio > 0) # Risk ratio should be positive | ||
|
||
def test_ipw_risk_ratio_treated(self): | ||
risk_ratio_treated = compute_ipw_risk_ratio_treated(self.A, self.Y, self.ps) | ||
self.assertIsInstance(risk_ratio_treated, float) | ||
self.assertTrue( | ||
risk_ratio_treated > 0 | ||
) # Risk ratio for treated should be positive | ||
|
||
def test_edge_case_ps_near_0_or_1(self): | ||
# Test with ps values close to 0 or 1 | ||
ps_edge = np.clip(self.ps, 0.01, 0.99) | ||
ate_edge = compute_ipw_ate(self.A, self.Y, ps_edge) | ||
self.assertIsInstance(ate_edge, float) | ||
self.assertTrue(-1 <= ate_edge <= 1) | ||
|
||
att_edge = compute_ipw_att(self.A, self.Y, ps_edge) | ||
self.assertIsInstance(att_edge, float) | ||
self.assertTrue(-1 <= att_edge <= 1) | ||
|
||
def test_mismatched_shapes(self): | ||
# Test with mismatched input shapes | ||
A = np.array([1, 0, 1]) | ||
Y = np.array([3, 1, 4]) | ||
ps = np.array([0.8, 0.6]) # Mismatched length | ||
|
||
with self.assertRaises(ValueError): | ||
compute_ipw_ate(A, Y, ps) | ||
|
||
def test_single_value_input(self): | ||
# Test with single value input | ||
A = np.array([1]) | ||
Y = np.array([1]) | ||
ps = np.array([0.5]) | ||
|
||
ate = compute_ipw_ate(A, Y, ps) | ||
self.assertIsInstance(ate, float) | ||
|
||
|
||
# Run the unittests | ||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters