-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #76 from sinzlab/revert-74-master
Revert "[WIP] Remove project specific code, Add measures and scores."
- Loading branch information
Showing
13 changed files
with
2,525 additions
and
480 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,251 @@ | ||
import torch | ||
import torch.utils.data as utils | ||
import numpy as np | ||
import pickle | ||
|
||
# These functions provide compatibility with the previous data loading logic of the monkey V1 data. | ||
# Individual sessions are no longer identified by a session key for different readouts, | ||
# but all sessions will be in a single loader. This provides backwards compatibility for | ||
# the Divisive Normalization model of Max Burg, and allows for direct comparison to the new way of dataloading as | ||
# a proof of principle for these kinds of models. | ||
|
||
def csrf_v1_legacy(datapath, image_path, batch_size, seed, train_frac=0.8,
                   subsample=1, crop=65, time_bins_sum=tuple(range(12))):
    """
    Build the train / validation / test loaders for the legacy CSRF V1 data.

    Args:
        datapath: path to the pickle file holding the experimental data.
        image_path: path to the pickle file holding the images (may be None).
        batch_size: batch size handed to every loader.
        seed: random seed forwarded to CSRF_V1_Data.
        train_frac: forwarded to CSRF_V1_Data.
        subsample: forwarded to CSRF_V1_Data.
        crop: forwarded to CSRF_V1_Data.
        time_bins_sum: forwarded to CSRF_V1_Data.

    Returns:
        dict with the keys 'train_loader', 'val_loader' and 'test_loader'.
    """
    dataset = CSRF_V1_Data(raw_data_path=datapath, image_path=image_path, seed=seed,
                           train_frac=train_frac, subsample=subsample, crop=crop,
                           time_bins_sum=time_bins_sum)

    # (loader name, accessor on the data object, shuffle flag).
    # Only the test loader is built with shuffling switched off.
    splits = (
        ("train_loader", dataset.train, True),
        ("val_loader", dataset.val, True),
        ("test_loader", dataset.test, False),
    )

    data_loader = {}
    for loader_name, fetch, shuffle in splits:
        images, responses, valid_responses = fetch()
        # `1 * mask` turns the boolean validity mask into a numeric array.
        data_loader[loader_name] = get_loader_csrf_V1_legacy(
            images, responses, 1 * valid_responses,
            batch_size=batch_size, shuffle=shuffle)

    return data_loader
|
||
|
||
# begin of helper functions | ||
|
||
def get_loader_csrf_V1_legacy(images, responses, valid_responses, batch_size=None, shuffle=True, retina_warp=False):
    """
    Wrap images, responses and the validity mask into a single DataLoader.

    Args:
        images: numpy array of images. Expected layout is
            images x channels x size_x x size_y; if the second dimension is
            larger than 3 the array is treated as channels-last and transposed.
        responses: numpy array of responses, one row per image.
        valid_responses: numeric array of the same shape as `responses`
            marking which entries are real measurements.
        batch_size: batch size passed to the DataLoader.
        shuffle: whether the DataLoader reshuffles the data.
        retina_warp: if True, every image (channel 0) is passed through
            `warp_image` before being converted to a tensor.

    Returns:
        torch.utils.data.DataLoader over a TensorDataset of
        (images, responses, valid_responses), all float tensors.
    """
    # Expected dimension of the image tensor is Images x Channels x size_x x size_y.
    # In some CSRF files the channels are the last dimension, so reorder them.
    if images.shape[1] > 3:
        images = images.transpose((0, 3, 1, 2))

    if retina_warp:
        # NOTE(review): `warp_image` is neither defined nor imported in this module,
        # so this branch raises NameError unless it is provided elsewhere - confirm.
        images = np.array(list(map(warp_image, images[:, 0])))[:, None]

    # Use the GPU when one is present, otherwise stay on the CPU instead of crashing.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _to_float_tensor(array):
        return torch.tensor(array).to(device=device, dtype=torch.float)

    dataset = utils.TensorDataset(_to_float_tensor(images),
                                  _to_float_tensor(responses),
                                  _to_float_tensor(valid_responses))
    data_loader = utils.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    return data_loader
