Skip to content

Commit

Permalink
improve unittest: Move setup to a base class (only for tmle so far)
Browse files Browse the repository at this point in the history
  • Loading branch information
kirilklein committed Oct 17, 2024
1 parent dbabff8 commit e8d5073
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 57 deletions.
62 changes: 62 additions & 0 deletions tests/test_functional/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import unittest
from typing import List

import numpy as np
from scipy.special import expit

from CausalEstimate.simulation.binary_simulation import (
compute_ATE_theoretical_from_data,
compute_ATT_theoretical_from_data,
simulate_binary_data,
)


class TestEffectBase(unittest.TestCase):
n: int = 2000
alpha: List[float] = [0.1, 0.2, -0.3, 0]
beta: List[float] = [0.5, 0.8, -0.6, 0.3, 0]
seed: int = 42

@classmethod
def setUpClass(cls):
# Simulate realistic data for testing
rng = np.random.default_rng(cls.seed)
# Covariates
data = simulate_binary_data(
cls.n, alpha=cls.alpha, beta=cls.beta, seed=cls.seed
)

# Predicted outcomes
X = data[["X1", "X2"]].values
A = data["A"].values
Y = data["Y"].values
ps = expit(
cls.alpha[0] + cls.alpha[1] * X[:, 0] + cls.alpha[2] * X[:, 1]
) + 0.01 * rng.normal(size=cls.n)
Y1_hat = expit(
cls.beta[0]
+ cls.beta[1] * 1
+ cls.beta[2] * X[:, 0]
+ cls.beta[3] * X[:, 1]
) + 0.01 * rng.normal(size=cls.n)
Y0_hat = expit(
cls.beta[0] + cls.beta[2] * X[:, 0] + cls.beta[3] * X[:, 1]
) + 0.01 * rng.normal(size=cls.n)
Yhat = expit(
cls.beta[0]
+ cls.beta[1] * A
+ cls.beta[2] * X[:, 0]
+ cls.beta[3] * X[:, 1]
) + 0.01 * rng.normal(size=cls.n)

cls.A = A
cls.Y = Y
cls.ps = ps
cls.Y1_hat = Y1_hat
cls.Y0_hat = Y0_hat
cls.Yhat = Yhat

true_ate = compute_ATE_theoretical_from_data(data, beta=cls.beta)
true_att = compute_ATT_theoretical_from_data(data, beta=cls.beta)
cls.true_ate = true_ate
cls.true_att = true_att
34 changes: 18 additions & 16 deletions tests/test_functional/test_aipw.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,22 @@ def setUpClass(cls):
# Covariates
alpha = [0.1, 0.2, -0.3, 0]
beta = [0.5, 0.8, -0.6, 0.3, 0]
data = simulate_binary_data(
n, alpha=alpha, beta=beta, seed=42
)
true_ate = compute_ATE_theoretical_from_data(
data, beta=beta
)
data = simulate_binary_data(n, alpha=alpha, beta=beta, seed=42)
true_ate = compute_ATE_theoretical_from_data(data, beta=beta)

# Predicted outcomes
X = data[["X1", "X2"]].values
A = data["A"].values
Y = data["Y"].values
ps = expit(alpha[0] + alpha[1] * X[:, 0] + alpha[2] * X[:, 1]) + 0.01 * rng.normal(size=n)
ps = expit(
alpha[0] + alpha[1] * X[:, 0] + alpha[2] * X[:, 1]
) + 0.01 * rng.normal(size=n)
Y1_hat = expit(
beta[0] + beta[1] * 1 + beta[2] * X[:, 0] + beta[3] * X[:, 1]
) + 0.01 * rng.normal(size=n)
Y0_hat = expit(beta[0] + beta[2] * X[:, 0] + beta[3] * X[:, 1]) + 0.01 * rng.normal(size=n)
Y0_hat = expit(
beta[0] + beta[2] * X[:, 0] + beta[3] * X[:, 1]
) + 0.01 * rng.normal(size=n)

cls.A = A
cls.Y = Y
Expand Down Expand Up @@ -73,6 +73,7 @@ def test_edge_case_ps_close_to_zero_or_one(self):
# Check if the estimate is still close to the true ATE
self.assertAlmostEqual(ate_aipw, self.true_ate, delta=0.15)


