# test_marginal_utility.py
# Origin: repository forked from FFmgll/shapiq.
import copy
import numpy as np
from games.all import MarginalUtility
from approximators import SHAPIQEstimator
from approximators.regression import RegressionEstimator
def unique_nonzero_values(values_by_key, decimals=6):
    """Summarise each array by its distinct non-zero entries.

    For every entry of the mapping, the non-zero elements of the array are
    selected, rounded, and de-duplicated with ``np.unique``.

    Args:
        values_by_key: Mapping from a key to a NumPy array.
            (Presumably the keys are interaction orders - confirm against
            the estimators' return values.)
        decimals: Number of decimal places handed to ``np.round``.

    Returns:
        A dict with the same keys, in the same order; each value is the
        1-D array returned by ``np.unique`` for that key.
    """
    return {
        key: np.unique(np.round(values[np.nonzero(values)], decimals))
        for key, values in values_by_key.items()
    }


def exact_interaction_values(game_fun, players, interaction_order):
    """Collect the exact SII, STI, FSI and n-Shapley values for one game.

    Args:
        game_fun: Callable game passed to the estimators.
        players: Set of player indices, passed to the estimators as ``N``.
        interaction_order: Interaction order passed to the estimators.

    Returns:
        A dict with the keys ``"SII"``, ``"STI"``, ``"FSI"`` and
        ``"n_shapley"`` (in this order) holding whatever the respective
        estimator methods return.
    """
    # One SHAP-IQ estimator per index: the Shapley Interaction Index (SII)
    # and the Shapley Taylor Index (STI).
    approximators = {
        interaction_type: SHAPIQEstimator(
            N=players,
            order=interaction_order,
            interaction_type=interaction_type,
            top_order=False,
        )
        for interaction_type in ("SII", "STI")
    }

    # Deep-copy each result so later calls on the same estimator cannot
    # alter the stored values.
    exact_values = {
        interaction_type: copy.deepcopy(
            approximator.compute_interactions_complete(game_fun)
        )
        for interaction_type, approximator in approximators.items()
    }

    # FSI values
    fsi_estimator = RegressionEstimator(players, interaction_order)
    exact_values["FSI"] = fsi_estimator.compute_exact_values(game_fun)

    # n-Shapley values are derived from the SII values.
    exact_values["n_shapley"] = approximators[
        "SII"
    ].transform_interactions_in_n_shapley(exact_values["SII"])

    return exact_values


def main():
    """Print the distinct non-zero exact values for several games."""
    # Set up the game functions: MarginalUtility games of different sizes.
    game_list = [
        MarginalUtility(n=11, p=0.2, example=1),
        MarginalUtility(n=10, p=0.2, example=1),
        MarginalUtility(n=12, p=0.2, example=1),
    ]
    interaction_order = 2

    for game in game_list:
        players = set(range(game.n))
        exact_values = exact_interaction_values(
            game.set_call, players, interaction_order
        )

        print("--------------------")
        print(game.n, " features and p=", game.p, " example: ", game.example)
        # NOTE(review): the indentation of the original script was lost;
        # this assumes the summary is printed once per index type, after
        # all of its keys have been processed - confirm.
        for index_name, values_by_key in exact_values.items():
            print("computed by ", index_name)
            print(unique_nonzero_values(values_by_key))


if __name__ == "__main__":
    main()