|
||
|
||
class CSRF_V1_Data:
    """
    For use with George's and Kelli's csrf data set.

    Loads responses (and images) from pickle files, z-scores the images,
    splits the training data into a train and a validation set, sums the
    responses over the selected time bins and keeps a boolean mask that marks
    which responses are real measurements (i.e. were not NaN).
    """

    def __init__(self, raw_data_path, image_path=None, seed=None, train_frac=0.8,
                 subsample=1, crop=65, time_bins_sum=tuple(range(7))):
        """
        Args:
            raw_data_path: Path pointing to a pickle file that contains the experimental data.
                Not all pickle files of the CSRF dataset contain the image data.
                If the images are missing, an image_path argument should be provided.
            image_path: Path pointing to a pickle file which should contain the image data
                (training and testing images). If omitted, the images are read
                from the file at raw_data_path.
            seed: Random seed for the train/val split and the initial permutations.
                Any value other than None (including 0) seeds numpy's global RNG.
            train_frac: Fraction of the experiment's training data used for model training.
                Remaining data serves as validation set.
                Float value between 0 and 1 (see get_validation_split for the exact sizes).
            subsample: Integer value to downsample the input.
                Example usage:  subsample=1 keeps original resolution
                                subsample=2 cuts the resolution in half
            crop: Integer value to crop stimuli from each side (left, right, bottom, top), before subsampling
            time_bins_sum: a tuple which specifies which time bins are included in the analysis.
                there are 13 bins (0 to 12), which correspond to 10ms bins from 40 to 160 ms
                after stimulus presentation.
                Example usage: (0,1,2,3) will only include the first four time bins into the analysis.
                If None, the responses keep their time dimension.
        """
        # NOTE(security): pickle.load can execute arbitrary code - only open trusted files.
        with open(raw_data_path, "rb") as pkl:
            raw_data = pickle.load(pkl)

        self._subject_ids = raw_data["subject_ids"]
        self._session_ids = raw_data["session_ids"]
        self._session_unit_response_link = raw_data["session_unit_response_link"]
        self._repetitions_test = raw_data["repetitions_test"]
        responses_train = raw_data["responses_train"].astype(np.float32)
        self._responses_test = raw_data["responses_test"].astype(np.float32)

        # A response counts as real if it is not NaN.
        real_responses = np.logical_not(np.isnan(responses_train))
        self._real_responses_test = np.logical_not(np.isnan(self._responses_test))

        # The locator belongs to the experimental data, so read it before any
        # other pickle is opened; None means "keep the test images as they are".
        test_image_locator = raw_data["test_image_locator"] if "test_image_locator" in raw_data else None

        # Images come from the dedicated image file if one is given,
        # otherwise from the experimental data file itself.
        if image_path:
            with open(image_path, "rb") as pkl:
                image_data = pickle.load(pkl)
        else:
            image_data = raw_data

        _, h, w = image_data['images_train'].shape[:3]
        rows = slice(crop, h - crop, subsample)
        cols = slice(crop, w - crop, subsample)
        images_train = image_data['images_train'][:, rows, cols]
        images_test = image_data['images_test'][:, rows, cols]

        # z-score all images by mean and sigma of all images (train and test together)
        all_images = np.append(images_train, images_test, axis=0)
        img_mean = np.mean(all_images)
        img_std = np.std(all_images)
        images_train = (images_train - img_mean) / img_std
        self._images_test = (images_test - img_mean) / img_std
        if test_image_locator is not None:
            # the locator is used as a 1-based index into the test images
            self._images_test = self._images_test[test_image_locator - 1, ::]

        # split into train and val set, images randomly assigned
        train_split, val_split = self.get_validation_split(real_responses, train_frac, seed)
        self._images_train = images_train[train_split]
        self._responses_train = responses_train[train_split]
        self._real_responses_train = real_responses[train_split]

        self._images_val = images_train[val_split]
        self._responses_val = responses_train[val_split]
        self._real_responses_val = real_responses[val_split]

        # `is not None` so that seed=0 is honoured as well
        if seed is not None:
            np.random.seed(seed)

        self._train_perm = np.random.permutation(self._images_train.shape[0])
        self._val_perm = np.random.permutation(self._images_val.shape[0])

        if time_bins_sum is not None:  # then sum over the given time bins
            self._responses_train = np.sum(self._responses_train[:, :, time_bins_sum], axis=-1)
            self._responses_test = np.sum(self._responses_test[:, :, time_bins_sum], axis=-1)
            self._responses_val = np.sum(self._responses_val[:, :, time_bins_sum], axis=-1)

            # In real responses: If an entry for any time is False, real_responses is False for all times.
            self._real_responses_train = np.all(self._real_responses_train[:, :, time_bins_sum], axis=-1)
            self._real_responses_test = np.all(self._real_responses_test[:, :, time_bins_sum], axis=-1)
            self._real_responses_val = np.all(self._real_responses_val[:, :, time_bins_sum], axis=-1)

        # in responses, change nan to zero. Then: Use real responses vector for all valid responses
        self._responses_train[np.isnan(self._responses_train)] = 0
        self._responses_val[np.isnan(self._responses_val)] = 0
        self._responses_test[np.isnan(self._responses_test)] = 0

        self._minibatch_idx = 0

    # getters
    @property
    def images_train(self):
        """
        Returns:
            train images in current order (changes every time a new epoch starts),
            with a trailing singleton dimension appended
        """
        return np.expand_dims(self._images_train[self._train_perm], -1)

    @property
    def responses_train(self):
        """
        Returns:
            train responses in current order (changes every time a new epoch starts)
        """
        return self._responses_train[self._train_perm]

    # legacy property
    @property
    def real_resps_train(self):
        """Legacy alias of real_responses_train."""
        return self._real_responses_train[self._train_perm]

    @property
    def real_responses_train(self):
        """Validity mask of the train responses, in the current train order."""
        return self._real_responses_train[self._train_perm]

    @property
    def images_val(self):
        """Validation images with a trailing singleton dimension appended."""
        return np.expand_dims(self._images_val, -1)

    @property
    def responses_val(self):
        """Validation responses."""
        return self._responses_val

    @property
    def real_responses_val(self):
        """Validity mask of the validation responses."""
        return self._real_responses_val

    @property
    def images_test(self):
        """Test images with a trailing singleton dimension appended."""
        return np.expand_dims(self._images_test, -1)

    @property
    def responses_test(self):
        """Test responses."""
        return self._responses_test

    @property
    def real_responses_test(self):
        """Validity mask of the test responses."""
        return self._real_responses_test

    @property
    def image_dimensions(self):
        """Dimensions 1 and 2 of the train image array."""
        return self.images_train.shape[1:3]

    @property
    def num_neurons(self):
        """Size of dimension 1 of the train responses."""
        return self.responses_train.shape[1]

    # methods
    def next_epoch(self):
        """
        Gets new random index permutation for train set, reset minibatch index.
        """
        self._minibatch_idx = 0
        self._train_perm = np.random.permutation(self._train_perm)

    def get_validation_split(self, real_responses_train, train_frac=0.8, seed=None):
        """
        Splits the Training Data into the trainset and validation set.
        The Validation set should recruit itself from the images that most neurons have seen.

        The candidates for validation are the floor(train_frac / 2 * num_images)
        images with the most real responses in the first time bin; half of
        these candidates are drawn at random as the validation set and every
        remaining image goes to the train set.
        NOTE(review): the validation share therefore only equals 1 - train_frac
        for train_frac=0.8; kept unchanged so that existing splits stay reproducible.

        Args:
            real_responses_train: boolean array (images x neurons x time bins)
                marking real responses.
            train_frac: controls the size of the candidate pool, see above.
            seed: if not None (0 included), seeds numpy's global RNG first.

        :return: returns permuted indices for the training and validation set
        """
        if seed is not None:
            np.random.seed(seed)

        num_images = real_responses_train.shape[0]
        # number of neurons with a real response per image, taken from time bin 0
        neurons_per_image = np.sum(real_responses_train, axis=1)[:, 0]
        sort_idx = np.argsort(neurons_per_image)

        top_images = sort_idx[-int(np.floor(train_frac / 2 * num_images)):]
        val_images_idx = np.random.choice(top_images, int(len(top_images) / 2), replace=False)

        train_idx_filter = np.logical_not(np.isin(sort_idx, val_images_idx))
        train_images_idx = np.random.permutation(sort_idx[train_idx_filter])

        return train_images_idx, val_images_idx

    # Methods for compatibility with Santiago's code base.
    def train(self):
        """
        For compatibility with Santiago's code base.
        Returns:
            images_train, responses_train, real_responses_train
        """
        return self.images_train, self.responses_train, self.real_responses_train

    def val(self):
        """
        For compatibility with Santiago's code base.
        Returns:
            images_val, responses_val, real_responses_val
        """
        return self.images_val, self.responses_val, self._real_responses_val

    def test(self):
        """
        For compatibility with Santiago's code base.
        Returns:
            images_test, responses_test, real_responses_test
        """
        return self.images_test, self.responses_test, self._real_responses_test
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
from collections import OrderedDict | ||
from itertools import zip_longest | ||
import numpy as np | ||
|
||
import torch | ||
from torch.utils.data import DataLoader | ||
from torch.utils.data.sampler import SubsetRandomSampler | ||
|
||
from mlutils.data.datasets import StaticImageSet | ||
from mlutils.data.transforms import Subsample, ToTensor, NeuroNormalizer, AddBehaviorAsChannels | ||
from mlutils.data.samplers import SubsetSequentialSampler | ||
|
||
from ..utility.nn_helpers import set_random_seed | ||
|
||
|
||
def mouse_static_loader(path, batch_size, seed=None, area='V1', layer='L2/3',
                        tier=None, neuron_ids=None, get_key=False, cuda=True, normalize=True, include_behavior=False,
                        exclude=None, select_input_channel=None, toy_data=False, **kwargs):
    """
    Returns the dataloaders of a single dataset.

    Args:
        path (str): path of the dataset file.
        batch_size (int): batch size.
        seed (int, optional): random seed for images. Defaults to None.
        area (str, optional): the visual area. Defaults to 'V1'.
        layer (str, optional): the layer from visual area. Defaults to 'L2/3'.
        tier (str, optional): tier is a placeholder to specify which set of images to pick for train, val, and test loader. Defaults to None.
        neuron_ids (list, optional): select neurons by their ids. None or an empty sequence selects all neurons. Defaults to None.
        get_key (bool, optional): whether to return the data key, along with the dataloaders. Defaults to False.
        cuda (bool, optional): whether to place the data on gpu or not. Defaults to True.
        normalize (bool, optional): whether to add a NeuroNormalizer transform. Defaults to True.
        include_behavior (bool, optional): whether to load 'behavior' and add it as input channels. Defaults to False.
        exclude (optional): passed on to NeuroNormalizer. Defaults to None.
        select_input_channel (optional): if not None, a SelectInputChannel transform with this value is added.
            Can not be combined with include_behavior. Defaults to None.
        toy_data (bool, optional): if True, no neurons are subsampled by area/layer/id. Defaults to False.
        **kwargs: accepted for call compatibility and ignored.

    Returns:
        if get_key is False returns a dictionary of dataloaders for one dataset, where the keys are 'train', 'validation', and 'test'
        (or only the requested tier).
        if get_key is True it also returns the data_key (as the first output) followed by the dataloader dictionary.
    """
    # The previous check `(include_behavior and select_input_channel) is False` evaluated to
    # `None is False` for include_behavior=True without a channel selection and therefore
    # rejected every call that asked for behavior.
    assert not (include_behavior and select_input_channel is not None), \
        "Selecting an Input Channel and Adding Behavior can not both be true"

    dat = StaticImageSet(path, 'images', 'responses', 'behavior') if include_behavior else StaticImageSet(path, 'images', 'responses')

    if toy_data:
        dat.transforms = [ToTensor(cuda)]
    else:
        # specify condition(s) for sampling neurons. If you want to sample specific neurons define conditions that would effect idx
        # (explicit None/empty test so that numpy arrays of ids work as well)
        if neuron_ids is None or len(neuron_ids) == 0:
            neuron_ids = dat.neurons.unit_ids
        conds = ((dat.neurons.area == area) &
                 (dat.neurons.layer == layer) &
                 (np.isin(dat.neurons.unit_ids, neuron_ids)))

        idx = np.where(conds)[0]
        dat.transforms = [Subsample(idx), ToTensor(cuda)]

    if normalize:
        # NOTE(review): with toy_data=True this places the normalizer after ToTensor - confirm that is intended.
        dat.transforms.insert(1, NeuroNormalizer(dat, exclude=exclude))

    if include_behavior:
        dat.transforms.insert(0, AddBehaviorAsChannels())

    if select_input_channel is not None:
        # NOTE(review): SelectInputChannel was never imported in this module; it is assumed
        # to live next to the other transforms in mlutils.data.transforms - confirm.
        from mlutils.data.transforms import SelectInputChannel
        dat.transforms.insert(0, SelectInputChannel(select_input_channel))

    # subsample images
    dataloaders = {}
    tier_names = [tier] if tier else ['train', 'validation', 'test']
    for tier_name in tier_names:

        if seed is not None:
            set_random_seed(seed)

        # sample images: random order for training, fixed order otherwise
        subset_idx = np.where(dat.tiers == tier_name)[0]
        sampler = SubsetRandomSampler(subset_idx) if tier_name == 'train' else SubsetSequentialSampler(subset_idx)

        dataloaders[tier_name] = DataLoader(dat, sampler=sampler, batch_size=batch_size)

    # create the data_key for a specific data path
    data_key = path.split('static')[-1].split('.')[0].replace('preproc', '')
    return (data_key, dataloaders) if get_key else dataloaders
