Skip to content

Commit

Permalink
save changes made to populate STAs
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhuokunDing committed Feb 16, 2024
1 parent e2a3d3c commit 0e0ebde
Show file tree
Hide file tree
Showing 14 changed files with 562 additions and 59 deletions.
2 changes: 1 addition & 1 deletion foundation/fnn/compute/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def _trial_traces(self, trial_ids, datatype):
# load trials
for trial_id in trial_ids:

traces = (recording.ResampledTraces & key & {"trial_id": trial_id}).fetch1("traces")
traces = (recording.ResampledTraces & key & {"trial_id": trial_id}).fetch1("traces") # [time, units]
yield transform(traces).astype(np.float32)[:, order]

@rowmethod
Expand Down
202 changes: 202 additions & 0 deletions foundation/fnn/compute/visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,205 @@ def units(self):
correlations.append(cc(unit_targ, unit_pred))

return np.array(correlations)


@keys
class VisualResp:
    """Visual Mean Response Over Repeated Presentation and Corresponding Stimuli"""

    @property
    def keys(self):
        return [
            fnn.Model,
            recording.TrialFilterSet,
            stimulus.VideoSet,
            utility.Resize,
            utility.Burnin,
            utility.Bool.proj(perspective="bool"),
            utility.Bool.proj(modulation="bool"),
        ]

    @rowproperty
    def video_response(self):
        """
        Compute, for each video in the video set, the resized stimulus frames and the
        recorded responses averaged over repeated presentations.

        Returns
        -------
        frames : np.ndarray [videos, time, height, width, channel]
            resized stimulus frames (uint8), burnin frames dropped
        responses : np.ndarray [videos, time, units]
            trial-averaged responses, burnin frames dropped; units ordered by unit_id
            (via the traceset_index -> trace_order mapping)

        Returns None (with a warning) if no filtered trials exist for any video.
        """
        from foundation.stimulus.video import VideoSet
        from foundation.stimulus.resize import ResizedVideo
        from foundation.recording.compute.visual import VisualTrials

        # trial set -- all trials shown in the scan (trial filter set with no members)
        all_trial_filt = (recording.TrialFilterSet & "not members").proj()
        trialset = dict(
            trialset_id=(
                recording.ScanTrials & (scan.Scan & self.item) & all_trial_filt
            ).fetch1("trialset_id")
        )

        # trace set of the scan units
        traceset = dict(traceset_id=(recording.ScanUnits & self.item).fetch1("traceset_id"))

        # videos, ordered by primary key
        video_ids = (VideoSet & self.item).members
        video_ids = video_ids.fetch("KEY", order_by=video_ids.primary_key)

        # unit order -- maps storage order (traceset_index) to unit_id order (trace_order)
        unit_order = (
            merge(recording.TraceSet.Member & traceset, recording.ScanUnitOrder & self.item)
        ).fetch("traceset_index", order_by="trace_order")

        with cache_rowproperty():
            # resampled frames of visual stimuli
            frames = []
            # visual responses
            responses = []

            for video_id in tqdm(video_ids, desc="Videos"):
                # trials of this video, filtered by the trial filter set
                trial_ids = (VisualTrials & trialset & video_id & self.item).trial_ids

                # no trials for video
                if not trial_ids:
                    logger.warning(f"No trials found for video_id `{video_id['video_id']}`")
                    continue

                trial_resps = []
                trial_video_index = []
                for trial_id in trial_ids:
                    # trial responses -- [time, units], units ordered by traceset_index
                    trial_resps.append(
                        (
                            recording.ResampledTraces
                            & dict(trial_id=trial_id)
                            & traceset
                            & self.item
                        ).fetch1("traces")
                    )
                    # frame index into the resized video for this trial
                    index = (
                        ResizedVideo
                        * recording.TrialVideo
                        * recording.ResampledTrial
                        & dict(trial_id=trial_id)
                        & self.item
                    ).fetch1("index")
                    trial_video_index.append(index)

                # [trials, time, units] -- reorder units by unit_id; trials ordered by trial start
                trial_resps = np.stack(trial_resps, axis=0)[..., unit_order]
                # average over repeated presentations -> [time, units]
                responses.append(trial_resps.mean(axis=0))

                # all repeats of the same video must resolve to the same frame index
                trial_video_index = np.stack(trial_video_index, axis=0)
                assert np.all(trial_video_index == trial_video_index[0])
                trial_video_index = trial_video_index[0]

                resized_frames = (ResizedVideo & video_id & self.item).fetch1("video")
                frames.append(resized_frames[trial_video_index].astype(np.uint8))

        # no trials at all -- avoid stacking empty lists below
        if not responses:
            logger.warning("No trials found")
            return

        responses = np.stack(responses, axis=0)[:, self.item["burnin"]:, :]  # [videos, time, units]
        frames = np.stack(frames, axis=0)[:, self.item["burnin"]:, :]  # [videos, time, height, width, channel]
        assert responses.shape[0] == frames.shape[0]
        return frames, responses
