-
Notifications
You must be signed in to change notification settings - Fork 2
/
makeFigure-SAT-modelComparison-simple.R
251 lines (222 loc) · 12.4 KB
/
makeFigure-SAT-modelComparison-simple.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
## load data
# Script preamble: load the DMC toolbox and project helpers, and set paths/flags.
# The workspace is intentionally not wiped with rm(list = ls()) (that silently
# destroys the caller's objects); run this script in a fresh R session instead.
library(snowfall)
# skewness() is needed by calculateByBin() on the master process as well as on
# the snowfall workers (where getDataPpBPIC() loads it via sfLibrary()).
library(moments)
source("dmc/dmc.R")
source("utils.R")
source("models.R")
samplesDir <- "samples"  # directory passed to loadSamples()
savePlot <- FALSE        # TRUE writes the single-model figure to ./figures/
# Summarise one trial-level data set (observed or simulated) by trial bin.
#
# Adds a logical `acc` column (as.integer(R) == 2; presumably level 2 of R is
# the correct response -- TODO confirm) and attaches summary tables as
# attributes of the returned data frame:
#   qRTs, qRTsCorrect, qRTsError            10/50/90% RT quantiles by reps*bin
#                                           (all / correct / error trials),
#                                           columns RT.10., RT.50., RT.90.
#   qRTsCorrectByCue, qRTsErrorByCue        the same, additionally by cue
#   RTsOverBins, AccOverBins, SkewOverBins  mean RT, mean acc, RT skewness
#   RTsByChoiceByEase                       mean RT by reps*bin*ease*R
#   RTsByEase, AccByEase, SkewByEase        by reps*ease*bin
#   RTsByCue, AccByCue, SkewByCue           by reps*cue*bin
# If `df` carries an 'adapt' attribute (assumed row-aligned with df), every
# column of it except reps/ease/bin is averaged by reps*ease*bin and stored as
# attribute '<column>OverBins'. `reps` is copied from df when the adapt table
# lacks it, and the grouping columns themselves are not summarised.
#
# Requires columns R, RT, reps, bin, ease and cue, and skewness() (package
# moments) on the search path.
#
# @param df data frame of trials.
# @return `df` with the `acc` column and the summary attributes added.
calculateByBin <- function(df) {
  df$acc <- as.integer(df$R) == 2
  # 10%, 50%, 90% quantiles -> flattened columns RT.10., RT.50., RT.90.
  probs <- seq(.1, .9, .4)
  # aggregate() returns the quantiles as one matrix column;
  # do.call(data.frame, .) splits it into one column per quantile.
  rtQuantiles <- function(form, d) {
    do.call(data.frame, aggregate(form, d, quantile, probs = probs))
  }
  correct <- df[df$acc == 1, ]
  errors <- df[df$acc == 0, ]

  attr(df, 'qRTs') <- rtQuantiles(RT ~ reps * bin, df)
  attr(df, 'qRTsCorrect') <- rtQuantiles(RT ~ reps * bin, correct)
  attr(df, 'qRTsError') <- rtQuantiles(RT ~ reps * bin, errors)
  attr(df, 'qRTsCorrectByCue') <- rtQuantiles(RT ~ reps * bin * cue, correct)
  attr(df, 'qRTsErrorByCue') <- rtQuantiles(RT ~ reps * bin * cue, errors)

  attr(df, 'RTsOverBins') <- aggregate(RT ~ reps * bin, df, mean)
  attr(df, 'AccOverBins') <- aggregate(acc ~ reps * bin, df, mean)
  attr(df, 'SkewOverBins') <- aggregate(RT ~ reps * bin, df, skewness)
  attr(df, 'RTsByChoiceByEase') <- aggregate(RT ~ reps * bin * ease * R, df, mean)
  attr(df, 'RTsByEase') <- aggregate(RT ~ reps * ease * bin, df, mean)
  attr(df, 'AccByEase') <- aggregate(acc ~ reps * ease * bin, df, mean)
  attr(df, 'SkewByEase') <- aggregate(RT ~ reps * ease * bin, df, skewness)
  attr(df, 'RTsByCue') <- aggregate(RT ~ reps * cue * bin, df, mean)
  attr(df, 'AccByCue') <- aggregate(acc ~ reps * cue * bin, df, mean)
  attr(df, 'SkewByCue') <- aggregate(RT ~ reps * cue * bin, df, skewness)

  adapt <- attr(df, 'adapt')
  if (!is.null(adapt)) {
    adapted <- as.data.frame(adapt)
    # Only summarise the adaptation variables, not the grouping columns.
    adaptCols <- setdiff(colnames(adapted), c('reps', 'ease', 'bin'))
    # The formula groups by reps; without this it is not found in `adapted`.
    if (!'reps' %in% colnames(adapted)) adapted$reps <- df$reps
    adapted$ease <- df$ease
    adapted$bin <- df$bin
    for (colName in adaptCols) {
      # backquote so non-syntactic column names still parse
      form <- as.formula(paste0('`', colName, '` ~ reps * ease * bin'))
      attr(df, paste0(colName, 'OverBins')) <- aggregate(form, adapted, mean)
    }
  }
  df
}
# Load posterior samples for one model/data combination and, unless only the
# information criterion is requested, simulate posterior predictives and
# summarise both simulated and observed data per trial bin.
#
# @param modelName model identifier passed to setupModel().
# @param dataName data-set identifier passed to loadData().
# @param do.plot if TRUE, plot hyper-level posterior densities with plot.dmc().
# @param BPIConly if TRUE, skip posterior prediction and return only the BPIC.
# @param cores cores used by h.post.predict.dmc() and for starting the snowfall
#   cluster when none is running (default 30, the previously hard-coded value).
# @return list(pp3, data3, BPIC) -- pp3/data3 are lists of data frames
#   processed by calculateByBin() -- or list(BPIC) when BPIConly is TRUE.
getDataPpBPIC <- function(modelName, dataName, do.plot=FALSE, BPIConly=FALSE, cores=30) {
  model <- setupModel(modelName) # calls load_model(), which loads transform.dmc() and transform2.dmc()
  dat <- loadData(dataName)[['dat']]
  fn <- paste0('model-', modelName, '_data-', dataName)

  # Load samples; each element carries the participant's observed data.
  samples <- loadSamples(fn, samplesDir)
  data <- lapply(samples, function(x) x$data)
  if (do.plot) plot.dmc(samples, hyper=TRUE, density=TRUE, layout=c(4, 4))
  if (BPIConly) {
    return(list('BPIC'=h.IC.dmc(samples)))
  }

  pp <- h.post.predict.dmc(samples=samples, adapt=TRUE, save.simulation=TRUE, cores=cores)

  # Append stimulus-set info to simulated and observed data.
  if (!'ease' %in% colnames(dat)) dat$ease <- 1
  pp2 <- lapply(seq_along(pp), addStimSetInfo, input=pp, orig_dat=dat)
  data2 <- lapply(seq_along(data), addStimSetInfo, input=data, orig_dat=dat)

  # Reuse a running cluster if there is one. moments (skewness) is loaded on
  # the cluster every time, whether or not the cluster was just started.
  if (!sfIsRunning()) {
    sfInit(parallel=TRUE, cpus=cores)
  }
  sfLibrary(moments)
  pp3 <- sfLapply(pp2, calculateByBin)
  data3 <- lapply(data2, calculateByBin)

  list('pp3'=pp3, 'data3'=data3, 'BPIC'=h.IC.dmc(samples))
}
# Pair observed-data and posterior-predictive descriptives per cue.
#
# Each returned element is list(<observed>, <predicted>) as produced by
# getDescriptives(). Observed data are summarised over bin*cue; posterior
# predictives are summarised over reps*bin*cue. The elements are the 10/50/90%
# RT quantiles of correct trials (q..RTsByCue), of error trials
# (q..RTsByCueE), and mean accuracy (meanAccByCue).
getqRTsByCue <- function(data3, pp3) {
  # output name = c(dependent variable, attribute holding the table)
  summaries <- list(
    q10RTsByCue  = c('RT.10.', 'qRTsCorrectByCue'),
    q50RTsByCue  = c('RT.50.', 'qRTsCorrectByCue'),
    q90RTsByCue  = c('RT.90.', 'qRTsCorrectByCue'),
    q10RTsByCueE = c('RT.10.', 'qRTsErrorByCue'),
    q50RTsByCueE = c('RT.50.', 'qRTsErrorByCue'),
    q90RTsByCueE = c('RT.90.', 'qRTsErrorByCue'),
    meanAccByCue = c('acc', 'AccByCue')
  )
  lapply(summaries, function(spec) {
    observed <- getDescriptives(data3, dep.var=spec[[1]], attr.name=spec[[2]],
                                id.var1='~bin*cue', id.var2=NULL)
    predicted <- getDescriptives(pp3, dep.var=spec[[1]], attr.name=spec[[2]],
                                 id.var1='~reps*bin*cue', id.var2=NULL)
    list(observed, predicted)
  })
}
# Load posterior predictives, observed-data summaries and BPIC of the winning
# model, then extract per-cue RT quantiles and accuracy ----------------------
modelName <- 'ddm-RL-SAT-a-st0' #'arw-RL-mag-SAT-V02'
dataName <- 'exp3'
fn <- sprintf('model-%s_data-%s', modelName, dataName)  # also used as a plot title
tmp <- getDataPpBPIC(modelName, dataName)
qRTs <- getqRTsByCue(data3 = tmp[['data3']], pp3 = tmp[['pp3']])
# Combine -----------------------------------------------------------------
# One entry per model; the plotting loops below iterate over this list.
allqRTs <- list(qRTs)
# Plot single --------------------------------------------------------------------
# Layout (5 rows x 2 columns): figure 1 occupies rows 1 and 5 and carries the
# titles. Figures 2-4 fill column 1 (Speed cue: accuracy, correct-RT quantiles,
# error-RT quantiles) and figures 5-7 fill column 2 (Accuracy cue), matching
# the plotting order of the loop below. Every cell is assigned explicitly, so
# the matrix starts empty. The former matrix(1:7, nrow=5) recycled 7 values into
# 5 rows and raised a warning.
layoutM <- matrix(0, nrow=5, ncol=2)
layoutM[c(1, 5),1:2] <- 1
#layoutM[c(1, 5),4:5] <- 9
layoutM[2:4,1:2] <- 2:7 #matrix(c(2:13), nrow=3, byrow=TRUE)
#layoutM[2:4,4:5] <- 10:15 #matrix(c(2:13), nrow=3, byrow=TRUE)
#layoutM[,3] <- 8
layoutM
if(savePlot) pdf('./figures/exp3-SAT-RLDDMst0-2.pdf', width=3.5, height=7/4*3)
layout(layoutM, heights = c(0.01, .8, 1, 1, 0.01), widths=c(1,1)) #,.1,1,1))
par(oma=c(3,4,2,0), mar=c(0, 0, 1, 0.5) + 0.1,
    mgp=c(2.75,.75,0), las=1, bty='l')
