forked from opensbt/opensbt-fmnist
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_mnist.py
101 lines (85 loc) · 3.55 KB
/
run_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from opensbt.algorithm.nsga2_optimizer import NsgaIIOptimizer
from opensbt.algorithm.nsga2d_optimizer import NSGAIIDOptimizer
from opensbt.evaluation.fitness import *
from opensbt.experiment.search_configuration import DefaultSearchConfiguration
from opensbt.experiment.experiment import *
from opensbt.algorithm.algorithm import *
from opensbt.evaluation.critical import *
from mnist.mnist_problem import *
from mnist.fitness_mnist import *
from mnist.utils_mnist import get_number_verts
from mnist.operator import MnistSamplingValid
from opensbt.config import *
from mnist import mnist_simulation
from mnist.mnist_simulation import MnistSimulator
from opensbt.config import RESULTS_FOLDER as results_folder
# MNIST testing with single seed mutation.

import random

import numpy as np

# Seed Python's `random` module and NumPy's global RNG with a fixed value.
# (Any other RNGs used by the simulator are not seeded here.)
random.seed(42)
np.random.seed(42)

# Search budget: population of 20 individuals, evolved for 20 generations.
config = DefaultSearchConfiguration()
config.population_size = 20
config.n_generations = 20

# Custom operators -- uncomment a line to override the corresponding default.
# config.operators["mut"] = MnistMutation
# config.operators["cx"] = MyNoCrossover
# config.operators["dup"] = MnistDuplicateElimination
config.operators["init"] = MnistSamplingValid
# Seed passed to `generate_digit`; it selects the digit that gets mutated.
# Other possible seeds: 8, 15, 23, 45, 52, 53, 102, 120, 127, 129, 132, 152
seed = 120

# Bounds of the two displacement variables ("mut_extent_1", "mut_extent_2").
lb = -8
ub = +8

digit = mnist_simulation.generate_digit(seed)
vertex_num = get_number_verts(digit)
# "vertex_control" ranges over [0, vertex_num - 1]; with no vertices the
# upper bound would fall below the lower bound, so fail fast instead of
# handing the optimizer an empty/inverted search interval.
if vertex_num < 1:
    raise ValueError(
        f"Digit generated from seed {seed} has no vertices to mutate "
        f"(get_number_verts returned {vertex_num})"
    )
ub_vert = vertex_num - 1

# Note: the "init" sampling operator is configured once, together with the
# other search settings; it is not re-assigned here.

# 3-D problem: two displacement extents plus the index of the controlled vertex.
mnistproblem = MNISTProblem(
    problem_name="MNIST_3D",
    xl=[lb, lb, 0],
    xu=[ub, ub, ub_vert],
    simulation_variables=[
        "mut_extent_1",
        "mut_extent_2",
        "vertex_control",
    ],
    simulate_function=MnistSimulator.simulate,
    fitness_function=FitnessMNIST(),
    critical_function=CriticalMNISTConf_05(),
    # NOTE(review): label is fixed to 5 regardless of `seed`; confirm it
    # matches the class of the digit generated from the chosen seed.
    expected_label=5,
    min_saturation=0.1,
    seed=seed,
)
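##############
# 6-D problem: alternative to the 3-D problem above (four displacement
# extents plus two vertex indices). Uncomment to use it instead.
# NOTE: relies on EXPECTED_LABEL, which is not defined in this script;
# verify it is provided (e.g. by one of the star imports) before enabling.
##############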
# mnistproblem = MNISTProblem(
# problem_name=f"MNIST_6D",
# xl=[lb, lb, lb, lb, 0, 0],
# xu=[ub, ub, ub, ub, ub_vert, ub_vert],
# simulation_variables=[
# "mut_extent_1",
# "mut_extent_2",
# "mut_extent_3",
# "mut_extent_4",
# "vertex_control",
# "vertex_start"
# ],
# simulate_function=MnistSimulator.simulate,
# fitness_function=FitnessMNIST(),
# critical_function=CriticalMNISTConf_05(),
# expected_label=EXPECTED_LABEL,
# min_saturation=0.1,
# max_seed_distance=4,
# seed=seed
# )
# Replace the fitness function with FitnessMNIST(diversify=True) and set the
# critical function (same class as the one given to the constructor).
mnistproblem.set_fitness_function(FitnessMNIST(diversify=True))
mnistproblem.critical_function = CriticalMNISTConf_05()

# Suffix the problem name with the algorithm tag and the digit seed.
mnistproblem.problem_name = f"{mnistproblem.problem_name}_NSGA-II-DJ_D{seed}"

optimizer = NSGAIIDOptimizer(problem=mnistproblem, config=config)
res = optimizer.run()
res.write_results(results_folder=results_folder, params=optimizer.parameters)

log.info(f"====== Algorithm search time: {res.exec_time:.2f} sec")