-
Notifications
You must be signed in to change notification settings - Fork 2
/
makeFigure-exp1-modelComparison-Appendix.R
147 lines (130 loc) · 6.82 KB
/
makeFigure-exp1-modelComparison-Appendix.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
####### NB: This is the same code as makeFigure-exp1-modelComparison.R, just applied to different models (the appendix RL-DDM variants A1-A4)
rm(list=ls())
library(snowfall)
source ("dmc/dmc.R")
source('utils.R')
source('models.R')
# Folder that loadSamples() reads the fitted samples objects from.
samplesDir <- "samples"
# TRUE: draw the figure into a PDF under ./figures/; FALSE: use the active device.
savePlot <- FALSE
# Load the fitted DMC samples for one model/dataset pair and compute the
# quantities used in the model-comparison figure.
#
# Arguments:
#   modelName - model identifier; passed to setupModel() and used in the
#               samples file name 'model-<modelName>_data-<dataName>'.
#   dataName  - dataset identifier; passed to loadData() and used in the
#               samples file name.
#   do.plot   - if TRUE, plot hyper-level posterior densities with plot.dmc().
#   BPIConly  - if TRUE, skip the posterior-predictive simulation and return
#               only the BPIC table.
#   cores     - number of cores for h.post.predict.dmc() and, when no snowfall
#               cluster is running yet, for sfInit(). Defaults to 30, the value
#               that used to be hard-coded.
#
# Returns:
#   BPIConly = FALSE: list(pp3   = calculateByBin() output per element of the
#                                  posterior predictives,
#                          data3 = calculateByBin() output per element of the
#                                  observed data,
#                          BPIC  = h.IC.dmc(samples)).
#   BPIConly = TRUE:  list(BPIC = h.IC.dmc(samples)).
#
# Reads the global `samplesDir`.
getDataPpBPIC <- function(modelName, dataName, do.plot=FALSE, BPIConly=FALSE, cores=30) {
  # setupModel() calls load_model(), which loads transform.dmc() and
  # transform2.dmc(); it is needed for that side effect even though `model`
  # itself is not used below.
  model <- setupModel(modelName)
  dat <- loadData(dataName, removeBlock = NULL)[['dat']]
  fn <- paste0('model-', modelName, '_data-', dataName)

  # Load samples -------------------------------------------------------------
  samples <- loadSamples(fn, samplesDir)
  if(do.plot) plot.dmc(samples, hyper=TRUE, density=TRUE, layout=c(4,4))
  if(BPIConly) return(list('BPIC'=h.IC.dmc(samples)))

  # Generate posterior predictives -------------------------------------------
  pp <- h.post.predict.dmc(samples = samples, adapt=TRUE, save.simulation = TRUE, cores=cores)
  # NOTE(review): the result is not used below; the call is kept only in case
  # h.pp.summary() has side effects (e.g. printing). Remove if it has none.
  ppNoSim <- h.pp.summary(pp, samples=samples)

  #### Append stimulus set info to data & model --------
  # Data stored in each element of `samples` (presumably one per participant).
  obsData <- lapply(samples, function(x) x$data)
  # seq_along() rather than 1:length(): an empty list yields an empty result
  # instead of calling addStimSetInfo() with indices 1 and 0.
  pp2 <- lapply(seq_along(pp), addStimSetInfo, input=pp, orig_dat=dat)
  data2 <- lapply(seq_along(obsData), addStimSetInfo, input=obsData, orig_dat=dat)

  # Per-bin summaries. The predictives are processed on a snowfall cluster
  # (started only if none is running); sfLibrary() loads 'moments' on every
  # call, exactly as before.
  if(!sfIsRunning()) sfInit(parallel=TRUE, cpus=cores)
  sfLibrary(moments)
  pp3 <- sfLapply(pp2, calculateByBin)
  data3 <- lapply(data2, calculateByBin)

  list('pp3'=pp3, 'data3'=data3, 'BPIC'=h.IC.dmc(samples))
}
# Summarise observed vs. predicted RT quantiles and accuracy per trial bin.
#
# Arguments:
#   data3 - per-bin observed-data summaries (getDataPpBPIC()$data3).
#   pp3   - per-bin posterior-predictive summaries (getDataPpBPIC()$pp3).
#
# Returns:
#   A named list (q10/q50/q90 correct-RT quantiles, the same for error RTs,
#   and mean accuracy). Each element is an unnamed list of length 2:
#   [[1]] getDescriptives() of data3 (grouped by '~bin*s' then '~bin'),
#   [[2]] getDescriptives() of pp3 (getDescriptives() default grouping).
getqRTs <- function(data3, pp3) {
  # Output name -> dependent variable and attribute holding it.
  specs <- list(
    q10RTsOverTime  = c(dep.var = 'RT.10.', attr.name = 'qRTsCorrect'),
    q50RTsOverTime  = c(dep.var = 'RT.50.', attr.name = 'qRTsCorrect'),
    q90RTsOverTime  = c(dep.var = 'RT.90.', attr.name = 'qRTsCorrect'),
    q10RTsOverTimeE = c(dep.var = 'RT.10.', attr.name = 'qRTsError'),
    q50RTsOverTimeE = c(dep.var = 'RT.50.', attr.name = 'qRTsError'),
    q90RTsOverTimeE = c(dep.var = 'RT.90.', attr.name = 'qRTsError'),
    meanAccOverTime = c(dep.var = 'acc',    attr.name = 'AccOverBins')
  )
  # lapply keeps the spec names, so the result is named like the table above.
  lapply(specs, function(spec) {
    depVar <- spec[['dep.var']]
    attrName <- spec[['attr.name']]
    observed <- getDescriptives(data3, dep.var=depVar, attr.name=attrName,
                                id.var1='~bin*s', id.var2="~bin")
    predicted <- getDescriptives(pp3, dep.var=depVar, attr.name=attrName)
    list(observed, predicted)
  })
}
# Load BPICs & quantiles per bin ---------------------------------------------------------------
# Appendix RL-DDM variants in plotting order; the names are the panel labels.
appendixModels <- c('RL-DDM A1' = 'ddm-RL-nonlinear',
                    'RL-DDM A2' = 'ddm-RL-svsz',
                    'RL-DDM A3' = 'ddm-RL-st0',
                    'RL-DDM A4' = 'ddm-RL-nonlinear-svszst0')

