-
Notifications
You must be signed in to change notification settings - Fork 2
/
makeFigure-exp1-difficulty-Qvalues.R
210 lines (182 loc) · 10.6 KB
/
makeFigure-exp1-difficulty-Qvalues.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
rm(list=ls())
library(snowfall)
source ("dmc/dmc.R")
source('utils.R')
source('models.R')
# Directory that loadSamples() reads the fitted DMC samples from.
samplesDir <- "samples"
# TRUE: write each figure to a PDF under ./figures; FALSE: draw on the active device.
savePlot <- FALSE
# Load the fitted samples for one model/dataset pair and either return only the
# BPIC or, additionally, posterior predictives and data summarised per trial bin.
#
# modelName: model identifier; passed to setupModel() and used in the samples
#            file name 'model-<modelName>_data-<dataName>'.
# dataName:  dataset identifier; passed to loadData() and used in the file name.
# do.plot:   if TRUE, plot the hyper-level samples with plot.dmc().
# BPIConly:  if TRUE, skip posterior predictives and return list(BPIC=).
# cores:     cores for h.post.predict.dmc() and for the snowfall cluster when
#            one is not already running (default 30, the former hard-coded value).
#
# Returns list(pp3=, data3=, BPIC=), or list(BPIC=) when BPIConly is TRUE.
getDataPpBPIC <- function(modelName, dataName, do.plot=FALSE, BPIConly=FALSE, cores=30) {
  # setupModel() calls load_model(), which loads transform.dmc() and
  # transform2.dmc(); its return value is not needed here.
  setupModel(modelName)
  dat <- loadData(dataName, removeBlock = NULL)[['dat']]
  fn <- paste0('model-', modelName, '_data-', dataName)

  # Load, generate posterior preds -------------------------------------------
  samples <- loadSamples(fn, samplesDir)
  if(do.plot) plot.dmc(samples, hyper=TRUE, density=TRUE, layout=c(4,4))
  if(BPIConly) {
    return(list('BPIC'=h.IC.dmc(samples)))
  }

  data <- lapply(samples, function(x) x$data)
  pp <- h.post.predict.dmc(samples = samples, adapt=TRUE, save.simulation = TRUE, cores=cores)
  # NOTE(review): the summary itself is not used below; the call is kept in
  # case it has side effects — TODO confirm whether it is still needed.
  h.pp.summary(pp, samples=samples)

  #### Append stimulus set info to data & model --------
  # NOTE(review): nBins is not referenced in this function; kept in case a
  # helper looks it up dynamically — TODO confirm and remove.
  nBins <- 10
  pp2 <- lapply(seq_along(pp), addStimSetInfo, input=pp, orig_dat=dat)
  data2 <- lapply(seq_along(data), addStimSetInfo, input=data, orig_dat=dat)

  # Start a snowfall cluster only when none is running; the `moments` package
  # is (re)loaded on the workers unconditionally.
  if(!sfIsRunning()) sfInit(parallel=TRUE, cpus=cores)
  sfLibrary(moments)
  pp3 <- sfLapply(pp2, calculateByBin)
  data3 <- lapply(data2, calculateByBin)

  list('pp3'=pp3, 'data3'=data3, 'BPIC'=h.IC.dmc(samples))
}
# Load BPICs & quantiles per bin ---------------------------------------------------------------
modelName <- 'alba-RL-mag' #'arw-RL-mag' #'ddm-RL-st0'
tmp <- getDataPpBPIC(modelName, 'exp1')
BPIC <- tmp$BPIC
data3 <- tmp[['data3']]
pp3 <- tmp[['pp3']]
# When TRUE, drop participants whose accuracy in trial bin 1 of the
# ease == '0.6' condition is perfect, and tag modelName so the saved
# figure file names show that the exclusion was applied.
excludePerfectAcc <- FALSE
if(excludePerfectAcc) {
  # One mask, applied to both lists, keeps data and posterior predictives
  # aligned participant-by-participant. A participant with no trials in that
  # cell gives mean() = NaN; treat that as "not demonstrably perfect" and keep
  # them, rather than letting an NA index insert NULL elements into the lists.
  keepParticipant <- vapply(data3, function(x) {
    firstBinAcc <- mean(x[x$bin==1 & x$ease=='0.6', 'acc'])
    !isTRUE(firstBinAcc >= 1)
  }, logical(1))
  data3 <- data3[keepParticipant]
  pp3 <- pp3[keepParticipant]
  modelName <- paste0(modelName, '-exclperfacc')
}
# Descriptives per trial bin and ease level, each stored as
# list(<data summary>, <posterior-predictive summary>). The posterior
# predictives are additionally split by replicate ('reps').
.descPairByEase <- function(depVar, attrName) {
  fromData <- getDescriptives(data3, dep.var=depVar, attr.name=attrName,
                              id.var1='~bin*ease', id.var2=NULL)
  fromModel <- getDescriptives(pp3, dep.var=depVar, attr.name=attrName,
                               id.var1='~reps*bin*ease', id.var2=NULL)
  list(fromData, fromModel)
}
# RT quantiles (10th, 50th, 90th) of correct responses
q10RTsByEase <- .descPairByEase('RT.10.', 'qRTsCorrectByEase')
q50RTsByEase <- .descPairByEase('RT.50.', 'qRTsCorrectByEase')
q90RTsByEase <- .descPairByEase('RT.90.', 'qRTsCorrectByEase')
# RT quantiles (10th, 50th, 90th) of error responses
q10RTsByEaseE <- .descPairByEase('RT.10.', 'qRTsErrorByEase')
q50RTsByEaseE <- .descPairByEase('RT.50.', 'qRTsErrorByEase')
q90RTsByEaseE <- .descPairByEase('RT.90.', 'qRTsErrorByEase')
# Mean accuracy
meanAccByEase <- .descPairByEase('acc', 'AccByEase')
rm(.descPairByEase)
# Plot posterior predictives -----------------------------------------------
# 3 x 4 grid filled column-wise, one column per ease level:
#   row 1: mean accuracy per trial bin,
#   row 2: 10th/50th/90th percentile RTs of correct responses,
#   row 3: the same percentiles for error responses.
# Element [[1]] of each descriptives list holds the data and [[2]] the
# posterior predictives; plotDataPPBins() draws both.
if(savePlot) pdf(file=paste0('./figures/exp1_difficulty_', modelName, '-QQ-horizontal.pdf'), width=7, height=7/4*3)
par(oma=c(3,3,1,0), mar=c(0, 1, 1, 0) + 0.1, mfcol=c(3,4), mgp=c(2.75,.75,0), las=1, bty='l')
corrRTylim <- c(0.45, 1.1)
errRTylim <- c(0.45, 1.1)
data.cex <- 1.5
# Column titles keyed by ease value. get.color() and the Q-value legend pair
# ease '0.6' with '0.8/0.2', ..., '0.2' with '0.6/0.4', i.e. ease is the
# difference of the two reward probabilities. Looking titles up by value keeps
# each label on the right column whatever order the ease levels come back in;
# unknown levels fall back to the raw ease value.
easeTitles <- c('0.2'='0.6/0.4 (Hardest)', '0.3'='0.65/0.35',
                '0.4'='0.7/0.3', '0.6'='0.8/0.2 (Easiest)')
