forked from FederatedAI/FATE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ftl_param.py
161 lines (143 loc) · 7.63 KB
/
ftl_param.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections
import collections.abc
import copy
from types import SimpleNamespace

from pipeline.param import consts
from pipeline.param.base_param import BaseParam
from pipeline.param.callback_param import CallbackParam
from pipeline.param.encrypt_param import EncryptParam
from pipeline.param.encrypted_mode_calculation_param import EncryptedModeCalculatorParam
from pipeline.param.intersect_param import IntersectParam
from pipeline.param.predict_param import PredictParam


class FTLParam(BaseParam):
    """Parameters for the FTL (federated transfer learning) component."""

    def __init__(self, alpha=1, tol=0.000001,
                 n_iter_no_change=False, validation_freqs=None, optimizer={'optimizer': 'Adam', 'learning_rate': 0.01},
                 nn_define={}, epochs=1, intersect_param=IntersectParam(consts.RSA), config_type='keras', batch_size=-1,
                 encrypte_param=EncryptParam(),
                 encrypted_mode_calculator_param=EncryptedModeCalculatorParam(mode="confusion_opt"),
                 predict_param=PredictParam(), mode='plain', communication_efficient=False,
                 local_round=5, callback_param=CallbackParam()):
        """
        Args:
            alpha: float, a loss coefficient defined in paper, it defines the importance of alignment loss
            tol: float, loss tolerance
            n_iter_no_change: bool, check loss convergence or not
            validation_freqs: None or positive integer or container object in python. Do validation in training process or Not.
                if equals None, will not do validation in train process;
                if equals positive integer, will validate data every validation_freqs epochs passes;
                if container object in python, will validate data if epochs belong to this container.
                e.g. validation_freqs = [10, 15], will validate data when epoch equals to 10 and 15.
                Default: None
                1 is suggested. You can set it to a number larger than 1 in order to speed up training
                by skipping validation rounds. When it is larger than 1, choosing it so that "epochs" is
                divisible by it is recommended, otherwise you will miss the validation scores of the
                last training epoch.
            optimizer: optimizer method, accept following types:
                1. a string, one of "Adadelta", "Adagrad", "Adam", "Adamax", "Nadam", "RMSprop", "SGD"
                2. a dict, with a required key-value pair keyed by "optimizer",
                    with optional key-value pairs such as learning rate.
                defaults to {'optimizer': 'Adam', 'learning_rate': 0.01}.
                check() replaces it with SimpleNamespace(optimizer=<name>, kwargs=<dict>).
            nn_define: dict, a dict represents the structure of neural network, it can be output by tf-keras
            epochs: int, epochs num
            intersect_param: define the intersect method
            config_type: now only 'keras' is supported
            batch_size: batch size when computing transformed feature embedding, -1 use full data.
                Otherwise it must be an int no less than consts.MIN_BATCH_SIZE.
            encrypte_param: encrypted param
            encrypted_mode_calculator_param: EncryptedModeCalculatorParam instance
            predict_param: predict param
            mode:
                plain: will not use any encrypt algorithms, data exchanged in plaintext
                encrypted: use paillier to encrypt gradients
            communication_efficient:
                bool, will use communication efficient or not. when communication efficient is enabled, FTL model will
                update gradients by several local rounds using intermediate data
            local_round: local update round when using communication efficient
            callback_param: CallbackParam instance

        All dict/param-object arguments are deep-copied, so neither the shared
        default objects nor caller-owned objects are aliased by this instance.
        """
        super(FTLParam, self).__init__()
        self.alpha = alpha
        self.tol = tol
        self.n_iter_no_change = n_iter_no_change
        self.validation_freqs = validation_freqs
        # Copy mutable arguments: the defaults for `optimizer` and `nn_define`
        # are single dict objects shared by every call of __init__.
        self.optimizer = copy.deepcopy(optimizer)
        self.nn_define = copy.deepcopy(nn_define)
        self.epochs = epochs
        self.intersect_param = copy.deepcopy(intersect_param)
        self.config_type = config_type
        self.batch_size = batch_size
        self.encrypted_mode_calculator_param = copy.deepcopy(encrypted_mode_calculator_param)
        self.encrypt_param = copy.deepcopy(encrypte_param)
        self.predict_param = copy.deepcopy(predict_param)
        self.mode = mode
        self.communication_efficient = communication_efficient
        self.local_round = local_round
        self.callback_param = copy.deepcopy(callback_param)

    def check(self):
        """Validate all parameters.

        Side effect: ``self.optimizer`` is normalized by ``_parse_optimizer``
        into ``SimpleNamespace(optimizer=..., kwargs=...)``.

        Raises:
            ValueError: on an invalid config_type, tol, epochs, nn_define,
                batch_size, validation_freqs or optimizer.
            AssertionError: on an invalid communication_efficient or mode.
        """
        self.intersect_param.check()
        self.encrypt_param.check()
        self.encrypted_mode_calculator_param.check()

        self.optimizer = self._parse_optimizer(self.optimizer)

        supported_config_type = ["keras"]
        if self.config_type not in supported_config_type:
            raise ValueError(f"config_type should be one of {supported_config_type}")

        if not isinstance(self.tol, (int, float)):
            raise ValueError("tol should be numeric")

        if not isinstance(self.epochs, int) or self.epochs <= 0:
            raise ValueError("epochs should be a positive integer")

        if self.nn_define and not isinstance(self.nn_define, dict):
            raise ValueError("nn_define should be a dict defining the structure of neural network")

        # -1 is the sentinel for "use the full data set".
        if self.batch_size != -1:
            if not isinstance(self.batch_size, int) \
                    or self.batch_size < consts.MIN_BATCH_SIZE:
                raise ValueError(
                    " {} not supported, should be no less than {} or -1 represent for all data".format(
                        self.batch_size, consts.MIN_BATCH_SIZE))

        if self.validation_freqs is None:
            pass
        elif isinstance(self.validation_freqs, int):
            if self.validation_freqs < 1:
                raise ValueError("validation_freqs should be larger than 0 when it's integer")
        # collections.Container was removed in Python 3.10; use the abc module.
        elif not isinstance(self.validation_freqs, collections.abc.Container):
            raise ValueError("validation_freqs should be None or positive integer or container")

        # NOTE(review): asserts are stripped under `python -O`; kept so the
        # exception type callers may rely on (AssertionError) is unchanged.
        assert isinstance(self.communication_efficient, bool), 'communication efficient must be a boolean'
        assert self.mode in [
            'encrypted', 'plain'], 'mode options: encrypted or plain, but {} is offered'.format(
            self.mode)

        self.check_positive_integer(self.epochs, 'epochs')
        self.check_positive_number(self.alpha, 'alpha')
        self.check_positive_integer(self.local_round, 'local round')

    @staticmethod
    def _parse_optimizer(opt):
        """Normalize an optimizer config into ``SimpleNamespace(optimizer, kwargs)``.

        Examples:
            1. "optimizer": "SGD"
            2. "optimizer": {
                    "optimizer": "SGD",
                    "learning_rate": 0.05
                }

        A value that was already parsed (a SimpleNamespace carrying
        ``optimizer`` and ``kwargs``) is returned unchanged, so ``check()``
        may safely be called more than once.

        Raises:
            ValueError: if a dict has no truthy "optimizer" entry, or ``opt``
                is of an unsupported type.
        """
        if isinstance(opt, SimpleNamespace) and hasattr(opt, "optimizer") and hasattr(opt, "kwargs"):
            return opt
        if isinstance(opt, str):
            return SimpleNamespace(optimizer=opt, kwargs={})
        if isinstance(opt, dict):
            optimizer = opt.get("optimizer")
            if not optimizer:
                raise ValueError(f"optimizer config: {opt} invalid")
            # Every other key is forwarded as an optimizer keyword argument.
            kwargs = {k: v for k, v in opt.items() if k != "optimizer"}
            return SimpleNamespace(optimizer=optimizer, kwargs=kwargs)
        raise ValueError(f"invalid type for optimize: {type(opt)}")