diff --git a/embed.py b/embed.py index 9368ec2..e785aae 100644 --- a/embed.py +++ b/embed.py @@ -94,8 +94,8 @@ def main(): path, filename = os.path.split(args.wm_config) copyfile(args.wm_config, os.path.join(output_dir, filename)) - source_model: torch.nn.Sequential = defense_config.source_model() - optimizer = defense_config.optimizer(source_model.parameters()) + source_model: torch.nn.Sequential = mlconfig.instantiate(defense_config.source_model) + optimizer = mlconfig.instantiate(defense_config.optimizer, source_model.parameters()) source_model: PyTorchClassifier = __load_model(source_model, optimizer, @@ -104,26 +104,27 @@ def main(): filename=args.filename, pretrained_dir=args.pretrained_dir) # Load the training and testing data. - train_loader = defense_config.dataset(train=True) - valid_loader = defense_config.dataset(train=False) + train_loader = mlconfig.instantiate(defense_config.dataset, train=True) + valid_loader = mlconfig.instantiate(defense_config.dataset, train=False) # Optionally load a dataset to load watermarking images from. wm_loader = None if "wm_dataset" in dict(defense_config).keys(): - wm_loader = defense_config.wm_dataset() + wm_loader = mlconfig.instantiate(defense_config.wm_dataset) print(f"Instantiated watermark loader (with {len(wm_loader)} batches): {wm_loader}") source_test_acc_before_attack = evaluate_test_accuracy(source_model, valid_loader) print(f"Source model test acc (before): {source_test_acc_before_attack}") # Create the defense instance with the pretrained source model. Note: The source model is copied here. - defense: Watermark = defense_config.wm_scheme(source_model, config=defense_config) + defense: Watermark = mlconfig.instantiate(defense_config.wm_scheme, source_model, config=defense_config) # Save this configuration. 
+ from omegaconf import OmegaConf with open(os.path.join(output_dir, "config.json"), "w") as f: config = { "timestamp": str(datetime.now()), - "defense_config": defense_config, + "defense_config": OmegaConf.to_container(defense_config, resolve=True), "args": vars(args) } json.dump(config, f) @@ -131,7 +132,8 @@ def main(): # Embed the watermark. Note that all inputs are copied here. # We assume the defense stores the model and all auxiliary information in the output directory. start_time = time.time() - (x_wm, y_wm), defense = defense_config.embed(defense=defense, + (x_wm, y_wm), defense = mlconfig.instantiate(defense_config.embed, + defense=defense, train_loader=train_loader, valid_loader=valid_loader, wm_loader=wm_loader, diff --git a/steal.py b/steal.py index a58f4db..7390962 100644 --- a/steal.py +++ b/steal.py @@ -125,15 +125,15 @@ def main(): model_basedir, model_filename = os.path.split(pth_file) - source_model = defense_config.source_model() + source_model = mlconfig.instantiate(defense_config.source_model) source_model = source_model.to(device) - optimizer = defense_config.optimizer(source_model.parameters()) + optimizer = mlconfig.instantiate(defense_config.optimizer, source_model.parameters()) source_model = __load_model(source_model, optimizer, image_size=defense_config.source_model.image_size, num_classes=defense_config.source_model.num_classes, defense_filename=pth_file) - defense = defense_config.wm_scheme(classifier=source_model, optimizer=optimizer, config=defense_config) + defense = mlconfig.instantiate(defense_config.wm_scheme, classifier=source_model, optimizer=optimizer, config=defense_config) x_wm, y_wm = defense.load(filename=model_filename, path=model_basedir) print(y_wm) @@ -142,12 +142,12 @@ def main(): print(f"Using ground truth labels? 
{use_gt}") if use_gt: print("Using ground-truth labels ..") - train_loader = attack_config.dataset(train=True) - valid_loader = attack_config.dataset(train=False) + train_loader = mlconfig.instantiate(attack_config.dataset, train=True) + valid_loader = mlconfig.instantiate(attack_config.dataset, train=False) else: print("Using predicted labels ..") - train_loader = attack_config.dataset(source_model=source_model, train=True) - valid_loader = attack_config.dataset(source_model=source_model, train=False) + train_loader = mlconfig.instantiate(attack_config.dataset, source_model=source_model, train=True) + valid_loader = mlconfig.instantiate(attack_config.dataset, source_model=source_model, train=False) source_test_acc_before_attack = evaluate_test_accuracy(source_model, valid_loader) print(f"Source model test acc: {source_test_acc_before_attack:.4f}") @@ -155,8 +155,8 @@ def main(): print(f"Source model wm acc: {source_wm_acc:.4f}") if "surrogate_model" in attack_config.keys(): - surrogate_model = attack_config.surrogate_model() - optimizer = attack_config.optimizer(surrogate_model.parameters()) + surrogate_model = mlconfig.instantiate(attack_config.surrogate_model) + optimizer = mlconfig.instantiate(attack_config.optimizer, surrogate_model.parameters()) surrogate_model = __load_model(surrogate_model, optimizer, image_size=attack_config.surrogate_model.image_size, num_classes=attack_config.surrogate_model.num_classes) @@ -182,11 +182,12 @@ def main(): print(f"[ERROR] {e}") print("Could not extract watermark accuracy from the surrogate model ... Continuing ..") - attack: RemovalAttack = attack_config.create(classifier=surrogate_model, config=attack_config) + attack: RemovalAttack = mlconfig.instantiate(attack_config.create, classifier=surrogate_model, config=attack_config) # Run the removal. We still need wrappers to conform to the old interface. 
start = time.time() - attack, train_metric = attack_config.remove(attack=attack, + attack, train_metric = mlconfig.instantiate(attack_config.remove, + attack=attack, source_model=source_model, train_loader=train_loader, valid_loader=valid_loader, diff --git a/train.py b/train.py index c819e8c..1965a7b 100644 --- a/train.py +++ b/train.py @@ -69,22 +69,29 @@ def main(): device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu') - model: torch.nn.Sequential = config.model() + model: torch.nn.Sequential = mlconfig.instantiate(config.model) model = model.to(device) - optimizer = config.optimizer(model.parameters()) - scheduler = config.scheduler(optimizer) + optimizer = mlconfig.instantiate(config.optimizer, model.parameters()) + scheduler = mlconfig.instantiate(config.scheduler, optimizer=optimizer) model: PyTorchClassifier = __load_model(model, optimizer=optimizer, image_size=config.model.image_size, num_classes=config.model.num_classes) - train_loader = config.dataset(train=True) - valid_loader = config.dataset(train=False) + train_loader = mlconfig.instantiate(config.dataset, train=True) + valid_loader = mlconfig.instantiate(config.dataset, train=False) - trainer = config.trainer(model=model, train_loader=train_loader, valid_loader=valid_loader, - scheduler=scheduler, device=device, output_dir=output_dir) + trainer = mlconfig.instantiate( + config.trainer, + model=model, + train_loader=train_loader, + valid_loader=valid_loader, + scheduler=scheduler, + device=device, + output_dir=output_dir + ) if args.resume is not None: trainer.resume(args.resume) diff --git a/wrt/classifiers/pytorch.py b/wrt/classifiers/pytorch.py index 882104f..558d7ba 100644 --- a/wrt/classifiers/pytorch.py +++ b/wrt/classifiers/pytorch.py @@ -975,6 +975,7 @@ def reduce_labels(self): return isinstance(self.loss, (torch.nn.CrossEntropyLoss, torch.nn.NLLLoss, torch.nn.MultiMarginLoss)) def compute_loss(self, pred, true, x=None): + true = true.to(torch.int64) 
return self.loss(pred, true) def __call__(self, *args, **kwargs): diff --git a/wrt/defenses/__init__.py b/wrt/defenses/__init__.py index 6707ad0..329341a 100644 --- a/wrt/defenses/__init__.py +++ b/wrt/defenses/__init__.py @@ -2,4 +2,4 @@ Module implementing defenses for neural networks. """ from .watermark import * -from backdoor import * \ No newline at end of file +# from backdoor import * \ No newline at end of file diff --git a/wrt/training/datasets/trigger_datasets.py b/wrt/training/datasets/trigger_datasets.py index 005531d..bffef19 100644 --- a/wrt/training/datasets/trigger_datasets.py +++ b/wrt/training/datasets/trigger_datasets.py @@ -31,7 +31,7 @@ def __getitem__(self, idx): class AdiTrigger(Trigger): - url = "https://www.dropbox.com/s/z11ds7jvewkgv18/adi.zip?dl=1" + url = "https://www.dropbox.com/scl/fi/5fbrlbkxwlse8zotgih3z/adi.zip?rlkey=trg2s2fm9tx57uhn2c8tdzc46&st=ae39w6mn&dl=1" filename = "adi.zip" folder_name = "adi"