93 changes: 61 additions & 32 deletions foundation/fnn/query/scan.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from djutils import keys, merge
from foundation.virtual import utility, stimulus, scan, recording, fnn

import pandas as pd

@keys
class VisualScanRecording:
Expand All @@ -19,36 +19,65 @@ def units(self):
return units.proj("trace_order")


# @keys
# class VisualScanCorrelation:
# """Visual Scan Correlation"""
@keys
class VisualScanCorrelation:
    """Visual Scan Correlation"""

    @property
    def keys(self):
        return [
            (scan.Scan * fnn.Model) & fnn.Data.VisualScan,
            recording.TrialFilterSet,
            stimulus.VideoSet,
            utility.Burnin,
            utility.Bool.proj(perspective="bool"),
            utility.Bool.proj(modulation="bool"),
        ]

    def cc_norm(self):
        """
        Build a per-unit DataFrame of normalized correlations for the scans/models in
        `self.key`.

        For every unit of each (scan, model) key, joins the model's data spec to the
        scan's units, fetches the model-vs-recording correlation (cc_abs) and the
        repeat-reliability ceiling (cc_max), and computes cc_norm = cc_abs / cc_max.

        Returns
        -------
        pandas.DataFrame
            One row per unit, containing the joined key columns plus
            `cc_abs`, `cc_max`, and `cc_norm`. Units with no matching
            correlation/measure row keep NaN (how="left" merges).
        """
        # project the model's data entry down to its spec and scan identifiers,
        # using the unit-specific offset/resample ids from the visual spec
        data_scan_spec = fnn.Data.VisualScan.proj(
            "spec_id",
            "trace_filterset_id",
            "pipe_version",
            "animal_id",
            "session",
            "scan_idx",
            "segmentation_method",
            "spike_method",
        ) * fnn.Spec.VisualSpec.proj(
            "rate_id", offset_id="offset_id_unit", resample_id="resample_id_unit"
        )
        # one row per unit: key -> spec/scan -> trace -> unit
        all_unit_trace_rel = (
            self.key
            * data_scan_spec  # data_id -> specs + scan key
            * recording.ScanUnitOrder  # scan key + trace_filterset_id -> trace_ids
            * recording.Trace.ScanUnit  # trace_id -> unit key
        )
        all_units_df = all_unit_trace_rel.fetch(format="frame").reset_index()
        # fetch cc_max
        # NOTE(review): reliability ceiling per unit, restricted to the units above
        cc_max = (
            (recording.VisualMeasure & utility.Measure.CCMax & all_unit_trace_rel)
            .fetch(format="frame")
            .reset_index()
            .rename(columns={"measure": "cc_abs"})
        )
        # fetch cc_abs
        # NOTE(review): `.reset_index()` on the default RangeIndex adds an `index`
        # column to cc_abs_df; it does not participate in the left merge below
        # (merge keys are the shared columns) but does end up in the output -- confirm
        # whether downstream callers rely on it or it should be dropped.
        cc_abs_df = pd.DataFrame(
            (
                (
                    fnn.VisualRecordingCorrelation
                    & utility.Correlation.CCSignal
                ).proj(
                    ..., trace_order="unit"
                )
                & all_unit_trace_rel
            )
            .fetch(as_dict=True)  # this fetch is very slow
        ).reset_index().rename(columns={"correlation": "cc_abs"})
        # compute cc_norm
        # left-merge so every unit is kept; validate guards against duplicate keys
        cc_norm_df = (
            all_units_df.merge(cc_abs_df, how="left", validate="one_to_one")
            .merge(cc_max, how="left", validate="one_to_one")
            .assign(cc_norm=lambda df: df.cc_abs / df.cc_max)
        )
        return cc_norm_df
Loading

0 comments on commit 0e0ebde

Please sign in to comment.