-
Notifications
You must be signed in to change notification settings - Fork 1
/
vignette.Rmd
94 lines (75 loc) · 2.83 KB
/
vignette.Rmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
---
title: "Vignette"
output: html_document
date: '2023-07-26'
---
```{r setup, include=FALSE}
# Set the default knitr chunk option echo = TRUE for the chunks in this document.
knitr::opts_chunk$set(echo = TRUE)
```
```{r}
# Install causalHAL from GitHub only when it is not already available.
# requireNamespace() tests for the package without attaching it as a side
# effect; attach it explicitly with library() before use.
if (!requireNamespace("causalHAL", quietly = TRUE)) {
  devtools::install_github("Larsvanderlaan/causalHAL")
}
```
# ADMLE for ATE using adaptive partially linear regression models
## Generate example dataset
```{r}
# Dataset used for simulations.
#
# Arguments:
#   n:         sample size (number of rows to simulate).
#   pos_const: multiplier applied to the propensity-score logit; used to
#              control treatment overlap.
#   muIsHard:  whether the control outcome regression is hard (TRUE: sums of
#              sines/cosines of the covariates) or simple (FALSE: linear and
#              absolute-value terms).
#
# Returns a list with elements
#   W:   n x 4 covariate matrix with columns W1..W4, entries drawn from
#        Uniform(-1, 1);
#   A:   binary treatment drawn with success probability pi;
#   Y:   outcome, normal with mean mu0 + A * tau and standard deviation 0.5;
#   ATE: the constant 1.31 (rounded mean of tau over the covariate
#        distribution: 1 + 0 + 1/2 + sin(4)/4 + 0);
#   pi:  true propensity score P(A = 1 | W);
#   mu0: true control outcome regression;
#   tau: true CATE.
get_data <- function(n, pos_const, muIsHard = TRUE) {
  # covariate dimension
  d <- 4
  # Fill the matrix directly instead of using replicate(): replicate()
  # simplifies to a plain vector when n == 1, which makes the colnames()
  # assignment below fail. The uniform draws are consumed in the same
  # column-by-column order, so results for a given seed are unchanged.
  W <- matrix(runif(n * d, -1, 1), nrow = n, ncol = d)
  colnames(W) <- paste0("W", seq_len(d))

  # propensity score
  pi0 <- plogis(pos_const * (
    W[, 1] + sin(4 * W[, 1]) + W[, 2] + cos(4 * W[, 2]) +
      W[, 3] + sin(4 * W[, 3]) + W[, 4] + cos(4 * W[, 4])
  ))
  # treatment
  A <- rbinom(n, 1, pi0)

  # control outcome regression
  if (muIsHard) {
    mu0 <- sin(4 * W[, 1]) + sin(4 * W[, 2]) + sin(4 * W[, 3]) +
      sin(4 * W[, 4]) + cos(4 * W[, 2])
  } else {
    mu0 <- W[, 1] + abs(W[, 2]) + W[, 3] + abs(W[, 4])
  }

  # CATE
  tau <- 1 + W[, 1] + abs(W[, 2]) + cos(4 * W[, 3]) + W[, 4]
  # outcome
  Y <- rnorm(n, mu0 + A * tau, 0.5)

  list(W = W, A = A, Y = Y, ATE = 1.31, pi = pi0, mu0 = mu0, tau = tau)
}
```
## Run ADMLE using HAL and glmnet
```{r}
library(causalHAL)

# Use a fixed integer seed and set it before simulating, so the whole example
# (data and fits) is reproducible. set.seed() interprets its argument as an
# integer, so a random draw such as rnorm(1) would be truncated and would
# change from run to run.
seed <- 12345L
set.seed(seed)
data <- get_data(1000, 1, TRUE)
print(paste0("True ATE: ", data$ATE))

# get nuisance functions for R-learner
# (here the true functions returned by get_data() are plugged in)

# User-supplied estimate of propensity score pi = P(A=1|W)
pi.hat <- data$pi
# User-supplied estimate of treatment-marginalized outcome regression
# m = E(Y|W) = P(A=0|W) * E(Y|A=0,W) + P(A=1|W) * E(Y|A=1,W):
# the control mean mu0 is weighted by (1 - pi) and the treated mean
# (mu0 + tau) by pi.
m.hat <- data$mu0 * (1 - pi.hat) + (data$mu0 + data$tau) * pi.hat

# ADMLE for ATE using partially linear model with HAL.
# Fits additive piece-wise linear spline model for CATE with 50 knot points
# per covariate using highly adaptive lasso (see tlverse/hal9001 github R
# package)
set.seed(seed)
ADMLE_fit <- fit_cate_hal_partially_linear(
  data$W, data$A, data$Y,
  m.hat = m.hat,
  pi.hat = pi.hat,
  smoothness_orders_cate = 1,
  num_knots_cate = 50,
  max_degree_cate = 1
)
# Provides estimates and CI for ATE
inference_ate(ADMLE_fit)

# Same analysis but using glmnet implementation with hal9001-basis design
# matrix.
# May not reproduce estimates exactly but should be close.
# For those not familiar with hal9001 package, the below code may be easier
# to play around with.
basis_list <- hal9001::enumerate_basis(
  data$W,
  smoothness_orders = 1,
  num_knots = 50,
  max_degree = 1
)
tau_basis <- hal9001::make_design_matrix(data$W, basis_list)
# Reset to the same seed as the HAL fit above.
set.seed(seed)
ADMLE_fit <- fit_cate_lasso_partially_linear(
  tau_basis, data$A, data$Y,
  m.hat = m.hat,
  pi.hat = pi.hat,
  standardize = FALSE
)
# Provides estimates and CI for ATE
inference_ate(ADMLE_fit)
```