i <- 0
for(ease in unique(meanAccByEase[[1]]$ease)) {
  i <- i+1
  idxD <- meanAccByEase[[1]]$ease == ease
  idxM <- meanAccByEase[[2]]$ease == ease

  # Row 1: accuracy (legend and y tick labels only in the first column)
  plotDataPPBins(data=meanAccByEase[[1]][idxD,], pp=meanAccByEase[[2]][idxM,],
                 xaxt='n', draw.legend = i==1, data.cex = data.cex,
                 dep.var='acc', ylab='', xlab = '', yaxt='n',
                 legend.pos='topleft', ylim=c(0.5, 0.95), hline.by=0.1)
  axis(1, at=seq(2, 10, 2), labels=rep(NA, 5), lwd=2)
  if(i == 1) {
    mtext('Accuracy', side=2, cex=.66, line=3, las=0, font=1)
    axis(2, at=seq(.5, .9, .1), lwd=1.5)
  } else {
    axis(2, at=seq(.5, .9, .1), labels=rep(NA, 5), lwd=1.5)
  }
  panelTitle <- unname(easeTitles[as.character(ease)])
  title(if(is.na(panelTitle)) as.character(ease) else panelTitle)

  # Row 2: correct-response RT quantiles overlaid in one panel
  plotDataPPBins(data=q10RTsByEase[[1]][idxD,], pp=q10RTsByEase[[2]][idxM,], dep.var='RT.10.',
                 ylim=corrRTylim, xaxt='n', ylab='', yaxt='n', draw.legend = FALSE, data.cex = data.cex)
  plotDataPPBins(data=q50RTsByEase[[1]][idxD,], pp=q50RTsByEase[[2]][idxM,], dep.var='RT.50.', plot.new = FALSE, draw.legend=FALSE, data.cex = data.cex)
  plotDataPPBins(data=q90RTsByEase[[1]][idxD,], pp=q90RTsByEase[[2]][idxM,], dep.var='RT.90.', plot.new = FALSE, draw.legend=FALSE, data.cex = data.cex)
  axis(1, at=seq(2, 10, 2), labels=NA, lwd=1.5)
  if(i == 1) {
    mtext('Correct RTs (s)', side=2, cex=.66, line=3, las=0, font=1)
    axis(2, seq(.4, 1.2, .2), lwd=1.5)
  } else {
    axis(2, seq(.4, 1.2, .2), labels=NA, lwd=1.5)
  }

  # Row 3: error-response RT quantiles, with the labelled x axis
  plotDataPPBins(data=q10RTsByEaseE[[1]][idxD,], pp=q10RTsByEaseE[[2]][idxM,], dep.var='RT.10.', ylim=errRTylim, xaxt='n', ylab='', yaxt='n', draw.legend = FALSE, data.cex = data.cex)
  plotDataPPBins(data=q50RTsByEaseE[[1]][idxD,], pp=q50RTsByEaseE[[2]][idxM,], dep.var='RT.50.', plot.new = FALSE, draw.legend=FALSE, data.cex = data.cex)
  plotDataPPBins(data=q90RTsByEaseE[[1]][idxD,], pp=q90RTsByEaseE[[2]][idxM,], dep.var='RT.90.', plot.new = FALSE, draw.legend=FALSE, data.cex = data.cex)
  if(i == 1) {
    mtext('Error RTs (s)', side=2, cex=.66, line=3, las=0, font=1)
    axis(2, seq(.4, 1.2, .2), lwd=1.5)
  } else {
    axis(2, seq(.4, 1.2, .2), labels=NA, lwd=1.5)
  }
  axis(1, at=seq(2, 10, 2), lwd=1.5)
  mtext('Trial bin', side=1, cex=.66, line=2)
}
if(savePlot) dev.off()
# Q-values, drift rates ---------------------------------------------------
# Map one ease level to the palette index used for it in the Q-value and
# drift-rate panels: '0.6' -> 1, '0.4' -> 2, '0.3' -> 3, '0.2' -> 4.
# The value is compared as a string, so "0.6", factor("0.6") and 0.6 all work.
# Stops with an informative error for any other (or non-scalar) input, instead
# of returning NULL and failing later inside rgb().
get.color <- function(ease) {
  paletteIdx <- c('0.6' = 1, '0.4' = 2, '0.3' = 3, '0.2' = 4)
  key <- as.character(ease)
  if(length(key) != 1 || !(key %in% names(paletteIdx))) {
    stop("get.color(): unknown ease level '", paste(key, collapse = ', '), "'", call. = FALSE)
  }
  paletteIdx[[key]]
}
# Shade, on the current plot, the band between two quantiles of `dep.var`
# computed separately at each value of `xaxis`.
#
# pp:                data frame with columns named by `xaxis` and `dep.var`.
# dep.var:           name of the column to summarise.
# xaxis:             name of the grouping / x-axis column (default 'bin').
# colorM:            any colour accepted by col2rgb(); filled at 30% opacity.
# plot.model.points: if TRUE, also draw every row of `pp` as a tiny point.
# probs:             lower and upper quantile bounding the band.
#
# Returns (invisibly) the polygon vertices as list(x=, y=).
draw.polygon <- function(pp, dep.var, xaxis='bin', colorM='blue', plot.model.points=FALSE, probs=c(.025, .975)) {
  # Group by `xaxis`, so the aggregated table actually contains that column.
  fml <- as.formula(paste0(dep.var, '~', xaxis))
  lowerQ <- aggregate(fml, pp, quantile, probs[1])
  upperQ <- aggregate(fml, pp, quantile, probs[2])
  # Walk along the lower bound, then back along the upper bound.
  xs <- c(lowerQ[[xaxis]], rev(upperQ[[xaxis]]))
  ys <- c(lowerQ[[dep.var]], rev(upperQ[[dep.var]]))
  rgbFrac <- col2rgb(colorM)[, 1] / 255
  polygon(xs, ys, col=rgb(rgbFrac[1], rgbFrac[2], rgbFrac[3], alpha=.3), lty = NULL, border=NA)
  if(plot.model.points) points(pp[[xaxis]], pp[[dep.var]], pch=20, col=colorM, cex=.01)
  invisible(list(x=xs, y=ys))
}
# Stack one per-participant data frame, stored as attribute `attrName` on each
# element of `x`, into a single data frame. Each row is tagged with the
# participant index `s`.
stackSubjectAttr <- function(x, attrName) {
  do.call(rbind, lapply(seq_along(x), function(s) {
    d <- attr(x[[s]], attrName)
    d$s <- s
    d
  }))
}
# Q-values for responses r1 and r2, then averaged over participants per reps x bin x ease
allQ1OverTimeM <- stackSubjectAttr(pp3, 'SR.r1OverBins')
allQ2OverTimeM <- stackSubjectAttr(pp3, 'SR.r2OverBins')
meanQ1OverTimeM <- aggregate(SR.r1~reps*bin*ease, allQ1OverTimeM, mean)
meanQ2OverTimeM <- aggregate(SR.r2~reps*bin*ease, allQ2OverTimeM, mean)
# differences and sums: the r1 and r2 tables are combined row by row, so first
# check that they list the same participant/replicate/bin/ease cells in the
# same order (otherwise the difference would silently mix cells).
qKeyCols <- c('s', 'reps', 'bin', 'ease')
stopifnot(nrow(allQ1OverTimeM) == nrow(allQ2OverTimeM),
          isTRUE(all.equal(allQ1OverTimeM[qKeyCols], allQ2OverTimeM[qKeyCols], check.attributes=FALSE)))
