From 7bbcf47751364357bd99411b377c75fca62effeb Mon Sep 17 00:00:00 2001 From: Mohammad Bashiri Date: Sun, 13 Feb 2022 17:05:52 +0100 Subject: [PATCH] Update imports. --- nnsysident/models/models.py | 98 +++++++++++++++++++++++++++++-------- 1 file changed, 77 insertions(+), 21 deletions(-) diff --git a/nnsysident/models/models.py b/nnsysident/models/models.py index ed1323f..833c91a 100644 --- a/nnsysident/models/models.py +++ b/nnsysident/models/models.py @@ -3,7 +3,7 @@ import copy from nnfabrik.utility.nn_helpers import set_random_seed, get_dims_for_loader_dict -from neuralpredictors.layers import ( +from neuralpredictors.layers.readouts import ( MultipleFullGaussian2d, MultiplePointPooled2d, MultipleSpatialXFeatureLinear, @@ -126,7 +126,11 @@ def se2d_fullgaussian2d( in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} input_channels = [v[in_name][1] for v in session_shape_dict.values()] - core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0] + core_input_channels = ( + list(input_channels.values())[0] + if isinstance(input_channels, dict) + else input_channels[0] + ) source_grids = None grid_mean_predictor_type = None @@ -140,7 +144,9 @@ def se2d_fullgaussian2d( # real data try: if v.dataset.neurons.animal_ids[0] != 0: - source_grids[k] = v.dataset.neurons.cell_motor_coordinates[:, :input_dim] + source_grids[k] = v.dataset.neurons.cell_motor_coordinates[ + :, :input_dim + ] # simulated data -> get random linear non-degenerate transform of true positions else: source_grid_true = v.dataset.neurons.center[:, :input_dim] @@ -152,18 +158,30 @@ def se2d_fullgaussian2d( det = np.linalg.det(matrix) loops += 1 assert det > 5.0, "Did not find a non-degenerate matrix" - source_grids[k] = np.add((matrix @ source_grid_true.T).T, grid_bias) + source_grids[k] = np.add( + (matrix @ source_grid_true.T).T, grid_bias + ) except FileNotFoundError: - print("Dataset type is not recognized to be from Baylor College of Medicine.") - source_grids[k] = v.dataset.neurons.cell_motor_coordinates[:, :input_dim] + print( + "Dataset type is not recognized to be from Baylor College of Medicine." + ) + source_grids[k] = v.dataset.neurons.cell_motor_coordinates[ + :, :input_dim + ] elif grid_mean_predictor_type == "shared": pass else: - raise ValueError("Grid mean predictor type {} not understood.".format(grid_mean_predictor_type)) + raise ValueError( + "Grid mean predictor type {} not understood.".format( + grid_mean_predictor_type + ) + ) shared_match_ids = None if share_features or share_grid: - shared_match_ids = {k: v.dataset.neurons.multi_match_id for k, v in dataloaders.items()} + shared_match_ids = { + k: v.dataset.neurons.multi_match_id for k, v in dataloaders.items() + } all_multi_unit_ids = set(np.hstack(shared_match_ids.values())) for match_id in shared_match_ids.values(): @@ -285,7 +303,11 @@ def se2d_pointpooled( in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} input_channels = [v[in_name][1] for v in session_shape_dict.values()] - core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0] + core_input_channels = ( + list(input_channels.values())[0] + if isinstance(input_channels, dict) + else input_channels[0] + ) set_random_seed(seed) @@ -384,7 +406,11 @@ def se2d_spatialxfeaturelinear( in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} input_channels = [v[in_name][1] for v in session_shape_dict.values()] - core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0] + core_input_channels = ( + list(input_channels.values())[0] + if isinstance(input_channels, dict) + else input_channels[0] + ) set_random_seed(seed) @@ -494,11 +520,17 @@ def se2d_fullSXF( in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} input_channels = [v[in_name][1] for v in session_shape_dict.values()] - core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0] + core_input_channels = ( + list(input_channels.values())[0] + if isinstance(input_channels, dict) + else input_channels[0] + ) shared_match_ids = None if share_features: - shared_match_ids = {k: v.dataset.neurons.multi_match_id for k, v in dataloaders.items()} + shared_match_ids = { + k: v.dataset.neurons.multi_match_id for k, v in dataloaders.items() + } all_multi_unit_ids = set(np.hstack(shared_match_ids.values())) for match_id in shared_match_ids.values(): @@ -632,7 +664,11 @@ def taskdriven_fullgaussian2d( in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} input_channels = [v[in_name][1] for v in session_shape_dict.values()] - core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0] + core_input_channels = ( + list(input_channels.values())[0] + if isinstance(input_channels, dict) + else input_channels[0] + ) source_grids = None grid_mean_predictor_type = None @@ -646,7 +682,9 @@ def taskdriven_fullgaussian2d( # real data try: if v.dataset.neurons.animal_ids[0] != 0: - source_grids[k] = v.dataset.neurons.cell_motor_coordinates[:, :input_dim] + source_grids[k] = v.dataset.neurons.cell_motor_coordinates[ + :, :input_dim + ] # simulated data -> get random linear non-degenerate transform of true positions else: source_grid_true = v.dataset.neurons.center[:, :input_dim] @@ -658,18 +696,30 @@ def taskdriven_fullgaussian2d( det = np.linalg.det(matrix) loops += 1 assert det > 5.0, "Did not find a non-degenerate matrix" - source_grids[k] = np.add((matrix @ source_grid_true.T).T, grid_bias) + source_grids[k] = np.add( + (matrix @ source_grid_true.T).T, grid_bias + ) except FileNotFoundError: - print("Dataset type is not recognized to be from Baylor College of Medicine.") - source_grids[k] = v.dataset.neurons.cell_motor_coordinates[:, :input_dim] + print( + "Dataset type is not recognized to be from Baylor College of Medicine." + ) + source_grids[k] = v.dataset.neurons.cell_motor_coordinates[ + :, :input_dim + ] elif grid_mean_predictor_type == "shared": pass else: - raise ValueError("Grid mean predictor type {} not understood.".format(grid_mean_predictor_type)) + raise ValueError( + "Grid mean predictor type {} not understood.".format( + grid_mean_predictor_type + ) + ) shared_match_ids = None if share_features or share_grid: - shared_match_ids = {k: v.dataset.neurons.multi_match_id for k, v in dataloaders.items()} + shared_match_ids = { + k: v.dataset.neurons.multi_match_id for k, v in dataloaders.items() + } all_multi_unit_ids = set(np.hstack(shared_match_ids.values())) for match_id in shared_match_ids.values(): @@ -767,11 +817,17 @@ def taskdriven_fullSXF( in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()} input_channels = [v[in_name][1] for v in session_shape_dict.values()] - core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0] + core_input_channels = ( + list(input_channels.values())[0] + if isinstance(input_channels, dict) + else input_channels[0] + ) shared_match_ids = None if share_features: - shared_match_ids = {k: v.dataset.neurons.multi_match_id for k, v in dataloaders.items()} + shared_match_ids = { + k: v.dataset.neurons.multi_match_id for k, v in dataloaders.items() + } all_multi_unit_ids = set(np.hstack(shared_match_ids.values())) for match_id in shared_match_ids.values():