-
Notifications
You must be signed in to change notification settings - Fork 0
/
excludeTrialsMonkeyLogic.m
141 lines (124 loc) · 5.23 KB
/
excludeTrialsMonkeyLogic.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
function [ taskDataValid ] = excludeTrialsMonkeyLogic( taskData, params )
%excludeTrialsMonkeyLogic removes excluded trials from a taskData struct.
%   taskDataValid = excludeTrialsMonkeyLogic(taskData, params) returns a
%   struct whose per-trial fields hold only the trials that pass every
%   enabled exclusion criterion; fields that are not per-trial are copied
%   unchanged.
%
%   Exclusion criteria implemented in this function (all optional fields
%   of params):
%    - excludeFailed: when true, trials whose taskData.errorArray entry is
%      nonzero are excluded.
%    - frameDropThreshold: trials for which
%      floor(taskData.stimFramesLost/frameDropThreshold) is nonzero are
%      excluded.
%    - excludeAfterFailed: number of trials to additionally exclude after
%      every trial flagged by the two criteria above.
%
%   Other optional fields of params:
%    - needExcludeTrials: when present and false, taskData is returned
%      untouched.
%    - ephysDuration: only checked for presence; a notice is printed when
%      it is missing.
%    - DEBUG: when present and true, a diagnostic figure of kept trial
%      start times and juice on/off times is drawn.
%
%   NOTE: criteria based on fix spot flashes, broken fixation, juice
%   delivery, acceleration, and photodiode/digital trigger alignment
%   (params fields fixPre, fixPost, flashPre, flashPost, juicePre,
%   juicePost, accel1, accel2, maxEventTimeAdjustmentDeviation, minStimDur,
%   ephysDataPre, ephysDataPost) are not applied by this function; those
%   fields may be present in params but are ignored here.
%   todo: exclude stimuli shorter than a minimum stimulus duration (for
%   arrhythmic runs).

    % Exclusion switched off entirely: hand the input back unchanged.
    if isfield(params,'needExcludeTrials') && ~params.needExcludeTrials
        taskDataValid = taskData;
        return
    end

    % The duration value is not used by any criterion below; the notice is
    % kept so the console output is unchanged for existing callers.
    if ~isfield(params,'ephysDuration')
        disp('No ephysDuration supplied to excludeParams; using final task event start time as duration lower bound.');
    end

    % Per-trial exclusion counter (column vector). Any nonzero entry marks
    % the trial for removal.
    nTrials = length(taskData.taskEventIDs);
    exclusionCount = zeros(nTrials,1);

    % Exclude failed trials. (:) forces a column so that a row-vector
    % errorArray cannot expand the counter into an nTrials-by-nTrials matrix.
    if isfield(params, 'excludeFailed') && params.excludeFailed
        exclusionCount = exclusionCount + taskData.errorArray(:);
    end

    % Exclude trials where too many frames are dropped.
    if isfield(params, 'frameDropThreshold')
        exclusionCount = exclusionCount + floor(taskData.stimFramesLost(:)/params.frameDropThreshold);
    end

    % Exclude the x trial(s) following every trial flagged so far.
    if isfield(params, 'excludeAfterFailed') && ~(params.excludeAfterFailed == 0)
        flaggedIdx = find(exclusionCount);
        % Rows: flagged trials; columns: offsets 1..excludeAfterFailed.
        followIdx = unique(bsxfun(@plus, flaggedIdx, 1:params.excludeAfterFailed));
        % Discard indices that run past the last trial.
        followIdx = followIdx(followIdx <= nTrials);
        exclusionCount(followIdx) = exclusionCount(followIdx) + 1;
    end

    % Logical mask of the trials that survive.
    trialValid = (exclusionCount == 0);

    fprintf('Percent of trials excluded: %d%%\n', round((1-sum(trialValid)/length(trialValid))*100))
    fprintf('Number of trials kept: %d\n', sum(trialValid))

    taskDataValid = struct;
    if isfield(taskData, 'eventData')
        taskDataValid.eventData = taskData.eventData;
    end

    % Not per trial, can be passed as is.
    sessionFields = {'recDepth', 'gridHole', 'taskDataSummary', 'taskEventList', ...
        'frameMotionData', 'RFmap', 'eyeCal', 'runTime', 'paradigm', ...
        'rewardParadigm', 'scaleFactor'};
    for iField = 1:numel(sessionFields)
        taskDataValid.(sessionFields{iField}) = taskData.(sessionFields{iField});
    end

    % Per trial values, which are subset with the validity mask.
    trialFields = {'fixTime', 'taskEventIDs', 'taskEventIDsMerged', 'stimFramesLost', ...
        'taskEventStartTimes', 'taskEventEndTimes', 'juiceOnTimes', 'juiceOffTimes', ...
        'taskEventFixDur', 'rewardTimePerTrial'};
    for iField = 1:numel(trialFields)
        taskDataValid.(trialFields{iField}) = taskData.(trialFields{iField})(trialValid);
    end

    % Optional diagnostic plot; DEBUG is treated as off when not supplied.
    if isfield(params,'DEBUG') && params.DEBUG
        figure();
        hold on
        plot(taskDataValid.taskEventStartTimes,ones(size(taskDataValid.taskEventStartTimes)),'color','red','marker','o', 'linestyle','none');
        % Juice times are plotted for all trials (unfiltered taskData).
        plot(taskData.juiceOnTimes, 3*ones(size(taskData.juiceOnTimes)),'color','red','marker','o', 'linestyle','none');
        plot(taskData.juiceOffTimes, 3*ones(size(taskData.juiceOffTimes)),'color','green','marker','o', 'linestyle','none');
        plot([0,0],[-20 20],'marker','none');
        hold off
    end
end