deltaQ <- allQ1OverTimeM
deltaQ$SR.r2 <- allQ2OverTimeM$SR.r2
deltaQ$deltaQ <- deltaQ$SR.r1 - deltaQ$SR.r2
deltaQ$sumQ <- deltaQ$SR.r1 + deltaQ$SR.r2
meanDeltaQOverTime <- aggregate(deltaQ~reps*bin*ease, deltaQ, mean)
meanSumQOverTime <- aggregate(sumQ~reps*bin*ease, deltaQ, mean)
# drift rates for accumulators r1 and r2, averaged the same way
allV1OverTimeM <- stackSubjectAttr(pp3, 'mean_v.r1OverBins')
allV2OverTimeM <- stackSubjectAttr(pp3, 'mean_v.r2OverBins')
meanV1OverTimeM <- aggregate(mean_v.r1~reps*bin*ease, allV1OverTimeM, mean)
meanV2OverTimeM <- aggregate(mean_v.r2~reps*bin*ease, allV2OverTimeM, mean)
# Plot --------------------------------------------------------------------
# One row of four panels: (A) Q-values for responses r1 and r2, (B) their
# difference, (C) their sum, (D) mean drift rates for r1 and r2. Each ease
# level uses the colour from get.color(); draw.polygon() shades, per trial bin,
# the 2.5%-97.5% quantile band across replicates of the participant means.
# NOTE(review): unlike the posterior-predictive figure, this file name does not
# include modelName, so saving for several models overwrites the same file.
if(savePlot) pdf(file='./figures/q-values.pdf', width=7, height=2.5)
par(mfrow=c(1,4), las=1, bty='l', oma=c(0,1,1,0), mar=c(4, 3, 2, 0.5) + 0.1, mgp=c(2.25,.75,0))

