diff --git a/GANDLF/compute/inference_loop.py b/GANDLF/compute/inference_loop.py index c09b44cf7..2c9c0230a 100644 --- a/GANDLF/compute/inference_loop.py +++ b/GANDLF/compute/inference_loop.py @@ -89,7 +89,16 @@ def inference_loop( assert file_to_load != None, "The 'best_file' was not found" main_dict = torch.load(file_to_load, map_location=parameters["device"]) - model.load_state_dict(main_dict["model_state_dict"]) + state_dict = main_dict["model_state_dict"] + if parameters.get("differential_privacy"): + # this is required for torch==1.11 and for DP inference + new_state_dict = {} + for key, val in state_dict.items(): + new_key = key.replace("_module.", "") + new_state_dict[new_key] = val # remove `module.` + state_dict = new_state_dict + + model.load_state_dict(state_dict) parameters["previous_parameters"] = main_dict.get("parameters", None) model.eval() elif parameters["model"]["type"].lower() == "openvino": diff --git a/GANDLF/compute/training_loop.py b/GANDLF/compute/training_loop.py index 32b52f188..9c56a232d 100644 --- a/GANDLF/compute/training_loop.py +++ b/GANDLF/compute/training_loop.py @@ -31,6 +31,10 @@ from .forward_pass import validate_network from .generic import create_pytorch_objects +from GANDLF.privacy.opacus.model_handling import empty_collate +from GANDLF.privacy.opacus import handle_dynamic_batch_size, prep_for_opacus_training +from opacus.utils.batch_memory_manager import wrap_data_loader + # hides torchio citation request, see https://github.com/fepegar/torchio/issues/235 os.environ["TORCHIO_HIDE_CITATION_PROMPT"] = "1" @@ -91,6 +95,14 @@ def train_network( for batch_idx, (subject) in enumerate( tqdm(train_dataloader, desc="Looping over training data") ): + if params.get("differential_privacy"): + subject, params["batch_size"] = handle_dynamic_batch_size( + subject=subject, params=params + ) + assert not isinstance( + model, torch.nn.DataParallel + ), "Differential privacy is not supported with DataParallel or DistributedDataParallel. Please use a single GPU or DDP with Opacus." + optimizer.zero_grad() image = ( # 5D tensor: (B, C, H, W, D) torch.cat( @@ -212,6 +224,23 @@ def train_network( return average_epoch_train_loss, average_epoch_train_metric +def train_network_wrapper(model, train_dataloader, optimizer, params): + """ + Wrapper Function to handle train_dataloader for benign and DP cases and pass on to train a network for a single epoch + """ + + if params.get("differential_privacy"): + with train_dataloader as memory_safe_data_loader: + epoch_train_loss, epoch_train_metric = train_network( + model, memory_safe_data_loader, optimizer, params + ) + else: + epoch_train_loss, epoch_train_metric = train_network( + model, train_dataloader, optimizer, params + ) + return epoch_train_loss, epoch_train_metric + + def training_loop( training_data: pd.DataFrame, validation_data: pd.DataFrame, @@ -368,6 +397,7 @@ def training_loop( logger_csv_filename=os.path.join(output_dir, "logs_validation.csv"), metrics=metrics_log, mode="valid", + add_epsilon=bool(params.get("differential_privacy")), ) if testingDataDefined: test_logger = Logger( @@ -392,6 +422,36 @@ def training_loop( print("Using device:", device, flush=True) + if params.get("differential_privacy"): + print( + "Using Opacus to make training differentially private with respect to the training data." 
+        )
+
+        model, optimizer, train_dataloader, privacy_engine = prep_for_opacus_training(
+            model=model,
+            optimizer=optimizer,
+            train_dataloader=train_dataloader,
+            params=params,
+        )
+
+        train_dataloader.collate_fn = empty_collate(train_dataloader.dataset[0])
+
+        # train_dataloader = BatchMemoryManager(
+        #     data_loader=train_dataloader,
+        #     max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,
+        #     optimizer=optimizer,
+        # )
+        batch_size = params["batch_size"]
+        max_physical_batch_size = params["differential_privacy"].get(
+            "physical_batch_size"
+        )
+        if max_physical_batch_size and max_physical_batch_size != batch_size:
+            train_dataloader = wrap_data_loader(
+                data_loader=train_dataloader,
+                max_batch_size=max_physical_batch_size,
+                optimizer=optimizer,
+            )
+
     # Iterate for number of epochs
     for epoch in range(start_epoch, epochs):
         if params["track_memory_usage"]:
@@ -453,6 +513,14 @@ def training_loop(
             patience += 1

+        # if training with differential privacy, print the privacy budget (epsilon)
+        if params.get("differential_privacy"):
+            delta = params["differential_privacy"]["delta"]
+            this_epsilon = privacy_engine.get_epsilon(delta)
+            print(f" Epoch Final Privacy: (ε = {this_epsilon:.2f}, δ = {delta})")
+            # save for logging
+            epoch_valid_metric["epsilon"] = this_epsilon
+
         # Write the losses to a logger
         train_logger.write(epoch, epoch_train_loss, epoch_train_metric)
         valid_logger.write(epoch, epoch_valid_loss, epoch_valid_metric)
diff --git a/GANDLF/config_manager.py b/GANDLF/config_manager.py
index 99497fbb1..ed26ec8e1 100644
--- a/GANDLF/config_manager.py
+++ b/GANDLF/config_manager.py
@@ -7,6 +7,7 @@
 from .utils import version_check
 from GANDLF.data.post_process import postprocessing_after_reverse_one_hot_encoding
+from GANDLF.privacy.opacus import parse_opacus_params
 from GANDLF.metrics import surface_distance_ids
 from importlib.metadata import version
@@ -710,6 +711,10 @@ def _parseConfig(
         temp_dict["type"] = params["optimizer"]
         params["optimizer"] = temp_dict

+    # initialize defaults for DP
+    if params.get("differential_privacy"):
+        params = parse_opacus_params(params, initialize_key)
+
     # initialize defaults for inference mechanism
     inference_mechanism = {"grid_aggregator_overlap": "crop", "patch_overlap": 0}
     initialize_inference_mechanism = False
diff --git a/GANDLF/logger.py b/GANDLF/logger.py
index 2562eb17d..4f1d76e03 100755
--- a/GANDLF/logger.py
+++ b/GANDLF/logger.py
@@ -12,14 +12,21 @@
 class Logger:
-    def __init__(self, logger_csv_filename: str, metrics: List[str], mode: str) -> None:
+    def __init__(
+        self,
+        logger_csv_filename: str,
+        metrics: List[str],
+        mode: str,
+        add_epsilon: bool = False,
+    ) -> None:
         """
-        Logger class to log the training and validation metrics to a csv file.
-        May append to existing file if headers match; elsewise raises an error.
+        Logger class to log the training and validation metrics to a csv file. May append to an existing file if the headers match; otherwise an error is raised.

         Args:
             logger_csv_filename (str): Path to a filename where the csv has to be stored.
             metrics (Dict[str, float]): The metrics to be logged.
+            mode (str): The mode of the logger, used as a suffix to metric names.
+                Typically one of `train`, `valid`, or `test`.
+            add_epsilon (bool): Whether to also log the privacy budget (epsilon), for differential privacy runs.
         """
         self.filename = logger_csv_filename
         mode = mode.lower()
@@ -28,6 +35,8 @@ def __init__(self, logger_csv_filename: str, metrics: List[str], mode: str) -> N
         new_header = ["epoch_no", f"{mode}_loss"] + [
             f"{mode}_{metric}" for metric in metrics
         ]
+        if add_epsilon:
+            new_header.append(f"{mode}_epsilon")

         # TODO: do we really need to support appending to existing files?
         if os.path.exists(self.filename):
diff --git a/GANDLF/models/imagenet_unet.py b/GANDLF/models/imagenet_unet.py
index 940987e1f..f1a203d4a 100644
--- a/GANDLF/models/imagenet_unet.py
+++ b/GANDLF/models/imagenet_unet.py
@@ -252,6 +252,10 @@ def __init__(self, parameters) -> None:
             aux_params=classifier_head_parameters,
         )

+        # all BatchNorm layers should be replaced with InstanceNorm for DP experiments
+        if "differential_privacy" in parameters:
+            self.replace_batchnorm(self.model)
+
         if self.n_dimensions == 3:
             self.model = self.converter(self.model).model

diff --git a/GANDLF/privacy/__init__.py b/GANDLF/privacy/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/GANDLF/privacy/opacus/__init__.py b/GANDLF/privacy/opacus/__init__.py
new file mode 100644
index 000000000..3a7aa2672
--- /dev/null
+++ b/GANDLF/privacy/opacus/__init__.py
@@ -0,0 +1,3 @@
+from .config_parsing import parse_opacus_params
+from .model_handling import opacus_model_fix, prep_for_opacus_training
+from .training_utils import handle_dynamic_batch_size
diff --git a/GANDLF/privacy/opacus/config_parsing.py b/GANDLF/privacy/opacus/config_parsing.py
new file mode 100644
index 000000000..9ae2fd12b
--- /dev/null
+++ b/GANDLF/privacy/opacus/config_parsing.py
@@ -0,0 +1,59 @@
+from typing import Callable
+
+
+def parse_opacus_params(params: dict, initialize_key: Callable) -> dict:
+    """
+    Function to set defaults and augment the parameters related to making a trained model differentially
+    private with respect to the training data.
+
+    Args:
+        params (dict): Training parameters.
+        initialize_key (Callable): Function to fill in the value for a missing key.
+
+    Returns:
+        dict: Updated training parameters.
+    """
+
+    if not isinstance(params["differential_privacy"], dict):
+        print(
+            "WARNING: A non-dictionary value was provided for the key 'differential_privacy'; replacing it with a dictionary of default values."
+        )
+        params["differential_privacy"] = {}
+    # these are some defaults
+    params["differential_privacy"] = initialize_key(
+        params["differential_privacy"], "noise_multiplier", 10.0
+    )
+    params["differential_privacy"] = initialize_key(
+        params["differential_privacy"], "max_grad_norm", 1.0
+    )
+    params["differential_privacy"] = initialize_key(
+        params["differential_privacy"], "accountant", "rdp"
+    )
+    params["differential_privacy"] = initialize_key(
+        params["differential_privacy"], "secure_mode", False
+    )
+    params["differential_privacy"] = initialize_key(
+        params["differential_privacy"], "allow_opacus_model_fix", True
+    )
+    params["differential_privacy"] = initialize_key(
+        params["differential_privacy"], "delta", 1e-5
+    )
+    params["differential_privacy"] = initialize_key(
+        params["differential_privacy"], "physical_batch_size", params["batch_size"]
+    )
+
+    if params["differential_privacy"]["physical_batch_size"] > params["batch_size"]:
+        print(
+            f"WARNING: The physical batch size {params['differential_privacy']['physical_batch_size']} is greater "
+            f"than the batch size {params['batch_size']}, setting the physical batch size to the batch size."
+        )
+        params["differential_privacy"]["physical_batch_size"] = params["batch_size"]
+
+    # these keys need to be parsed as floats, not strings
+    for key in ["noise_multiplier", "max_grad_norm", "delta", "epsilon"]:
+        if key in params["differential_privacy"]:
+            params["differential_privacy"][key] = float(
+                params["differential_privacy"][key]
+            )
+
+    return params
diff --git a/GANDLF/privacy/opacus/model_handling.py b/GANDLF/privacy/opacus/model_handling.py
new file mode 100644
index 000000000..92d9b55bc
--- /dev/null
+++ b/GANDLF/privacy/opacus/model_handling.py
@@ -0,0 +1,149 @@
+import collections.abc as abc
+from functools import partial
+from torch.utils.data._utils.collate import default_collate
+from torch.utils.data import DataLoader
+from typing import Union, Callable, Tuple
+import copy
+
+import numpy as np
+import torch
+from opacus import PrivacyEngine
+from opacus.validators import ModuleValidator
+
+
+def opacus_model_fix(model: torch.nn.Module, params: dict) -> torch.nn.Module:
+    """
+    Function to detect model components that are not compatible with Opacus differentially private training,
+    replacing them with compatible components where possible, or raising an exception when a fix cannot be handled by Opacus.
+
+    Args:
+        model (torch.nn.Module): The model to be trained.
+        params (dict): Training parameters.
+
+    Returns:
+        torch.nn.Module: Model, with potentially some components replaced with ones compatible with Opacus.
+ """ + # use opacus to detect issues with model + opacus_errors_detected = ModuleValidator.validate(model, strict=False) + + if not params["differential_privacy"]["allow_opacus_model_fix"]: + assert ( + opacus_errors_detected == [] + ), f"Training parameters are set to not allow Opacus to try to fix incompatible model components, and the following issues were detected: {opacus_errors_detected}" + elif opacus_errors_detected != []: + print( + f"Allowing Opacus to try and patch the model due to the following issues: {opacus_errors_detected}" + ) + print() + model = ModuleValidator.fix(model) + # If the fix did not work, raise an exception + ModuleValidator.validate(model, strict=True) + return model + + +def prep_for_opacus_training( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + train_dataloader: DataLoader, + params: dict, +) -> Tuple[torch.nn.Module, torch.optim.Optimizer, DataLoader, PrivacyEngine]: + """ + Function to prepare the model, optimizer, and dataloader for differentially private training using Opacus. + + Args: + model (torch.nn.Module): The model to be trained. + optimizer (torch.optim.Optimizer): The optimizer to be used for training. + train_dataloader (DataLoader): The dataloader for the training data. + params (dict): Training parameters. + + Returns: + Tuple[torch.nn.Module, torch.optim.Optimizer, DataLoader, PrivacyEngine]: Model, optimizer, dataloader, and privacy engine. + """ + + privacy_engine = PrivacyEngine( + accountant=params["differential_privacy"]["accountant"], + secure_mode=params["differential_privacy"]["secure_mode"], + ) + + if not "epsilon" in params["differential_privacy"]: + model, optimizer, train_dataloader = privacy_engine.make_private( + module=model, + optimizer=optimizer, + data_loader=train_dataloader, + noise_multiplier=params["differential_privacy"]["noise_multiplier"], + max_grad_norm=params["differential_privacy"]["max_grad_norm"], + ) + else: + (model, optimizer, train_dataloader) = privacy_engine.make_private_with_epsilon( + module=model, + optimizer=optimizer, + data_loader=train_dataloader, + max_grad_norm=params["differential_privacy"]["max_grad_norm"], + epochs=params["num_epochs"], + target_epsilon=params["differential_privacy"]["epsilon"], + target_delta=params["differential_privacy"]["delta"], + ) + return model, optimizer, train_dataloader, privacy_engine + + +def build_empty_batch_value( + sample: Union[torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str] +): + """ + Build an empty batch value from a sample. This function is used to create a placeholder for empty batches in an iteration. Inspired from https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/collate.py#L108. The key difference is that pytorch `collate` has to traverse batch of objects AND unite its fields to lists, while this function traverse a single item AND creates an "empty" version of the batch. + + Args: + sample (Union[torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str]): A sample from the dataset. + + Raises: + TypeError: If the data type is not supported. + + Returns: + Union[torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str]: An empty batch value. + """ + if isinstance(sample, torch.Tensor): + # Create an empty tensor with the same shape except for the zeroed batch dimension. + return torch.empty((0,) + sample.shape) + elif isinstance(sample, np.ndarray): + # Create an empty tensor from a numpy array, also with the zeroed batch dimension. 
+        return torch.empty((0,) + sample.shape, dtype=torch.from_numpy(sample).dtype)
+    elif isinstance(sample, abc.Mapping):
+        # Recursively handle dictionary-like objects.
+        return {key: build_empty_batch_value(value) for key, value in sample.items()}
+    elif isinstance(sample, tuple) and hasattr(sample, "_fields"):  # namedtuple
+        return type(sample)(*(build_empty_batch_value(item) for item in sample))
+    elif isinstance(sample, abc.Sequence) and not isinstance(sample, str):
+        # Handle lists and tuples, but exclude strings.
+        return [build_empty_batch_value(item) for item in sample]
+    elif isinstance(sample, (int, float, str)):
+        # Return an empty list for basic data types.
+        return []
+    else:
+        raise TypeError(f"Unsupported data type: {type(sample)}")
+
+
+def empty_collate(
+    item_example: Union[
+        torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str
+    ]
+) -> Callable:
+    """
+    Creates a new collate function that behaves the same as the default PyTorch one,
+    but can also process empty batches.
+
+    Args:
+        item_example (Union[torch.Tensor, np.ndarray, abc.Mapping, abc.Sequence, int, float, str]): An example item from the dataset.
+
+    Returns:
+        Callable: Function that should replace the dataloader's collate function: `dataloader.collate_fn = empty_collate(...)`
+    """
+
+    def custom_collate(batch, _empty_batch_value):
+        if len(batch) > 0:
+            return default_collate(batch)  # default behavior
+        else:
+            return copy.copy(_empty_batch_value)
+
+    empty_batch_value = build_empty_batch_value(item_example)
+
+    return partial(custom_collate, _empty_batch_value=empty_batch_value)
diff --git a/GANDLF/privacy/opacus/training_utils.py b/GANDLF/privacy/opacus/training_utils.py
new file mode 100644
index 000000000..0664ebc45
--- /dev/null
+++ b/GANDLF/privacy/opacus/training_utils.py
@@ -0,0 +1,106 @@
+from typing import Tuple
+import torch
+import torchio
+
+
+def handle_nonempty_batch(subject: dict, params: dict) -> Tuple[dict, int]:
+    """
+    Function to detect the batch size from the subject an Opacus loader provides in the case of a non-empty batch, and make any changes to the subject dictionary that are needed for GaNDLF to use it.
+
+    Args:
+        subject (dict): Training data subject dictionary.
+        params (dict): Training parameters.
+
+    Returns:
+        Tuple[dict, int]: Modified subject dictionary and batch size.
+    """
+    batch_size = len(subject[params["channel_keys"][0]][torchio.DATA])
+    return subject, batch_size
+
+
+def handle_empty_batch(subject: dict, params: dict, feature_shape: list) -> dict:
+    """
+    Function to replace the list of empty arrays an Opacus loader provides in the case of an empty batch with a subject dictionary GaNDLF can consume.
+
+    Args:
+        subject (dict): Training data subject dictionary.
+        params (dict): Training parameters.
+        feature_shape (list): Shape of the features.
+
+    Returns:
+        dict: Modified subject dictionary.
+ """ + + print("\nConstructing empty batch dictionary.\n") + + subject = { + "subject_id": "empty_batch", + "spacing": None, + "path_to_metadata": None, + "location": None, + } + subject.update( + { + key: {torchio.DATA: torch.zeros(tuple([0] + feature_shape))} + for key in params["channel_keys"] + } + ) + if params["problem_type"] != "segmentation": + subject.update( + { + key: torch.zeros((0, params["model"]["num_classes"])).to(torch.int64) + for key in params["value_keys"] + } + ) + else: + subject.update( + { + "label": { + torchio.DATA: torch.zeros(tuple([0] + feature_shape)).to( + torch.int64 + ) + } + } + ) + + return subject + + +def handle_dynamic_batch_size(subject: dict, params: dict) -> Tuple[dict, int]: + """ + Function to process the subject Opacus loaders provide and prepare to handle their dynamic batch size (including possible empty batches). + + Args: + subject (dict): Training data subject dictionary. + params (dict): Training parameters. + + Raises: + RuntimeError: If the subject is a list object that is not an empty batch. + + Returns: + Tuple[dict, int]: Modified subject dictionary and batch size. + """ + + # The handling performed here is currently to be able to comprehend what + # batch size we are currently working with (which we may later see as not needed) + # and also to handle the previously observed case where Opacus produces + # a subject that is not a dictionary but rather a list of empty arrays + # (due to the empty batch result). The latter case is detected as a subject that + # is a list object. + if isinstance(subject, list): + are_empty = torch.Tensor( + [torch.equal(tensor, torch.Tensor([])) for tensor in subject] + ) + assert torch.all( + are_empty + ), "Detected a list subject that is not an empty batch, which is not expected behavior." 
+ # feature_shape = [params["model"]["num_channels"]]+params["patch_size"] + feature_shape = [params["model"]["num_channels"]] + params["patch_size"] + subject = handle_empty_batch( + subject=subject, params=params, feature_shape=feature_shape + ) + batch_size = 0 + else: + subject, batch_size = handle_nonempty_batch(subject=subject, params=params) + + return subject, batch_size diff --git a/setup.py b/setup.py index 293dd1ffe..07f747476 100644 --- a/setup.py +++ b/setup.py @@ -81,6 +81,7 @@ "packaging==24.0", "typer==0.9.0", "colorlog", + "opacus==1.5.2", ] if __name__ == "__main__": diff --git a/testing/test_full.py b/testing/test_full.py index b36a8ab64..28e1b9437 100644 --- a/testing/test_full.py +++ b/testing/test_full.py @@ -3245,3 +3245,91 @@ def test_generic_debug_info(): print("53: Starting test for logging") _debug_info(True) print("passed") + + +def test_differential_privacy_epsilon_classification_rad_2d(device): + print("54: Testing complex DP training for 2D classification") + # overwrite previous results + sanitize_outputDir() + # read and initialize parameters for specific data dimension + parameters = parseConfig( + testingDir + "/config_classification.yaml", version_check_flag=False + ) + parameters["modality"] = "rad" + parameters["opt"] = "adam" + parameters["patch_size"] = patch_size["2D"] + parameters["batch_size"] = 32 # needs to be revised + parameters["model"]["dimension"] = 2 + parameters["model"]["amp"] = True + # read and parse csv + training_data, parameters["headers"] = parseTrainingCSV( + inputDir + "/train_2d_rad_classification.csv" + ) + parameters = populate_header_in_parameters(parameters, parameters["headers"]) + parameters["model"]["num_channels"] = 3 + parameters["model"]["norm_type"] = "instance" + parameters["differential_privacy"] = {"epsilon": 25.0, "physical_batch_size": 4} + file_config_temp = os.path.join(outputDir, "config_classification_temp.yaml") + # if found in previous run, discard. + if os.path.exists(file_config_temp): + os.remove(file_config_temp) + + with open(file_config_temp, "w") as file: + yaml.dump(parameters, file) + parameters = parseConfig(file_config_temp, version_check_flag=True) + + TrainingManager( + dataframe=training_data, + outputDir=outputDir, + parameters=parameters, + device=device, + resume=False, + reset=True, + ) + sanitize_outputDir() + + print("passed") + + +def test_differential_privacy_simple_classification_rad_2d(device): + print("55: Testing simple DP") + # overwrite previous results + sanitize_outputDir() + # read and initialize parameters for specific data dimension + parameters = parseConfig( + testingDir + "/config_classification.yaml", version_check_flag=False + ) + parameters["modality"] = "rad" + parameters["opt"] = "adam" + parameters["patch_size"] = patch_size["2D"] + parameters["batch_size"] = 32 # needs to be revised + parameters["model"]["dimension"] = 2 + parameters["model"]["amp"] = False + # read and parse csv + training_data, parameters["headers"] = parseTrainingCSV( + inputDir + "/train_2d_rad_classification.csv" + ) + parameters = populate_header_in_parameters(parameters, parameters["headers"]) + parameters["model"]["num_channels"] = 3 + parameters["model"]["norm_type"] = "instance" + parameters["differential_privacy"] = True + file_config_temp = os.path.join(outputDir, "config_classification_temp.yaml") + # if found in previous run, discard. 
+ if os.path.exists(file_config_temp): + os.remove(file_config_temp) + + with open(file_config_temp, "w") as file: + yaml.dump(parameters, file) + parameters = parseConfig(file_config_temp, version_check_flag=True) + + TrainingManager( + dataframe=training_data, + outputDir=outputDir, + parameters=parameters, + device=device, + resume=False, + reset=True, + ) + sanitize_outputDir() + + print("passed")
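
For reference, below is a minimal usage sketch of the two new pieces exercised above: the `differential_privacy` options whose missing keys `parse_opacus_params` fills in, and the empty-batch-safe collate function. The toy dataset, shapes, and specific values are illustrative assumptions only; the imports assume this branch (with its new `opacus` dependency) is installed.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from GANDLF.privacy.opacus.model_handling import empty_collate

    # Illustrative 'differential_privacy' block; any key omitted here is filled in by
    # parse_opacus_params (noise_multiplier=10.0, max_grad_norm=1.0, accountant="rdp",
    # secure_mode=False, allow_opacus_model_fix=True, delta=1e-5,
    # physical_batch_size=batch_size). Supplying "epsilon" makes prep_for_opacus_training
    # call make_private_with_epsilon instead of make_private.
    differential_privacy = {
        "noise_multiplier": 1.0,
        "max_grad_norm": 1.0,
        "delta": 1e-5,
        "physical_batch_size": 4,
        # "epsilon": 25.0,
    }

    # Toy dataset standing in for a GaNDLF loader; shapes are arbitrary.
    dataset = TensorDataset(torch.randn(8, 3, 32, 32), torch.randint(0, 2, (8,)))
    loader = DataLoader(dataset, batch_size=4)

    # As in training_loop.py, swap in the collate function that tolerates empty batches
    # (Opacus' Poisson sampling can legitimately draw a batch with zero samples).
    loader.collate_fn = empty_collate(dataset[0])

    empty_batch = loader.collate_fn([])
    print([t.shape for t in empty_batch])  # tensors with a zero-sized batch dimension

In training, `handle_dynamic_batch_size` then converts such zero-sized batches into a placeholder subject dictionary so the loop can skip them gracefully.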