# Process one model at a time and keep only its BPIC table and per-bin
# summaries, so the posterior-predictive output of a model (generated with
# save.simulation = TRUE) is not retained while the next model is processed.
fits <- lapply(appendixModels, function(modelName) {
  fit <- getDataPpBPIC(modelName, 'exp1')
  list(BPIC = fit$BPIC, qRTs = getqRTs(fit[['data3']], fit[['pp3']]))
})

# Combine accuracy per bin, quantile per bin (one element per model)
allqRTs <- lapply(fits, `[[`, 'qRTs')

# Model comparison for experiment 1 (appendix models): column 2 of each
# model's h.IC.dmc() table, one column per model, then summed per model.
allBPICs <- do.call(cbind, lapply(fits, function(fit) fit$BPIC[, 2]))
colSums(allBPICs)
# Plot posterior predictives
# Layout: one column per model; rows are accuracy, correct-RT quantiles and
# error-RT quantiles per trial bin.
if(savePlot) pdf('./figures/modelcomparison-exp1-RLDDMs.pdf', width=7, height=7/4*3)
# par() returns the previous values of the settings it changes; they are
# restored after the loop so this layout does not leak into later plots.
oldPar <- par(oma=c(3,4,2,0), mar=c(0, 0, 1, 0.5) + 0.1, mfcol=c(3, length(allqRTs)),
              mgp=c(2.75,.75,0), las=1, bty='l')