# Open an empty panel spanning the observed bins (+/- half a bin) and draw the
# grey reference grid: horizontal lines at `hgrid`, vertical every 2 bins.
emptyBinPanel <- function(bins, ylim, ylab, main, hgrid) {
  plot(0, 0, type='n', xlim=range(bins)+c(-.5, .5), ylim=ylim, xlab='Trial bin', ylab=ylab, main=main)
  abline(h=hgrid, col='grey')
  abline(v=seq(0, 10, 2), col='grey')
}

# A. Q-values
emptyBinPanel(meanQ1OverTimeM$bin, c(0, .85), 'Q-values', 'A. Q-values', seq(0, 1, .1))
for(ease in unique(meanQ1OverTimeM$ease)) {
  draw.polygon(meanQ1OverTimeM[meanQ1OverTimeM$ease==ease,], dep.var='SR.r1', colorM=get.color(ease))
  draw.polygon(meanQ2OverTimeM[meanQ2OverTimeM$ease==ease,], dep.var='SR.r2', colorM=get.color(ease))
}
# B. Difference of Q-values
emptyBinPanel(meanDeltaQOverTime$bin, c(0, .85), expression(paste(Delta, 'Q-values')),
              expression(bold(paste('B. ', Delta, 'Q-values'))), seq(0, 1, .1))
