-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
517 changed files
with
352,996 additions
and
7 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,308 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ebd66700", | ||
"metadata": {}, | ||
"source": [ | ||
"## Demo_Quality\n", | ||
"This is a demo for visualizing the Image Quality\n", | ||
"\n", | ||
"To run this demo from scratch, you need first generate a BadNet attack result by using the following cell" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b950f4fc", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"! python ../../attack/badnet.py --save_folder_name badnet_demo" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8f81f973", | ||
"metadata": {}, | ||
"source": [ | ||
"or run the following command in your terminal\n", | ||
"\n", | ||
"```python attack/badnet.py --save_folder_name badnet_demo```" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "87bd9f5a", | ||
"metadata": {}, | ||
"source": [ | ||
"### Step 1: Import modules and set arguments" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "71b7087b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import sys, os\n", | ||
"import yaml\n", | ||
"import torch\n", | ||
"import shap\n", | ||
"import numpy as np\n", | ||
"import torchvision.transforms as transforms\n", | ||
"\n", | ||
"sys.path.append(\"../\")\n", | ||
"sys.path.append(\"../../\")\n", | ||
"sys.path.append(os.getcwd())\n", | ||
"from visual_utils import *\n", | ||
"from utils.aggregate_block.dataset_and_transform_generate import (\n", | ||
" get_transform,\n", | ||
" get_dataset_denormalization,\n", | ||
")\n", | ||
"from utils.aggregate_block.fix_random import fix_random\n", | ||
"from utils.aggregate_block.model_trainer_generate import generate_cls_model\n", | ||
"from utils.save_load_attack import load_attack_result\n", | ||
"from utils.defense_utils.dbd.model.utils import (\n", | ||
" get_network_dbd,\n", | ||
" load_state,\n", | ||
" get_criterion,\n", | ||
" get_optimizer,\n", | ||
" get_scheduler,\n", | ||
")\n", | ||
"from utils.defense_utils.dbd.model.model import SelfModel, LinearModel\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "2fb719c7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"### Basic setting: args\n", | ||
"args = get_args(True)\n", | ||
"\n", | ||
"########## For Demo Only ##########\n", | ||
"args.yaml_path = \"../../\"+args.yaml_path\n", | ||
"args.result_file_attack = \"badnet_demo\"\n", | ||
"######## End For Demo Only ##########\n", | ||
"\n", | ||
"with open(args.yaml_path, \"r\") as stream:\n", | ||
" config = yaml.safe_load(stream)\n", | ||
"config.update({k: v for k, v in args.__dict__.items() if v is not None})\n", | ||
"args.__dict__ = config\n", | ||
"args = preprocess_args(args)\n", | ||
"fix_random(int(args.random_seed))\n", | ||
"\n", | ||
"save_path_attack = \"../..//record/\" + args.result_file_attack\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f959b510", | ||
"metadata": {}, | ||
"source": [ | ||
"### Step 2: Load data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "b8b67ac9", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"WARNING:root:save_path MUST have 'record' in its abspath, and data_path in attack result MUST have 'data' in its path\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Files already downloaded and verified\n", | ||
"Files already downloaded and verified\n", | ||
"loading...\n", | ||
"max_num_samples is given, use sample number limit now.\n", | ||
"subset bd dataset with length: 5000\n", | ||
"Create visualization dataset with \n", | ||
" \t Dataset: bd_train \n", | ||
" \t Number of samples: 5000 \n", | ||
" \t Selected classes: [0 1 2 3 4 5 6 7 8 9]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Load result\n", | ||
"result_attack = load_attack_result(save_path_attack + \"/attack_result.pt\")\n", | ||
"selected_classes = np.arange(args.num_classes)\n", | ||
"\n", | ||
"# Select classes to visualize\n", | ||
"if args.num_classes>args.c_sub:\n", | ||
" selected_classes = np.delete(selected_classes, args.target_class)\n", | ||
" selected_classes = np.random.choice(selected_classes, args.c_sub-1, replace=False)\n", | ||
" selected_classes = np.append(selected_classes, args.target_class)\n", | ||
"\n", | ||
"# keep the same transforms for train and test dataset for better visualization\n", | ||
"result_attack[\"clean_train\"].wrap_img_transform = result_attack[\"clean_test\"].wrap_img_transform \n", | ||
"result_attack[\"bd_train\"].wrap_img_transform = result_attack[\"bd_test\"].wrap_img_transform \n", | ||
"\n", | ||
"# Create dataset\n", | ||
"args.visual_dataset = 'bd_train'\n", | ||
"if args.visual_dataset == 'mixed':\n", | ||
" bd_test_with_trans = result_attack[\"bd_test\"]\n", | ||
" visual_dataset = generate_mix_dataset(bd_test_with_trans, args.target_class, args.pratio, selected_classes, max_num_samples=args.n_sub)\n", | ||
"elif args.visual_dataset == 'clean_train':\n", | ||
" clean_train_with_trans = result_attack[\"clean_train\"]\n", | ||
" visual_dataset = generate_clean_dataset(clean_train_with_trans, selected_classes, max_num_samples=args.n_sub)\n", | ||
"elif args.visual_dataset == 'clean_test':\n", | ||
" clean_test_with_trans = result_attack[\"clean_test\"]\n", | ||
" visual_dataset = generate_clean_dataset(clean_test_with_trans, selected_classes, max_num_samples=args.n_sub)\n", | ||
"elif args.visual_dataset == 'bd_train': \n", | ||
" bd_train_with_trans = result_attack[\"bd_train\"]\n", | ||
" visual_dataset = generate_bd_dataset(bd_train_with_trans, args.target_class, selected_classes, max_num_samples=args.n_sub)\n", | ||
"elif args.visual_dataset == 'bd_test':\n", | ||
" bd_test_with_trans = result_attack[\"bd_test\"]\n", | ||
" visual_dataset = generate_bd_dataset(bd_test_with_trans, args.target_class, selected_classes, max_num_samples=args.n_sub)\n", | ||
"else:\n", | ||
" assert False, \"Illegal vis_class\"\n", | ||
"\n", | ||
"print(f'Create visualization dataset with \\n \\t Dataset: {args.visual_dataset} \\n \\t Number of samples: {len(visual_dataset)} \\n \\t Selected classes: {selected_classes}')\n", | ||
"\n", | ||
"# Create data loader\n", | ||
"data_loader = torch.utils.data.DataLoader(\n", | ||
" visual_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False\n", | ||
")\n", | ||
"\n", | ||
"# Create denormalization function\n", | ||
"for trans_t in data_loader.dataset.wrap_img_transform.transforms:\n", | ||
" if isinstance(trans_t, transforms.Normalize):\n", | ||
" denormalizer = get_dataset_denormalization(trans_t)\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "67cbfec4", | ||
"metadata": {}, | ||
"source": [ | ||
"### Step 3: SSIM" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "39104beb", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Number Poisoned samples: 489\n", | ||
"Average SSIM: 0.9929845929145813\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"visual_poison_indicator = np.array(get_poison_indicator_from_bd_dataset(visual_dataset))\n", | ||
"bd_idx = np.where(visual_poison_indicator == 1)[0]\n", | ||
"\n", | ||
"from torchmetrics import StructuralSimilarityIndexMeasure\n", | ||
"ssim = StructuralSimilarityIndexMeasure()\n", | ||
"ssim_list = []\n", | ||
"if visual_poison_indicator.sum() > 0:\n", | ||
" print(f'Number Poisoned samples: {visual_poison_indicator.sum()}')\n", | ||
" # random choose two poisoned samples\n", | ||
" start_idx = 0\n", | ||
" for i in range(bd_idx.shape[0]):\n", | ||
" bd_sample = denormalizer(visual_dataset[i][0]).unsqueeze(0)\n", | ||
" with temporary_all_clean(visual_dataset):\n", | ||
" clean_sample = denormalizer(visual_dataset[i][0]).unsqueeze(0)\n", | ||
" ssim_list.append(ssim(bd_sample, clean_sample)) \n", | ||
"print(f'Average SSIM: {np.mean(ssim_list)}')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "2c2b0104", | ||
"metadata": {}, | ||
"source": [ | ||
"### Step 4: FID" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "57497927", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Number Poisoned samples: 489\n", | ||
"FID: 0.00030133521067909896\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"visual_poison_indicator = np.array(get_poison_indicator_from_bd_dataset(visual_dataset))\n", | ||
"bd_idx = np.where(visual_poison_indicator == 1)[0]\n", | ||
"\n", | ||
"from torchmetrics.image.fid import FrechetInceptionDistance\n", | ||
"fid = FrechetInceptionDistance(feature=64, normalize = True)\n", | ||
"if visual_poison_indicator.sum() > 0:\n", | ||
" print(f'Number Poisoned samples: {visual_poison_indicator.sum()}')\n", | ||
" # random choose two poisoned samples\n", | ||
" start_idx = 0\n", | ||
" for i in range(bd_idx.shape[0]):\n", | ||
" bd_sample = denormalizer(visual_dataset[i][0]).unsqueeze(0)\n", | ||
" with temporary_all_clean(visual_dataset):\n", | ||
" clean_sample = denormalizer(visual_dataset[i][0]).unsqueeze(0)\n", | ||
" fid.update(clean_sample, real=True)\n", | ||
" fid.update(bd_sample, real=False)\n", | ||
" fid_value = fid.compute().numpy() \n", | ||
"print(f'FID: {fid_value}')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "870cf186", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.12" | ||
}, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "6869619afde5ccaa692f7f4d174735a0f86b1f7ceee086952855511b0b6edec0" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.