|
||
|
||
def mouse_static_loaders(paths, batch_size, seed=None, area='V1', layer='L2/3', tier=None,
                         neuron_ids=None, cuda=True, normalize=False, include_behavior=False,
                         exclude=None, select_input_channel=None, toy_data=False, **kwargs):
    """
    Collect the dataloaders of one or more datasets, grouped by tier.

    Args:
        paths (list): list of path(s) for the dataset(s)
        batch_size (int): batch size.
        seed (int, optional): random seed for images. Defaults to None.
        area (str, optional): the visual area. Defaults to 'V1'.
        layer (str, optional): the layer from visual area. Defaults to 'L2/3'.
        tier (str, optional): restrict the result to this tier; by default 'train', 'validation' and 'test' are returned.
        neuron_ids (list, optional): one collection of neuron ids per path. Defaults to None.
        cuda (bool, optional): whether to place the data on gpu or not. Defaults to True.
        normalize, include_behavior, exclude, select_input_channel, toy_data:
            handed unchanged to mouse_static_loader.

    Returns:
        dict: dictionary of dictionaries where the first level keys are the tiers and the second level keys are data_keys.
    """
    if neuron_ids is None:
        neuron_ids = []

    tier_names = [tier] if tier else ['train', 'validation', 'test']
    dls = OrderedDict((name, OrderedDict()) for name in tier_names)

    # options that are identical for every dataset
    shared_options = dict(seed=seed, area=area, layer=layer, cuda=cuda, tier=tier,
                          normalize=normalize, include_behavior=include_behavior,
                          exclude=exclude, select_input_channel=select_input_channel,
                          toy_data=toy_data)

    # paths without a matching id collection get neuron_ids=None
    for path, ids in zip_longest(paths, neuron_ids, fillvalue=None):
        data_key, loaders = mouse_static_loader(path, batch_size, get_key=True,
                                                neuron_ids=ids, **shared_options)
        for name, per_dataset in dls.items():
            per_dataset[data_key] = loaders[name]

    return dls
Oops, something went wrong.