evaluation_param.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from federatedml.util import consts, LOGGER
from federatedml.param.base_param import BaseParam


class EvaluateParam(BaseParam):
"""
Define the evaluation method of binary/multiple classification and regression
Parameters
----------
eval_type : {'binary', 'regression', 'multi'}
support 'binary' for HomoLR, HeteroLR and Secureboosting,
support 'regression' for Secureboosting,
'multi' is not support these version
unfold_multi_result : bool
unfold multi result and get several one-vs-rest binary classification results
pos_label : int or float or str
specify positive label type, depend on the data's label. this parameter effective only for 'binary'
need_run: bool, default True
Indicate if this module needed to be run
"""
    def __init__(self, eval_type="binary", pos_label=1, need_run=True, metrics=None,
                 run_clustering_arbiter_metric=False, unfold_multi_result=False):
        super().__init__()
        self.eval_type = eval_type
        self.pos_label = pos_label
        self.need_run = need_run
        self.metrics = metrics
        self.unfold_multi_result = unfold_multi_result
        self.run_clustering_arbiter_metric = run_clustering_arbiter_metric

        self.default_metrics = {
            consts.BINARY: consts.ALL_BINARY_METRICS,
            consts.MULTY: consts.ALL_MULTI_METRICS,
            consts.REGRESSION: consts.ALL_REGRESSION_METRICS,
            consts.CLUSTERING: consts.ALL_CLUSTER_METRICS
        }

        self.allowed_metrics = {
            consts.BINARY: consts.ALL_BINARY_METRICS,
            consts.MULTY: consts.ALL_MULTI_METRICS,
            consts.REGRESSION: consts.ALL_REGRESSION_METRICS,
            consts.CLUSTERING: consts.ALL_CLUSTER_METRICS
        }
    def _use_single_value_default_metrics(self):

        self.default_metrics = {
            consts.BINARY: consts.DEFAULT_BINARY_METRIC,
            consts.MULTY: consts.DEFAULT_MULTI_METRIC,
            consts.REGRESSION: consts.DEFAULT_REGRESSION_METRIC,
            consts.CLUSTERING: consts.DEFAULT_CLUSTER_METRIC
        }
    def _check_valid_metric(self, metrics_list):

        metric_list = consts.ALL_METRIC_NAME
        alias_name: dict = consts.ALIAS

        full_name_list = []

        metrics_list = [str.lower(i) for i in metrics_list]

        for metric in metrics_list:

            if metric in metric_list:
                if metric not in full_name_list:
                    full_name_list.append(metric)
                continue

            # resolve aliases to their canonical metric names
            valid_flag = False
            for alias, full_name in alias_name.items():
                if metric in alias:
                    if full_name not in full_name_list:
                        full_name_list.append(full_name)
                    valid_flag = True
                    break

            if not valid_flag:
                raise ValueError('metric {} is not supported'.format(metric))

        allowed_metrics = self.allowed_metrics[self.eval_type]

        for m in full_name_list:
            if m not in allowed_metrics:
                raise ValueError('metric {} is not used for {} task'.format(m, self.eval_type))

        # precision and recall are always computed as a pair
        if consts.RECALL in full_name_list and consts.PRECISION not in full_name_list:
            full_name_list.append(consts.PRECISION)

        if consts.RECALL not in full_name_list and consts.PRECISION in full_name_list:
            full_name_list.append(consts.RECALL)

        return full_name_list
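    # For illustration (hypothetical input; the exact alias table lives in
    # consts.ALIAS): _check_valid_metric(["AUC", "recall"]) lower-cases the
    # names, maps any alias to its full name, and returns the list with
    # "precision" appended, since recall and precision are reported together.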
    def check(self):

        descr = "evaluate param's "
        self.eval_type = self.check_and_change_lower(self.eval_type,
                                                     [consts.BINARY, consts.MULTY, consts.REGRESSION,
                                                      consts.CLUSTERING],
                                                     descr)

        if type(self.pos_label).__name__ not in ["str", "float", "int"]:
            raise ValueError(
                "evaluate param's pos_label {} not supported, should be str or float or int type".format(
                    self.pos_label))

        if type(self.need_run).__name__ != "bool":
            raise ValueError(
                "evaluate param's need_run {} not supported, should be bool".format(
                    self.need_run))

        if self.metrics is None or len(self.metrics) == 0:
            self.metrics = self.default_metrics[self.eval_type]
            LOGGER.warning('use default metric {} for eval type {}'.format(self.metrics, self.eval_type))

        self.check_boolean(self.unfold_multi_result, 'multi_result_unfold')

        self.metrics = self._check_valid_metric(self.metrics)

        return True
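    # Note on check() above (illustrative call): an EvaluateParam constructed
    # with eval_type="binary" and no metrics will, after check(), carry
    # consts.ALL_BINARY_METRICS and emit the warning logged above.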
    def check_single_value_default_metric(self):

        self._use_single_value_default_metrics()

        # in validation strategy, psi, f1-score, confusion-mat and pr-quantile
        # are not supported in the current version
        if self.metrics is None or len(self.metrics) == 0:
            self.metrics = self.default_metrics[self.eval_type]
            LOGGER.warning('use default metric {} for eval type {}'.format(self.metrics, self.eval_type))

        ban_metric = [consts.PSI, consts.F1_SCORE, consts.CONFUSION_MAT, consts.QUANTILE_PR]
        # filter into a new list instead of removing while iterating,
        # which would skip elements that follow a removed one
        self.metrics = [metric for metric in self.metrics if metric not in ban_metric]

        self.check()
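# A minimal end-to-end sketch (illustrative only; requires an environment
# where federatedml is installed):
#
#     param = EvaluateParam(eval_type="binary")
#     param.check_single_value_default_metric()
#     # param.metrics now holds the single-value default binary metrics,
#     # with psi / f1_score / confusion_mat / quantile_pr filtered out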