i <- 0  # column counter, incremented once per cue
data.cex <- 1.5
corrRTylim <- errRTylim <- c(.35,1.1)
for(qRTs in allqRTs) {
  plot.new()  # figure 1: title area
  # if(i == 0) mtext(fn, side=3, cex=.66*1.2, font=2, line=1)
  if(i == 0) {
    mtext('RL-DDM A3', side=3, cex=.66*1.2, font=2, line=2)
    mtext(paste0('BPIC = ', round(sum(tmp$BPIC[,2]))), line=1, cex=.66)
  }
  #if(i == 2) {plot.new(); mtext('RL-fARD', side=3, cex=.66*1.2, font=2, line=1)}
  for(cue in c('SPD', 'ACC')) {
    i <- i+1
    # NOTE(review): these row selectors come from the accuracy tables and are
    # reused for the RT-quantile tables below. This assumes all tables share the
    # same bin*cue rows; a bin without errors would misalign them. Confirm.
    idxD <- qRTs$meanAccByCue[[1]]$cue==cue
    idxM <- qRTs$meanAccByCue[[2]]$cue==cue
    ## accuracy per bin
    plotDataPPBins(data=qRTs$meanAccByCue[[1]][idxD,], pp=qRTs$meanAccByCue[[2]][idxM,],
                   xaxt='n', draw.legend = i==1, data.cex = data.cex,
                   dep.var='acc', ylab='', xlab = '', yaxt='n',
                   legend.pos='bottomright', ylim=c(0.5, 0.9), hline.by=0.1)
    axis(1, at=seq(2, 10, 2), labels=rep(NA, 5), lwd=2)
    if(i == 1) {
      mtext('Accuracy', side=2, cex=.66, line=3, las=0, font=1)
      axis(2, at=seq(.5, .9, .1), lwd=1.5)
    } else {
      axis(2, at=seq(.5, .9, .1), labels=rep(NA, 5), lwd=1.5)
    }
    if(i == 1) title('Speed')
    if(i == 2) title('Accuracy')
    if(i == 3) title('Speed')
    if(i == 4) title('Accuracy')
    ## 10/50/90% quantiles of correct RTs
    plotDataPPBins(data=qRTs$q10RTsByCue[[1]][idxD,], pp=qRTs$q10RTsByCue[[2]][idxM,], dep.var='RT.10.',
                   ylim=corrRTylim, xaxt='n', ylab='', yaxt='n', draw.legend = FALSE, data.cex = data.cex)
    plotDataPPBins(data=qRTs$q50RTsByCue[[1]][idxD,], pp=qRTs$q50RTsByCue[[2]][idxM,], dep.var='RT.50.', plot.new = FALSE, draw.legend=FALSE, data.cex = data.cex)
    plotDataPPBins(data=qRTs$q90RTsByCue[[1]][idxD,], pp=qRTs$q90RTsByCue[[2]][idxM,], dep.var='RT.90.', plot.new = FALSE, draw.legend=FALSE, data.cex = data.cex)
    axis(1, at=seq(2, 10, 2), labels=NA, lwd=1.5)
    if(i == 1) {
      mtext('Correct RTs (s)', side=2, cex=.66, line=3, las=0, font=1)
      axis(2, seq(.4, 1.2, .2), lwd=1.5)
    } else {
      axis(2, seq(.4, 1.2, .2), labels=NA, lwd=1.5)
    }
    ## 10/50/90% quantiles of error RTs
    plotDataPPBins(data=qRTs$q10RTsByCueE[[1]][idxD,], pp=qRTs$q10RTsByCueE[[2]][idxM,], dep.var='RT.10.', ylim=errRTylim, xaxt='n', ylab='', yaxt='n', draw.legend = FALSE, data.cex = data.cex)
    plotDataPPBins(data=qRTs$q50RTsByCueE[[1]][idxD,], pp=qRTs$q50RTsByCueE[[2]][idxM,], dep.var='RT.50.', plot.new = FALSE, draw.legend=FALSE, data.cex = data.cex)
    plotDataPPBins(data=qRTs$q90RTsByCueE[[1]][idxD,], pp=qRTs$q90RTsByCueE[[2]][idxM,], dep.var='RT.90.', plot.new = FALSE, draw.legend=FALSE, data.cex = data.cex)
    if(i == 1) {
      mtext('Error RTs (s)', side=2, cex=.66, line=3, las=0, font=1)
      axis(2, seq(.4, 1.2, .2), lwd=1.5)
    } else {
      axis(2, seq(.4, 1.2, .2), labels=NA, lwd=1.5)
    }
    axis(1, at=seq(2, 10, 2), lwd=1.5)
    mtext('Trial bin', side=1, cex=.66, line=2)
  }
}
# Close the device only if a PDF was opened above. Closing unconditionally
# would discard the on-screen figure, or the default Rplots.pdf that the next
# section overwrites.
if(savePlot) dev.off()
# Plot double --------------------------------------------------------------------
# Two-model layout (5 x 5):
# - Figures 1 and 9 carry the titles of the left and right model.
# - Figures 2-7 and 10-15 hold the panels, filled column-wise per cue
#   (accuracy, correct-RT quantiles, error-RT quantiles).
# - Figure 8 is the spacer column between the two models.
# NOTE: allqRTs currently holds a single model, so only figures 1-7 are drawn.
layoutM <- matrix(1:25, nrow=5, byrow=TRUE)
layoutM[c(1, 5),1:2] <- 1
layoutM[c(1, 5),4:5] <- 9
layoutM[2:4,1:2] <- 2:7 #matrix(c(2:13), nrow=3, byrow=TRUE)
layoutM[2:4,4:5] <- 10:15 #matrix(c(2:13), nrow=3, byrow=TRUE)
layoutM[,3] <- 8
layoutM
# One width per column, with a narrow spacer column 3. The previous c(1,1) was
# padded with 1s by layout(), making the spacer as wide as a panel column.
layout(layoutM, heights = c(0.01, .8, 1, 1, 0.01), widths=c(1, 1, .1, 1, 1))
par(oma=c(3,4,2,0), mar=c(0, 0, 1, 0.5) + 0.1,
    mgp=c(2.75,.75,0), las=1, bty='l')