#par(oma=c(3,4,1,0), mar=c(0, 0, 1, 0.5) + 0.1, mfcol=c(3,4), mgp=c(2.75,.75,0), las=1, bty='l')
data.cex <- 1.5
corrRTylim <- errRTylim <- c(.45, 1.1)
# Summed BPIC per model (column of allBPICs), computed once for all panels.
summedBPICs <- colSums(allBPICs)

# Draw one RT panel: observed vs. predicted 10th, 50th and 90th percentiles.
# Each argument is a list(data, pp) pair as produced by getqRTs().
plotRTQuantiles <- function(q10, q50, q90, ylim) {
  plotDataPPBins(data=q10[[1]], pp=q10[[2]], dep.var='RT.10.',
                 ylim=ylim, xaxt='n', ylab='', yaxt='n', draw.legend = FALSE, data.cex = data.cex)
  plotDataPPBins(data=q50[[1]], pp=q50[[2]], dep.var='RT.50.', plot.new = FALSE, draw.legend=FALSE, data.cex = data.cex)
  plotDataPPBins(data=q90[[1]], pp=q90[[2]], dep.var='RT.90.', plot.new = FALSE, draw.legend=FALSE, data.cex = data.cex)
}

# Left axis of an RT panel: tick labels and the row title only in column 1.
drawRTAxis <- function(rowLabel, firstColumn) {
  if(firstColumn) {
    mtext(rowLabel, side=2, cex=.66, line=3, las=0, font=1)
    axis(2, seq(.4, 1.2, .2), lwd=1.5)
  } else {
    axis(2, seq(.4, 1.2, .2), labels=NA, lwd=1.5)
  }
}

for(i in seq_along(allqRTs)) {
  qRTs <- allqRTs[[i]]
  ## Row 1: accuracy per bin
  plotDataPPBins(data=qRTs$meanAccOverTime[[1]], pp=qRTs$meanAccOverTime[[2]],
                 xaxt='n', draw.legend = i==1, data.cex = data.cex,
                 dep.var='acc', ylab='', xlab = '', yaxt='n',
                 legend.pos='topleft', ylim=c(0.5, 0.85), hline.by=0.1)
  # NOTE(review): lwd=2 here but 1.5 on every other axis -- confirm intended.
  axis(1, at=seq(2, 10, 2), labels=rep(NA, 5), lwd=2)
  if(i == 1) {
    mtext('Accuracy', side=2, cex=.66, line=3, las=0, font=1)
    axis(2, at=seq(.5, .9, .1), lwd=1.5)
  } else {
    axis(2, at=seq(.5, .9, .1), labels=rep(NA, 5), lwd=1.5)
  }
  # Column title and summed BPIC; xpd=NA lets them extend past the plot region.
  par(xpd=NA)
  title(paste0('RL-DDM A', i), line=1.2)
  mtext(paste0('BPIC = ', round(summedBPICs[i])), cex=.66)
  par(xpd=FALSE)
  ## Row 2: correct RTs
  plotRTQuantiles(qRTs$q10RTsOverTime, qRTs$q50RTsOverTime, qRTs$q90RTsOverTime, ylim=corrRTylim)
  axis(1, at=seq(2, 10, 2), labels=NA, lwd=1.5)
  drawRTAxis('Correct RTs (s)', i == 1)
  ## Row 3: error RTs
  plotRTQuantiles(qRTs$q10RTsOverTimeE, qRTs$q50RTsOverTimeE, qRTs$q90RTsOverTimeE, ylim=errRTylim)
  drawRTAxis('Error RTs (s)', i == 1)
  axis(1, at=seq(2, 10, 2), lwd=1.5)
  mtext('Trial bin', side=1, cex=.66, line=2)
}
# Restore graphics settings before closing the PDF device (restoring after
# dev.off() could open a new default device).
par(oldPar)
if(savePlot) dev.off()