diff --git a/foundation/fnn/compute/data.py b/foundation/fnn/compute/data.py index 82803c5..cdd648d 100644 --- a/foundation/fnn/compute/data.py +++ b/foundation/fnn/compute/data.py @@ -332,7 +332,7 @@ def _trial_traces(self, trial_ids, datatype): # load trials for trial_id in trial_ids: - traces = (recording.ResampledTraces & key & {"trial_id": trial_id}).fetch1("traces") + traces = (recording.ResampledTraces & key & {"trial_id": trial_id}).fetch1("traces") # [time, units] yield transform(traces).astype(np.float32)[:, order] @rowmethod diff --git a/foundation/fnn/compute/visual.py b/foundation/fnn/compute/visual.py index 63ec6a1..9d8e176 100644 --- a/foundation/fnn/compute/visual.py +++ b/foundation/fnn/compute/visual.py @@ -138,3 +138,205 @@ def units(self): correlations.append(cc(unit_targ, unit_pred)) return np.array(correlations) + + +@keys +class VisualResp: + """Visual Mean Response Over Repeated Presentation and Corresponding Stimuli""" + + # @property + # def keys(self): + # return [ + # recording.ScanUnits, + # recording.TrialFilterSet, + # stimulus.VideoSet, + # utility.Resize, + # utility.Resolution, + # utility.Resample, + # utility.Offset, + # utility.Rate, + # utility.Burnin, + # ] + + @property + def keys(self): + return [ + fnn.Model, + recording.TrialFilterSet, + stimulus.VideoSet, + utility.Resize, + utility.Burnin, + utility.Bool.proj(perspective="bool"), + utility.Bool.proj(modulation="bool"), + ] + + @rowproperty + def video_response(self): + """ + Returns + ------- + videos : np.ndarray [videos, time, height, width, channel] + responses : np.ndarray [videos, time, units] + """ + from foundation.recording.compute.visual import VisualTrials + from foundation.utility.response import Correlation + from foundation.stimulus.video import VideoSet + from foundation.fnn.model import Model + from foundation.fnn.data import Data + from foundation.utils.response import Trials, concatenate + from foundation.utils import cuda_enabled + + # load model + model = (Model & self.item).model(device="cuda" if cuda_enabled() else "cpu") + + # load data + data = (Data & self.item).link.compute + + # trial set + trialset = {"trialset_id": data.trialset_id} + + # videos + videos = (VideoSet & self.item).members + videos = videos.fetch("KEY", order_by=videos.primary_key) + + # trials, targets, predictions + trials = [] + targs = [] + preds = [] + + + with cache_rowproperty(): + + for video in tqdm(videos, desc="Videos"): + + # trials + trial_ids = (VisualTrials & trialset & video & self.item).trial_ids + + # no trials for video + if not trial_ids: + logger.warning(f"No trials found for video_id `{video['video_id']}`") + continue + + # stimuli + stimuli = data.trial_stimuli(trial_ids) + + # units + units = data.trial_units(trial_ids) + + # perspectives + if self.item["perspective"]: + perspectives = data.trial_perspectives(trial_ids) + else: + perspectives = repeat(None) + + # modulations + if self.item["modulation"]: + modulations = data.trial_modulations(trial_ids) + else: + modulations = repeat(None) + + # video targets and predictions + _targs = [] + _preds = [] + + for s, p, m, u in zip(stimuli, perspectives, modulations, units): + + # generate prediction + r = model.generate_response(stimuli=s, perspectives=p, modulations=m) + r = np.stack(list(r), axis=0) + + _targs.append(u) + _preds.append(r) + + assert len(trial_ids) == len(_targs) == len(_preds) + + trials.append(trial_ids) + targs.append(_targs) + preds.append(_preds) + + # no trials at all + if not trials: + logger.warning(f"No trials found") + return + + from foundation.stimulus.video import VideoSet, Video + from foundation.utils.video import Video as VideoGenerator, Frame + from foundation.stimulus.resize import ResizedVideo + from foundation.recording.compute.visual import VisualTrials + from foundation.utility.resample import Rate + + # trial set + all_trial_filt = (recording.TrialFilterSet & "not members").proj() + trialset = dict( + trialset_id=( + recording.ScanTrials & (scan.Scan & self.item) & all_trial_filt + ).fetch1("trialset_id") + ) # all trials shown in the scan + + # trace set + traceset = dict(traceset_id=(recording.ScanUnits & self.item).fetch1("traceset_id")) + + # videos + video_ids = (VideoSet & self.item).members + video_ids = video_ids.fetch("KEY", order_by=video_ids.primary_key) + + # unit order + unit_order = ( + merge(recording.TraceSet.Member & traceset, recording.ScanUnitOrder & self.item) + ).fetch("traceset_index", order_by="trace_order") + with cache_rowproperty(): + # resampled frames of visual stimuli + frames = [] + # visual responses + responses = [] + + for video_id in tqdm(video_ids, desc="Videos"): + # trial ids + trial_ids = ( + VisualTrials & trialset & video_id & self.item + ).trial_ids # filtered trials by video and TrialFilterSet + + # no trials for video + if not trial_ids: + logger.warning(f"No trials found for video_id `{video_id['video_id']}`") + continue + + trial_resps = [] + trial_video_index = [] + for trial_id in trial_ids: + # trial responses + trial_resps.append( + ( + recording.ResampledTraces + & dict(trial_id=trial_id) + & traceset + & self.item + ).fetch1("traces") + ) # [time, units], units are ordered by traceset_index + # trial stimuli index + index = ( + ResizedVideo + * recording.TrialVideo + * recording.ResampledTrial + & dict(trial_id=trial_id) + & self.item + ).fetch1("index") + trial_video_index.append(index) + # videos.append(video[index].astype(np.uint8)) + + trial_resps = np.stack( + trial_resps, axis=0 + )[ + ..., unit_order + ] # [trials, time, units], units are ordered by unit_id, trials are ordered by trial start + responses.append(trial_resps.mean(axis=0)) # [time, units] + trial_video_index = np.stack(trial_video_index, axis=0) + assert np.all(trial_video_index == trial_video_index[0]) + trial_video_index = trial_video_index[0] + resized_frames = (ResizedVideo & video_id & self.item).fetch1("video") + frames.append(resized_frames[trial_video_index].astype(np.uint8)) + responses = np.stack(responses, axis=0) # [videos, time, units] + responses = responses[:, self.item["burnin"]:, :] # [videos, time, units] + frames = np.stack(frames, axis=0)[:, self.item["burnin"]:, :] # [videos, time, height, width, channel] + assert responses.shape[0] == frames.shape[0] + return frames, responses diff --git a/foundation/fnn/query/scan.py b/foundation/fnn/query/scan.py index 573fffd..0667a67 100644 --- a/foundation/fnn/query/scan.py +++ b/foundation/fnn/query/scan.py @@ -1,6 +1,6 @@ from djutils import keys, merge from foundation.virtual import utility, stimulus, scan, recording, fnn - +import pandas as pd @keys class VisualScanRecording: @@ -19,36 +19,65 @@ def units(self): return units.proj("trace_order") -# @keys -# class VisualScanCorrelation: -# """Visual Scan Correlation""" +@keys +class VisualScanCorrelation: + """Visual Scan Correlation""" -# @property -# def keys(self): -# return [ -# (scan.Scan * fnn.Model) & fnn.Data.VisualScan, -# recording.TrialFilterSet, -# stimulus.VideoSet, -# utility.Burnin, -# utility.Bool.proj(perspective="bool"), -# utility.Bool.proj(modulation="bool"), -# ] + @property + def keys(self): + return [ + (scan.Scan * fnn.Model) & fnn.Data.VisualScan, + recording.TrialFilterSet, + stimulus.VideoSet, + utility.Burnin, + utility.Bool.proj(perspective="bool"), + utility.Bool.proj(modulation="bool"), + ] -# def cc_norm(self): -# spec = fnn.Data.VisualScan.proj( -# "spec_id", "trace_filterset_id", "pipe_version", "segmentation_method", "spike_method" -# ) -# units = merge( -# self.key, -# spec, -# fnn.Spec.VisualSpec.proj("rate_id", offset_id="offset_id_unit", resample_id="resample_id_unit"), -# (fnn.VisualRecordingCorrelation & utility.Correlation.CCSignal).proj(..., trace_order="unit"), -# recording.ScanUnitOrder, -# recording.Trace.ScanUnit, -# recording.VisualMeasure & utility.Measure.CCMax, -# ) -# return units.proj( -# cc_abs="correlation", -# cc_max="measure", -# cc_norm="correlation/measure", -# ) + def cc_norm(self): + data_scan_spec = fnn.Data.VisualScan.proj( + "spec_id", + "trace_filterset_id", + "pipe_version", + "animal_id", + "session", + "scan_idx", + "segmentation_method", + "spike_method", + ) * fnn.Spec.VisualSpec.proj( + "rate_id", offset_id="offset_id_unit", resample_id="resample_id_unit" + ) + all_unit_trace_rel = ( + self.key + * data_scan_spec # data_id -> specs + scan key + * recording.ScanUnitOrder # scan key + trace_filterset_id -> trace_ids + * recording.Trace.ScanUnit # trace_id -> unit key + ) + all_units_df = all_unit_trace_rel.fetch(format="frame").reset_index() + # fetch cc_max + cc_max = ( + (recording.VisualMeasure & utility.Measure.CCMax & all_unit_trace_rel) + .fetch(format="frame") + .reset_index() + .rename(columns={"measure": "cc_max"}) + ) + # fetch cc_abs + cc_abs_df = pd.DataFrame( + ( + ( + fnn.VisualRecordingCorrelation + & utility.Correlation.CCSignal + ).proj( + ..., trace_order="unit" + ) + & all_unit_trace_rel + ) + .fetch(as_dict=True) # this fetch is very slow + ).reset_index().rename(columns={"correlation": "cc_abs"}) + # compute cc_norm + cc_norm_df = ( + all_units_df.merge(cc_abs_df, how="left", validate="one_to_one") + .merge(cc_max, how="left", validate="one_to_one") + .assign(cc_norm=lambda df: df.cc_abs / df.cc_max) + ) + return cc_norm_df diff --git a/foundation/recording/compute/visual.py b/foundation/recording/compute/visual.py index 46eafe4..65cd5e3 100644 --- a/foundation/recording/compute/visual.py +++ b/foundation/recording/compute/visual.py @@ -1,7 +1,8 @@ import numpy as np -from djutils import keys, merge, rowproperty, cache_rowproperty, MissingError +import pandas as pd +from djutils import keys, merge, rowproperty, cache_rowproperty, MissingError, U from foundation.utils import tqdm, logger -from foundation.virtual import utility, stimulus, recording +from foundation.virtual import utility, stimulus, recording, scan @keys @@ -24,7 +25,13 @@ def trial_ids(self): Tuple[str] tuple of keys (foundation.recording.trial.Trial) -- ordered by trial start time """ - from foundation.recording.trial import Trial, TrialSet, TrialVideo, TrialBounds, TrialFilterSet + from foundation.recording.trial import ( + Trial, + TrialSet, + TrialVideo, + TrialBounds, + TrialFilterSet, + ) # all trials trials = Trial & (TrialSet & self.item).members @@ -70,25 +77,27 @@ def measure(self): from foundation.utils.response import Trials, concatenate # trial set - trialset = (recording.TraceTrials & self.item).fetch1() + trialset = (recording.TraceTrials & self.item).fetch1() # all trials # videos videos = (VideoSet & self.item).members videos = videos.fetch("KEY", order_by=videos.primary_key) with cache_rowproperty(): - # visual responses responses = [] for video in tqdm(videos, desc="Videos"): - # trial ids - trial_ids = (VisualTrials & trialset & video & self.item).trial_ids + trial_ids = ( + VisualTrials & trialset & video & self.item + ).trial_ids # filter trials by TrialFilterSet # no trials for video if not trial_ids: - logger.warning(f"No trials found for video_id `{video['video_id']}`") + logger.warning( + f"No trials found for video_id `{video['video_id']}`" + ) continue # trial responses @@ -108,3 +117,120 @@ def measure(self): # response measure return (Measure & self.item).link.measure(responses) + + +@keys +class VisualTrialResp: + """Visual Mean Response Over Repeated Presentation and Corresponding Stimuli""" + + @property + def keys(self): + return [ + recording.ScanUnits, + recording.TrialFilterSet, + stimulus.VideoSet, + utility.Resize, + utility.Resolution, + utility.Resample, + utility.Offset, + utility.Rate, + utility.Burnin, + ] + + @rowproperty + def video_response(self): + """ + Returns + ------- + videos : list of np.ndarray [time, height, width, channel] per video + responses : list of np.ndarray [trials, time, units] per video + """ + from foundation.stimulus.video import VideoSet, Video + from foundation.utils.video import Video as VideoGenerator, Frame + from foundation.stimulus.resize import ResizedVideo + from foundation.recording.compute.visual import VisualTrials + from foundation.utility.resample import Rate + + # trial set + all_trial_filt = (recording.TrialFilterSet & "not members").proj() + trialset = dict( + trialset_id=( + recording.ScanTrials & (scan.Scan & self.item) & all_trial_filt + ).fetch1("trialset_id") + ) # all trials shown in the scan + + # trace set + traceset = dict( + traceset_id=(recording.ScanUnits & self.item).fetch1("traceset_id") + ) + + # videos + video_ids = (VideoSet & self.item).members + video_ids = video_ids.fetch("KEY", order_by=video_ids.primary_key) + + # unit order + unit_order = ( + merge( + recording.TraceSet.Member & traceset, + recording.ScanUnitOrder & self.item, + ) + ).fetch("traceset_index", order_by="trace_order") + with cache_rowproperty(): + # resampled frames of visual stimuli + frames = [] + # visual responses + responses = [] + + for video_id in tqdm(video_ids, desc="Videos"): + # trial ids + trial_ids = ( + VisualTrials & trialset & video_id & self.item + ).trial_ids # filtered trials by video and TrialFilterSet + + # no trials for video + if not trial_ids: + logger.warning( + f"No trials found for video_id `{video_id['video_id']}`" + ) + continue + + trial_resps = [] + trial_video_index = [] + for trial_id in trial_ids: + # trial responses + trial_resps.append( + ( + recording.ResampledTraces + & dict(trial_id=trial_id) + & traceset + & self.item + ).fetch1("traces") + ) # [time, units], units are ordered by traceset_index + # trial stimuli index + index = ( + ResizedVideo * recording.TrialVideo * recording.ResampledTrial + & dict(trial_id=trial_id) + & self.item + ).fetch1("index") + trial_video_index.append(index) + # videos.append(video[index].astype(np.uint8)) + + trial_resps = np.stack( + trial_resps, axis=0 + )[ + ..., unit_order + ] # [trials, time, units], units are ordered by unit_id, trials are ordered by trial start + # responses.append(trial_resps.mean(axis=0)) # [time, units] + trial_resps = trial_resps[:, self.item["burnin"] :, :] + responses.append(trial_resps) # [trials, time, units] + trial_video_index = np.stack(trial_video_index, axis=0) + assert np.all(trial_video_index == trial_video_index[0]) + trial_video_index = trial_video_index[0] + resized_frames = (ResizedVideo & video_id & self.item).fetch1("video") + frames.append( + resized_frames[trial_video_index].astype(np.uint8)[ + self.item["burnin"] :, ... + ] + ) + assert len(responses) == len(frames) + return frames, responses diff --git a/foundation/recording/dot.py b/foundation/recording/dot.py new file mode 100644 index 0000000..e3a1a08 --- /dev/null +++ b/foundation/recording/dot.py @@ -0,0 +1,28 @@ +import numpy as np +from foundation.virtual import utility, recording, stimulus +from foundation.schemas import recording as schema + + + +@schema.computed +class DotResponses: + definition = """ + -> recording.ScanUnits + -> recording.TrialFilterSet + -> stimulus.VideoSet + -> utility.Offset + --- + dots : blob@external # [dots] + responses : blob@external # [dots, traces], traces are ordered by unit_id + finite : bool # all values finite + """ + + def make(self, key): + from foundation.recording.compute.visual import DotTrialResp + dots, responses = (DotTrialResp & key).dot_response + + # trace values finite + finite = np.isfinite(responses).all() + + # insert + self.insert1(dict(key, dots=dots, responses=responses, finite=bool(finite))) diff --git a/foundation/recording/fill/scan.py b/foundation/recording/fill/scan.py index 900f43a..650ac3b 100644 --- a/foundation/recording/fill/scan.py +++ b/foundation/recording/fill/scan.py @@ -1,6 +1,6 @@ -from djutils import keys, merge +from djutils import keys, merge, U from foundation.virtual.bridge import pipe_fuse, pipe_shared -from foundation.virtual import scan, recording +from foundation.virtual import scan, recording, utility, stimulus class _VisualScanRecording: diff --git a/foundation/recording/fill/visual.py b/foundation/recording/fill/visual.py new file mode 100644 index 0000000..7300bda --- /dev/null +++ b/foundation/recording/fill/visual.py @@ -0,0 +1,39 @@ +from djutils import keys, rowproperty, merge, U +from foundation.virtual import recording, stimulus, utility + +@keys +class ScanVideoType(): + """Fill videoset for a scan with a specific video type""" + @property + def keys(self): + return [ + recording.ScanRecording, + recording.TrialFilterSet, + U('video_type') & stimulus.Video, + ] + + @rowproperty + def videoset(self): + # filter trials + from foundation.recording.trial import ( + Trial, + TrialSet, + TrialVideo, + TrialFilterSet, + ) + from foundation.stimulus.video import VideoSet + # fill videoset_id + all_trialset_id = dict( + trialset_id=(recording.ScanRecording & self.item).fetch1("trialset_id") + ) + # all trials + trials = Trial & (TrialSet & all_trialset_id).members + # filtered trials + trials = (TrialFilterSet & self.item).filter(trials) + # video_ids of the same video_type that are shown in the scan + videos = stimulus.Video & ( + merge(trials, TrialVideo, stimulus.Video) & self.item + ).proj() + # trial ids, ordered by trial start + videoset = VideoSet.fill(videos, prompt=False) + return videoset \ No newline at end of file diff --git a/foundation/recording/trial.py b/foundation/recording/trial.py index d65cc01..688a001 100644 --- a/foundation/recording/trial.py +++ b/foundation/recording/trial.py @@ -142,8 +142,6 @@ def filter(self, trials): key = merge(trials, self, Trial.ScanTrial, scan.PupilNans) & "nans < max_nans" return trials & key.proj() - - # -- Trial Filter -- diff --git a/foundation/tuning/compute/direction.py b/foundation/tuning/compute/direction.py new file mode 100644 index 0000000..e69de29 diff --git a/foundation/tuning/compute/dot.py b/foundation/tuning/compute/dot.py index b977fa9..8a8406e 100644 --- a/foundation/tuning/compute/dot.py +++ b/foundation/tuning/compute/dot.py @@ -1,5 +1,5 @@ from djutils import keys, rowproperty, cache_rowproperty, merge -from foundation.virtual import recording, stimulus, utility, tuning +from foundation.virtual import recording, stimulus, utility import numpy as np import torch import pandas as pd @@ -71,7 +71,7 @@ def dots_responses(self): from foundation.stimulus.compute.video import SquareDotType from foundation.recording.compute.visual import VisualTrials from foundation.recording.trace import TraceSet, Trace - from foundation.recording import scan + from foundation.scan.experiment import Scan from foundation.utility.resample import Offset, Rate from foundation.utility.response import Burnin @@ -79,7 +79,7 @@ def dots_responses(self): all_trial_filt = (recording.TrialFilterSet & "not members").proj() trialset = dict( trialset_id=( - recording.ScanTrials & (scan.Scan & self.item) & all_trial_filt + recording.ScanTrials & (Scan & self.item) & all_trial_filt ).fetch1("trialset_id") ) # all trials shown in the scan diff --git a/foundation/tuning/direction.py b/foundation/tuning/direction.py new file mode 100644 index 0000000..60bac2d --- /dev/null +++ b/foundation/tuning/direction.py @@ -0,0 +1,78 @@ +from djutils import merge, rowproperty, rowmethod +from foundation.virtual import fnn, recording +from foundation.schemas import tuning as schema + + +# ------------------------------------ DirResp ------------------------------------ + +# -- DirResp Interface -- + + +class DirRespType: + """Tuning Direction""" + + @rowproperty + def compute(self): + """ + Returns + ------- + directions : pd.DataFrame, rows are stimulus.compute.video.direction + responses : np.ndarray [directions, units] + """ + raise NotImplementedError() + + +# -- DirResp Types -- + + +@schema.lookup +class RecordingDir(DirRespType): + definition = """ + -> recording.ScanUnits + -> recording.TrialFilterSet + -> stimulus.VideoSet + -> utility.Resize + -> utility.Resolution + -> utility.Resample + -> utility.Offset + -> utility.Rate + -> utility.Burnin + -> utility.Offset.proj(tuning_offset_id="offset_id") + """ + + @rowproperty + def compute(self): + from foundation.tuning.compute.direction import RecordingDir + return RecordingDir & self + + +# -- DirResp -- + + +@schema.link +class DirResp: + links = [RecordingDir] + name = "dir_resp" + comment = "responses to directional stimuli" + + +# -- Computed Directional Tuning -- +@schema.computed +class DirTuning: + definition = """ + -> DirResp + --- + direction : longblob # [directions] + tuning : longblob # [directions, traces] + """ + + @rowmethod + def compute(self): + from foundation.tuning.compute.direction import DirTuning + return DirTuning & self + + +# ------------------------------------ BiVonMisesFit ------------------------------------ +@schema.computed +class BiVonMisesFit: + pass \ No newline at end of file diff --git a/foundation/tuning/dot.py b/foundation/tuning/dot.py index 89c2aa8..36407fd 100644 --- a/foundation/tuning/dot.py +++ b/foundation/tuning/dot.py @@ -10,12 +10,12 @@ class DotOnOff: on : bool # True for on, False for off """ -# ---------------------------- DotResponse ---------------------------- +# ---------------------------- DotResp ---------------------------- -# -- DotResponse Interface -- +# -- DotResp Interface -- -class DotResponseType: +class DotRespType: """Tuning Dot""" @rowproperty @@ -29,11 +29,11 @@ def compute(self): raise NotImplementedError() -# -- DotResponse Types -- +# -- DotResp Types -- @schema.lookup -class RecordingDot(DotResponseType): +class RecordingDot(DotRespType): definition = """ -> recording.ScanUnits -> recording.TrialFilterSet @@ -53,21 +53,21 @@ def compute(self): return RecordingDot & self -# -- DotResponse -- +# -- DotResp -- @schema.link -class DotResponse: +class DotResp: links = [RecordingDot] - name = "dot_response" - comment = "dot response" + name = "dot_resp" + comment = "responses to single dot stimuli" # -- Computed Dot STA -- @schema.computed class DotSta: definition = """ - -> DotResponse + -> DotResp -> DotOnOff --- sta : longblob # [height, width, traces], traces are ordered by unit_id diff --git a/foundation/utility/tuning.py b/foundation/utility/tuning.py new file mode 100644 index 0000000..d80d85b --- /dev/null +++ b/foundation/utility/tuning.py @@ -0,0 +1,3 @@ +import numpy as np +from djutils import rowproperty +from foundation.schemas import utility as schema diff --git a/foundation/virtual/__init__.py b/foundation/virtual/__init__.py index f7d5306..ab3fd03 100644 --- a/foundation/virtual/__init__.py +++ b/foundation/virtual/__init__.py @@ -5,4 +5,4 @@ scan = create_virtual_module("scan", "foundation_scan") recording = create_virtual_module("recording", "foundation_recording") fnn = create_virtual_module("fnn", "foundation_fnn") -tuning = create_virtual_module("tuning", "foundation_tuning") \ No newline at end of file +# tuning = create_virtual_module("tuning", "foundation_tuning") \ No newline at end of file