Commit
Showing 1 changed file with 310 additions and 0 deletions.
310 changes: 310 additions & 0 deletions
3_Epidemiology_Analysis/c_causal_inference/1_time-fixed-treatments/5_AIPTW.ipynb
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,310 @@ | ||
{ | ||
"cells": [ | ||
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Augmented Inverse Probability of Treatment Weights\n", | |
"Augmented-IPTW (AIPTW) is a doubly robust estimator. Essentially, AIPTW combines the IPTW estimator and g-formula into a single estimate. Before continuing, I will briefly outline what a doubly robust estimator is and why you would want to use one. In observational research with high-dimensional data, we (generally) are forced to use parametric models to adjust for many confounders. In this scenario, we assume that our parametric models are correctly specified. Our statistical model, $\\mathcal{M}$, must include the distribution that the data came from. \n", | |
"\n", | |
"With other estimators, like IPTW or g-formula, we have one chance to specify $\\mathcal{M}$ correctly. Doubly robust estimators use a model to predict the treatment (like IPTW) and another model to predict the outcome (like g-formula). The estimator then combines the estimates, such that if either is correct, then our estimate will be consistent. Essentially, we get two chances to get the statistical model correct.\n", | |
"\n", | |
"A more in-depth description of doubly robust estimators is available in [this pre-print](https://statnav.files.wordpress.com/2017/10/doublerobustness-preprint.pdf).\n", | |
"\n", | |
"## AIPTW\n", | |
"\n", | |
"AIPTW takes the following form:\n", | |
"$$\\hat{E}[Y^a] = \\frac{1}{n} \\sum_{i=1}^n \\left(\\frac{Y \\times I(A=a)}{\\widehat{\\Pr}(A=a|L)} - \\frac{\\hat{E}[Y|A=a, L] \\times (I(A=a) - \\widehat{\\Pr}(A=a|L))}{\\widehat{\\Pr}(A=a|L)}\\right)$$\n", | |
"where $\\widehat{\\Pr}(A=a|L)$ comes from the IPTW model and $\\hat{E}[Y|A=a,L]$ comes from the g-formula. If we do some manipulations and assume an infinite sample size, we can end up with\n", | |
"$$\\hat{E}^{IPW}[Y^a] \\times \\frac{\\Pr(A=a|L)}{\\widehat{\\Pr}(A=a|L, \\mathcal{M})} - \\hat{E}^{STD}[Y^a] \\times \\frac{\\Pr(A=a|L) - \\widehat{\\Pr}(A=a|L, \\mathcal{M})}{\\widehat{\\Pr}(A=a|L, \\mathcal{M})}$$\n", | |
"From this form, we can see that as long as one of the two models is correctly specified, AIPTW will be consistent. For example, if the treatment model is correct, the first ratio equals one and the second term is zero, leaving the IPW estimate.\n", | |
"\n", | |
"## An example\n", | |
"To motivate our example, we will use a simulated data set included with *zEpid*. In the data set, we have a cohort of HIV-positive individuals. We are interested in the sample average treatment effect of antiretroviral therapy (ART) on all-cause mortality at 45 weeks. Based on substantive background knowledge, we believe that the treated and untreated populations are exchangeable conditional on gender, age, CD4 T-cell count, and detectable viral load. \n", | |
"\n", | |
"In this tutorial, we will focus on a complete case analysis. Therefore, we will drop the `cd4_wk45` column and all observations with missing data in `dead`. This will leave 517 observations with no missing data." | |
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"<class 'pandas.core.frame.DataFrame'>\n", | ||
"Int64Index: 517 entries, 0 to 546\n", | ||
"Data columns (total 12 columns):\n", | ||
"id 517 non-null int64\n", | ||
"male 517 non-null int64\n", | ||
"age0 517 non-null int64\n", | ||
"cd40 517 non-null int64\n", | ||
"dvl0 517 non-null int64\n", | ||
"art 517 non-null int64\n", | ||
"dead 517 non-null float64\n", | ||
"t 517 non-null float64\n", | ||
"age_rs1 517 non-null float64\n", | ||
"age_rs2 517 non-null float64\n", | ||
"cd4_rs1 517 non-null float64\n", | ||
"cd4_rs2 517 non-null float64\n", | ||
"dtypes: float64(6), int64(6)\n", | ||
"memory usage: 52.5 KB\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import pandas as pd\n", | ||
"\n", | ||
"from zepid import load_sample_data, spline\n", | ||
"from zepid.causal.doublyrobust import AIPTW\n", | ||
"\n", | ||
"df = load_sample_data(False)\n", | ||
"df[['age_rs1', 'age_rs2']] = spline(df, 'age0', n_knots=3, term=2, restricted=True)\n", | ||
"df[['cd4_rs1', 'cd4_rs2']] = spline(df, 'cd40', n_knots=3, term=2, restricted=True)\n", | ||
"\n", | ||
"dfcc = df.drop(columns=['cd4_wk45']).dropna()\n", | ||
"dfcc.info()" | ||
] | ||
}, | ||
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Our data is now ready for a complete case analysis using AIPTW. First, we initialize `AIPTW` with our complete-case data (`dfcc`), the treatment (`art`), and the outcome (`dead`)." | |
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"aipw = AIPTW(dfcc, exposure='art', outcome='dead')" | ||
] | ||
}, | ||
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Treatment Model\n", | |
"First, we will specify our treatment model. We believe the sufficient set for the treatment model is gender (`male`), age (`age0`), CD4 T-cell count (`cd40`), and detectable viral load (`dvl0`). To relax the functional form assumptions, we will model age and CD4 using restricted quadratic splines." | |
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\n", | ||
"----------------------------------------------------------------\n", | ||
"MODEL: art ~ male + age0 + age_rs1 + age_rs2 + cd40 + cd4_rs1 + cd4_rs2 + dvl0\n", | ||
"-----------------------------------------------------------------\n", | ||
" Generalized Linear Model Regression Results \n", | ||
"==============================================================================\n", | ||
"Dep. Variable: art No. Observations: 517\n", | ||
"Model: GLM Df Residuals: 508\n", | ||
"Model Family: Binomial Df Model: 8\n", | ||
"Link Function: logit Scale: 1.0000\n", | ||
"Method: IRLS Log-Likelihood: -206.06\n", | ||
"Date: Mon, 11 Mar 2019 Deviance: 412.12\n", | ||
"Time: 08:24:05 Pearson chi2: 510.\n", | ||
"No. Iterations: 5 Covariance Type: nonrobust\n", | ||
"==============================================================================\n", | ||
" coef std err z P>|z| [0.025 0.975]\n", | ||
"------------------------------------------------------------------------------\n", | ||
"Intercept 1.4498 1.679 0.864 0.388 -1.841 4.741\n", | ||
"male -0.1159 0.321 -0.361 0.718 -0.745 0.513\n", | ||
"age0 -0.1026 0.059 -1.726 0.084 -0.219 0.014\n", | ||
"age_rs1 0.0048 0.003 1.706 0.088 -0.001 0.010\n", | ||
"age_rs2 -0.0077 0.006 -1.373 0.170 -0.019 0.003\n", | ||
"cd40 0.0041 0.004 0.964 0.335 -0.004 0.012\n", | ||
"cd4_rs1 -2.422e-05 1.2e-05 -2.014 0.044 -4.78e-05 -6.49e-07\n", | ||
"cd4_rs2 8.875e-05 4.55e-05 1.952 0.051 -3.81e-07 0.000\n", | ||
"dvl0 -0.0158 0.399 -0.040 0.968 -0.797 0.765\n", | ||
"==============================================================================\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"aipw.exposure_model('male + age0 + age_rs1 + age_rs2 + cd40 + cd4_rs1 + cd4_rs2 + dvl0')" | ||
] | ||
}, | ||
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"`AIPTW` uses a logistic regression model to estimate the probabilities of treatment, and the corresponding summary of the model fit is printed to the console. \n", | |
"\n", | |
"### Outcome Model\n", | |
"Now, we will estimate the outcome model. We will model the outcome as a function of ART (`art`), gender (`male`), age (`age0`), CD4 T-cell count (`cd40`), and detectable viral load (`dvl0`). Again, we will model age and CD4 using restricted quadratic splines." | |
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\n", | ||
"----------------------------------------------------------------\n", | ||
"MODEL: dead ~ art + male + age0 + age_rs1 + age_rs2 + cd40 + cd4_rs1 + cd4_rs2 + dvl0\n", | ||
"-----------------------------------------------------------------\n", | ||
" Generalized Linear Model Regression Results \n", | ||
"==============================================================================\n", | ||
"Dep. Variable: dead No. Observations: 517\n", | ||
"Model: GLM Df Residuals: 507\n", | ||
"Model Family: Binomial Df Model: 9\n", | ||
"Link Function: logit Scale: 1.0000\n", | ||
"Method: IRLS Log-Likelihood: -202.85\n", | ||
"Date: Mon, 11 Mar 2019 Deviance: 405.71\n", | ||
"Time: 08:25:18 Pearson chi2: 535.\n", | ||
"No. Iterations: 6 Covariance Type: nonrobust\n", | ||
"==============================================================================\n", | ||
" coef std err z P>|z| [0.025 0.975]\n", | ||
"------------------------------------------------------------------------------\n", | ||
"Intercept -4.0961 2.713 -1.510 0.131 -9.413 1.220\n", | ||
"art -0.7274 0.392 -1.853 0.064 -1.497 0.042\n", | ||
"male -0.0774 0.334 -0.232 0.817 -0.732 0.577\n", | ||
"age0 0.1605 0.096 1.670 0.095 -0.028 0.349\n", | ||
"age_rs1 -0.0058 0.004 -1.481 0.139 -0.013 0.002\n", | ||
"age_rs2 0.0128 0.006 2.026 0.043 0.000 0.025\n", | ||
"cd40 -0.0123 0.004 -2.987 0.003 -0.020 -0.004\n", | ||
"cd4_rs1 1.872e-05 1.18e-05 1.584 0.113 -4.45e-06 4.19e-05\n", | ||
"cd4_rs2 -3.868e-05 4.59e-05 -0.842 0.400 -0.000 5.13e-05\n", | ||
"dvl0 -0.1261 0.398 -0.317 0.751 -0.906 0.653\n", | ||
"==============================================================================\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"aipw.outcome_model('art + male + age0 + age_rs1 + age_rs2 + cd40 + cd4_rs1 + cd4_rs2 + dvl0')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Again, logistic regression is used to predict the outcome data\n", | ||
"\n", | ||
"### Estimation\n", | ||
"To estimate the risk difference and risk ratio, we will now call the `fit()` function. After this, `AIPTW` gains the attributes `risk_difference` and `risk_ratio`. Additionally, results can be printed to the console using the `summary()` function" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"RD: -0.06857139263314216\n", | ||
"RR: 0.5844630051369846\n", | ||
"----------------------------------------------------------------------\n", | ||
"Risk Difference: -0.0686\n", | ||
"Risk Ratio: 0.5845\n", | ||
"----------------------------------------------------------------------\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"aipw.fit()\n", | ||
"\n", | ||
"print('RD:', aipw.risk_difference)\n", | ||
"print('RR:', aipw.risk_ratio)\n", | ||
"\n", | ||
"aipw.summary()" | ||
] | ||
}, | ||
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Interpreting the risk difference, we would conclude that had everyone in our cohort been treated with ART, the risk of all-cause mortality would have been 6.9 percentage points lower than had no one been treated.\n", | |
"\n", | ||
"### Confidence Intervals\n", | ||
"To obtain correct confidence intervals, we need to use a bootstrap procedure. Influence curves are an alternative, but not currently available" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"95% LCL: -0.12324970574432371\n", | ||
"95% UCL -0.005298084807482341\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"rd_results = []\n", | ||
"for i in range(1000):\n", | ||
" dfs = dfcc.sample(n=dfcc.shape[0],replace=True)\n", | ||
" s = AIPTW(dfs,exposure='art',outcome='dead')\n", | ||
" s.exposure_model('male + age0 + age_rs1 + age_rs2 + cd40 + cd4_rs1 + cd4_rs2 + dvl0',\n", | ||
" print_results=False)\n", | ||
" s.outcome_model('art + male + age0 + age_rs1 + age_rs2 + cd40 + cd4_rs1 + cd4_rs2 + dvl0',\n", | ||
" print_results=False)\n", | ||
" s.fit()\n", | ||
" rd_results.append(s.risk_difference)\n", | ||
"\n", | ||
"\n", | ||
"print('95% LCL:', np.percentile(rd_results, q=2.5))\n", | ||
"print('95% UCL', np.percentile(rd_results, q=97.5))" | ||
] | ||
}, | ||
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Under the counterfactual of everyone receiving treatment with ART, the risk of all-cause mortality was 6.9 percentage points lower (95% CL for the risk difference: -12.3, -0.5 percentage points) than under the counterfactual where no one had been treated.\n", | |
"\n", | ||
"# Conclusion\n", | ||
"In this tutorial, I introduced the concept of doubly-robust estimators and detailed augmented-IPTW. I demonstrated estimation with `AIPTW` using *zEpid* and how to obtain confidence intervals. Please view other tutorials for information on other functionality within *zEpid*\n", | ||
"\n", | ||
"## References\n", | ||
"Funk MJ, Westreich D, Wiesen C, Stürmer T, Brookhart MA, Davidian M. (2011). Doubly robust estimation of causal effects. *AJE*, 173(7), 761-767.\n", | ||
"\n", | ||
"Keil AP et al. (2018). Resolving an apparent paradox in doubly robust estimators. *AJE*, 187(4), 891-892.\n", | ||
"\n", | ||
"Robins JM, Rotnitzky A, Zhao LP. (1994). Estimation of regression coefficients when some regressors are not always observed. *JASA*, 89(427), 846-866." | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
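The double-robustness property described in the notebook can be checked by hand on a tiny data set. The following is a minimal sketch in plain Python and is not how *zEpid* implements `AIPTW`. The helper names (`aiptw_mean`, `stratum`), the single binary confounder, and all counts are hypothetical values chosen so the true risks are known exactly. It uses the standard AIPTW form, in which both terms are divided by the estimated probability of the treatment level of interest.

```python
def aiptw_mean(y, a, pr_a, y_hat, level):
    """AIPTW estimate of E[Y^level].

    pr_a[i]  : estimated Pr(A=level | L_i)   (treatment model)
    y_hat[i] : estimated E[Y | A=level, L_i] (outcome model)
    """
    total = 0.0
    for yi, ai, pi, mi in zip(y, a, pr_a, y_hat):
        ind = 1.0 if ai == level else 0.0
        # IPTW term minus the augmentation term, same denominator in both
        total += yi * ind / pi - mi * (ind - pi) / pi
    return total / len(y)


def stratum(s, n_trt, d_trt, n_unt, d_unt):
    """Rows (L, A, Y) for one confounder stratum with exact counts."""
    return ([(s, 1, 1)] * d_trt + [(s, 1, 0)] * (n_trt - d_trt)
            + [(s, 0, 1)] * d_unt + [(s, 0, 0)] * (n_unt - d_unt))


# Hypothetical cohort: 40 people with L=0 and 60 people with L=1
rows = stratum(0, 10, 2, 30, 12) + stratum(1, 30, 15, 30, 24)
strata = [r[0] for r in rows]
a = [r[1] for r in rows]
y = [r[2] for r in rows]

# Quantities that hold exactly in this cohort
true_ps = {0: 0.25, 1: 0.50}   # Pr(A=1 | L)
true_m1 = {0: 0.20, 1: 0.50}   # E[Y | A=1, L]
true_m0 = {0: 0.40, 1: 0.80}   # E[Y | A=0, L]

# Misspecified models that ignore L entirely
wrong_ps = 0.40                # marginal Pr(A=1)
wrong_m1 = 0.425               # crude risk among the treated

ps_ok = [true_ps[s] for s in strata]
m1_ok = [true_m1[s] for s in strata]
ps_bad = [wrong_ps] * len(strata)
m1_bad = [wrong_m1] * len(strata)

# Standardized truth: 0.4 * 0.20 + 0.6 * 0.50 = 0.38
both_right = aiptw_mean(y, a, ps_ok, m1_ok, level=1)
ps_only = aiptw_mean(y, a, ps_ok, m1_bad, level=1)
outcome_only = aiptw_mean(y, a, ps_bad, m1_ok, level=1)
both_wrong = aiptw_mean(y, a, ps_bad, m1_bad, level=1)

# Risk under no treatment uses Pr(A=0 | L) = 1 - Pr(A=1 | L)
risk_untreated = aiptw_mean(y, a, [1 - p for p in ps_ok],
                            [true_m0[s] for s in strata], level=0)
risk_difference = both_right - risk_untreated

for name, value in [("both right", both_right), ("treatment only", ps_only),
                    ("outcome only", outcome_only), ("both wrong", both_wrong)]:
    print(name, round(value, 4))
print("RD", round(risk_difference, 4))
```

In this sketch the first three estimates agree with the standardized risk under treatment, while the estimate with both models misspecified falls back to the crude risk among the treated. With real data both models are fit rather than known, so the agreement holds only approximately and in large samples.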