diff --git a/mala/common/parameters.py b/mala/common/parameters.py index 1d2ba9d9..9dd1a0fa 100644 --- a/mala/common/parameters.py +++ b/mala/common/parameters.py @@ -1709,7 +1709,7 @@ def load_from_file( Parameters ---------- - file : string or ZipExtFile + file : string or ZipExtFile or dict File to which the parameters will be saved to. save_format : string @@ -1738,56 +1738,63 @@ def load_from_file( json_dict = json.load(open(file, encoding="utf-8")) else: json_dict = json.load(file) + loaded_parameters = cls._process_loaded_dict( + json_dict, no_snapshots, force_no_ddp + ) - loaded_parameters = cls() - for key in json_dict: - if ( - isinstance(json_dict[key], dict) - and key != "openpmd_configuration" - ): - # These are the other parameter classes. - sub_parameters = globals()[ - json_dict[key]["_parameters_type"] - ].from_json(json_dict[key]) - setattr(loaded_parameters, key, sub_parameters) - - # Backwards compatability: - if key == "descriptors": - if ( - "use_atomic_density_energy_formula" - in json_dict[key] - ): - loaded_parameters.use_atomic_density_formula = ( - json_dict[key][ - "use_atomic_density_energy_formula" - ] - ) - - # We iterate a second time, to set global values, so that they - # are properly forwarded. - for key in json_dict: - if ( - not isinstance(json_dict[key], dict) - or key == "openpmd_configuration" - ): - if key == "use_ddp" and force_no_ddp is True: - setattr(loaded_parameters, key, False) - else: - setattr(loaded_parameters, key, json_dict[key]) - if no_snapshots is True: - loaded_parameters.data.snapshot_directories_list = [] - # Backwards compatability: since the transfer of old property - # to new property happens _before_ all children descriptor classes - # are instantiated, it is not properly propagated. Thus, we - # simply have to set it to its own value again. - loaded_parameters.use_atomic_density_formula = ( - loaded_parameters.use_atomic_density_formula + elif save_format == "dict": + loaded_parameters = cls._process_loaded_dict( + file, no_snapshots, force_no_ddp ) + else: raise Exception("Unsupported parameter save format.") return loaded_parameters + @classmethod + def _process_loaded_dict(cls, json_dict, no_snapshots, force_no_ddp): + loaded_parameters = cls() + for key in json_dict: + if ( + isinstance(json_dict[key], dict) + and key != "openpmd_configuration" + ): + # These are the other parameter classes. + sub_parameters = globals()[ + json_dict[key]["_parameters_type"] + ].from_json(json_dict[key]) + setattr(loaded_parameters, key, sub_parameters) + + # Backwards compatability: + if key == "descriptors": + if "use_atomic_density_energy_formula" in json_dict[key]: + loaded_parameters.use_atomic_density_formula = ( + json_dict[key]["use_atomic_density_energy_formula"] + ) + + # We iterate a second time, to set global values, so that they + # are properly forwarded. + for key in json_dict: + if ( + not isinstance(json_dict[key], dict) + or key == "openpmd_configuration" + ): + if key == "use_ddp" and force_no_ddp is True: + setattr(loaded_parameters, key, False) + else: + setattr(loaded_parameters, key, json_dict[key]) + if no_snapshots is True: + loaded_parameters.data.snapshot_directories_list = [] + # Backwards compatability: since the transfer of old property + # to new property happens _before_ all children descriptor classes + # are instantiated, it is not properly propagated. Thus, we + # simply have to set it to its own value again. 
+        loaded_parameters.use_atomic_density_formula = (
+            loaded_parameters.use_atomic_density_formula
+        )
+        return loaded_parameters
+
     @classmethod
     def load_from_pickle(cls, file, no_snapshots=False):
         """
@@ -1838,3 +1845,32 @@ def load_from_json(cls, file, no_snapshots=False, force_no_ddp=False):
             no_snapshots=no_snapshots,
             force_no_ddp=force_no_ddp,
         )
+
+    @classmethod
+    def load_from_dict(
+        cls, param_dict, no_snapshots=False, force_no_ddp=False
+    ):
+        """
+        Load a Parameters object from a dictionary.
+
+        Parameters
+        ----------
+        param_dict : dict
+            Dictionary containing the parameters to be loaded.
+
+        no_snapshots : bool
+            If True, then the snapshot list will be emptied. Useful when
+            performing inference/testing after training a network.
+
+        Returns
+        -------
+        loaded_parameters : Parameters
+            The loaded Parameters object.
+
+        """
+        return Parameters.load_from_file(
+            param_dict,
+            save_format="dict",
+            no_snapshots=no_snapshots,
+            force_no_ddp=force_no_ddp,
+        )
diff --git a/mala/common/physical_data.py b/mala/common/physical_data.py
index c7dd08f4..d4bf4938 100644
--- a/mala/common/physical_data.py
+++ b/mala/common/physical_data.py
@@ -5,7 +5,7 @@ import json
 
 import numpy as np
 
-from mala.common.parallelizer import get_comm, get_rank
+from mala.common.parallelizer import get_comm, get_rank, printout
 from mala.version import __version__ as mala_version
 
 
@@ -86,7 +86,7 @@ def si_unit_conversion(self):
     ##############################
 
     def read_from_numpy_file(
-        self, path, units=None, array=None, reshape=False
+        self, path, units=None, array=None, reshape=False, selection_mask=None
     ):
         """
         Read the data from a numpy file.
@@ -103,6 +103,10 @@ def read_from_numpy_file(
             If not None, the array to save the data into. The array has
             to be 4-dimensional.
 
+        selection_mask : None or list of bool
+            If None, the entire snapshot is loaded; otherwise it is used
+            as a mask selecting which examples are loaded.
+
         reshape : bool
             If True, the loaded 4D array will be reshaped into a 2D array.
@@ -117,17 +122,42 @@ def read_from_numpy_file( if array is None: loaded_array = np.load(path)[:, :, :, self._feature_mask() :] self._process_loaded_array(loaded_array, units=units) - return loaded_array + + # Select portion of array if mask provided + if selection_mask is not None: + original_dims = loaded_array.shape + + # Pseudo-flatten to apply mask without causing dimensionality mismatch later on + loaded_array = loaded_array.reshape( + (-1, 1, 1, original_dims[-1]) + )[selection_mask] + return loaded_array + else: + return loaded_array else: if reshape: array_dims = np.shape(array) - array[:, :] = np.load(path)[ - :, :, :, self._feature_mask() : - ].reshape(array_dims) + if selection_mask is not None: + array[:, :] = np.load(path)[ + :, :, :, self._feature_mask() : + ].reshape((len(selection_mask), -1))[selection_mask] + else: + array[:, :] = np.load(path)[ + :, :, :, self._feature_mask() : + ].reshape(array_dims) else: + array_dims = np.shape(array) array[:, :, :, :] = np.load(path)[ :, :, :, self._feature_mask() : ] + + # Select portion of array if mask provided + if selection_mask is not None: + # Pseudo-flatten to apply mask without causing + # dimensionality mismatch later on + array = array.reshape((-1, 1, 1, array_dims[-1]))[ + selection_mask + ] self._process_loaded_array(array, units=units) def read_from_openpmd_file(self, path, units=None, array=None): @@ -272,7 +302,9 @@ def read_from_openpmd_file(self, path, units=None, array=None): else: self._process_loaded_array(array, units=units) - def read_dimensions_from_numpy_file(self, path, read_dtype=False): + def read_dimensions_from_numpy_file( + self, path, read_dtype=False, selection_mask=None + ): """ Read only the dimensions from a numpy file. @@ -293,6 +325,11 @@ def read_dimensions_from_numpy_file(self, path, read_dtype=False): be returned. """ loaded_array = np.load(path, mmap_mode="r") + if selection_mask is not None: + original_dims = loaded_array.shape + loaded_array = loaded_array.reshape((-1, 1, 1, original_dims[-1]))[ + selection_mask + ] if read_dtype: return ( self._process_loaded_dimensions(np.shape(loaded_array)), diff --git a/mala/datahandling/data_handler.py b/mala/datahandling/data_handler.py index 3b9521e4..b903d157 100644 --- a/mala/datahandling/data_handler.py +++ b/mala/datahandling/data_handler.py @@ -172,7 +172,7 @@ def clear_data(self): # Preparing data ###################### - def prepare_data(self, reparametrize_scaler=True): + def prepare_data(self, reparametrize_scaler=True, from_arrays_dict=None): """ Prepare the data to be used in a training process. @@ -188,6 +188,15 @@ def prepare_data(self, reparametrize_scaler=True): If True (default), the DataScalers are parametrized based on the training data. + from_arrays_dict : dict or None + (Allows user to provide data directly from memory) + Dictionary which assigns an array (values) to each snapshot, e.g., + {(0,'inputs') : fp_array, (0, 'outputs') : ldos_array, ...} where 0 + is the index of the snapshot (absolute, not relative to data + partition) and inputs/outputs indicates the nature of the array. + None value indicates the data should be pulled from disk according + to the snapshot objects. + """ # During data loading, there is no need to save target data to # calculators. 
@@ -203,7 +212,7 @@ def prepare_data(self, reparametrize_scaler=True): "Checking the snapshots and your inputs for consistency.", min_verbosity=1, ) - self._check_snapshots() + self._check_snapshots(from_arrays_dict=from_arrays_dict) printout("Consistency check successful.", min_verbosity=0) # If the DataHandler is used for inference, i.e. no training or @@ -225,7 +234,7 @@ def prepare_data(self, reparametrize_scaler=True): # Parametrize the scalers, if needed. if reparametrize_scaler: printout("Initializing the data scalers.", min_verbosity=1) - self.__parametrize_scalers() + self.__parametrize_scalers(from_arrays_dict=from_arrays_dict) printout("Data scalers initialized.", min_verbosity=0) elif ( self.parameters.use_lazy_loading is False @@ -235,12 +244,16 @@ def prepare_data(self, reparametrize_scaler=True): "Data scalers already initilized, loading data to RAM.", min_verbosity=0, ) - self.__load_data("training", "inputs") - self.__load_data("training", "outputs") + self.__load_data( + "training", "inputs", from_arrays_dict=from_arrays_dict + ) + self.__load_data( + "training", "outputs", from_arrays_dict=from_arrays_dict + ) # Build Datasets. printout("Build datasets.", min_verbosity=1) - self.__build_datasets() + self.__build_datasets(from_arrays_dict=from_arrays_dict) printout("Build dataset: Done.", min_verbosity=0) # After the loading is done, target data can safely be saved again. @@ -253,6 +266,150 @@ def prepare_data(self, reparametrize_scaler=True): # allows for parallel I/O. barrier() + def refresh_data( + self, from_arrays_dict=None, partitions=["tr", "va", "te"] + ): + """ + Replace tr, va, te data for next generation of active learning. + + Internally replicates prepare_data function. + + Parameters + ---------- + from_arrays_dict : dict or None + (Allows user to provide data directly from memory) + Dictionary which assigns an array (values) to each snapshot, e.g., + {(0,'inputs') : fp_array, (0, 'outputs') : ldos_array, ...} where 0 + is the index of the snapshot (absolute, not relative to data + partition) and inputs/outputs indicates the nature of the array. + None value indicates the data should be pulled from disk according + to the snapshot objects. + + partitions: list + Specifies the partitions for which to reload data + """ + # During data loading, there is no need to save target data to + # calculators. + # Technically, this would be no issue, but due to technical reasons + # (i.e. float64 to float32 conversion) saving the data this way + # may create copies in memory. + self.target_calculator.save_target_data = False + + printout( + "Checking the snapshots and your inputs for consistency.", + min_verbosity=1, + ) + self._check_snapshots(from_arrays_dict=from_arrays_dict) + printout("Consistency check successful.", min_verbosity=0) + + # Reallocate arrays for data storage + if self.parameters.data_splitting_type == "by_snapshot": + ( + self.nr_training_snapshots, + self.nr_training_data, + self.nr_test_snapshots, + self.nr_test_data, + self.nr_validation_snapshots, + self.nr_validation_data, + ) = (0, 0, 0, 0, 0, 0) + # pprint(vars(self)) + # pprint(vars(self.parameters)) + snapshot: Snapshot + # As we are not actually interested in the number of snapshots, + # but in the number of datasets, we also need to multiply by that. 
+ + for i, snapshot in enumerate( + self.parameters.snapshot_directories_list + ): + # if snapshot._selection_mask: + # snapshot.grid_size = sum(snapshot._selection_mask) + printout( + f"Snapshot {i}: {snapshot.grid_size}", min_verbosity=3 + ) + if snapshot.snapshot_function == "tr": + self.nr_training_snapshots += 1 + self.nr_training_data += snapshot.grid_size + elif snapshot.snapshot_function == "te": + self.nr_test_snapshots += 1 + self.nr_test_data += snapshot.grid_size + elif snapshot.snapshot_function == "va": + self.nr_validation_snapshots += 1 + self.nr_validation_data += snapshot.grid_size + else: + raise Exception( + "Unknown option for snapshot splitting " "selected." + ) + + # Now we need to check whether or not this input is believable. + nr_of_snapshots = len(self.parameters.snapshot_directories_list) + if nr_of_snapshots != ( + self.nr_training_snapshots + + self.nr_test_snapshots + + self.nr_validation_snapshots + ): + raise Exception( + "Cannot split snapshots with specified " + "splitting scheme, " + "too few or too many options selected: " + f"[{nr_of_snapshots} != {self.nr_training_snapshots} + {self.nr_test_snapshots} + {self.nr_validation_snapshots}]" + ) + + # MALA can either be run in training or test-only mode. + # But it has to be run in either of those! + # So either training AND validation snapshots can be provided + # OR only test snapshots. + if self.nr_test_snapshots != 0: + if self.nr_training_snapshots == 0: + printout( + "DataHandler prepared for inference. No training " + "possible with this setup. If this is not what " + "you wanted, please revise the input script. " + "Validation snapshots you may have entered will" + "be ignored.", + min_verbosity=0, + ) + else: + if self.nr_training_snapshots == 0: + raise Exception("No training snapshots provided.") + if self.nr_validation_snapshots == 0: + raise Exception("No validation snapshots provided.") + else: + raise Exception("Wrong parameter for data splitting provided.") + + self.__allocate_arrays() + + ### Load updated data + expand_partition_name = { + "tr": "training", + "va": "validation", + "te": "test", + } + for partition in partitions: + self.__load_data( + expand_partition_name[partition], + "inputs", + from_arrays_dict=from_arrays_dict, + ) + self.__load_data( + expand_partition_name[partition], + "outputs", + from_arrays_dict=from_arrays_dict, + ) + + # After the loading is done, target data can safely be saved again. + self.target_calculator.save_target_data = True + + printout("Build datasets.", min_verbosity=1) + self.__build_datasets(from_arrays_dict=from_arrays_dict) + printout("Build dataset: Done.", min_verbosity=0) + + # Wait until all ranks are finished with data preparation. + # It is not uncommon that ranks might be asynchronous in their + # data preparation by a small amount of minutes. If you notice + # an elongated wait time at this barrier, check that your file system + # allows for parallel I/O. + barrier() + def prepare_for_testing(self): """ Prepare DataHandler for usage within Tester class. @@ -448,9 +605,9 @@ def resize_snapshots_for_debugging( # Loading data ###################### - def _check_snapshots(self): + def _check_snapshots(self, from_arrays_dict=None): """Check the snapshots for consistency.""" - super(DataHandler, self)._check_snapshots() + super(DataHandler, self)._check_snapshots(from_arrays_dict) # Now we need to confirm that the snapshot list has some inner # consistency. 
@@ -483,7 +640,8 @@ def _check_snapshots(self): raise Exception( "Cannot split snapshots with specified " "splitting scheme, " - "too few or too many options selected" + "too few or too many options selected: " + f"[{nr_of_snapshots} != {self.nr_training_snapshots} + {self.nr_test_snapshots} + {self.nr_validation_snapshots}]" ) # MALA can either be run in training or test-only mode. # But it has to be run in either of those! @@ -547,7 +705,7 @@ def __allocate_arrays(self): dtype=DEFAULT_NP_DATA_DTYPE, ) - def __load_data(self, function, data_type): + def __load_data(self, function, data_type, from_arrays_dict=None): """ Load data into the appropriate arrays. @@ -584,8 +742,9 @@ def __load_data(self, function, data_type): snapshot_counter = 0 gs_old = 0 - for snapshot in self.parameters.snapshot_directories_list: - # get the snapshot grid size + for i, snapshot in enumerate( + self.parameters.snapshot_directories_list + ): # get the snapshot grid size gs_new = snapshot.grid_size # Data scaling is only performed on the training data sets. @@ -602,7 +761,56 @@ def __load_data(self, function, data_type): ) units = snapshot.output_units - if snapshot.snapshot_type == "numpy": + # Pull from existing array rather than file + if from_arrays_dict is not None: + if snapshot._selection_mask is not None: + gs_new = sum(snapshot._selection_mask) + # TODO streamline this + if snapshot._selection_mask is not None: + # Update data already in tensor form + if torch.is_tensor(getattr(self, array)): + getattr(self, array)[ + gs_old : gs_old + gs_new, : + ] = torch.from_numpy( + from_arrays_dict[(i, data_type)][ + :, calculator._feature_mask() : + ][snapshot._selection_mask] + ) + + # Update a fresh numpy array + else: + getattr(self, array)[ + gs_old : gs_old + gs_new, : + ] = from_arrays_dict[(i, data_type)][ + :, calculator._feature_mask() : + ][ + snapshot._selection_mask + ] + else: + # Update data already in tensor form + if torch.is_tensor(getattr(self, array)): + getattr(self, array)[ + gs_old : gs_old + gs_new, : + ] = torch.from_numpy( + from_arrays_dict[(i, data_type)][ + :, calculator._feature_mask() : + ] + ) + # Update a fresh numpy array + else: + getattr(self, array)[ + gs_old : gs_old + gs_new, : + ] = from_arrays_dict[(i, data_type)][ + :, calculator._feature_mask() : + ] + + calculator._process_loaded_array( + getattr(self, array)[gs_old : gs_old + gs_new, :], + units=units, + ) + + # Pull directly from file + elif snapshot.snapshot_type == "numpy": calculator.read_from_numpy_file( file, units=units, @@ -610,8 +818,13 @@ def __load_data(self, function, data_type): gs_old : gs_old + gs_new, : ], reshape=True, + selection_mask=snapshot._selection_mask, ) elif snapshot.snapshot_type == "openpmd": + if snapshot._selection_mask is not None: + raise NotImplementedError( + "Selection mask is not implemented for openpmd" + ) getattr(self, array)[gs_old : gs_old + gs_new] = ( calculator.read_from_openpmd_file( file, units=units @@ -662,7 +875,7 @@ def __load_data(self, function, data_type): self._test_data_outputs ).float() - def __build_datasets(self): + def __build_datasets(self, from_arrays_dict=None): """Build the DataSets that are used during training.""" if ( self.parameters.use_lazy_loading @@ -803,10 +1016,14 @@ def __build_datasets(self): ) if self.nr_validation_data != 0: - self.__load_data("validation", "inputs") + self.__load_data( + "validation", "inputs", from_arrays_dict=from_arrays_dict + ) self.input_data_scaler.transform(self._validation_data_inputs) - 
self.__load_data("validation", "outputs") + self.__load_data( + "validation", "outputs", from_arrays_dict=from_arrays_dict + ) self.output_data_scaler.transform( self._validation_data_outputs ) @@ -843,7 +1060,7 @@ def __build_datasets(self): # Scaling ###################### - def __parametrize_scalers(self): + def __parametrize_scalers(self, from_arrays_dict=None): """Use the training data to parametrize the DataScalers.""" ################## # Inputs. @@ -868,6 +1085,7 @@ def __parametrize_scalers(self): snapshot.input_npy_file, ), units=snapshot.input_units, + selection_mask=snapshot._selection_mask, ) elif snapshot.snapshot_type == "openpmd": tmp = ( @@ -896,7 +1114,9 @@ def __parametrize_scalers(self): self.input_data_scaler.partial_fit(tmp) else: - self.__load_data("training", "inputs") + self.__load_data( + "training", "inputs", from_arrays_dict=from_arrays_dict + ) self.input_data_scaler.fit(self._training_data_inputs) printout("Input scaler parametrized.", min_verbosity=1) @@ -917,6 +1137,10 @@ def __parametrize_scalers(self): # We need to perform the data scaling over the entirety of the # training data. for snapshot in self.parameters.snapshot_directories_list: + if snapshot._selection_mask is not None: + raise NotImplementedError( + "Example selection hasn't been implemented for lazy loading yet." + ) # Data scaling is only performed on the training data sets. if snapshot.snapshot_function == "tr": if snapshot.snapshot_type == "numpy": @@ -953,7 +1177,9 @@ def __parametrize_scalers(self): i += 1 else: - self.__load_data("training", "outputs") + self.__load_data( + "training", "outputs", from_arrays_dict=from_arrays_dict + ) self.output_data_scaler.fit(self._training_data_outputs) printout("Output scaler parametrized.", min_verbosity=1) diff --git a/mala/datahandling/data_handler_base.py b/mala/datahandling/data_handler_base.py index c141551f..ee074533 100644 --- a/mala/datahandling/data_handler_base.py +++ b/mala/datahandling/data_handler_base.py @@ -106,6 +106,7 @@ def add_snapshot( input_units="None", calculation_output_file="", snapshot_type="numpy", + selection_mask=None, ): """ Add a snapshot to the data pipeline. @@ -144,7 +145,16 @@ def add_snapshot( snapshot_type : string Either "numpy" or "openpmd" based on what kind of files you want to operate on. + + selection_mask : None or [boolean] + If None, entire snapshot is loaded, if [boolean], it is used as a + mask to select which examples are loaded """ + if selection_mask is not None and self.parameters.use_lazy_loading: + raise NotImplementedError( + "Example selection hasn't been " + "implemented for lazy loading yet." + ) snapshot = Snapshot( input_file, input_directory, @@ -173,13 +183,15 @@ def clear_data(self): # Loading data ###################### - def _check_snapshots(self, comm=None): + def _check_snapshots(self, from_arrays_dict=None, comm=None): """Check the snapshots for consistency.""" self.nr_snapshots = len(self.parameters.snapshot_directories_list) # Read the snapshots using a memorymap to see if there is consistency. firstsnapshot = True - for snapshot in self.parameters.snapshot_directories_list: + for i, snapshot in enumerate( + self.parameters.snapshot_directories_list + ): #################### # Descriptors. 
#################### @@ -191,7 +203,28 @@ def _check_snapshots(self, comm=None): snapshot.input_npy_directory, min_verbosity=1, ) - if snapshot.snapshot_type == "numpy": + if from_arrays_dict is not None: + printout( + f'arrdim: {from_arrays_dict[(i, "inputs")].shape}', + min_verbosity=2, + ) + printout( + f"featmask: {self.descriptor_calculator._feature_mask()}", + min_verbosity=2, + ) + tmp_dimension = from_arrays_dict[(i, "inputs")][ + :, self.descriptor_calculator._feature_mask() : + ].shape + # We don't need any reference to full grid dim at this point + # so this is just for compatibility w other code + if len(tmp_dimension) > 2: + raise ValueError("Flatten the data pool arrays.") + tmp_dimension = (tmp_dimension[0], 1, 1, tmp_dimension[-1]) + printout( + f"from_arrays_dict dim {i}: {from_arrays_dict[(i, 'inputs')].shape}", + min_verbosity=2, + ) + elif snapshot.snapshot_type == "numpy": tmp_dimension = ( self.descriptor_calculator.read_dimensions_from_numpy_file( os.path.join( @@ -214,6 +247,11 @@ def _check_snapshots(self, comm=None): # for flexible grid sizes only this need be consistent tmp_input_dimension = tmp_dimension[-1] tmp_grid_dim = tmp_dimension[0:3] + + # If using selection_mask, apply to dimensions + if snapshot._selection_mask is not None: + tmp_grid_dim = (sum(snapshot._selection_mask), 1, 1) + snapshot.grid_dimension = tmp_grid_dim snapshot.grid_size = int(np.prod(snapshot.grid_dimension)) if firstsnapshot: @@ -235,7 +273,16 @@ def _check_snapshots(self, comm=None): snapshot.output_npy_directory, min_verbosity=1, ) - if snapshot.snapshot_type == "numpy": + if from_arrays_dict is not None: + tmp_dimension = from_arrays_dict[(i, "outputs")][ + :, self.target_calculator._feature_mask() : + ].shape + # We don't need any reference to full grid dim at this point + # so this is just for compatibility w other code + if len(tmp_dimension) > 2: + raise ValueError("Flatten the data pool arrays.") + tmp_dimension = (tmp_dimension[0], 1, 1, tmp_dimension[-1]) + elif snapshot.snapshot_type == "numpy": tmp_dimension = ( self.target_calculator.read_dimensions_from_numpy_file( os.path.join( diff --git a/mala/datahandling/snapshot.py b/mala/datahandling/snapshot.py index 1bac8488..6b6a708f 100644 --- a/mala/datahandling/snapshot.py +++ b/mala/datahandling/snapshot.py @@ -1,5 +1,7 @@ """Represents an entire atomic snapshot (including descriptor/target data).""" +import numpy as np + from mala.common.json_serializable import JSONSerializable @@ -43,6 +45,10 @@ class Snapshot(JSONSerializable): - tr: This snapshot will be a training snapshot. - va: This snapshot will be a validation snapshot. 
+    selection_mask : None or list of bool
+        If None, the entire snapshot is loaded; otherwise it is used as a
+        mask selecting which examples are loaded.
+
     Attributes
     ----------
     grid_dimensions : list
@@ -104,6 +110,7 @@ def __init__(
         output_units="",
         calculation_output="",
         snapshot_type="openpmd",
+        selection_mask=None,
     ):
         super(Snapshot, self).__init__()
 
@@ -133,6 +140,22 @@ def __init__(
         self.input_dimension = None
         self.output_dimension = None
 
+        # Mask determining which examples from the snapshot to use
+        if isinstance(selection_mask, np.ndarray):
+            self._selection_mask = selection_mask.tolist()
+        else:
+            self._selection_mask = selection_mask
+
+    def set_selection_mask(self, selection_mask):
+        """Set the selection mask for snapshot loading."""
+        if isinstance(selection_mask, np.ndarray):
+            self._selection_mask = selection_mask.tolist()
+        else:
+            self._selection_mask = selection_mask
+        if selection_mask is not None:
+            self.grid_size = sum(self._selection_mask)
+            # TODO: also adjust the other dimension parameters
+
     @classmethod
     def from_json(cls, json_dict):
         """
@@ -150,14 +173,27 @@ def from_json(cls, json_dict):
         The object as read from the JSON file.
 
         """
-        deserialized_object = cls(
-            json_dict["input_npy_file"],
-            json_dict["input_npy_directory"],
-            json_dict["output_npy_file"],
-            json_dict["output_npy_directory"],
-            json_dict["snapshot_function"],
-            json_dict["snapshot_type"],
-        )
+        # Temporary try/except for compatibility with
+        # pre-selection_mask parameter dicts. TODO: remove
+        try:
+            deserialized_object = cls(
+                json_dict["input_npy_file"],
+                json_dict["input_npy_directory"],
+                json_dict["output_npy_file"],
+                json_dict["output_npy_directory"],
+                json_dict["snapshot_function"],
+                snapshot_type=json_dict["snapshot_type"],
+                selection_mask=json_dict["selection_mask"],
+            )
+        except KeyError:
+            deserialized_object = cls(
+                json_dict["input_npy_file"],
+                json_dict["input_npy_directory"],
+                json_dict["output_npy_file"],
+                json_dict["output_npy_directory"],
+                json_dict["snapshot_function"],
+                snapshot_type=json_dict["snapshot_type"],
+            )
         for key in json_dict:
            setattr(deserialized_object, key, json_dict[key])
        return deserialized_object
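
Usage sketch for the new Parameters.load_from_dict entry point. The file name "trained_params.json" is hypothetical and stands in for any parameter JSON previously written out by MALA; load_from_dict simply takes the already-parsed dictionary instead of a path or file handle:

    import json

    import mala

    # Parse the JSON ourselves (or build the dictionary in memory) and
    # hand the dictionary straight to load_from_dict.
    with open("trained_params.json", encoding="utf-8") as f:
        param_dict = json.load(f)

    parameters = mala.Parameters.load_from_dict(param_dict, no_snapshots=True)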
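
A sketch of supplying data to prepare_data directly from memory via from_arrays_dict. Snapshots are still registered with add_snapshot so that indices and the tr/va/te split are defined, but with from_arrays_dict the listed files are not read for the actual data. The arrays fp_snapshot0, ldos_snapshot0, etc. are hypothetical, flattened to (n_grid_points, n_features), and the first descriptor_calculator._feature_mask() columns (e.g. the xyz columns) are stripped exactly as when reading from file:

    import mala

    parameters = mala.Parameters()
    data_handler = mala.DataHandler(parameters)

    # The file names mainly act as labels here; the keys of from_arrays_dict
    # use the absolute snapshot index and "inputs"/"outputs".
    data_handler.add_snapshot(
        "snapshot0.in.npy", "./", "snapshot0.out.npy", "./", "tr",
        snapshot_type="numpy",
    )
    data_handler.add_snapshot(
        "snapshot1.in.npy", "./", "snapshot1.out.npy", "./", "va",
        snapshot_type="numpy",
    )

    from_arrays_dict = {
        (0, "inputs"): fp_snapshot0,     # e.g. shape (8192, 94), hypothetical
        (0, "outputs"): ldos_snapshot0,  # e.g. shape (8192, 11), hypothetical
        (1, "inputs"): fp_snapshot1,
        (1, "outputs"): ldos_snapshot1,
    }
    data_handler.prepare_data(from_arrays_dict=from_arrays_dict)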
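
Continuing the sketch above, refresh_data swaps out the data of selected partitions for the next generation, e.g. inside an active-learning loop. Since the consistency check inspects every registered snapshot, the dictionary should again contain entries for all of them, even if only the training partition is refreshed:

    # new_fp_snapshot0 / new_ldos_snapshot0 are hypothetical replacement
    # arrays for snapshot 0; snapshot 1 (validation) is unchanged but must
    # still be present in the dictionary.
    new_arrays = {
        (0, "inputs"): new_fp_snapshot0,
        (0, "outputs"): new_ldos_snapshot0,
        (1, "inputs"): fp_snapshot1,
        (1, "outputs"): ldos_snapshot1,
    }
    data_handler.refresh_data(from_arrays_dict=new_arrays, partitions=["tr"])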
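
A sketch of per-snapshot example selection with set_selection_mask, assuming numpy snapshots and lazy loading disabled (masks are not yet implemented for lazy loading or openpmd snapshots). The file names, data path and grid size are hypothetical:

    import numpy as np

    import mala

    data_path = "./"  # hypothetical location of the snapshot files
    parameters = mala.Parameters()
    data_handler = mala.DataHandler(parameters)
    data_handler.add_snapshot(
        "snapshot0.in.npy", data_path, "snapshot0.out.npy", data_path, "tr",
        snapshot_type="numpy",
    )
    data_handler.add_snapshot(
        "snapshot1.in.npy", data_path, "snapshot1.out.npy", data_path, "va",
        snapshot_type="numpy",
    )

    # Boolean mask over the flattened grid of snapshot 0; its length must
    # equal the number of grid points in that snapshot (hypothetical here).
    n_grid_points = 8192
    rng = np.random.default_rng(0)
    mask = rng.random(n_grid_points) < 0.25

    parameters.data.snapshot_directories_list[0].set_selection_mask(mask)
    data_handler.prepare_data()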
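
The same kind of mask can also be applied when reading a single snapshot through a calculator, since read_from_numpy_file and read_dimensions_from_numpy_file now accept selection_mask; the returned data is pseudo-flattened to (n_selected, 1, 1, n_features). Continuing the previous sketch:

    masked_ldos = data_handler.target_calculator.read_from_numpy_file(
        data_path + "snapshot0.out.npy", selection_mask=mask
    )
    masked_dims = data_handler.target_calculator.read_dimensions_from_numpy_file(
        data_path + "snapshot0.out.npy", selection_mask=mask
    )
    # masked_ldos has shape (n_selected, 1, 1, n_ldos_levels).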