From fa95d409ef67589f075d8193a8ca5c71b56aa89d Mon Sep 17 00:00:00 2001
From: Konstantin Willeke
Date: Wed, 17 Jun 2020 16:12:33 +0200
Subject: [PATCH] Revert "[WIP] Remove project specific code, Add measures and scores."

---
 nnfabrik/datasets/csrf_legacy_loaders.py   | 251 +++
 nnfabrik/datasets/mouse.py                 | 120 ++++
 nnfabrik/datasets/movies.py                |  71 ++
 nnfabrik/datasets/sysident_v1_dataset.py   | 533 ++++++++++++++
 nnfabrik/measures/__init__.py              |   0
 nnfabrik/measures/measure_helpers.py       |  26 -
 nnfabrik/measures/measures.py              | 396 -----------
 nnfabrik/models/gaussian_readout_models.py | 784 +++++++++++++++++++++
 nnfabrik/models/pretrained_models.py       |  92 +++
 nnfabrik/models/v1_models.py               | 241 +++++++
 nnfabrik/template.py                       |  13 +-
 nnfabrik/training/trainers.py              | 350 +++++++++
 notebooks/nnfabrik_monkey_demo.ipynb       | 128 ++--
 13 files changed, 2525 insertions(+), 480 deletions(-)
 create mode 100644 nnfabrik/datasets/csrf_legacy_loaders.py
 create mode 100644 nnfabrik/datasets/mouse.py
 create mode 100644 nnfabrik/datasets/movies.py
 create mode 100644 nnfabrik/datasets/sysident_v1_dataset.py
 delete mode 100644 nnfabrik/measures/__init__.py
 delete mode 100644 nnfabrik/measures/measure_helpers.py
 delete mode 100644 nnfabrik/measures/measures.py
 create mode 100644 nnfabrik/models/gaussian_readout_models.py
 create mode 100644 nnfabrik/models/pretrained_models.py
 create mode 100644 nnfabrik/models/v1_models.py
 create mode 100644 nnfabrik/training/trainers.py

diff --git a/nnfabrik/datasets/csrf_legacy_loaders.py b/nnfabrik/datasets/csrf_legacy_loaders.py
new file mode 100644
index 00000000..e3ac069c
--- /dev/null
+++ b/nnfabrik/datasets/csrf_legacy_loaders.py
@@ -0,0 +1,251 @@
+import torch
+import torch.utils.data as utils
+import numpy as np
+import pickle
+
+# These functions provide compatibility with the previous data-loading logic for monkey V1 data.
+# Individual sessions are no longer identified by a session key for different readouts;
+# instead, all sessions end up in a single loader. This provides backwards compatibility for
+# the divisive normalization model of Max Burg, and allows a direct comparison to the new way
+# of data loading, as a proof of principle for these kinds of models.
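For orientation, a minimal usage sketch of the loader defined just below. This is an editor's illustration, not part of the patch: the pickle paths are placeholders, and a CUDA device is assumed because `get_loader_csrf_V1_legacy` moves every tensor onto the GPU.

```python
# Editor's sketch (hypothetical paths; CUDA assumed by the legacy loader).
loaders = csrf_v1_legacy(
    datapath="data/csrf_v1_responses.pickle",  # placeholder
    image_path="data/csrf_v1_images.pickle",   # placeholder
    batch_size=64,
    seed=1000,
)
for images, responses, valid_responses in loaders["train_loader"]:
    break  # images: N x 1 x H x W; responses, valid_responses: N x neurons
```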
+
+def csrf_v1_legacy(datapath, image_path, batch_size, seed, train_frac=0.8,
+                   subsample=1, crop=65, time_bins_sum=tuple(range(12))):
+    v1_data = CSRF_V1_Data(raw_data_path=datapath, image_path=image_path, seed=seed,
+                           train_frac=train_frac, subsample=subsample, crop=crop,
+                           time_bins_sum=time_bins_sum)
+
+    images, responses, valid_responses = v1_data.train()
+    train_loader = get_loader_csrf_V1_legacy(images, responses, 1 * valid_responses, batch_size=batch_size)
+
+    images, responses, valid_responses = v1_data.val()
+    val_loader = get_loader_csrf_V1_legacy(images, responses, 1 * valid_responses, batch_size=batch_size)
+
+    images, responses, valid_responses = v1_data.test()
+    test_loader = get_loader_csrf_V1_legacy(images, responses, 1 * valid_responses, batch_size=batch_size, shuffle=False)
+
+    data_loader = dict(train_loader=train_loader, val_loader=val_loader, test_loader=test_loader)
+
+    return data_loader
+
+
+# begin of helper functions
+
+def get_loader_csrf_V1_legacy(images, responses, valid_responses, batch_size=None, shuffle=True, retina_warp=False):
+    # The expected dimension of the image tensor is images x channels x size_x x size_y.
+    # In some CSRF files, the channels are on dim 4; the image tensor is reshaped accordingly.
+    if images.shape[1] > 3:
+        images = images.transpose((0, 3, 1, 2))
+
+    if retina_warp:
+        images = np.array(list(map(warp_image, images[:, 0])))[:, None]
+
+    images = torch.tensor(images).to(torch.float).cuda()
+
+    responses = torch.tensor(responses).cuda().to(torch.float)
+    valid_responses = torch.tensor(valid_responses).cuda().to(torch.float)
+    dataset = utils.TensorDataset(images, responses, valid_responses)
+    data_loader = utils.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
+
+    return data_loader
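The third tensor these loaders carry is a validity mask (1 where a response was actually recorded, 0 where the raw data contained NaNs). How the mask is consumed is up to the trainer; a minimal sketch of the usual convention, assuming a Poisson objective (an illustration, not code from this patch):

```python
import torch

def masked_poisson_loss(output, target, valid, eps=1e-12):
    """Poisson loss in which NaN-derived (invalid) entries contribute zero."""
    loss = output - target * torch.log(output + eps)
    return (loss * valid).sum() / valid.sum()
```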
+
+
+class CSRF_V1_Data:
+    """For use with George's and Kelli's CSRF data set."""
+
+    def __init__(self, raw_data_path, image_path=None, seed=None, train_frac=0.8,
+                 subsample=1, crop=65, time_bins_sum=tuple(range(7))):
+        """
+        Args:
+            raw_data_path: path pointing to a pickle file that contains the experimental data.
+                Not all pickle files of the CSRF dataset contain the image data.
+                If the images are missing, an image_path argument should be provided.
+            image_path: path pointing to a pickle file which should contain the image data
+                (training and testing images).
+            seed: random seed for the train/val dataset split (does not affect the order of the
+                stimuli within the train and val splits themselves).
+            train_frac: fraction of the experiment's training data used for model training;
+                the remaining data serves as the validation set. Float value between 0 and 1.
+            subsample: integer value to downsample the input.
+                Example usage: subsample=1 keeps the original resolution,
+                subsample=2 cuts the resolution in half.
+            crop: integer value to crop stimuli from each side (left, right, bottom, top), before subsampling.
+            time_bins_sum: a tuple which specifies which time bins are included in the analysis.
+                There are 13 bins (0 to 12), which correspond to 10 ms bins from 40 to 160 ms
+                after stimulus presentation.
+                Example usage: (0, 1, 2, 3) will only include the first four time bins in the analysis.
+        """
+        # unpacking pickle data
+        with open(raw_data_path, "rb") as pkl:
+            raw_data = pickle.load(pkl)
+
+        self._subject_ids = raw_data["subject_ids"]
+        self._session_ids = raw_data["session_ids"]
+        self._session_unit_response_link = raw_data["session_unit_response_link"]
+        self._repetitions_test = raw_data["repetitions_test"]
+        responses_train = raw_data["responses_train"].astype(np.float32)
+        self._responses_test = raw_data["responses_test"].astype(np.float32)
+
+        real_responses = np.logical_not(np.isnan(responses_train))
+        self._real_responses_test = np.logical_not(np.isnan(self.responses_test))
+
+        images_test = raw_data['images_test']
+        if 'test_image_locator' in raw_data:
+            test_image_locator = raw_data["test_image_locator"]
+
+        # if an image path is provided, load the images from the corresponding pickle file
+        if image_path:
+            with open(image_path, "rb") as pkl:
+                raw_data = pickle.load(pkl)
+
+        _, h, w = raw_data['images_train'].shape[:3]
+        images_train = raw_data['images_train'][:, crop:h - crop:subsample, crop:w - crop:subsample]
+        images_test = raw_data['images_test'][:, crop:h - crop:subsample, crop:w - crop:subsample]
+
+        # z-score all images by the mean and sigma of all images
+        all_images = np.append(images_train, images_test, axis=0)
+        img_mean = np.mean(all_images)
+        img_std = np.std(all_images)
+        images_train = (images_train - img_mean) / img_std
+        self._images_test = (images_test - img_mean) / img_std
+        if 'test_image_locator' in raw_data:
+            self._images_test = self._images_test[test_image_locator - 1, ::]
+        # split into train and val set, images randomly assigned
+        train_split, val_split = self.get_validation_split(real_responses, train_frac, seed)
+        self._images_train = images_train[train_split]
+        self._responses_train = responses_train[train_split]
+        self._real_responses_train = real_responses[train_split]
+
+        self._images_val = images_train[val_split]
+        self._responses_val = responses_train[val_split]
+        self._real_responses_val = real_responses[val_split]
+
+        if seed:
+            np.random.seed(seed)
+
+        self._train_perm = np.random.permutation(self._images_train.shape[0])
+        self._val_perm = np.random.permutation(self._images_val.shape[0])
+
+        if time_bins_sum is not None:  # then average over given time bins
+            self._responses_train = np.sum(self._responses_train[:, :, time_bins_sum], axis=-1)
+            self._responses_test = np.sum(self._responses_test[:, :, time_bins_sum], axis=-1)
+            self._responses_val = np.sum(self._responses_val[:, :, time_bins_sum], axis=-1)
+
+            # in real_responses: if an entry is False for any time bin, real_responses is False for all time bins
+            self._real_responses_train = np.all(self._real_responses_train[:, :, time_bins_sum], axis=-1)
+            self._real_responses_test = np.all(self._real_responses_test[:, :, time_bins_sum], axis=-1)
+            self._real_responses_val = np.all(self._real_responses_val[:, :, time_bins_sum], axis=-1)
+
+        # in responses, change NaN to zero; then use the real_responses mask to identify all valid responses
+        self._responses_train[np.isnan(self._responses_train)] = 0
+        self._responses_val[np.isnan(self._responses_val)] = 0
+        self._responses_test[np.isnan(self._responses_test)] = 0
+
+        self._minibatch_idx = 0
+
+    # getters
+    @property
+    def images_train(self):
+        """
+        Returns:
+            train images in current order (changes every time a new epoch starts)
+        """
+        return np.expand_dims(self._images_train[self._train_perm], -1)
+
+    @property
+    def responses_train(self):
+        """
+        Returns:
+            train responses in current order (changes every time a new epoch starts)
+        """
+        return self._responses_train[self._train_perm]
+
+    # legacy property
+    @property
+    def real_resps_train(self):
+        return self._real_responses_train[self._train_perm]
+
+    @property
+    def real_responses_train(self):
+        return self._real_responses_train[self._train_perm]
+
+    @property
+    def images_val(self):
+        return np.expand_dims(self._images_val, -1)
+
+    @property
+    def responses_val(self):
+        return self._responses_val
+
+    @property
+    def images_test(self):
+        return np.expand_dims(self._images_test, -1)
+
+    @property
+    def responses_test(self):
+        return self._responses_test
+
+    @property
+    def image_dimensions(self):
+        return self.images_train.shape[1:3]
+
+    @property
+    def num_neurons(self):
+        return self.responses_train.shape[1]
+
+    # methods
+    def next_epoch(self):
+        """
+        Gets a new random index permutation for the train set and resets the minibatch index.
+        """
+        self._minibatch_idx = 0
+        self._train_perm = np.random.permutation(self._train_perm)
+
+    def get_validation_split(self, real_responses_train, train_frac=0.8, seed=None):
+        """
+        Splits the training data into the train set and the validation set.
+        The validation set is drawn from the images that the most neurons have seen.
+        :return: permuted indices for the training and validation set
+        """
+        if seed:
+            np.random.seed(seed)
+
+        num_images = real_responses_train.shape[0]
+        Neurons_per_image = np.sum(real_responses_train, axis=1)[:, 0]
+        Neurons_per_image_sort_idx = np.argsort(Neurons_per_image)
+
+        top_images = Neurons_per_image_sort_idx[-int(np.floor(train_frac / 2 * num_images)):]
+        val_images_idx = np.random.choice(top_images, int(len(top_images) / 2), replace=False)
+
+        train_idx_filter = np.logical_not(np.isin(Neurons_per_image_sort_idx, val_images_idx))
+        train_images_idx = np.random.permutation(Neurons_per_image_sort_idx[train_idx_filter])
+
+        return train_images_idx, val_images_idx
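A toy run of the split heuristic above, with assumed shapes (images x neurons x time bins): with train_frac=0.8 and 100 images, the floor(0.8 / 2 * 100) = 40 images seen by the most neurons form the candidate pool, and half of them (20) become the validation set; everything else goes to training. Editor's sketch only:

```python
import numpy as np

real = np.random.rand(100, 50, 13) > 0.1       # hypothetical validity mask
ds = CSRF_V1_Data.__new__(CSRF_V1_Data)        # bypass __init__ for the demo
train_idx, val_idx = ds.get_validation_split(real, train_frac=0.8, seed=7)
assert len(val_idx) == 20 and len(train_idx) == 80
```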
+
+    # methods for compatibility with Santiago's code base
+    def train(self):
+        """
+        For compatibility with Santiago's code base.
+        Returns:
+            images_train, responses_train, real_responses_train
+        """
+
+        return self.images_train, self.responses_train, self.real_responses_train
+
+    def val(self):
+        """
+        For compatibility with Santiago's code base.
+        Returns:
+            images_val, responses_val, real_responses_val
+        """
+
+        return self.images_val, self.responses_val, self._real_responses_val
+
+    def test(self):
+        """
+        For compatibility with Santiago's code base.
+        Returns:
+            images_test, responses_test, real_responses_test
+        """
+
+        return self.images_test, self.responses_test, self._real_responses_test
\ No newline at end of file

diff --git a/nnfabrik/datasets/mouse.py b/nnfabrik/datasets/mouse.py
new file mode 100644
index 00000000..410b436a
--- /dev/null
+++ b/nnfabrik/datasets/mouse.py
@@ -0,0 +1,120 @@
+from collections import OrderedDict
+from itertools import zip_longest
+import numpy as np
+
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.data.sampler import SubsetRandomSampler
+
+from mlutils.data.datasets import StaticImageSet
+# SelectInputChannel is referenced below for the select_input_channel option;
+# it is assumed to be available from mlutils.data.transforms like the other transforms.
+from mlutils.data.transforms import Subsample, ToTensor, NeuroNormalizer, AddBehaviorAsChannels, SelectInputChannel
+from mlutils.data.samplers import SubsetSequentialSampler
+
+from ..utility.nn_helpers import set_random_seed
+
+
+def mouse_static_loader(path, batch_size, seed=None, area='V1', layer='L2/3',
+                        tier=None, neuron_ids=None, get_key=False, cuda=True, normalize=True, include_behavior=False,
+                        exclude=None, select_input_channel=None, toy_data=False, **kwargs):
+    """
+    Returns the dataloaders for a single mouse dataset.
+
+    Args:
+        path (str): path of the dataset file
+        batch_size (int): batch size.
+        seed (int, optional): random seed for images. Defaults to None.
+        area (str, optional): the visual area. Defaults to 'V1'.
+        layer (str, optional): the layer from the visual area. Defaults to 'L2/3'.
+        tier (str, optional): tier is a placeholder to specify which set of images to pick for train, val, and test loader. Defaults to None.
+        neuron_ids (list, optional): select neurons by their ids. neuron_ids and path should be of the same length. Defaults to None.
+        get_key (bool, optional): whether to return the data key along with the dataloaders. Defaults to False.
+        cuda (bool, optional): whether to place the data on the GPU or not. Defaults to True.
+
+    Returns:
+        if get_key is False, returns a dictionary of dataloaders for one dataset, where the keys are 'train', 'validation', and 'test'.
+        if get_key is True, it also returns the data_key (as the first output), followed by the dataloader dictionary.
+    """
+
+    dat = StaticImageSet(path, 'images', 'responses', 'behavior') if include_behavior else StaticImageSet(path, 'images', 'responses')
+
+    assert (include_behavior and select_input_channel) is False, "Selecting an input channel and adding behavior can not both be true"
+    if toy_data:
+        dat.transforms = [ToTensor(cuda)]
+    else:
+        # specify condition(s) for sampling neurons. If you want to sample specific
+        # neurons, define conditions that affect idx
+        neuron_ids = neuron_ids if neuron_ids else dat.neurons.unit_ids
+        conds = ((dat.neurons.area == area) &
+                 (dat.neurons.layer == layer) &
+                 (np.isin(dat.neurons.unit_ids, neuron_ids)))
+
+        idx = np.where(conds)[0]
+        dat.transforms = [Subsample(idx), ToTensor(cuda)]
+        if normalize:
+            dat.transforms.insert(1, NeuroNormalizer(dat, exclude=exclude))
+
+        if include_behavior:
+            dat.transforms.insert(0, AddBehaviorAsChannels())
+
+        if select_input_channel is not None:
+            dat.transforms.insert(0, SelectInputChannel(select_input_channel))
+
+    # create the tier-specific dataloaders over the subsampled images
+    dataloaders = {}
+    keys = [tier] if tier else ['train', 'validation', 'test']
+    for tier in keys:
+
+        if seed is not None:
+            set_random_seed(seed)
+            # torch.manual_seed(img_seed)
+
+        # sample images
+        subset_idx = np.where(dat.tiers == tier)[0]
+        sampler = SubsetRandomSampler(subset_idx) if tier == 'train' else SubsetSequentialSampler(subset_idx)
+
+        dataloaders[tier] = DataLoader(dat, sampler=sampler, batch_size=batch_size)
+
+    # create the data_key for a specific data path
+    data_key = path.split('static')[-1].split('.')[0].replace('preproc', '')
+    return (data_key, dataloaders) if get_key else dataloaders
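A hedged usage sketch for the single-dataset loader above; the .h5 filename is a placeholder that follows the 'static...preproc...' pattern the data_key logic expects.

```python
# Editor's sketch (hypothetical file path).
dataloaders = mouse_static_loader(
    path="data/static20457-5-9-preproc0.h5",  # placeholder
    batch_size=64,
    seed=1,
    area="V1",
    layer="L2/3",
    normalize=True,
)
images, responses = next(iter(dataloaders["train"]))
```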
+ """ + + neuron_ids = neuron_ids if neuron_ids is not None else [] + + dls = OrderedDict({}) + keys = [tier] if tier else ['train', 'validation', 'test'] + for key in keys: + dls[key] = OrderedDict({}) + + for path, neuron_id in zip_longest(paths, neuron_ids, fillvalue=None): + data_key, loaders = mouse_static_loader(path, batch_size, seed=seed, + area=area, layer=layer, cuda=cuda, + tier=tier, get_key=True, neuron_ids=neuron_id, + normalize=normalize, include_behavior=include_behavior, + exclude=exclude, select_input_channel=select_input_channel, + toy_data=toy_data) + for k in dls: + dls[k][data_key] = loaders[k] + + return dls diff --git a/nnfabrik/datasets/movies.py b/nnfabrik/datasets/movies.py new file mode 100644 index 00000000..e93e2d1d --- /dev/null +++ b/nnfabrik/datasets/movies.py @@ -0,0 +1,71 @@ +# Mouse Movie Datasets +import torch +from mlutils.data.datasets import MovieSet +from mlutils.data.transforms import Subsequence, Subsample, Normalizer, ToTensor +from torch.utils.data.sampler import SubsetRandomSampler +from torch.utils.data import DataLoader +import numpy as np + + +def load_movie_dataset( + data_path, batch_size, stats_source="all", seq_len=30, area="V1", layer="L2/3", normalize=False, tier="train" +): + + field_names = ["inputs", "behavior", "eye_position", "responses"] + + # load the dataset + dataset = MovieSet(data_path, *field_names) + + # configure the statistics source + dataset.stats_source = stats_source + + transforms = [] + + # configure the sequence length + transforms.append(Subsequence(seq_len)) + + # whether to add normalizer + if normalize: + transforms.append(Normalizer(dataset)) + + transforms.append(ToTensor(cuda=True)) + + # subselect to areas & layer + areas = dataset.neurons.area + layers = dataset.neurons.layer + idx = np.where((areas == area) & (layers == layer))[0] + + # place the area & layer subsampler at the very beginning + transforms.insert(-1, Subsample(idx)) + + dataset.transforms = transforms + + idx = np.where(dataset.tiers == tier)[0] + sampler = SubsetRandomSampler(idx) + + # create and return the data loader + return DataLoader(dataset, sampler=sampler, batch_size=batch_size) + + +def load_movie_set( + data_path, batch_size, stats_source="all", seq_len=30, area="V1", layer="L2/3", normalize=False, tiers_map=None +): + if tiers_map is None: + tiers_map = {"train_loader": "train", "val_loader": "validation", "test_loader": "test"} + + data_loaders = {} + + for key, tier in tiers_map.items(): + print("Packaging data loader for {tier}".format(tier=tier)) + data_loaders[key] = load_movie_dataset( + data_path, + batch_size, + stats_source=stats_source, + seq_len=seq_len, + area=area, + layer=layer, + normalize=normalize, + tier=tier, + ) + + return data_loaders diff --git a/nnfabrik/datasets/sysident_v1_dataset.py b/nnfabrik/datasets/sysident_v1_dataset.py new file mode 100644 index 00000000..68e9dfa9 --- /dev/null +++ b/nnfabrik/datasets/sysident_v1_dataset.py @@ -0,0 +1,533 @@ +import torch +import torch.utils.data as utils +import numpy as np +import pickle +#from retina.retina import warp_image +from collections import namedtuple, Iterable +import os +from mlutils.data.samplers import RepeatsBatchSampler + + +class ImageCache: + """ + A simple cache which loads images into memory given a path to the directory where the images are stored. 

diff --git a/nnfabrik/datasets/sysident_v1_dataset.py b/nnfabrik/datasets/sysident_v1_dataset.py
new file mode 100644
index 00000000..68e9dfa9
--- /dev/null
+++ b/nnfabrik/datasets/sysident_v1_dataset.py
@@ -0,0 +1,533 @@
+import torch
+import torch.utils.data as utils
+import numpy as np
+import pickle
+#from retina.retina import warp_image
+from collections import namedtuple, Iterable
+import os
+from mlutils.data.samplers import RepeatsBatchSampler
+
+
+class ImageCache:
+    """
+    A simple cache which loads images into memory, given a path to the directory where the images are stored.
+    Images need to be present as 2D .npy arrays.
+    """
+
+    def __init__(self, path=None, subsample=1, crop=0, img_mean=None, img_std=None, filename_precision=6):
+        """
+        Args:
+            path: str - path to the directory where the individual .npy files are stored
+            subsample: int - amount of downsampling
+            crop: the expected input is a list of tuples that specify the exact cropping from all four sides,
+                i.e. [(crop_left, crop_right), (crop_top, crop_down)]
+            img_mean: mean luminance across all images
+            img_std: std of the luminance across all images
+            filename_precision: number of leading zeros in the file names of the specified folder
+        """
+        self.cache = {}
+        self.path = path
+        self.subsample = subsample
+        self.crop = crop
+        self.img_mean = img_mean
+        self.img_std = img_std
+        self.leading_zeros = filename_precision
+
+    def __len__(self):
+        return len([file for file in os.listdir(self.path) if file.endswith('.npy')])
+
+    def __contains__(self, key):
+        return key in self.cache
+
+    def __getitem__(self, item):
+        return [self[i] for i in item] if isinstance(item, Iterable) else self.update(item)
+
+    def update(self, key):
+        if key in self.cache:
+            return self.cache[key]
+        else:
+            filename = os.path.join(self.path, str(key).zfill(self.leading_zeros) + '.npy')
+            image = np.load(filename)
+            transformed_image = self.transform_image(image)
+            self.cache[key] = transformed_image
+            return transformed_image
+
+    def transform_image(self, image):
+        """
+        Applies transformations to the image: downsampling and cropping, z-scoring, and dimension expansion.
+        """
+        h, w = image.shape
+        image = image[self.crop[0][0]:h - self.crop[0][1]:self.subsample, self.crop[1][0]:w - self.crop[1][1]:self.subsample]
+        image = (image - self.img_mean) / self.img_std
+        image = image[None, ]
+        return torch.tensor(image).to(torch.float)
+
+    @property
+    def cache_size(self):
+        return len(self.cache)
+
+
+class CachedTensorDataset(utils.Dataset):
+    """
+    Dataset wrapping tensors.
+
+    Each sample will be retrieved by indexing tensors along the first dimension.
+
+    Arguments:
+        *tensors (Tensor): tensors that have the same size in the first dimension.
+    """
+
+    def __init__(self, *tensors, names=('inputs', 'targets'), image_cache=None):
+        if not all(tensors[0].size(0) == tensor.size(0) for tensor in tensors):
+            raise ValueError('The tensors of the dataset have unequal lengths. The first dim of all tensors has to match exactly.')
+        if not len(tensors) == len(names):
+            raise ValueError('Number of tensors and names provided have to match. If there are more than two tensors, '
+                             'names have to be passed to the TensorDataset')
+        self.tensors = tensors
+        self.input_position = names.index("inputs")
+        self.DataPoint = namedtuple('DataPoint', names)
+        self.image_cache = image_cache
+
+    def __getitem__(self, index):
+        """
+        Retrieves the inputs (= tensors[0]) from the image cache. If the image ID is not present
+        in the cache, the cache is updated to load the corresponding image into memory.
+        """
+        if type(index) == int:
+            key = self.tensors[0][index].item()
+        else:
+            key = self.tensors[0][index].numpy().astype(np.int32)
+
+        tensors_expanded = [tensor[index] if pos != self.input_position else torch.stack(list(self.image_cache[key]))
+                            for pos, tensor in enumerate(self.tensors)]
+
+        return self.DataPoint(*tensors_expanded)
+
+    def __len__(self):
+        return self.tensors[0].size(0)
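To make the caching pattern concrete, a small sketch with placeholder paths and statistics: integer IDs index zero-padded .npy files on disk, each file is loaded, transformed, and memoized on first access, and CachedTensorDataset then resolves its inputs through the cache.

```python
import torch

# Editor's sketch (directory and image statistics are placeholders).
cache = ImageCache(path="data/images/", subsample=1,
                   crop=[(16, 16), (16, 16)], img_mean=110.0, img_std=60.0)
img = cache[1]                     # loads data/images/000001.npy and memoizes it
assert 1 in cache and cache.cache_size == 1

ids = torch.tensor([1, 2, 3])
ds = CachedTensorDataset(ids, torch.rand(3, 42), image_cache=cache)
point = ds[0]                      # DataPoint(inputs=..., targets=...)
```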
+ """ + if type(index) == int: + key = self.tensors[0][index].item() + else: + key = self.tensors[0][index].numpy().astype(np.int32) + + tensors_expanded = [tensor[index] if pos != self.input_position else torch.stack(list(self.image_cache[key])) + for pos, tensor in enumerate(self.tensors)] + + return self.DataPoint(*tensors_expanded) + + def __len__(self): + return self.tensors[0].size(0) + + +def get_cached_loader(image_ids, responses, batch_size, shuffle=True, image_cache=None, repeat_condition=None): + """ + + Args: + image_ids: an array of image IDs + responses: Numpy Array, Dimensions: N_images x Neurons + batch_size: int - batch size for the dataloader + shuffle: Boolean, shuffles image in the dataloader if True + image_cache: a cache object which stores the images + + Returns: a PyTorch DataLoader object + """ + + image_ids = torch.tensor(image_ids.astype(np.int32)) + responses = torch.tensor(responses).to(torch.float) + dataset = CachedTensorDataset(image_ids, responses, image_cache=image_cache) + sampler = RepeatsBatchSampler(torch.tensor(repeat_condition.astype(np.int32))) if repeat_condition is not None else None + + return utils.DataLoader(dataset, + batch_size=batch_size, + shuffle=shuffle, + sampler=sampler) + +def monkey_static_loader(dataset, + neuronal_data_files, + imagepath, + cached_images_path, + batch_size=64, + seed=None, + train_frac=0.8, + subsample=1, + crop=96, + time_bins_sum=12, + avg=False): + """ + Function that returns cached dataloaders for the Center Surround Visual Field Experiments. + Data recorded by George and Kelli at BCM, Houston. + + creates a nested dictionary of dataloaders in the format + {'train' : dict_of_loaders, + 'validation' : dict_of_loaders, + 'test' : dict_of_loaders, } + + in each dict_of_loaders, there will be one dataloader per data-key (refers to a unique session ID) + with the format: + {'data-key1': torch.utils.data.DataLoader, + 'data-key2': torch.utils.data.DataLoader, ... } + + required inputs is a list of datafiles specified as a full path, together with a full path + to a file that contains all the actually images + + Args: + dataset: a string, identifying the Dataset: + 'V1', 'CSRF_V1', 'CSRF_V4' + datafiles: a list paths that point to pickle files + imagepath: a path that points to the image files + batch_size: int - batch size of the dataloaders + seed: int - random seed, to calculate the random split + train_frac: ratio of train/validation images + subsample: int - downsampling factor + crop: int or tuple - crops x pixels from each side. Example: Input image of 100x100, crop=10 => Resulting img = 80x80. + if crop is tuple, the expected input is a list of tuples, the specify the exact cropping from all four sides + i.e. [(crop_left, crop_right), (crop_top, crop_bottom)] + time_bins_sum: sums the responses over x time bins. + avg: Boolean - Sums oder Averages the responses across bins. 
+ + Returns: nested dictionary of dataloaders + """ + + # initialize dataloaders as empty dict + dataloaders = {'train': {}, 'validation': {}, 'test': {}} + + if not isinstance(time_bins_sum, Iterable): + time_bins_sum = tuple(range(time_bins_sum)) + + if imagepath: + with open(imagepath, "rb") as pkl: + images = pickle.load(pkl) + + images = images[:, :, :, None] + _, h, w = images.shape[:3] + + if isinstance(crop, int): + crop = [(crop, crop), (crop, crop)] + + images_cropped = images[:, crop[0][0]:h - crop[0][1]:subsample, crop[1][0]:w - crop[1][1]:subsample, :] + img_mean = np.mean(images_cropped) + img_std = np.std(images_cropped) + + # set up parameters for the different dataset types + if dataset == 'V1': + # for the "Amadeus V1" Dataset, recorded by Santiago Cadena, there was no specified test set. + # instead, the last 20% of the dataset were classified as test set. To make sure that the test set + # of this dataset will always stay identical, the `train_test_split` value is hardcoded here. + train_test_split = 0.8 + image_id_offset = 1 + else: + train_test_split = 1 + image_id_offset = 0 + + all_train_ids, all_validation_ids = get_validation_split(n_images=images.shape[0] * train_test_split, + train_frac=train_frac, + seed=seed) + + # Initialize the Image Cache class + cache = ImageCache(path=cached_images_path, subsample=subsample, crop=crop, img_mean=img_mean, img_std=img_std) + + # cycling through all datafiles to fill the dataloaders with an entry per session + for i, datapath in enumerate(neuronal_data_files): + + with open(datapath, "rb") as pkl: + raw_data = pickle.load(pkl) + + subject_ids = raw_data["subject_id"] + data_key = str(raw_data["session_id"]) + responses_train = raw_data["training_responses"].astype(np.float32) + responses_test = raw_data["testing_responses"].astype(np.float32) + training_image_ids = raw_data["training_image_ids"] - image_id_offset + testing_image_ids = raw_data["testing_image_ids"] - image_id_offset + + if dataset != 'V1': + responses_test = responses_test.transpose((2, 0, 1)) + responses_train = responses_train.transpose((2, 0, 1)) + + if time_bins_sum is not None: # then average over given time bins + responses_train = (np.mean if avg else np.sum)(responses_train[:, :, time_bins_sum], axis=-1) + responses_test = (np.mean if avg else np.sum)(responses_test[:, :, time_bins_sum], axis=-1) + + train_idx = np.isin(training_image_ids, all_train_ids) + val_idx = np.isin(training_image_ids, all_validation_ids) + + responses_val = responses_train[val_idx] + responses_train = responses_train[train_idx] + + validation_image_ids = training_image_ids[val_idx] + training_image_ids = training_image_ids[train_idx] + + train_loader = get_cached_loader(training_image_ids, responses_train, batch_size=batch_size, image_cache=cache) + val_loader = get_cached_loader(validation_image_ids, responses_val, batch_size=batch_size, image_cache=cache) + test_loader = get_cached_loader(testing_image_ids, responses_test, batch_size=1, shuffle=False, + image_cache=cache, repeat_condition=testing_image_ids) + + dataloaders["train"][data_key] = train_loader + dataloaders["validation"][data_key] = val_loader + dataloaders["test"][data_key] = test_loader + + return dataloaders + + + +class NamedTensorDataset(utils.Dataset): + """ + Dataset wrapping tensors. + + Each sample will be retrieved by indexing tensors along the first dimension. + + Arguments: + *tensors (Tensor): tensors that have the same size of the first dimension. 
+ """ + def __init__(self, *tensors, names=('inputs','targets')): + if not all(tensors[0].size(0) == tensor.size(0) for tensor in tensors): + raise ValueError( + 'The tensors of the dataset have unequal lenghts. The first dim of all tensors has to match exactly.') + if not len(tensors) == len(names): + raise ValueError('Number of tensors and names provided have to match. If there are more than two tensors,' + 'names have to be passed to the TensorDataset') + self.tensors = tensors + self.DataPoint = namedtuple('DataPoint', names) + + def __getitem__(self, index): + return self.DataPoint(*[tensor[index] for tensor in self.tensors]) + + def __len__(self): + return self.tensors[0].size(0) + + +def csrf_v1(datafiles, imagepath, batch_size, seed, + train_frac=0.8, subsample=1, crop=65, + time_bins_sum=tuple(range(12)), avg=False, + crop_h=None, crop_w=None): + """ + Function that returns the dataloaders for the Center Surround Visual Field V1 Experiment. + Data recorded by George and Kelli at BCM, Houston. + + creates a nested dictionary of dataloaders in the format + {'train' : dict_of_loaders, + 'validation' : dict_of_loaders, + 'test' : dict_of_loaders, } + + in each dict_of_loaders, there will be one dataloader per data-key (refers to a unique session ID) + with the format: + {'data-key1': torch.utils.data.DataLoader, + 'data-key2': torch.utils.data.DataLoader, ... } + + required inputs is a list of datafiles specified as a full path, together with a full path + to a file that contains all the actually images + + Args: + datafiles: a list paths that point to pickle files + imagepath: a path that points to the image files + batch_size: int - batch size of the dataloaders + seed: int - random seed, to calculate the random split + train_frac: ratio of train/validation images + subsample: int - downsampling factor + crop: int - crops x pixels from each side. Example: Input image of 100x100, crop=10 => Resulting img = 80x80 + time_bins_sum: sums the responses over x time bins. + avg: Boolean - Sums oder Averages the responses across bins. + + Returns: nested dictionary of dataloaders + """ + + # + if not isinstance(time_bins_sum, Iterable): + time_bins_sum = tuple(range(time_bins_sum)) + + # initialize dataloaders as empty dict + dataloaders = {'train': {}, 'validation': {}, 'test': {}} + + if imagepath: + with open(imagepath, "rb") as pkl: + images = pickle.load(pkl) + + images = images[:, :, :, None] + _, h, w = images.shape[:3] + + if crop_h is None and crop_w is None: + images_cropped = images[:, crop:h - crop:subsample, crop:w - crop:subsample, :] + else: + images_cropped = images[:, crop_h[0]:h - crop_h[1]:subsample, crop_w[0]:w - crop_w[1]:subsample, :] + + img_mean = np.mean(images_cropped) + img_std = np.std(images_cropped) + + all_train_ids, all_validation_ids = get_validation_split(n_images=images.shape[0], + train_frac=train_frac, + seed=seed) + + # cycling through all datafiles to fill the dataloaders with an entry per session + for i, datapath in enumerate(datafiles): + + #Extract Session ID from the pickle filename + + with open(datapath, "rb") as pkl: + raw_data = pickle.load(pkl) + + # additional information related to session and animal. 
+
+
+def sysident_v1(datafiles, imagepath, batch_size, seed,
+                train_frac=0.8, subsample=2, crop=30):
+    """
+    Function that returns the dataloaders for the SysIdent V1 experiment.
+    Data recorded by Santiago Cadena at BCM, Houston.
+
+    Creates a nested dictionary of dataloaders in the format
+        {'train': dict_of_loaders,
+         'validation': dict_of_loaders,
+         'test': dict_of_loaders}
+
+    In each dict_of_loaders, there will be one dataloader per data-key (refers to a unique session ID),
+    with the format:
+        {'data-key1': torch.utils.data.DataLoader,
+         'data-key2': torch.utils.data.DataLoader, ...}
+
+    Args:
+        datafiles: a list of paths that point to pickle files
+        imagepath: a path that points to the image files
+        batch_size: int - batch size of the dataloaders
+        seed: int - random seed, to calculate the random split
+        train_frac: ratio of train/validation images
+        subsample: int - downsampling factor
+        crop: int - crops x pixels from each side.
Example: Input image of 100x100, crop=10 => Resulting img = 80x80 + + Returns: nested dictionary of dataloaders + """ + + # initialize dataloaders as empty dict + dataloaders = {'train': {}, 'validation': {}, 'test': {}} + + if imagepath: + with open(imagepath, "rb") as pkl: + images = pickle.load(pkl) + + images = images[:, :, :, None] + _, h, w = images.shape[:3] + images_cropped = images[:, crop:h - crop:subsample, crop:w - crop:subsample, :] + img_mean = np.mean(images_cropped) + img_std = np.std(images_cropped) + + # hard Coded Parameter used in the amadeus.pickle file + n_train_images = int(images.shape[0]*0.8) + + all_train_ids, all_validation_ids = get_validation_split(n_images=n_train_images, + train_frac=train_frac, + seed=seed) + # cycling through all datafiles to fill the dataloaders with an entry per session + for i, datapath in enumerate(datafiles): + + with open(datapath, "rb") as pkl: + raw_data = pickle.load(pkl) + + # additional information related to session and animal. Has to find its way into datajoint + subject_ids = raw_data["subject_id"] + session_id = raw_data["session_id"] + + data_key = str(session_id) + + responses_train = raw_data["training_responses"].astype(np.float32) + responses_test = raw_data["testing_responses"].astype(np.float32) + + # for proper indexing, IDs have to start from zero + training_image_ids = raw_data["training_image_ids"] - 1 + testing_image_ids = raw_data["testing_image_ids"] - 1 + + images_train = images[training_image_ids, crop:h - crop:subsample, crop:w - crop:subsample] + images_test = images[testing_image_ids, crop:h - crop:subsample, crop:w - crop:subsample] + images_train = (images_train - img_mean) / img_std + images_test = (images_test - img_mean) / img_std + + train_idx = np.isin(training_image_ids, all_train_ids) + val_idx = np.isin(training_image_ids, all_validation_ids) + + images_val = images_train[val_idx] + images_train = images_train[train_idx] + responses_val = responses_train[val_idx] + responses_train = responses_train[train_idx] + + train_loader = get_loader_csrf_v1(images_train, responses_train, batch_size=batch_size) + val_loader = get_loader_csrf_v1(images_val, responses_val, batch_size=batch_size) + test_loader = get_loader_csrf_v1(images_test, responses_test, batch_size=batch_size, shuffle=False) + + dataloaders["train"][data_key] = train_loader + dataloaders["validation"][data_key] = val_loader + dataloaders["test"][data_key] = test_loader + + return dataloaders + + +def get_validation_split(n_images, train_frac, seed): + """ + Splits the total number of images into train and test set. + This ensures that in every session, the same train and validation images are being used. + + Args: + n_images: Total number of images. 
These will be split into train and validation sets
+        train_frac: fraction of images used for the training set
+        seed: random seed
+
+    Returns: two arrays containing the image IDs of the whole image set, split into train and validation
+    """
+    if seed:
+        np.random.seed(seed)
+    train_idx, val_idx = np.split(np.random.permutation(int(n_images)), [int(n_images * train_frac)])
+    assert not np.any(np.isin(train_idx, val_idx)), "train_set and val_set are overlapping sets"
+
+    return train_idx, val_idx
+
+
+def get_loader_csrf_v1(images, responses, batch_size, shuffle=True):
+    """
+    Args:
+        images: numpy array of images, dimensions: N x C x W x H
+        responses: numpy array, dimensions: N_images x neurons
+        batch_size: int - batch size for the dataloader
+        shuffle: Boolean, shuffles images in the dataloader if True
+
+    Returns: a PyTorch DataLoader object
+    """
+
+    # The expected dimension of the image tensor is images x channels x size_x x size_y.
+    # In some CSRF files, the channels are on dim 4; the image tensor is reshaped accordingly.
+    if images.shape[1] > 3:
+        images = images.transpose((0, 3, 1, 2))
+
+    images = torch.tensor(images).to(torch.float)
+    responses = torch.tensor(responses).to(torch.float)
+
+    dataset = NamedTensorDataset(images, responses)
+    data_loader = utils.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
+
+    return data_loader
\ No newline at end of file

diff --git a/nnfabrik/measures/__init__.py b/nnfabrik/measures/__init__.py
deleted file mode 100644
index e69de29b..00000000

diff --git a/nnfabrik/measures/measure_helpers.py b/nnfabrik/measures/measure_helpers.py
deleted file mode 100644
index 7bc99a44..00000000
--- a/nnfabrik/measures/measure_helpers.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import warnings
-import numpy as np
-import types
-
-
-def get_subset_of_repeats(outputs, repeat_limit, randomize=True):
-    """
-    Args:
-        outputs (array or list): repeated responses/targets to the same input. with the shape [inputs, ] [reps, neurons]
-                                 or array(inputs, reps, neurons)
-        repeat_limit (int): how many reps are selected
-        randomize (cool): if True, takes a random selection of repetitions. if false, takes the first n repetitions.
-
-    Returns: limited_outputs (list): same shape as inputs, but with reduced number of repetitions
-
-    """
-    limited_output = []
-    for repetitions in outputs:
-        n_repeats = repetitions.shape[0]
-        limited_output.append(repetitions[:repeat_limit, ] if not randomize else repetitions[
-            np.random.choice(n_repeats, repeat_limit if repeat_limit < n_repeats else n_repeats, replace=False)])
-    return limited_output
-
-
-def is_ensemble_function(model):
-    return (isinstance(model, types.FunctionType))
\ No newline at end of file

diff --git a/nnfabrik/measures/measures.py b/nnfabrik/measures/measures.py
deleted file mode 100644
index f5c38125..00000000
--- a/nnfabrik/measures/measures.py
+++ /dev/null
@@ -1,396 +0,0 @@
-import warnings
-import numpy as np
-import torch
-from mlutils.measures import corr
-from mlutils.training import eval_state, device_state
-import types
-import contextlib
-import warnings
-from .measure_helpers import get_subset_of_repeats, is_ensemble_function
-
-
-def model_predictions_repeats(model, dataloader, data_key, device='cuda', broadcast_to_target=False):
-    """
-    Computes model predictions for a dataloader that yields batches with identical inputs along the first dimension.
-    Unique inputs will be forwarded only once through the model
-    Returns:
-        target: ground truth, i.e.
neuronal firing rates of the neurons as a list: [num_images][num_reaps, num_neurons] - output: responses as predicted by the network for the unique images. If broadcast_to_target, returns repeated - outputs of shape [num_images][num_reaps, num_neurons] else (default) returns unique outputs of shape [num_images, num_neurons] - """ - - target = [] - unique_images = torch.empty(0) - for images, responses in dataloader: - if len(images.shape) == 5: - images = images.squeeze(dim=0) - responses = responses.squeeze(dim=0) - - assert torch.all(torch.eq(images[-1,], images[0,],)), "All images in the batch should be equal" - unique_images = torch.cat((unique_images, images[0:1, ]), dim=0) - target.append(responses.detach().cpu().numpy()) - - # Forward unique images once: - with eval_state(model) if not is_ensemble_function(model) else contextlib.nullcontext(): - with device_state(model, device) if not is_ensemble_function(model) else contextlib.nullcontext(): - output = model(unique_images.to(device), data_key=data_key).detach().cpu() - - output = output.numpy() - - if broadcast_to_target: - output = [np.broadcast_to(x, target[idx].shape) for idx, x in enumerate(output)] - - return target, output - - -def model_predictions(model, dataloader, data_key, device='cpu'): - """ - computes model predictions for a given dataloader and a model - Returns: - target: ground truth, i.e. neuronal firing rates of the neurons - output: responses as predicted by the network - """ - - target, output = torch.empty(0), torch.empty(0) - for images, responses in dataloader: - if len(images.shape) == 5: - images = images.squeeze(dim=0) - responses = responses.squeeze(dim=0) - with torch.no_grad(): - with device_state(model, device) if not is_ensemble_function(model) else contextlib.nullcontext(): - output = torch.cat((output, (model(images.to(device), data_key=data_key).detach().cpu())), dim=0) - target = torch.cat((target, responses.detach().cpu()), dim=0) - - return target.numpy(), output.numpy() - - -def get_avg_correlations(model, dataloaders, device='cpu', as_dict=False, per_neuron=True, **kwargs): - """ - Returns correlation between model outputs and average responses over repeated trials - - """ - if 'test' in dataloaders: - dataloaders = dataloaders['test'] - - correlations = {} - for k, loader in dataloaders.items(): - - # Compute correlation with average targets - target, output = model_predictions_repeats(dataloader=loader, model=model, data_key=k, device=device, broadcast_to_target=False) - target_mean = np.array([t.mean(axis=0) for t in target]) - correlations[k] = corr(target_mean, output, axis=0) - - # Check for nans - if np.any(np.isnan(correlations[k])): - warnings.warn('{}% NaNs , NaNs will be set to Zero.'.format(np.isnan(correlations[k]).mean() * 100)) - correlations[k][np.isnan(correlations[k])] = 0 - - if not as_dict: - correlations = np.hstack([v for v in correlations.values()]) if per_neuron else np.mean(np.hstack([v for v in correlations.values()])) - return correlations - - -def get_correlations(model, dataloaders, device='cpu', as_dict=False, per_neuron=True, **kwargs): - correlations = {} - with eval_state(model) if not is_ensemble_function(model) else contextlib.nullcontext(): - for k, v in dataloaders.items(): - target, output = model_predictions(dataloader=v, model=model, data_key=k, device=device) - correlations[k] = corr(target, output, axis=0) - - if np.any(np.isnan(correlations[k])): - warnings.warn('{}% NaNs , NaNs will be set to Zero.'.format(np.isnan(correlations[k]).mean() * 100)) - 
correlations[k][np.isnan(correlations[k])] = 0 - - if not as_dict: - correlations = np.hstack([v for v in correlations.values()]) if per_neuron else np.mean(np.hstack([v for v in correlations.values()])) - return correlations - - -def get_poisson_loss(model, dataloaders, device='cpu', as_dict=False, avg=False, per_neuron=True, eps=1e-12): - poisson_loss = {} - with eval_state(model) if not is_ensemble_function(model) else contextlib.nullcontext(): - for k, v in dataloaders.items(): - target, output = model_predictions(dataloader=v, model=model, data_key=k, device=device) - loss = output - target * np.log(output + eps) - poisson_loss[k] = np.mean(loss, axis=0) if avg else np.sum(loss, axis=0) - if as_dict: - return poisson_loss - else: - if per_neuron: - return np.hstack([v for v in poisson_loss.values()]) - else: - return np.mean(np.hstack([v for v in poisson_loss.values()])) if avg else np.sum(np.hstack([v for v in poisson_loss.values()])) - - -def get_repeats(dataloader, min_repeats=2): - # save the responses of all neuron to the repeats of an image as an element in a list - repeated_inputs = [] - repeated_outputs = [] - for inputs, outputs in dataloader: - if len(inputs.shape) == 5: - inputs = np.squeeze(inputs.cpu().numpy(), axis=0) - outputs = np.squeeze(outputs.cpu().numpy(), axis=0) - else: - inputs = inputs.cpu().numpy() - outputs = outputs.cpu().numpy() - r, n = outputs.shape # number of frame repeats, number of neurons - if r < min_repeats: # minimum number of frame repeats to be considered for oracle, free choice - continue - assert np.all(np.abs(np.diff(inputs, axis=0)) == 0), "Images of oracle trials do not match" - repeated_inputs.append(inputs) - repeated_outputs.append(outputs) - return np.array(repeated_inputs), np.array(repeated_outputs) - - -def get_oracles(dataloaders, as_dict=False, per_neuron=True): - oracles = {} - for k, v in dataloaders.items(): - _, outputs = get_repeats(v) - oracles[k] = compute_oracle_corr(np.array(outputs)) - if not as_dict: - oracles = np.hstack([v for v in oracles.values()]) if per_neuron else np.mean(np.hstack([v for v in oracles.values()])) - return oracles - - -def get_oracles_corrected(dataloaders, as_dict=False, per_neuron=True): - oracles = {} - for k, v in dataloaders.items(): - _, outputs = get_repeats(v) - oracles[k] = compute_oracle_corr_corrected(np.array(outputs)) - if not as_dict: - oracles = np.hstack([v for v in oracles.values()]) if per_neuron else np.mean(np.hstack([v for v in oracles.values()])) - return oracles - - -def compute_oracle_corr_corrected(repeated_outputs): - """ - - Args: - repeated_outputs (list or array): array(images, repeats, responses), or a list of lists of repeats per image. - - Returns: the oracle correlations per neuron - - """ - if len(repeated_outputs.shape) == 3: - var_noise = repeated_outputs.var(axis=1).mean(0) - var_mean = repeated_outputs.mean(axis=1).var(0) - else: - var_noise, var_mean = [], [] - for repeat in repeated_outputs: - var_noise.append(repeat.var(axis=0)) - var_mean.append(repeat.mean(axis=0)) - var_noise = np.mean(np.array(var_noise), axis=0) - var_mean = np.var(np.array(var_mean), axis=0) - return var_mean / np.sqrt(var_mean * (var_mean + var_noise)) - - -def compute_oracle_corr(repeated_outputs): - if len(repeated_outputs.shape) == 3: - _, r, n = repeated_outputs.shape - oracles = (repeated_outputs.mean(axis=1, keepdims=True) - repeated_outputs / r) * r / (r - 1) - if np.any(np.isnan(oracles)): - warnings.warn('{}% NaNs when calculating the oracle. 
NaNs will be set to Zero.'.format(np.isnan(oracles).mean() * 100)) - oracles[np.isnan(oracles)] = 0 - return corr(oracles.reshape(-1, n), repeated_outputs.reshape(-1, n), axis=0) - else: - oracles = [] - for outputs in repeated_outputs: - r, n = outputs.shape - # compute the mean over repeats, for each neuron - mu = outputs.mean(axis=0, keepdims=True) - # compute oracle predictor - oracle = (mu - outputs / r) * r / (r - 1) - - if np.any(np.isnan(oracle)): - warnings.warn('{}% NaNs when calculating the oracle. NaNs will be set to Zero.'.format( - np.isnan(oracle).mean() * 100)) - oracle[np.isnan(oracle)] = 0 - - oracles.append(oracle) - return corr(np.vstack(repeated_outputs), np.vstack(oracles), axis=0) - - -def get_fraction_oracles(model, dataloaders, device='cpu', corrected=False): - dataloaders = dataloaders["test"] if "test" in dataloaders else dataloaders - if corrected: - oracles = get_oracles_corrected(dataloaders=dataloaders, as_dict=False, per_neuron=True) - else: - oracles = get_oracles(dataloaders=dataloaders, as_dict=False, per_neuron=True) - test_correlation = get_correlations(model=model, dataloaders=dataloaders, device=device, as_dict=False, per_neuron=True) - oracle_performance, _, _, _ = np.linalg.lstsq(np.hstack(oracles)[:, np.newaxis], np.hstack(test_correlation)) - return oracle_performance[0] - - -def get_explainable_var(dataloaders, as_dict=False, per_neuron=True, repeat_limit=None, randomize=True): - dataloaders = dataloaders["test"] if "test" in dataloaders else dataloaders - explainable_var = {} - for k ,v in dataloaders.items(): - _, outputs = get_repeats(v) - if repeat_limit is not None: - outputs = get_subset_of_repeats(outputs=outputs, repeat_limit=repeat_limit, randomize=randomize) - explainable_var[k] = compute_explainable_var(outputs) - if not as_dict: - explainable_var = np.hstack([v for v in explainable_var.values()]) if per_neuron else np.mean(np.hstack([v for v in explainable_var.values()])) - return explainable_var - - -def compute_explainable_var(outputs, eps=1e-9): - ImgVariance = [] - TotalVar = np.var(np.vstack(outputs), axis=0, ddof=1) - for out in outputs: - ImgVariance.append(np.var(out, axis=0, ddof=1)) - ImgVariance = np.vstack(ImgVariance) - NoiseVar = np.mean(ImgVariance, axis=0) - explainable_var = (TotalVar - NoiseVar) / (TotalVar + eps) - return explainable_var - - -def get_FEV(model, dataloaders, device='cpu', as_dict=False, per_neuron=True, threshold=None): - """ - Computes the fraction of explainable variance explained (FEVe) per Neuron, given a model and a dictionary of dataloaders. - The dataloaders will have to return batches of identical images, with the corresponing neuronal responses. - - Args: - model (object): PyTorch module - dataloaders (dict): Dictionary of dataloaders, with keys corresponding to "data_keys" in the model - device (str): 'cuda' or 'gpu - as_dict (bool): Returns the scores as a dictionary ('data_keys': values) if set to True. - per_neuron (bool): Returns the grand average if set to True. - threshold (float): for the avg feve, excludes neurons with a explainable variance below threshold - - Returns: - FEV (dict, or np.array, or float): Fraction of explainable varianced explained. Per Neuron or as grand average. 
- """ - dataloaders = dataloaders["test"] if "test" in dataloaders else dataloaders - FEV = {} - with eval_state(model) if not is_ensemble_function(model) else contextlib.nullcontext(): - for data_key, dataloader in dataloaders.items(): - targets, outputs = model_predictions_repeats(model=model, - dataloader=dataloader, - data_key=data_key, - device=device, - broadcast_to_target=True) - if threshold is None: - FEV[data_key] = compute_FEV(targets=targets, outputs=outputs) - else: - fev, feve = compute_FEV(targets=targets, outputs=outputs, return_exp_var=True) - FEV[data_key] = feve[fev>threshold] - if not as_dict: - FEV = np.hstack([v for v in FEV.values()]) if per_neuron else np.mean(np.hstack([v for v in FEV.values()])) - return FEV - - -def compute_FEV(targets, outputs, return_exp_var=False): - """ - - Args: - targets (list): Neuronal responses (ground truth) to image repeats. Dimensions: [num_images] np.array(num_reaps, num_neurons) - outputs (list): Model predictions to the repeated images, with an identical shape as the targets - return_exp_var (bool): returns the fraction of explainable variance per neuron if set to True - - Returns: - FEVe (np.array): the fraction of explainable variance explained per neuron - --- optional: FEV (np.array): the fraction - - """ - ImgVariance = [] - PredVariance = [] - for i, _ in enumerate(targets): - PredVariance.append((targets[i] - outputs[i]) ** 2) - ImgVariance.append(np.var(targets[i], axis=0, ddof=1)) - PredVariance = np.vstack(PredVariance) - ImgVariance = np.vstack(ImgVariance) - - TotalVar = np.var(np.vstack(targets), axis=0, ddof=1) - NoiseVar = np.mean(ImgVariance, axis=0) - FEV = (TotalVar - NoiseVar) / TotalVar - - PredVar = np.mean(PredVariance, axis=0) - FEVe = 1 - (PredVar - NoiseVar) / (TotalVar - NoiseVar) - return [FEV, FEVe] if return_exp_var else FEVe - - -def get_model_rf_size(model_config): - layers = model_config["layers"] - input_kern = model_config["input_kern"] - hidden_kern = model_config["hidden_kern"] - dil = model_config["hidden_dilation"] - rf_size = input_kern + ((hidden_kern-1) * dil)*(layers - 1) - return rf_size - - -def get_predictions(model, dataloaders, device='cpu', as_dict=False, per_neuron=True, test_data=True, **kwargs): - predictions = {} - with eval_state(model) if not isinstance(model, types.FunctionType) else contextlib.nullcontext(): - for k, v in dataloaders.items(): - if test_data: - _, output = model_predictions_repeats(dataloader=v, model=model, data_key=k, device=device) - else: - _, output = model_predictions(dataloader=v, model=model, data_key=k, device=device) - predictions[k] = output.T - - if not as_dict: - predictions = [v for v in predictions.values()] - return predictions - - -def get_targets(model, dataloaders, device='cpu', as_dict=True, per_neuron=True, test_data=True, **kwargs): - responses = {} - with eval_state(model) if not isinstance(model, types.FunctionType) else contextlib.nullcontext(): - for k, v in dataloaders.items(): - if test_data: - targets, _ = model_predictions_repeats(dataloader=v, model=model, data_key=k, device=device) - targets_per_neuron = [] - for i in range(targets[0].shape[1]): - neuronal_responses = [] - for repeats in targets: - neuronal_responses.append(repeats[:,i]) - targets_per_neuron.append(neuronal_responses) - responses[k] = targets_per_neuron - else: - targets, _ = model_predictions(dataloader=v, model=model, data_key=k, device=device) - responses[k] = targets.T - - if not as_dict: - responses = [v for v in responses.values()] - return responses - - 
-def get_avg_firing(dataloaders, as_dict=False, per_neuron=True): - """ - Returns average firing rate across the whole dataset - """ - - avg_firing = {} - for k, dataloader in dataloaders.items(): - target = torch.empty(0) - for images, responses in dataloader: - if len(images.shape) == 5: - responses = responses.squeeze(dim=0) - target = torch.cat((target, responses.detach().cpu()), dim=0) - avg_firing[k] = target.mean(0).numpy() - - if not as_dict: - avg_firing = np.hstack([v for v in avg_firing.values()]) if per_neuron else np.mean( - np.hstack([v for v in avg_firing.values()])) - return avg_firing - - -def get_fano_factor(dataloaders, as_dict=False, per_neuron=True): - """ - Returns average firing rate across the whole dataset - """ - - fano_factor = {} - for k, dataloader in dataloaders.items(): - target = torch.empty(0) - for images, responses in dataloader: - if len(images.shape) == 5: - responses = responses.squeeze(dim=0) - target = torch.cat((target, responses.detach().cpu()), dim=0) - fano_factor[k] = (target.var(0) / target.mean(0)).numpy() - - if not as_dict: - fano_factor = np.hstack([v for v in fano_factor.values()]) if per_neuron else np.mean( - np.hstack([v for v in fano_factor.values()])) - return fano_factor \ No newline at end of file diff --git a/nnfabrik/models/gaussian_readout_models.py b/nnfabrik/models/gaussian_readout_models.py new file mode 100644 index 00000000..20a3b625 --- /dev/null +++ b/nnfabrik/models/gaussian_readout_models.py @@ -0,0 +1,784 @@ +from collections import OrderedDict, Iterable +import numpy as np +import torch +import warnings +from torch import nn as nn +from torch.nn import Parameter +from torch.nn import functional as F +from torch.nn import ModuleDict +from mlutils.constraints import positive +from mlutils.layers.cores import DepthSeparableConv2d, Core2d, Stacked2dCore +from ..utility.nn_helpers import get_io_dims, get_module_output, set_random_seed, get_dims_for_loader_dict +from mlutils import regularizers +from mlutils.layers.readouts import PointPooled2d +from mlutils.layers.legacy import Gaussian2d +from .pretrained_models import TransferLearningCore + +# Squeeze and Excitation Block +class SQ_EX_Block(nn.Module): + def __init__(self, in_ch, reduction=16): + super(SQ_EX_Block, self).__init__() + self.se = nn.Sequential( + GlobalAvgPool(), + nn.Linear(in_ch, in_ch // reduction), + nn.ReLU(inplace=True), + nn.Linear(in_ch // reduction, in_ch), + nn.Sigmoid() + ) + + def forward(self, x): + se_weight = self.se(x).unsqueeze(-1).unsqueeze(-1) + return x.mul(se_weight) + + +class GlobalAvgPool(nn.Module): + def __init__(self): + super(GlobalAvgPool, self).__init__() + + def forward(self, x): + return x.view(*(x.shape[:-2]), -1).mean(-1) + + +class SE2dCore(Core2d, nn.Module): + def __init__( + self, + input_channels, + hidden_channels, + input_kern, + hidden_kern, + layers=3, + gamma_input=0.0, + skip=0, + final_nonlinearity=True, + bias=False, + momentum=0.1, + pad_input=True, + batch_norm=True, + hidden_dilation=1, + laplace_padding=None, + input_regularizer="LaplaceL2norm", + stack=None, + se_reduction=32, + n_se_blocks=1, + depth_separable=False, + ): + """ + Args: + input_channels: Integer, number of input channels as in + hidden_channels: Number of hidden channels (i.e feature maps) in each hidden layer + input_kern: kernel size of the first layer (i.e. 
the input layer) + hidden_kern: kernel size of each hidden layer's kernel + layers: number of layers + gamma_input: regularizer factor for the input weights (default: LaplaceL2, see mlutils.regularizers) + skip: Adds a skip connection + final_nonlinearity: Boolean, if true, appends an ELU layer after the last BatchNorm (if BN=True) + bias: Adds a bias layer. Note: bias and batch_norm cannot both be true + momentum: BN momentum + pad_input: Boolean, if True, applies zero padding to all convolutions + batch_norm: Boolean, if True appends a BN layer after each convolutional layer + hidden_dilation: If set to > 1, will apply dilated convs for all hidden layers + laplace_padding: Padding size for the laplace convolution. If padding = None, it defaults to half of + the kernel size (recommended). Setting the padding to 0 is not recommended, as it leads to artefacts; + it is kept as the default only for backwards compatibility. + input_regularizer: String that must match one of the regularizers in ..regularizers + stack: Int or iterable. Selects which layers of the core should be stacked for the readout. + default value will stack all layers on top of each other. + stack = -1 will only select the last layer as the readout layer + stack = 0 will only readout from the first layer + se_reduction: Int. Reduction of Channels for Global Pooling of the Squeeze and Excitation Block. + n_se_blocks: Int. Number of layers, counted from the end of the core, that get a Squeeze and Excitation Block appended. + depth_separable: Boolean, if True, uses depth-separable convolutions for all hidden layers. + """ + + super().__init__() + + assert not bias or not batch_norm, "bias and batch_norm should not both be true" + + regularizer_config = ( + dict(padding=laplace_padding, kernel=input_kern) + if input_regularizer == "GaussianLaplaceL2" + else dict(padding=laplace_padding) + ) + self._input_weights_regularizer = regularizers.__dict__[input_regularizer](**regularizer_config) + + self.layers = layers + self.gamma_input = gamma_input + self.input_channels = input_channels + self.hidden_channels = hidden_channels + self.skip = skip + self.features = nn.Sequential() + self.n_se_blocks = n_se_blocks + if stack is None: + self.stack = range(self.layers) + else: + self.stack = [*range(self.layers)[stack:]] if isinstance(stack, int) else stack + + # --- first layer + layer = OrderedDict() + layer["conv"] = nn.Conv2d( + input_channels, hidden_channels, input_kern, padding=input_kern // 2 if pad_input else 0, bias=bias + ) + if batch_norm: + layer["norm"] = nn.BatchNorm2d(hidden_channels, momentum=momentum) + if layers > 1 or final_nonlinearity: + layer["nonlin"] = nn.ELU(inplace=True) + self.features.add_module("layer0", nn.Sequential(layer)) + + if not isinstance(hidden_kern, Iterable): + hidden_kern = [hidden_kern] * (self.layers - 1) + + # --- other layers + for l in range(1, self.layers): + layer = OrderedDict() + hidden_padding = ((hidden_kern[l - 1] - 1) * hidden_dilation + 1) // 2 + if depth_separable: + layer["ds_conv"] = DepthSeparableConv2d(hidden_channels, hidden_channels, kernel_size=hidden_kern[l - 1], + dilation=hidden_dilation, padding=hidden_padding, bias=False, + stride=1) + else: + layer["conv"] = nn.Conv2d( + hidden_channels if not skip > 1 else min(skip, l) * hidden_channels, + hidden_channels, + hidden_kern[l - 1], + padding=hidden_padding, + bias=bias, + dilation=hidden_dilation, + ) + if batch_norm: + layer["norm"] = nn.BatchNorm2d(hidden_channels, momentum=momentum) + + if final_nonlinearity or l < self.layers - 1: +
layer["nonlin"] = nn.ELU(inplace=True) + + if (self.layers - l) <= self.n_se_blocks: + layer["seg_ex_block"] = SQ_EX_Block(in_ch=hidden_channels, reduction=se_reduction) + + self.features.add_module("layer{}".format(l), nn.Sequential(layer)) + + self.apply(self.init_conv) + + def forward(self, input_): + ret = [] + for l, feat in enumerate(self.features): + do_skip = l >= 1 and self.skip > 1 + input_ = feat(input_ if not do_skip else torch.cat(ret[-min(self.skip, l) :], dim=1)) + if l in self.stack: + ret.append(input_) + return torch.cat(ret, dim=1) + + def laplace(self): + return self._input_weights_regularizer(self.features[0].conv.weight) + + def regularizer(self): + return self.gamma_input * self.laplace() + + @property + def outchannels(self): + return len(self.features) * self.hidden_channels + + +class DepthSeparableCore(Core2d, nn.Module): + def __init__( + self, + input_channels, + hidden_channels, + input_kern, + hidden_kern, + layers=3, + gamma_input=0.0, + skip=0, + final_nonlinearity=True, + bias=False, + momentum=0.1, + pad_input=True, + batch_norm=True, + hidden_dilation=1, + laplace_padding=None, + input_regularizer="LaplaceL2norm", + stack=None, + ): + """ + Args: + input_channels: Integer, number of input channels as in + hidden_channels: Number of hidden channels (i.e feature maps) in each hidden layer + input_kern: kernel size of the first layer (i.e. the input layer) + hidden_kern: kernel size of each hidden layer's kernel + layers: number of layers + gamma_input: regularizer factor for the input weights (default: LaplaceL2, see mlutils.regularizers) + skip: Adds a skip connection + final_nonlinearity: Boolean, if true, appends an ELU layer after the last BatchNorm (if BN=True) + bias: Adds a bias layer. Note: bias and batch_norm can not both be true + momentum: BN momentum + pad_input: Boolean, if True, applies zero padding to all convolutions + batch_norm: Boolean, if True appends a BN layer after each convolutional layer + hidden_dilation: If set to > 1, will apply dilated convs for all hidden layers + laplace_padding: Padding size for the laplace convolution. If padding = None, it defaults to half of + the kernel size (recommended). Setting Padding to 0 is not recommended and leads to artefacts, + zero is the default however to recreate backwards compatibility. + normalize_laplace_regularizer: Boolean, if set to True, will use the LaplaceL2norm function from + mlutils.regularizers, which returns the regularizer as |laplace(filters)| / |filters| + input_regularizer: String that must match one of the regularizers in ..regularizers + stack: Int or iterable. Selects which layers of the core should be stacked for the readout. + default value will stack all layers on top of each other. 
+ stack = -1 will only select the last layer as the readout layer + stack = 0 will only readout from the first layer + """ + + super().__init__() + + assert not bias or not batch_norm, "bias and batch_norm should not both be true" + + regularizer_config = ( + dict(padding=laplace_padding, kernel=input_kern) + if input_regularizer == "GaussianLaplaceL2" + else dict(padding=laplace_padding) + ) + self._input_weights_regularizer = regularizers.__dict__[input_regularizer](**regularizer_config) + + self.layers = layers + self.gamma_input = gamma_input + self.input_channels = input_channels + self.hidden_channels = hidden_channels + self.skip = skip + self.features = nn.Sequential() + if stack is None: + self.stack = range(self.layers) + else: + self.stack = [*range(self.layers)[stack:]] if isinstance(stack, int) else stack + + # --- first layer + layer = OrderedDict() + layer["conv"] = nn.Conv2d( + input_channels, hidden_channels, input_kern, padding=input_kern // 2 if pad_input else 0, bias=bias + ) + if batch_norm: + layer["norm"] = nn.BatchNorm2d(hidden_channels, momentum=momentum) + if layers > 1 or final_nonlinearity: + layer["nonlin"] = nn.ELU(inplace=True) + self.features.add_module("layer0", nn.Sequential(layer)) + + # def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True): + + if not isinstance(hidden_kern, Iterable): + hidden_kern = [hidden_kern] * (self.layers - 1) + + # --- other layers + for l in range(1, self.layers): + layer = OrderedDict() + hidden_padding = ((hidden_kern[l - 1] - 1) * hidden_dilation + 1) // 2 + layer["ds_conv"] = DepthSeparableConv2d(hidden_channels, hidden_channels, kernel_size=hidden_kern[l-1], dilation=hidden_dilation, padding=hidden_padding, bias=False, stride=1) + if batch_norm: + layer["norm"] = nn.BatchNorm2d(hidden_channels, momentum=momentum) + if final_nonlinearity or l < self.layers - 1: + layer["nonlin"] = nn.ELU(inplace=True) + self.features.add_module("layer{}".format(l), nn.Sequential(layer)) + + self.apply(self.init_conv) + + def forward(self, input_): + ret = [] + for l, feat in enumerate(self.features): + do_skip = l >= 1 and self.skip > 1 + input_ = feat(input_ if not do_skip else torch.cat(ret[-min(self.skip, l) :], dim=1)) + if l in self.stack: + ret.append(input_) + return torch.cat(ret, dim=1) + + def laplace(self): + return self._input_weights_regularizer(self.features[0].conv.weight) + + def regularizer(self): + return self.gamma_input * self.laplace() + + @property + def outchannels(self): + return len(self.features) * self.hidden_channels + + +class MultiplePointPooled2d(nn.ModuleDict): + def __init__(self, core, in_shape_dict, n_neurons_dict, pool_steps, pool_kern, bias, init_range, gamma_readout): + # super init to get the _module attribute + super(MultiplePointPooled2d, self).__init__() + for k in n_neurons_dict: + in_shape = get_module_output(core, in_shape_dict[k])[1:] + n_neurons = n_neurons_dict[k] + self.add_module(k, PointPooled2d( + in_shape, + n_neurons, + pool_steps=pool_steps, + pool_kern=pool_kern, + bias=bias, + init_range=init_range) + ) + + self.gamma_readout = gamma_readout + + def forward(self, *args, data_key=None, **kwargs): + if data_key is None and len(self) == 1: + data_key = list(self.keys())[0] + return self[data_key](*args, **kwargs) + + + def regularizer(self, data_key): + return self[data_key].feature_l1(average=False) * self.gamma_readout + + +class MultipleGaussian2d(nn.ModuleDict): + def __init__(self, core, in_shape_dict, n_neurons_dict, init_mu_range, 
init_sigma_range, bias, gamma_readout): + # super init to get the _module attribute + super(MultipleGaussian2d, self).__init__() + for k in n_neurons_dict: + in_shape = get_module_output(core, in_shape_dict[k])[1:] + n_neurons = n_neurons_dict[k] + self.add_module(k, Gaussian2d( + in_shape=in_shape, + outdims=n_neurons, + init_mu_range=init_mu_range, + init_sigma_range=init_sigma_range, + bias=bias) + ) + + self.gamma_readout = gamma_readout + + def forward(self, *args, data_key=None, **kwargs): + if data_key is None and len(self) == 1: + data_key = list(self.keys())[0] + return self[data_key](*args, **kwargs) + + def regularizer(self, data_key): + return self[data_key].feature_l1(average=False) * self.gamma_readout + + +def se_core_gauss_readout(dataloaders, seed, hidden_channels=32, input_kern=13, # core args + hidden_kern=3, layers=3, gamma_input=15.5, + skip=0, final_nonlinearity=True, momentum=0.9, + pad_input=False, batch_norm=True, hidden_dilation=1, + laplace_padding=None, input_regularizer='LaplaceL2norm', + init_mu_range=0.2, init_sigma_range=0.5, readout_bias=True, # readout args, + gamma_readout=4, elu_offset=0, stack=None, se_reduction=32, n_se_blocks=1, + depth_separable=False, + ): + """ + Model class of an SE2dCore (defined above) and a Gaussian readout + + Args: + dataloaders: a dictionary of dataloaders, one loader per session + in the format {'data_key': dataloader object, .. } + seed: random seed + elu_offset: Offset for the output non-linearity [F.elu(x + self.offset)] + + all other args: See documentation of SE2dCore above and of + Gaussian2d in mlutils.layers.legacy + + Returns: An initialized model which consists of model.core and model.readout + """ + + if "train" in dataloaders.keys(): + dataloaders = dataloaders["train"] + + # Obtain the named tuple fields from the first entry of the first dataloader in the dictionary + in_name, out_name = next(iter(list(dataloaders.values())[0]))._fields + + session_shape_dict = get_dims_for_loader_dict(dataloaders) + n_neurons_dict = {k: v[out_name][1] for k, v in session_shape_dict.items()} + in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} + input_channels = [v[in_name][1] for v in session_shape_dict.values()] + + class Encoder(nn.Module): + + def __init__(self, core, readout, elu_offset): + super().__init__() + self.core = core + self.readout = readout + self.offset = elu_offset + + def forward(self, x, data_key=None, **kwargs): + x = self.core(x) + + sample = kwargs["sample"] if 'sample' in kwargs else None + x = self.readout(x, data_key=data_key, sample=sample) + return F.elu(x + self.offset) + 1 + + def regularizer(self, data_key): + return self.core.regularizer() + self.readout.regularizer(data_key=data_key) + + set_random_seed(seed) + + # get the SE core defined above + core = SE2dCore(input_channels=input_channels[0], + hidden_channels=hidden_channels, + input_kern=input_kern, + hidden_kern=hidden_kern, + layers=layers, + gamma_input=gamma_input, + skip=skip, + final_nonlinearity=final_nonlinearity, + bias=False, + momentum=momentum, + pad_input=pad_input, + batch_norm=batch_norm, + hidden_dilation=hidden_dilation, + laplace_padding=laplace_padding, + input_regularizer=input_regularizer, + stack=stack, + se_reduction=se_reduction, + n_se_blocks=n_se_blocks, + depth_separable=depth_separable) + + readout = MultipleGaussian2d(core, in_shape_dict=in_shapes_dict, + n_neurons_dict=n_neurons_dict, + init_mu_range=init_mu_range, +
bias=readout_bias, + init_sigma_range=init_sigma_range, + gamma_readout=gamma_readout) + + # initializing readout bias to mean response + if readout_bias: + for k in dataloaders: + readout[k].bias.data = dataloaders[k].dataset[:][1].mean(0) + + model = Encoder(core, readout, elu_offset) + + return model + + +def ds_core_gauss_readout(dataloaders, seed, hidden_channels=32, input_kern=13, # core args + hidden_kern=3, layers=3, gamma_input=0.1, + skip=0, final_nonlinearity=True, momentum=0.9, + pad_input=False, batch_norm=True, hidden_dilation=1, + laplace_padding=None, input_regularizer='LaplaceL2norm', + init_mu_range=0.2, init_sigma_range=0.5, readout_bias=True, # readout args, + gamma_readout=4, elu_offset=0, stack=None, + ): + """ + Model class of a DepthSeparableCore (defined above) and a Gaussian readout + + Args: + dataloaders: a dictionary of dataloaders, one loader per session + in the format {'data_key': dataloader object, .. } + seed: random seed + elu_offset: Offset for the output non-linearity [F.elu(x + self.offset)] + + all other args: See documentation of DepthSeparableCore above and of + Gaussian2d in mlutils.layers.legacy + + Returns: An initialized model which consists of model.core and model.readout + """ + + if "train" in dataloaders.keys(): + dataloaders = dataloaders["train"] + + # Obtain the named tuple fields from the first entry of the first dataloader in the dictionary + in_name, out_name = next(iter(list(dataloaders.values())[0]))._fields + + session_shape_dict = get_dims_for_loader_dict(dataloaders) + n_neurons_dict = {k: v[out_name][1] for k, v in session_shape_dict.items()} + in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} + input_channels = [v[in_name][1] for v in session_shape_dict.values()] + + class Encoder(nn.Module): + + def __init__(self, core, readout, elu_offset): + super().__init__() + self.core = core + self.readout = readout + self.offset = elu_offset + + def forward(self, x, data_key=None, **kwargs): + x = self.core(x) + + sample = kwargs["sample"] if 'sample' in kwargs else None + x = self.readout(x, data_key=data_key, sample=sample) + return F.elu(x + self.offset) + 1 + + def regularizer(self, data_key): + return self.core.regularizer() + self.readout.regularizer(data_key=data_key) + + set_random_seed(seed) + + # get the depth-separable core defined above + core = DepthSeparableCore(input_channels=input_channels[0], + hidden_channels=hidden_channels, + input_kern=input_kern, + hidden_kern=hidden_kern, + layers=layers, + gamma_input=gamma_input, + skip=skip, + final_nonlinearity=final_nonlinearity, + bias=False, + momentum=momentum, + pad_input=pad_input, + batch_norm=batch_norm, + hidden_dilation=hidden_dilation, + laplace_padding=laplace_padding, + input_regularizer=input_regularizer, + stack=stack) + + readout = MultipleGaussian2d(core, in_shape_dict=in_shapes_dict, + n_neurons_dict=n_neurons_dict, + init_mu_range=init_mu_range, + bias=readout_bias, + init_sigma_range=init_sigma_range, + gamma_readout=gamma_readout) + + # initializing readout bias to mean response + if readout_bias: + for k in dataloaders: + readout[k].bias.data = dataloaders[k].dataset[:][1].mean(0) + + model = Encoder(core, readout, elu_offset) + + return model + + +def ds_core_point_readout(dataloaders, seed, hidden_channels=32, input_kern=13, # core args + hidden_kern=3, layers=3, gamma_input=0.1, + skip=0, final_nonlinearity=True, core_bias=False, momentum=0.9, + pad_input=False, batch_norm=True, hidden_dilation=1, +
laplace_padding=None, input_regularizer='LaplaceL2norm', + pool_steps=2, pool_kern=3, readout_bias=True, # readout args, + init_range=0.2, gamma_readout=0.1, elu_offset=0, stack=None, + ): + """ + Model class of a DepthSeparableCore (defined above) and a pointpooled (spatial transformer) readout + + Args: + dataloaders: a dictionary of dataloaders, one loader per session + in the format {'data_key': dataloader object, .. } + seed: random seed + elu_offset: Offset for the output non-linearity [F.elu(x + self.offset)] + + all other args: See documentation of DepthSeparableCore above and of + PointPooled2d in mlutils.layers.readouts + + Returns: An initialized model which consists of model.core and model.readout + """ + + if "train" in dataloaders.keys(): + dataloaders = dataloaders["train"] + + # Obtain the named tuple fields from the first entry of the first dataloader in the dictionary + in_name, out_name = next(iter(list(dataloaders.values())[0]))._fields + + session_shape_dict = get_dims_for_loader_dict(dataloaders) + n_neurons_dict = {k: v[out_name][1] for k, v in session_shape_dict.items()} + in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} + input_channels = [v[in_name][1] for v in session_shape_dict.values()] + + class Encoder(nn.Module): + + def __init__(self, core, readout, elu_offset): + super().__init__() + self.core = core + self.readout = readout + self.offset = elu_offset + + def forward(self, x, data_key=None, **kwargs): + x = self.core(x) + x = self.readout(x, data_key=data_key, **kwargs) + return F.elu(x + self.offset) + 1 + + def regularizer(self, data_key): + return self.core.regularizer() + self.readout.regularizer(data_key=data_key) + + set_random_seed(seed) + + # get the depth-separable core defined above + core = DepthSeparableCore(input_channels=input_channels[0], + hidden_channels=hidden_channels, + input_kern=input_kern, + hidden_kern=hidden_kern, + layers=layers, + gamma_input=gamma_input, + skip=skip, + final_nonlinearity=final_nonlinearity, + bias=core_bias, + momentum=momentum, + pad_input=pad_input, + batch_norm=batch_norm, + hidden_dilation=hidden_dilation, + laplace_padding=laplace_padding, + input_regularizer=input_regularizer, + stack=stack) + + readout = MultiplePointPooled2d(core, in_shape_dict=in_shapes_dict, + n_neurons_dict=n_neurons_dict, + pool_steps=pool_steps, + pool_kern=pool_kern, + bias=readout_bias, + gamma_readout=gamma_readout, + init_range=init_range) + + if readout_bias: + for k in dataloaders: + readout[k].bias.data = dataloaders[k].dataset[:][1].mean(0) + + model = Encoder(core, readout, elu_offset) + + return model + + +def stacked2d_core_gaussian_readout(dataloaders, seed, hidden_channels=32, input_kern=13, # core args + hidden_kern=3, layers=3, gamma_hidden=0, gamma_input=0.1, + skip=0, final_nonlinearity=True, core_bias=False, momentum=0.9, + pad_input=False, batch_norm=True, hidden_dilation=1, + laplace_padding=None, input_regularizer='LaplaceL2norm', + readout_bias=True, init_mu_range=0.2, init_sigma_range=0.5, # readout args, + gamma_readout=0.1, elu_offset=0, stack=None, + ): + """ + Model class of a stacked2dCore (from mlutils) and a Gaussian readout + + Args: + dataloaders: a dictionary of dataloaders, one loader per session + in the format {'data_key': dataloader object, ..
} + seed: random seed + elu_offset: Offset for the output non-linearity [F.elu(x + self.offset)] + + all other args: See documentation of Stacked2dCore in mlutils.layers.cores and of + Gaussian2d in mlutils.layers.legacy + + Returns: An initialized model which consists of model.core and model.readout + """ + + if "train" in dataloaders.keys(): + dataloaders = dataloaders["train"] + + in_name, out_name = next(iter(list(dataloaders.values())[0]))._fields + + session_shape_dict = get_dims_for_loader_dict(dataloaders) + n_neurons_dict = {k: v[out_name][1] for k, v in session_shape_dict.items()} + in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} + input_channels = [v[in_name][1] for v in session_shape_dict.values()] + assert np.unique(input_channels).size == 1, "all input channels must be of equal size" + + class Encoder(nn.Module): + + def __init__(self, core, readout, elu_offset): + super().__init__() + self.core = core + self.readout = readout + self.offset = elu_offset + + def forward(self, x, data_key=None, **kwargs): + x = self.core(x) + x = self.readout(x, data_key=data_key, **kwargs) + return F.elu(x + self.offset) + 1 + + def regularizer(self, data_key): + return self.core.regularizer() + self.readout.regularizer(data_key=data_key) + + set_random_seed(seed) + + # get a stacked2D core from mlutils + core = Stacked2dCore(input_channels=input_channels[0], + hidden_channels=hidden_channels, + input_kern=input_kern, + hidden_kern=hidden_kern, + layers=layers, + gamma_hidden=gamma_hidden, + gamma_input=gamma_input, + skip=skip, + final_nonlinearity=final_nonlinearity, + bias=core_bias, + momentum=momentum, + pad_input=pad_input, + batch_norm=batch_norm, + hidden_dilation=hidden_dilation, + laplace_padding=laplace_padding, + input_regularizer=input_regularizer, + stack=stack) + + readout = MultipleGaussian2d(core, in_shape_dict=in_shapes_dict, + n_neurons_dict=n_neurons_dict, + init_mu_range=init_mu_range, + init_sigma_range=init_sigma_range, + bias=readout_bias, + gamma_readout=gamma_readout) + + if readout_bias: + for k in dataloaders: + readout[k].bias.data = dataloaders[k].dataset[:][1].mean(0) + + model = Encoder(core, readout, elu_offset) + + return model + + + +def vgg_core_gauss_readout(dataloaders, seed, + input_channels=1, tr_model_fn='vgg16', # begin of core args + model_layer=11, momentum=0.1, final_batchnorm=True, + final_nonlinearity=True, bias=False, + init_mu_range=0.4, init_sigma_range=0.6, readout_bias=True, # begin of readout args + gamma_readout=0.002, elu_offset=-1): + """ + A Model class of a predefined core (using models from torchvision.models). Can be initialized pretrained or random. + Can also be set to be trainable or not, independent of initialization. + + Args: + dataloaders: a dictionary of train-dataloaders, one loader per session + in the format {'data_key': dataloader object, .. } + seed: random seed
+ init_mu_range: initial value range for the readout mean positions + init_sigma_range: initial value range for the readout sigmas + readout_bias: Boolean, if True, adds a readout bias, initialized to the mean response + gamma_readout: regularizer factor for the readout + + Returns: An initialized model which consists of model.core and model.readout + """ + + if "train" in dataloaders.keys(): + dataloaders = dataloaders["train"] + + in_name, out_name = next(iter(list(dataloaders.values())[0]))._fields + + session_shape_dict = get_dims_for_loader_dict(dataloaders) + n_neurons_dict = {k: v[out_name][1] for k, v in session_shape_dict.items()} + in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} + input_channels = [v[in_name][1] for v in session_shape_dict.values()] + assert np.unique(input_channels).size == 1, "all input channels must be of equal size" + + class Encoder(nn.Module): + """ + helper nn class that combines the core and readout into the final model + """ + def __init__(self, core, readout, elu_offset): + super().__init__() + self.core = core + self.readout = readout + self.offset = elu_offset + + def forward(self, x, data_key=None, **kwargs): + x = self.core(x) + x = self.readout(x, data_key=data_key) + return F.elu(x + self.offset) + 1 + + def regularizer(self, data_key): + return self.readout.regularizer(data_key=data_key) + self.core.regularizer() + + set_random_seed(seed) + + core = TransferLearningCore(input_channels=input_channels[0], + tr_model_fn=tr_model_fn, + model_layer=model_layer, + momentum=momentum, + final_batchnorm=final_batchnorm, + final_nonlinearity=final_nonlinearity, + bias=bias) + + readout = MultipleGaussian2d(core, in_shape_dict=in_shapes_dict, + n_neurons_dict=n_neurons_dict, + init_mu_range=init_mu_range, + bias=readout_bias, + init_sigma_range=init_sigma_range, + gamma_readout=gamma_readout) + + if readout_bias: + for k in dataloaders: + readout[k].bias.data = dataloaders[k].dataset[:][1].mean(0) + + model = Encoder(core, readout, elu_offset) + + return model diff --git a/nnfabrik/models/pretrained_models.py b/nnfabrik/models/pretrained_models.py new file mode 100644 index 00000000..fb35efe6 --- /dev/null +++ b/nnfabrik/models/pretrained_models.py @@ -0,0 +1,92 @@ +from mlutils.layers.readouts import PointPooled2d +from mlutils.layers.cores import Core2d, Core +from ..utility.nn_helpers import get_io_dims, get_module_output, set_random_seed, get_dims_for_loader_dict + +from itertools import count +import numpy as np + +from torch import nn +from torch.nn import functional as F +import torchvision +from torchvision.models import vgg16, alexnet, vgg19 + + +class TransferLearningCore(Core2d, nn.Module): + """ + A Class to create a Core based on a model class from torchvision.models. + """ + + def __init__( + self, + input_channels, + tr_model_fn, + model_layer, + pretrained=True, + final_batchnorm=True, + final_nonlinearity=True, + bias=False, + momentum=0.1, + fine_tune=False, + **kwargs + ): + """ + Args: + input_channels: number of input channels + tr_model_fn: string to specify the pretrained model, as in torchvision.models, e.g. 'vgg16' + model_layer: up to which layer of the pretrained model the core should be built + pretrained: boolean, if pretrained weights should be used + final_batchnorm: adds a batch norm layer + final_nonlinearity: adds a nonlinearity + bias: Adds a bias term. Currently unused.
+ momentum: batch norm momentum + fine_tune: boolean, sets all weights to trainable if True + **kwargs: + """ + print("Ignoring input {} when creating {}".format(repr(kwargs), self.__class__.__name__)) + super().__init__() + + # resolve the torchvision model constructor (e.g. vgg16) from its name + tr_model_fn = globals()[tr_model_fn] + + self.input_channels = input_channels + self.tr_model_fn = tr_model_fn + + tr_model = tr_model_fn(pretrained=pretrained) + self.model_layer = model_layer + self.features = nn.Sequential() + + tr_features = nn.Sequential(*list(tr_model.features.children())[:model_layer]) + + # Freeze the pretrained parameters during training + if not fine_tune: + for param in tr_features.parameters(): + param.requires_grad = False + + self.features.add_module("TransferLearning", tr_features) + print(self.features) + if final_batchnorm: + self.features.add_module("OutBatchNorm", nn.BatchNorm2d(self.outchannels, momentum=momentum)) + if final_nonlinearity: + self.features.add_module("OutNonlin", nn.ReLU(inplace=True)) + + def forward(self, x): + if self.input_channels == 1: + x = x.expand(-1, 3, -1, -1) + return self.features(x) + + def regularizer(self): + return 0 + + @property + def outchannels(self): + """ + Returns: dimensions of the output, after a forward pass through the model + """ + found_out_channels = False + i = 1 + while not found_out_channels: + if "out_channels" in self.features.TransferLearning[-i].__dict__: + found_out_channels = True + else: + i = i + 1 + return self.features.TransferLearning[-i].out_channels diff --git a/nnfabrik/models/v1_models.py b/nnfabrik/models/v1_models.py new file mode 100644 index 00000000..a95311cf --- /dev/null +++ b/nnfabrik/models/v1_models.py @@ -0,0 +1,241 @@ +import numpy as np +from torch import nn as nn +from torch.nn import functional as F + +from mlutils.layers.readouts import PointPooled2d +from mlutils.layers.cores import Stacked2dCore +from mlutils.training import eval_state + +from .pretrained_models import TransferLearningCore +from ..utility.nn_helpers import get_io_dims, get_module_output, set_random_seed, get_dims_for_loader_dict + +class MultiplePointPooled2d(nn.ModuleDict): + def __init__(self, core, in_shape_dict, n_neurons_dict, pool_steps, pool_kern, bias, init_range, gamma_readout, readout_reg_avg): + # super init to get the _module attribute + super(MultiplePointPooled2d, self).__init__() + for k in n_neurons_dict: + in_shape = get_module_output(core, in_shape_dict[k])[1:] + n_neurons = n_neurons_dict[k] + self.add_module(k, PointPooled2d( + in_shape, + n_neurons, + pool_steps=pool_steps, + pool_kern=pool_kern, + bias=bias, + init_range=init_range) + ) + + self.gamma_readout = gamma_readout + self.readout_reg_avg = readout_reg_avg + + def forward(self, *args, data_key=None, **kwargs): + if data_key is None and len(self) == 1: + data_key = list(self.keys())[0] + return self[data_key](*args, **kwargs) + + + def regularizer(self, data_key): + return self[data_key].feature_l1(average=self.readout_reg_avg) * self.gamma_readout + + +def stacked2d_core_point_readout(dataloaders, seed, hidden_channels=32, input_kern=13, # core args + hidden_kern=3, layers=3, gamma_hidden=0, gamma_input=0.1, + skip=0, final_nonlinearity=True, core_bias=False, momentum=0.9, + pad_input=False, batch_norm=True, hidden_dilation=1, + laplace_padding=None, input_regularizer='LaplaceL2norm', + pool_steps=2, pool_kern=7, readout_bias=True, init_range=0.1, # readout args, + gamma_readout=0.1, elu_offset=0, stack=None, readout_reg_avg=False, + use_avg_reg=False): + """ + Model
class of a stacked2dCore (from mlutils) and a pointpooled (spatial transformer) readout + + Args: + dataloaders: a dictionary of dataloaders, one loader per session + in the format {'data_key': dataloader object, .. } + seed: random seed + elu_offset: Offset for the output non-linearity [F.elu(x + self.offset)] + + all other args: See Documentation of Stacked2dCore in mlutils.layers.cores and + PointPooled2D in mlutils.layers.readouts + + Returns: An initialized model which consists of model.core and model.readout + """ + + + # make sure trainloader is being used + dataloaders = dataloaders.get("train", dataloaders) + + # Obtain the named tuple fields from the first entry of the first dataloader in the dictionary + in_name, out_name = next(iter(list(dataloaders.values())[0]))._fields + + session_shape_dict = get_dims_for_loader_dict(dataloaders) + n_neurons_dict = {k: v[out_name][1] for k, v in session_shape_dict.items()} + in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} + input_channels = [v[in_name][1] for v in session_shape_dict.values()] + + assert np.unique(input_channels).size == 1, "all input channels must be of equal size" + + class Encoder(nn.Module): + + def __init__(self, core, readout, elu_offset): + super().__init__() + self.core = core + self.readout = readout + self.offset = elu_offset + + def forward(self, x, data_key=None, **kwargs): + x = self.core(x) + x = self.readout(x, data_key=data_key) + return F.elu(x + self.offset) + 1 + + def regularizer(self, data_key): + return self.core.regularizer() + self.readout.regularizer(data_key=data_key) + + def _readout_regularizer_val(self): + ret = 0 + with eval_state(model): + for data_key in model.readout: + ret += self.readout.regularizer(data_key).detach().cpu().numpy() + return ret + + def _core_regularizer_val(self): + with eval_state(model): + return self.core.regularizer().detach().cpu().numpy() if model.core.regularizer() else 0 + + @property + def tracked_values(self): + return dict(readout_l1=self._readout_regularizer_val, + core_reg=self._core_regularizer_val) + + set_random_seed(seed) + + # get a stacked2D core from mlutils + core = Stacked2dCore(input_channels=input_channels[0], + hidden_channels=hidden_channels, + input_kern=input_kern, + hidden_kern=hidden_kern, + layers=layers, + gamma_hidden=gamma_hidden, + gamma_input=gamma_input, + skip=skip, + final_nonlinearity=final_nonlinearity, + bias=core_bias, + momentum=momentum, + pad_input=pad_input, + batch_norm=batch_norm, + hidden_dilation=hidden_dilation, + laplace_padding=laplace_padding, + input_regularizer=input_regularizer, + stack=stack, + use_avg_reg=use_avg_reg) + + readout = MultiplePointPooled2d(core, + in_shape_dict=in_shapes_dict, + n_neurons_dict=n_neurons_dict, + pool_steps=pool_steps, + pool_kern=pool_kern, + bias=readout_bias, + init_range=init_range, + gamma_readout=gamma_readout, + readout_reg_avg=readout_reg_avg) + + # initializing readout bias to mean response + if readout_bias: + for k in dataloaders: + readout[k].bias.data = dataloaders[k].dataset[:][1].mean(0) + + model = Encoder(core, readout, elu_offset) + + return model + + +def vgg_core_point_readout(dataloaders, seed, + input_channels=1, tr_model_fn='vgg16', # begin of core args + model_layer=11, momentum=0.1, final_batchnorm=True, + final_nonlinearity=True, bias=False, + pool_steps=1, pool_kern=7, readout_bias=True, # begin or readout args + init_range=0.1, gamma_readout=0.002, elu_offset=-1, readout_reg_avg=False): + """ + A Model class of a predefined core (using 
models from torchvision.models). Can be initialized pretrained or random. + Can also be set to be trainable or not, independent of initialization. + + Args: + dataloaders: a dictionary of train-dataloaders, one loader per session + in the format {'data_key': dataloader object, .. } + seed: .. + pool_steps: + pool_kern: + readout_bias: + init_range: + gamma_readout: + + Returns: + """ + + if "train" in dataloaders.keys(): + dataloaders = dataloaders["train"] + + in_name, out_name = next(iter(list(dataloaders.values())[0]))._fields + + session_shape_dict = get_dims_for_loader_dict(dataloaders) + n_neurons_dict = {k: v[out_name][1] for k, v in session_shape_dict.items()} + in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} + input_channels = [v[in_name][1] for v in session_shape_dict.values()] + + class Encoder(nn.Module): + """ + helper nn class that combines the core and readout into the final model + """ + def __init__(self, core, readout, elu_offset): + super().__init__() + self.core = core + self.readout = readout + self.offset = elu_offset + + def forward(self, x, data_key=None, **kwargs): + x = self.core(x) + x = self.readout(x, data_key=data_key) + return F.elu(x + self.offset) + 1 + + def regularizer(self, data_key): + return self.readout.regularizer(data_key=data_key) + self.core.regularizer() + + def _readout_regularizer_val(self): + ret = 0 + with eval_state(model): + for data_key in model.readout: + ret += self.readout.regularizer(data_key).detach().cpu().numpy() + return ret + + @property + def tracked_values(self): + return dict(readout_l1=self._readout_regularizer_val) + + + set_random_seed(seed) + + core = TransferLearningCore(input_channels=input_channels[0], + tr_model_fn=tr_model_fn, + model_layer=model_layer, + momentum=momentum, + final_batchnorm=final_batchnorm, + final_nonlinearity=final_nonlinearity, + bias=bias) + + readout = MultiplePointPooled2d(core, in_shape_dict=in_shapes_dict, + n_neurons_dict=n_neurons_dict, + pool_steps=pool_steps, + pool_kern=pool_kern, + bias=readout_bias, + init_range=init_range, + gamma_readout=gamma_readout, + readout_reg_avg=readout_reg_avg) + + # initializing readout bias to mean response + if readout_bias: + for k in dataloaders: + readout[k].bias.data = dataloaders[k].dataset[:][1].mean(0) + + model = Encoder(core, readout, elu_offset) + + return model diff --git a/nnfabrik/template.py b/nnfabrik/template.py index 4bf38994..8d05b056 100644 --- a/nnfabrik/template.py +++ b/nnfabrik/template.py @@ -135,6 +135,7 @@ def get_full_config(self, key=None, include_state_dict=True, include_trainer=Tru model_fn, model_config = (self.model_table & key).fn_config dataset_fn, dataset_config = (self.dataset_table & key).fn_config + ret = dict(model_fn=model_fn, model_config=model_config, dataset_fn=dataset_fn, dataset_config=dataset_config) @@ -196,10 +197,6 @@ def load_model(self, key=None, include_dataloader=True, include_trainer=False, i print("Model could not be built without the dataloader. Dataloader will be built in order to create the model. 
" "Make sure to have an The 'model_fn' also has to be able to" "accept 'data_info' as an input arg, and use that over the dataloader to build the model.") - - ret = get_all_parts(**config_dict, seed=seed) - return ret[1:] if include_trainer else ret[1] - return get_all_parts(**config_dict, seed=seed) def call_back(self, uid=None, epoch=None, model=None, info=None): @@ -346,10 +343,9 @@ def get_model(self, key=None): model = self.model_cache.load(key=key, include_state_dict=True, include_dataloader=False) - model.eval() - model.to("cuda") return model + def get_dataloaders(self, key=None): if key is None: key = self.fetch1('KEY') @@ -357,7 +353,10 @@ def get_dataloaders(self, key=None): return dataloaders[self.measure_dataset] def get_repeats_dataloaders(self, key=None): - raise NotImplementedError("Function to return the repeats-dataloader has to be implemented") + if key is None: + key = self.fetch1('KEY') + dataloaders = self.dataset_table().get_dataloader(key=key) if self.data_cache is None else self.data_cache.load(key=key) + return dataloaders["test"] def get_avg_of_unit_dict(self, unit_scores_dict): return np.mean(np.hstack([v for v in unit_scores_dict.values()])) diff --git a/nnfabrik/training/trainers.py b/nnfabrik/training/trainers.py new file mode 100644 index 00000000..213597e8 --- /dev/null +++ b/nnfabrik/training/trainers.py @@ -0,0 +1,350 @@ +import warnings +from functools import partial + +import numpy as np +import torch +from scipy import stats +from tqdm import tqdm + +from mlutils import measures +from mlutils.measures import * +from mlutils.training import early_stopping, MultipleObjectiveTracker, eval_state, cycle_datasets, Exhauster, LongCycler +from ..utility.nn_helpers import set_random_seed + +from ..utility import metrics +from ..utility.metrics import corr_stop, poisson_stop + + +def early_stop_trainer(model, seed, stop_function='corr_stop', + loss_function='PoissonLoss', epoch=0, interval=1, patience=10, max_iter=75, + maximize=True, tolerance=0.001, device='cuda', restore_best=True, + lr_init=0.005, lr_decay_factor=0.3, min_lr=0.0001, optim_batch_step=True, + verbose=True, lr_decay_steps=3, dataloaders=None, **kwargs): + """" + Args: + model: PyTorch nn module + seed: random seed + trainer_config: + lr_schedule: list or ndarray that contains lr and lr decrements after early stopping kicks in + stop_function: stop condition in early stopping, has to be one string of the following: + 'corr_stop' + 'gamma stop' + 'exp_stop' + 'poisson_stop' + loss_function: has to be a string that gets evaluated with eval() + Loss functions that are built in at mlutils that can + be selected in the trainer config are: + 'PoissonLoss' + 'GammaLoss' + device: Device that the model resides on. 
Expects a device string as accepted by torch.device(). + Examples: 'cpu', 'cuda:2' (0-indexed gpu) + + PyTorch dataloaders are expected as a dictionary of individual loaders: + train: PyTorch DataLoader -- training data + val: validation data loader + test: test data loader -- not used during training + + Returns: + score: performance score of the model + output: user specified validation object based on the 'stop function' + model_state: the full state_dict() of the trained model + """ + + train = dataloaders["train"] if dataloaders else kwargs["train"] + val = dataloaders["val"] if dataloaders else kwargs["val"] + test = dataloaders["test"] if dataloaders else kwargs["test"] + + # --- begin of helper function definitions + def model_predictions(loader, model, data_key): + """ + computes model predictions for a given dataloader and a model + Returns: + target: ground truth, i.e. neuronal firing rates of the neurons + output: responses as predicted by the network + """ + target, output = torch.empty(0), torch.empty(0) + for images, responses in loader[data_key]: + output = torch.cat((output, (model(images.to(device), data_key=data_key).detach().cpu())), dim=0) + target = torch.cat((target, responses.detach().cpu()), dim=0) + + return target.numpy(), output.numpy() + + # all early stopping conditions + def corr_stop(model, loader=None, avg=True): + """ + Returns either the average correlation of all neurons or the correlations per neuron. + Gets called by early stopping and the model performance evaluation + """ + loader = val if loader is None else loader + n_neurons, correlations_sum = 0, 0 + if not avg: + all_correlations = np.array([]) + + for data_key in loader: + with eval_state(model): + target, output = model_predictions(loader, model, data_key) + + ret = corr(target, output, axis=0) + + if np.any(np.isnan(ret)): + warnings.warn('{}% NaNs '.format(np.isnan(ret).mean() * 100)) + ret[np.isnan(ret)] = 0 + + if not avg: + all_correlations = np.append(all_correlations, ret) + else: + n_neurons += output.shape[1] + correlations_sum += ret.sum() + + corr_ret = correlations_sum / n_neurons if avg else all_correlations + return corr_ret + + def gamma_stop(model): + with eval_state(model): + target, output = model_predictions(val, model) + + ret = -stats.gamma.logpdf(target + 1e-7, output + 0.5).mean(axis=1) / np.log(2) + if np.any(np.isnan(ret)): + warnings.warn(' {}% NaNs '.format(np.isnan(ret).mean() * 100)) + ret[np.isnan(ret)] = 0 + return ret.mean() + + def exp_stop(model, bias=1e-12, target_bias=1e-7): + with eval_state(model): + target, output = model_predictions(val, model) + target = target + target_bias + output = output + bias + ret = (target / output + np.log(output)).mean(axis=1) / np.log(2) + if np.any(np.isnan(ret)): + warnings.warn(' {}% NaNs '.format(np.isnan(ret).mean() * 100)) + ret[np.isnan(ret)] = 0 + # -- average if requested + return ret.mean() + + def poisson_stop(model, loader=None, avg=False): + poisson_losses = np.array([]) + loader = val if loader is None else loader + n_neurons = 0 + for data_key in loader: + with eval_state(model): + target, output = model_predictions(loader, model, data_key) + + ret = output - target * np.log(output + 1e-12) + if np.any(np.isnan(ret)): + warnings.warn(' {}% NaNs '.format(np.isnan(ret).mean() * 100)) + + poisson_losses = np.append(poisson_losses, np.nansum(ret, 0)) + n_neurons += output.shape[1] + return poisson_losses.sum()/n_neurons if avg else poisson_losses.sum() + + def readout_regularizer_stop(model): + ret = 0 + with
eval_state(model): + for data_key in val: + ret += model.readout.regularizer(data_key).detach().cpu().numpy() + return ret + + def core_regularizer_stop(model): + with eval_state(model): + if model.core.regularizer(): + return model.core.regularizer().detach().cpu().numpy() + else: + return 0 + + def full_objective(model, data_key, inputs, targets, **kwargs): + """ + Computes the training loss for the model and prespecified criterion. + Default: PoissonLoss, summed over Neurons and Batches, scaled by dataset + size and batch size to account for batch noise. + + Args: + inputs: i.e. images + targets: neuronal responses that the model should predict + + Returns: training loss, summed over all neurons and batches + + """ + m = len(train[data_key].dataset) + k = inputs.shape[0] + + return np.sqrt(m / k) * criterion(model(inputs.to(device), data_key=data_key, **kwargs), targets.to(device)).sum() \ + + model.regularizer(data_key) + + + def run(model, full_objective, optimizer, scheduler, stop_closure, train_loader, + epoch, interval, patience, max_iter, maximize, tolerance, + restore_best, tracker, optim_step_count, lr_decay_steps): + + for epoch, val_obj in early_stopping(model, stop_closure, + interval=interval, patience=patience, + start=epoch, max_iter=max_iter, maximize=maximize, + tolerance=tolerance, restore_best=restore_best, + tracker=tracker, scheduler=scheduler, lr_decay_steps=lr_decay_steps): + optimizer.zero_grad() + + # reports the entry of the current epoch for all tracked objectives + if verbose: + for key in tracker.log.keys(): + print(key, tracker.log[key][-1]) + + # Beginning of main training loop + for batch_no, (data_key, data) in tqdm(enumerate(LongCycler(train_loader)), + desc='Epoch {}'.format(epoch)): + + loss = full_objective(model, data_key, *data) + loss.backward() + # accumulate gradients over optim_step_count batches before each optimizer step + if (batch_no+1) % optim_step_count == 0: + optimizer.step() + optimizer.zero_grad() + + # End of training + return model, epoch + + # model setup + set_random_seed(seed) + model.to(device) + model.train() + + # the current criterion is assumed to be Poisson loss.
Only for that loss are the additional arguments defined + criterion = eval(loss_function)(per_neuron=True, avg=False) + + # get stopping criterion from helper functions based on keyword + stop_closure = eval(stop_function) + + tracker = MultipleObjectiveTracker(correlation=partial(corr_stop, model), + poisson_loss=partial(poisson_stop, model), + poisson_loss_val=partial(poisson_stop, model, val), + readout_l1=partial(readout_regularizer_stop, model), + core_regularizer=partial(core_regularizer_stop, model)) + + trainable_params = [p for p in list(model.parameters()) if p.requires_grad] + optimizer = torch.optim.Adam(trainable_params, lr=lr_init) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, + mode='max' if maximize else 'min', + factor=lr_decay_factor, + patience=patience, + threshold=tolerance, + min_lr=min_lr, + verbose=verbose, + threshold_mode='abs', + ) + + optim_step_count = len(train.keys()) if optim_batch_step else 1 + + model, epoch = run(model=model, + full_objective=full_objective, + optimizer=optimizer, + scheduler=scheduler, + stop_closure=stop_closure, + train_loader=train, + epoch=epoch, + interval=interval, + patience=patience, + max_iter=max_iter, + lr_decay_steps=lr_decay_steps, + maximize=maximize, + tolerance=tolerance, + restore_best=restore_best, + tracker=tracker, + optim_step_count=optim_step_count) + + model.eval() + tracker.finalize() + + # compute average test correlations as the score + avg_corr = corr_stop(model, test, avg=True) + + # return the whole tracker output as a dict + output = {k: v for k, v in tracker.log.items()} + return avg_corr, output, model.state_dict() + + +def standard_early_stop_trainer(model, dataloaders, seed, avg_loss=True, scale_loss=True, # trainer args + loss_function='PoissonLoss', stop_function='corr_stop', + loss_accum_batch_n=None, device='cuda', verbose=True, + interval=1, patience=5, epoch=0, lr_init=0.005, # early stopping args + max_iter=100, maximize=True, tolerance=1e-6, + restore_best=True, lr_decay_steps=3, + lr_decay_factor=0.3, min_lr=0.0001, # lr scheduler args + cb=None, **kwargs): + + def full_objective(model, data_key, inputs, targets): + if scale_loss: + m = len(trainloaders[data_key].dataset) + k = inputs.shape[0] + loss_scale = np.sqrt(m / k) + else: + loss_scale = 1.0 + + return loss_scale * criterion(model(inputs.to(device), data_key), targets.to(device)) + model.regularizer(data_key) + + trainloaders = dataloaders["train"] + valloaders = dataloaders.get("validation", dataloaders["val"] if "val" in dataloaders.keys() else None) + testloaders = dataloaders["test"] + + ##### Model training #################################################################################################### + model.to(device) + set_random_seed(seed) + model.train() + + criterion = getattr(measures, loss_function)(avg=avg_loss) + stop_closure = partial(getattr(metrics, stop_function), model, valloaders, device=device) + + n_iterations = len(LongCycler(trainloaders)) + + optimizer = torch.optim.Adam(model.parameters(), lr=lr_init) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max' if maximize else 'min', + factor=lr_decay_factor, patience=patience, threshold=tolerance, + min_lr=min_lr, verbose=verbose, threshold_mode='abs') + + # set the number of iterations over which you would like to accumulate gradients + optim_step_count = len(trainloaders.keys()) if loss_accum_batch_n is None else loss_accum_batch_n + + # define some trackers + tracker_dict = dict(correlation=partial(corr_stop,
model, valloaders, device=device), + poisson_loss=partial(poisson_stop, model, valloaders, device=device), + poisson_loss_val=partial(poisson_stop, model, valloaders, device=device)) + + if hasattr(model, 'tracked_values'): + tracker_dict.update(model.tracked_values) + + tracker = MultipleObjectiveTracker(**tracker_dict) + + # train over epochs + for epoch, val_obj in early_stopping(model, stop_closure, interval=interval, patience=patience, + start=epoch, max_iter=max_iter, maximize=maximize, + tolerance=tolerance, restore_best=restore_best, tracker=tracker, + scheduler=scheduler, lr_decay_steps=lr_decay_steps): + + # print the quantities from tracker + if verbose and tracker is not None: + print("=======================================") + for key in tracker.log.keys(): + print(key, tracker.log[key][-1], flush=True) + + # executes callback function if passed in keyword args + if cb is not None: + cb() + + # train over batches + optimizer.zero_grad() + for batch_no, (data_key, data) in tqdm(enumerate(LongCycler(trainloaders)), total=n_iterations, desc="Epoch {}".format(epoch)): + + loss = full_objective(model, data_key, *data) + loss.backward() + if (batch_no+1) % optim_step_count == 0: + optimizer.step() + optimizer.zero_grad() + + ##### Model evaluation #################################################################################################### + model.eval() + tracker.finalize() + + # Compute avg validation and test correlation + avg_val_corr = corr_stop(model, valloaders, avg=True, device=device) + avg_test_corr = corr_stop(model, testloaders, avg=True, device=device) + + # return the whole tracker output as a dict + output = {k: v for k, v in tracker.log.items()} + + return avg_test_corr, output, model.state_dict() diff --git a/notebooks/nnfabrik_monkey_demo.ipynb b/notebooks/nnfabrik_monkey_demo.ipynb index 323ebb02..dfaeb13c 100644 --- a/notebooks/nnfabrik_monkey_demo.ipynb +++ b/notebooks/nnfabrik_monkey_demo.ipynb @@ -20,10 +20,9 @@ "source": [ "- nnfabrik, mlutils, and nnvision from the sinzlab repository have to be installed/cloned\n", "\n", - "- we are using the pytorch image called `sinzlab/pytorch:v3.8-torch1.4.0-cuda10.1-dj0.12.4`\n", + "- we are using the pytorch image called `sinzlab/pytorch:1.3.1-cuda10.1-dj0.12.4`\n", " - docker image can be found here: https://github.com/sinzlab/pytorch-docker or https://hub.docker.com/r/sinzlab/pytorch/dockerfile\n", - " - there, the complete list of packages to be installed can be found.\n", - " - torch >= 1.4 is required.\n", + " - there, the complete list of packages to be installed can be found. \n", "
\n", "
\n", "- All individual pickle files and the image-pickle file have to be present. The can be found on the GPU server under /var/lib/nova/sinz-shared/data\n" @@ -40,7 +39,35 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "make sure to install all required packages. dependencies are listed in the dockerfile above. If necessary, install packages within the environment\n" + "make sure to install all required packages. dependencies are listed in the dockerfile above. If necessary, install packages within the environment" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Make sure that a dj-database is connected. Recommended dj version is 0.12.4" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Connecting kwilleke@sinzlab.chlkmukhxp6i.eu-central-1.rds.amazonaws.com:3306\n" + ] + } + ], + "source": [ + "import datajoint as dj\n", + "\n", + "dj.config['enable_python_native_blobs'] = True\n", + "dj.config['schema_name'] = \"nnfabrik_playground\"\n", + "schema = dj.schema(\"nnfabrik_playground\")\n" ] }, { @@ -52,8 +79,9 @@ "import torch\n", "\n", "import nnfabrik\n", + "from nnfabrik.main import *\n", "from nnfabrik import builder\n", - "\n", + "from nnfabrik.template import TrainedModelBase\n", "\n", "import numpy as np\n", "import pickle\n", @@ -244,7 +272,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Building a model" + "### A Model from nnfabrik.models" ] }, { @@ -253,7 +281,45 @@ "metadata": {}, "outputs": [], "source": [ + "model_fn='stacked2d_core_point_readout'\n", "\n", + "model_config = dict(\n", + " hidden_kern=3,\n", + " layers=2,\n", + " pad_input=False,\n", + " gamma_readout=.25,\n", + " gamma_input=11.2,\n", + " gamma_hidden=1e-6,\n", + " pool_kern=2,\n", + " pool_steps=1,\n", + " stack=None,\n", + " )\n", + " \n", + "model = builder.get_model(model_fn, model_config, dataloaders,seed=1000)\n", + "print(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Above is the model description. Each session has its own readout. The Readout learns an x,y position between -1 and 1, relative to image space, and reads out from that point in feature space. THat means that the effective receptive field size of a unit in the last hidden layer will also be the receptive field size of the neuron. The x/y coordinates can be accessed like this:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Loading a model from a different repo: \n", + "#### reponame.module.function_name\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "model_fn = 'nnvision.models.se_core_full_gauss_readout'\n", "model_config = {'pad_input': False,\n", " 'stack': -1,\n", @@ -265,15 +331,7 @@ " 'hidden_kern': 5,\n", " 'n_se_blocks': 0,\n", " 'hidden_channels': 32}\n", - "model = builder.get_model(model_fn, model_config, dataloaders=dataloaders,seed=1000)\n", - "print(model)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Above is the model description. Each session has its own readout. The Readout learns an x,y position between -1 and 1, relative to image space, and reads out from that point in feature space. THat means that the effective receptive field size of a unit in the last hidden layer will also be the receptive field size of the neuron. 
The x/y coordinates can be accessed like this:\n" + "se_model = builder.get_model(model_fn, model_config, dataloaders=dataloaders,seed=1000)\n" ] }, { @@ -364,38 +422,15 @@ "source": [ "Instead of using the builder to get the data/model/and trainer, we can use datajoint to manage that process for us.\n", "There are Model, Dataset, and Trainer Tables. And each combination in those tables should in principle lead to a fully trained model.\n", - "For completeness, there is also a Seed table that stores the random seed, and a Fabrikant table, that stores the name and contact details of the creator (=Fabrikant).\n" + "For completeness, there is also a Seed table that stores the random seed, and a Fabrikant table, that stores the name and contact details of the creator (=Fabrikant)." ] }, - { - "cell_type": "markdown", - "source": [ - "### Make sure that a dj-database is connected. Recommended dj version is 0.12.4" - ], - "metadata": { - "collapsed": false - } - }, { "cell_type": "code", "execution_count": null, + "metadata": {}, "outputs": [], - "source": [ - "import datajoint as dj\n", - "\n", - "dj.config['enable_python_native_blobs'] = True\n", - "dj.config['schema_name'] = \"nnfabrik_playground\"\n", - "schema = dj.schema(\"nnfabrik_playground\")\n", - "\n", - "from nnfabrik.template import TrainedModelBase\n", - "from nnfabrik.main import *\n" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + "source": [] }, { "cell_type": "code", @@ -742,17 +777,8 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.0" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "source": [], - "metadata": { - "collapsed": false - } - } } }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +}
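For completeness, a minimal usage sketch of the pieces restored by this revert, mirroring the demo notebook above. It assumes `dataloaders` is the usual dict with 'train'/'val'/'test' entries, each mapping data_keys to PyTorch DataLoaders; the import paths follow the file locations re-added by this patch:

from nnfabrik.models.v1_models import stacked2d_core_point_readout
from nnfabrik.training.trainers import standard_early_stop_trainer

# build a core + readout model directly from the dataloaders
model = stacked2d_core_point_readout(dataloaders, seed=1000,
                                     hidden_kern=3, layers=2,
                                     gamma_input=11.2, gamma_readout=0.25)

# train with early stopping; returns the avg test correlation, the tracker log, and the weights
score, output, state_dict = standard_early_stop_trainer(model, dataloaders, seed=1000)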