i <- 0  # column counter, incremented once per cue (2 per model)
data.cex <- 1.5
corrRTylim <- errRTylim <- c(.35,1.1)
for(qRTs in allqRTs) {
  plot.new()  # figure 1 (first model) or spacer figure 8 (second model)
  if(i == 0) mtext(fn, side=3, cex=.66*1.2, font=2, line=1)
  if(i == 2) {plot.new(); mtext('RL-fARD', side=3, cex=.66*1.2, font=2, line=1)}  # figure 9
  for(cue in c('SPD', 'ACC')) {
    i <- i+1
    # NOTE(review): row selectors from the accuracy tables are reused for the
    # RT-quantile tables; this assumes they share the same bin*cue rows.
    idxD <- qRTs$meanAccByCue[[1]]$cue==cue
    idxM <- qRTs$meanAccByCue[[2]]$cue==cue
    ## accuracy per bin
    plotDataPPBins(data=qRTs$meanAccByCue[[1]][idxD,], pp=qRTs$meanAccByCue[[2]][idxM,],
                   xaxt='n', draw.legend = i==1, data.cex = data.cex,
                   dep.var='acc', ylab='', xlab = '', yaxt='n',
                   legend.pos='bottomright', ylim=c(0.5, 0.9), hline.by=0.1)
    axis(1, at=seq(2, 10, 2), labels=rep(NA, 5), lwd=2)
    if(i == 1) {
      mtext('Accuracy', side=2, cex=.66, line=3, las=0, font=1)
      axis(2, at=seq(.5, .9, .1), lwd=1.5)
    } else {
      axis(2, at=seq(.5, .9, .1), labels=rep(NA, 5), lwd=1.5)
    }
    if(i == 1) title('Speed')
    if(i == 2) title('Accuracy')
    if(i == 3) title('Speed')
    if(i == 4) title('Accuracy')
    ## 10/50/90% quantiles of correct RTs
    plotDataPPBins(data=qRTs$q10RTsByCue[[1]][idxD,], pp=qRTs$q10RTsByCue[[2]][idxM,], dep.var='RT.10.',
                   ylim=corrRTylim, xaxt='n', ylab='', yaxt='n', draw.legend = FALSE, data.cex = data.cex)
    plotDataPPBins(data=qRTs$q50RTsByCue[[1]][idxD,], pp=qRTs$q50RTsByCue[[2]][idxM,], dep.var='RT.50.', plot.new = FALSE, draw.legend=FALSE, data.cex = data.cex)
    plotDataPPBins(data=qRTs$q90RTsByCue[[1]][idxD,], pp=qRTs$q90RTsByCue[[2]][idxM,], dep.var='RT.90.', plot.new = FALSE, draw.legend=FALSE, data.cex = data.cex)
    axis(1, at=seq(2, 10, 2), labels=NA, lwd=1.5)
    if(i == 1) {
      mtext('Correct RTs (s)', side=2, cex=.66, line=3, las=0, font=1)
      axis(2, seq(.4, 1.2, .2), lwd=1.5)
    } else {
      axis(2, seq(.4, 1.2, .2), labels=NA, lwd=1.5)
    }
    ## 10/50/90% quantiles of error RTs
    plotDataPPBins(data=qRTs$q10RTsByCueE[[1]][idxD,], pp=qRTs$q10RTsByCueE[[2]][idxM,], dep.var='RT.10.', ylim=errRTylim, xaxt='n', ylab='', yaxt='n', draw.legend = FALSE, data.cex = data.cex)
    plotDataPPBins(data=qRTs$q50RTsByCueE[[1]][idxD,], pp=qRTs$q50RTsByCueE[[2]][idxM,], dep.var='RT.50.', plot.new = FALSE, draw.legend=FALSE, data.cex = data.cex)
    plotDataPPBins(data=qRTs$q90RTsByCueE[[1]][idxD,], pp=qRTs$q90RTsByCueE[[2]][idxM,], dep.var='RT.90.', plot.new = FALSE, draw.legend=FALSE, data.cex = data.cex)
    if(i == 1) {
      mtext('Error RTs (s)', side=2, cex=.66, line=3, las=0, font=1)
      axis(2, seq(.4, 1.2, .2), lwd=1.5)
    } else {
      axis(2, seq(.4, 1.2, .2), labels=NA, lwd=1.5)
    }
    axis(1, at=seq(2, 10, 2), lwd=1.5)
    mtext('Trial bin', side=1, cex=.66, line=2)
  }
}