class TestComputeAIPWATT(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand All @@ -82,22 +83,22 @@ def setUpClass(cls):
# Covariates
alpha = [0.1, 0.2, -0.3, 0]
beta = [0.5, 0.8, -0.6, 0.3, 0]
data = simulate_binary_data(
n, alpha=alpha, beta=beta, seed=42
)
true_att = compute_ATT_theoretical_from_data(
data, beta=beta
)
data = simulate_binary_data(n, alpha=alpha, beta=beta, seed=42)
true_att = compute_ATT_theoretical_from_data(data, beta=beta)

# Predicted outcomes
X = data[["X1", "X2"]].values
A = data["A"].values
Y = data["Y"].values
ps = expit(alpha[0] + alpha[1] * X[:, 0] + alpha[2] * X[:, 1]) + 0.01 * rng.normal(size=n)
ps = expit(
alpha[0] + alpha[1] * X[:, 0] + alpha[2] * X[:, 1]
) + 0.01 * rng.normal(size=n)
Y1_hat = expit(
beta[0] + beta[1] * 1 + beta[2] * X[:, 0] + beta[3] * X[:, 1]
) + 0.01 * rng.normal(size=n)
Y0_hat = expit(beta[0] + beta[2] * X[:, 0] + beta[3] * X[:, 1]) + 0.01 * rng.normal(size=n)
Y0_hat = expit(
beta[0] + beta[2] * X[:, 0] + beta[3] * X[:, 1]
) + 0.01 * rng.normal(size=n)

cls.A = A
cls.Y = Y
Expand All @@ -113,6 +114,7 @@ def test_aipw_att_computation(self):
# Check if the AIPW estimate is close to the true ATT
self.assertAlmostEqual(att_aipw, self.true_att, delta=0.1)


# Run the unittests
if __name__ == "__main__":
unittest.main()
43 changes: 2 additions & 41 deletions tests/test_functional/test_tmle.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,15 @@
import unittest

import numpy as np
from scipy.special import expit

from CausalEstimate.estimators.functional.tmle import (
compute_tmle_ate,
estimate_fluctuation_parameter,
update_ate_estimate,
)
from CausalEstimate.simulation.binary_simulation import (
simulate_binary_data,
compute_ATE_theoretical_from_data,
)


class TestTMLEFunctions(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Simulate realistic data for testing
rng = np.random.default_rng(42)
n = 2000
# Covariates
data = simulate_binary_data(
n, alpha=[0.1, 0.2, -0.3, 0], beta=[0.5, 0.8, -0.6, 0.3, 0], seed=42
)
true_ate = compute_ATE_theoretical_from_data(
data, beta=[0.5, 0.8, -0.6, 0.3, 0]
)

# Predicted outcomes
X = data[["X1", "X2"]].values
A = data["A"].values
Y = data["Y"].values
ps = expit(0.1 + 0.2 * X[:, 0] - 0.3 * X[:, 1]) + 0.01 * rng.normal(size=n)
Y1_hat = expit(
0.5 + 0.8 * 1 + -0.6 * X[:, 0] + 0.3 * X[:, 1]
) + 0.01 * rng.normal(size=n)
Y0_hat = expit(0.5 + -0.6 * X[:, 0] + 0.3 * X[:, 1]) + 0.01 * rng.normal(size=n)
Yhat = expit(
0.5 + 0.8 * A + -0.6 * X[:, 0] + 0.3 * X[:, 1]
) + 0.01 * rng.normal(size=n)
from tests.test_functional.base import TestEffectBase

cls.A = A
cls.Y = Y
cls.ps = ps
cls.Y1_hat = Y1_hat
cls.Y0_hat = Y0_hat
cls.Yhat = Yhat
cls.true_ate = true_ate

class TestTMLEFunctions(TestEffectBase):
def test_estimate_fluctuation_parameter(self):
epsilon = estimate_fluctuation_parameter(self.A, self.Y, self.ps, self.Yhat)
self.assertIsInstance(epsilon, float)
Expand Down

0 comments on commit e8d5073

Please sign in to comment.