diff --git a/code/nnv/examples/NN/FairNNV/adult_exact_verify.m b/code/nnv/examples/NN/FairNNV/adult_exact_verify.m index c7838bbb2..15590f10c 100644 --- a/code/nnv/examples/NN/FairNNV/adult_exact_verify.m +++ b/code/nnv/examples/NN/FairNNV/adult_exact_verify.m @@ -12,6 +12,7 @@ clear; clc; modelDir = './adult_onnx'; % Directory containing ONNX models onnxFiles = dir(fullfile(modelDir, '*.onnx')); % List all .onnx files +onnxFiles = onnxFiles(1); % simplify for debugging load("adult_data.mat", 'X', 'y'); % Load data once @@ -58,12 +59,12 @@ % First, we define the reachability options reachOptions = struct; % initialize reachOptions.reachMethod = 'exact-star'; - reachOptions.relaxFactor = 0.5; nR = 50; % ---> just chosen arbitrarily % ADJUST epsilons value here - epsilon = [0.0,0.001,0.01]; + % epsilon = [0.001,0.01]; + epsilon = 0.01; % Set up results nE = 3; @@ -87,19 +88,23 @@ start(verificationTimer); % Start the timer - for i=1:numObs + % for i=1:numObs + for i=57 idx = rand_indices(i); IS = perturbationIF(X_test_loaded(:, idx), epsilon(e), min_values, max_values); - + + unsafeRegion = net.robustness_set(y_test_loaded(idx), 'min'); t = tic; % Start timing the verification for each sample - temp = net.verify_robustness(IS, reachOptions, y_test_loaded(idx)); + temp = net.verify_robustness(IS, reachOptions, unsafeRegion); met(i,e) = 'exact'; - res(i,e) = temp; % robust result - % end - + res(i,e) = temp; % robust result time(i,e) = toc(t); % store computation time + + if ~(temp == 1) + counterExs = getCounterRegion(IS,unsafeRegion,net.reachSet{end}); + end % Check for timeout flag if evalin('base', 'timeoutOccurred') @@ -164,6 +169,7 @@ % Calculate disturbed lower and upper bounds considering min and max values lb = max(x - disturbance, min_values); ub = min(x + disturbance, max_values); - IS = ImageStar(single(lb), single(ub)); % default: single (assume onnx input models) + IS = Star(single(lb), single(ub)); % default: single (assume onnx input models) + end diff --git a/code/nnv/examples/NN/FairNNV/adult_verifiy.m b/code/nnv/examples/NN/FairNNV/adult_verifiy.m index d093f7b7a..845fcfdf9 100644 --- a/code/nnv/examples/NN/FairNNV/adult_verifiy.m +++ b/code/nnv/examples/NN/FairNNV/adult_verifiy.m @@ -64,7 +64,7 @@ nR = 50; % ---> just chosen arbitrarily % ADJUST epsilon values here - epsilon = [0.0,0.001,0.01]; + epsilon = [0.001,0.01]; % Set up results @@ -90,7 +90,7 @@ start(verificationTimer); % Start the timer % Iterate through observations - for i=1:numObs + for i=38 idx = rand_indices(i); [IS, xRand] = perturbationIF(X_test_loaded(:, idx), epsilon(e), nR, min_values, max_values); @@ -109,6 +109,7 @@ time(i,e) = toc(t); met(i,e) = "counterexample"; skipTryCatch = true; % Set the flag to skip try-catch block + disp('Counter example found'); continue; end end @@ -198,7 +199,7 @@ % Calculate disturbed lower and upper bounds considering min and max values lb = max(x - disturbance, min_values); ub = min(x + disturbance, max_values); - IS = ImageStar(single(lb), single(ub)); % default: single (assume onnx input models) + IS = Star(single(lb), single(ub)); % default: single (assume onnx input models) % Create random samples from initial set % Adjusted reshaping according to specific needs @@ -208,7 +209,7 @@ xRand = xB.sample(nR); xRand = reshape(xRand,[13,nR]); xRand(:,nR+1) = x; % add original image - xRand(:,nR+2) = IS.im_lb; % add lower bound image - xRand(:,nR+3) = IS.im_ub; % add upper bound image + xRand(:,nR+2) = xB.lb; % add lower bound image + xRand(:,nR+3) = xB.ub; % add upper bound image end diff --git a/code/nnv/examples/NN/FairNNV/getCounterRegion.m b/code/nnv/examples/NN/FairNNV/getCounterRegion.m new file mode 100644 index 000000000..40b8206a4 --- /dev/null +++ b/code/nnv/examples/NN/FairNNV/getCounterRegion.m @@ -0,0 +1,38 @@ +function counterExamples = getCounterRegion(inputSet, unsafeRegion, reachSet) + % counterExamples = getCounterRegion(inputSet, unsafeRegion, reachSet) + % NOTE: This is only to be used with exact-star method + % unsafeRegion = HalfSpace (unsafe/undesired region) + % inputSet = ImageStar/Star + % reachSet = Star + % + % check the "safety" of the reachSet + % Then, generate counterexamples + + % Initialize variables + counterExamples = []; + + % Get halfspace variables + G = unsafeRegion.G; + g = unsafeRegion.g; + + % Check for valid inputs + if ~isa(inputSet, "Star") + error("Must be a Star"); + end + if ~isa(reachSet, "Star") + error("Must be Star or ImageStar"); + end + + % Begin counterexample computation + n = length(reachSet); % number of stars in the output set + V = inputSet.V; + for i=1:n + % Check for safety, if unsafe, add to counter + if ~isempty(reachSet(i).intersectHalfSpace(G, g)) + counterExamples = [counterExamples Star(V, reachSet(i).C, reachSet(i).d,... + reachSet(i).predicate_lb, reachSet(i).predicate_ub)]; + end + end + +end +