for(ease in unique(meanDeltaQOverTime$ease)) {
  draw.polygon(meanDeltaQOverTime[meanDeltaQOverTime$ease==ease,], dep.var='deltaQ', colorM=get.color(ease))
}
# C. Sum of Q-values (limits and ease levels taken from the table being plotted)
emptyBinPanel(meanSumQOverTime$bin, c(0, .85), expression(paste(Sigma, 'Q-values')),
              expression(bold(paste('C. ', Sigma, 'Q-values'))), seq(0, 1, .1))
for(ease in unique(meanSumQOverTime$ease)) {
  draw.polygon(meanSumQOverTime[meanSumQOverTime$ease==ease,], dep.var='sumQ', colorM=get.color(ease))
}
legend('bottomright', legend=c('0.8/0.2', '0.7/0.3', '0.65/0.35', '0.6/0.4'), bg='white', col=1:4, pch=15, title='Difficulty')
# D. Drift rates (ease levels taken from the averaged table being plotted)
emptyBinPanel(meanV1OverTimeM$bin, c(1.0, 4.5), 'Drift rates', 'D. Drift rates', seq(0, 5, .25))
for(ease in unique(meanV1OverTimeM$ease)) {
  draw.polygon(meanV1OverTimeM[meanV1OverTimeM$ease==ease,], dep.var='mean_v.r1', colorM=get.color(ease))
  draw.polygon(meanV2OverTimeM[meanV2OverTimeM$ease==ease,], dep.var='mean_v.r2', colorM=get.color(ease))
}
if(savePlot) dev.off()