Sif develop #133 (Open)

Wants to merge 26 commits into master from SIF_develop.

Commits (26):
5b5703d  Add SIF algorithm (zhenglecheng, Aug 17, 2021)
a7160fb  Update test_SIF.py (zhenglecheng, Aug 17, 2021)
5d1276a  Update SIF.ipynb (zhenglecheng, Aug 17, 2021)
48c69fc  Update README.md (zhenglecheng, Aug 26, 2021)
b2dc3b9  Update .travis.yml (zhenglecheng, Sep 16, 2021)
dc0efab  Update SIF Algorithm (zhenglecheng, Sep 22, 2021)
0d8e096  Update SIF.ipynb (zhenglecheng, Sep 22, 2021)
6c36abd  Update SIF.ipynb (zhenglecheng, Sep 22, 2021)
5467993  Update SIF.ipynb (zhenglecheng, Sep 22, 2021)
9a7f177  Update .travis.yml (zhenglecheng, Sep 22, 2021)
f2607d4  Update .travis.yml (zhenglecheng, Sep 22, 2021)
0b75812  Create requirement.txt (zhenglecheng, Sep 22, 2021)
4c28bfd  Merge branch 'SIF_develop' of https://github.com/Leo02016/AIX360 into… (zhenglecheng, Sep 22, 2021)
cca9cb0  Update .travis.yml (zhenglecheng, Sep 22, 2021)
e2b8027  Update .travis.yml (zhenglecheng, Sep 22, 2021)
84140e6  Update .travis.yml (zhenglecheng, Sep 22, 2021)
b28a9f2  Update .travis.yml (zhenglecheng, Sep 22, 2021)
022fea8  Update .travis.yml (zhenglecheng, Sep 22, 2021)
31494ae  Update .travis.yml (zhenglecheng, Sep 22, 2021)
5b11aaa  Update .travis.yml (zhenglecheng, Sep 22, 2021)
8c1b7d3  Update .travis.yml (zhenglecheng, Sep 22, 2021)
dbdd574  Update .travis.yml (zhenglecheng, Sep 22, 2021)
dbd758f  Update .travis.yml (zhenglecheng, Sep 22, 2021)
53ad047  Update .travis.yml (zhenglecheng, Sep 28, 2021)
52c60d1  Merge branch 'SIF_develop' of https://github.com/Leo02016/AIX360 into… (zhenglecheng, Sep 28, 2021)
35fd6ac  Merge branch 'master' into SIF_develop (zhenglecheng, Dec 10, 2021)
.travis.yml (5 changes: 3 additions & 2 deletions)

@@ -6,7 +6,7 @@ python:
 install:
 - pip3 install --upgrade setuptools==41.0.0
 - pip3 install .
-#- pip3 install -r requirements.txt
+# - pip3 install -r requirements.txt

 # commands to run tes
 # before_script: redis-cli ping
@@ -22,7 +22,8 @@ script:
 - python3.6 ./tests/rbm/test_Logistic_Rule_Regression.py
 - python3.6 ./tests/lime/test_lime.py
 - python3.6 ./tests/shap/test_shap.py
-
+- python3.6 ./tests/sif/test_SIF.py
+- python3.6 ./tests/sif/test_SIF.py
 after_success:
 # - codecov
README.md (2 changes: 1 addition & 1 deletion)

@@ -26,7 +26,7 @@ We have developed the package with extensibility in mind. This library is still
 - Contrastive Explanations Method with Monotonic Attribute Functions ([Luss et al., 2019](https://arxiv.org/abs/1905.12698))
 - LIME ([Ribeiro et al. 2016](https://arxiv.org/abs/1602.04938), [Github](https://github.com/marcotcr/lime))
 - SHAP ([Lundberg, et al. 2017](http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions), [Github](https://github.com/slundberg/shap))
-
+- SIF ([Jianbo, et al. 2021](https://ojs.aaai.org/index.php/AAAI/article/view/17379/17186))
 ### Local direct explanation

 - Teaching AI to Explain its Decisions ([Hind et al., 2019](https://doi.org/10.1145/3306618.3314273))
aix360/algorithms/sif/SIF.py (new file, 1,186 additions)

Large diffs are not rendered by default.
aix360/algorithms/sif/SIF_NN.py (new file, 201 additions)
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import tensorflow as tf
from aix360.algorithms.sif.SIF import SIFExplainer
from aix360.datasets.SIF_dataset import DataSet
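# This module defines three SIFExplainer subclasses for multivariate time series:
# a linear autoregressive model (AllAR) and two recurrent models (AllLSTM and
# AllRNN). All of them target the TensorFlow 1.x graph API (tf.placeholder,
# tf.variable_scope, tf.nn.static_rnn).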


class AllAR(SIFExplainer):
def __init__(self, x_dim, y_dim, time_steps, share_param, **kwargs):
self.time_steps = time_steps
self.x_dim = x_dim
self.cells = None
self.y_dim = y_dim
self.share_param = share_param
if share_param:
self.out_weights = tf.Variable(tf.random_normal([self.time_steps, 1]))
else:
self.out_weights = tf.Variable(tf.random_normal([self.time_steps, self.y_dim]))
super().__init__(**kwargs)

def get_all_params(self):
all_params = []
all_params.append(self.out_weights)
return all_params

def retrain(self, num_steps, feed_dict):
retrain_dataset = DataSet(feed_dict[self.input_index_placeholder], feed_dict[self.labels_index_placeholder])
for step in range(num_steps):
iter_feed_dict = self.fill_feed_dict_with_batch(retrain_dataset)
self.sess.run(self.train_op, feed_dict=iter_feed_dict)

def placeholder_inputs(self):
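        # Inputs arrive as integer indices into the series: one (batch, 1) int32
        # placeholder per time step, plus label indices and the raw float series.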
input_index_placeholder = tuple([tf.placeholder(
tf.int32,
shape=(None, 1),
name='input_index_placeholder_{}'.format(i)) for i in range(self.time_steps)])
labels_index_placeholder = tf.placeholder(
tf.int32,
shape=(None, 1),
name='labels_index_placeholder')
ts_placeholder = tf.placeholder(
tf.float32,
shape=[None, self.y_dim],
name='input_ts')
return input_index_placeholder, labels_index_placeholder, ts_placeholder

def inference(self, input_x, labels_placeholder=None, keep_probs_placeholder=None):
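        # share_param ties a single column of time-step weights across all series.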
if self.share_param:
weight = tf.tile(self.out_weights, [1, self.y_dim], name='Weight')
else:
weight = self.out_weights
x = tf.stack([x[:, 0] for x in input_x], axis=1, name='x')
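        # y_hat[i, k] = sum_j x[i, j, k] * weight[j, k]: one AR prediction per series k.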
y_hat = tf.einsum('ijk,jk->ik', x, weight, name='y_hat')
return y_hat

    def predictions(self, logits):
        # Regression model: the raw outputs are used directly as predictions.
        return logits


class AllLSTM(SIFExplainer):
def __init__(self, x_dim, y_dim, time_steps, num_units, share_param, **kwargs):
self.time_steps = time_steps
self.x_dim = x_dim
self.num_units = num_units
self.cells = None
self.y_dim = y_dim
self.share_param = share_param
if share_param:
self.out_weights = tf.Variable(tf.random_normal([self.num_units, 1]))
self.out_bias = tf.Variable(tf.random_normal([1, 1]))
else:
self.out_weights = tf.Variable(tf.random_normal([self.num_units, self.y_dim]))
self.out_bias = tf.Variable(tf.random_normal([1, self.y_dim]))
super().__init__(**kwargs)

def get_all_params(self):
all_params = []
lstm_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="LSTM")
all_params += lstm_variables
all_params.append(self.out_weights)
all_params.append(self.out_bias)
return all_params

def retrain(self, num_steps, feed_dict):
retrain_dataset = DataSet(feed_dict[self.input_index_placeholder], feed_dict[self.labels_index_placeholder])
for step in range(num_steps):
iter_feed_dict = self.fill_feed_dict_with_batch(retrain_dataset)
self.sess.run(self.train_op, feed_dict=iter_feed_dict)


def placeholder_inputs(self):
input_index_placeholder = tuple([tf.placeholder(
tf.int32,
shape=(None, 1),
name='input_index_placeholder_{}'.format(i)) for i in range(self.time_steps)])
labels_index_placeholder = tf.placeholder(
tf.int32,
shape=(None, 1),
name='labels_index_placeholder')
ts_placeholder = tf.placeholder(
tf.float32,
shape=[None, self.y_dim],
name='input_ts')
return input_index_placeholder, labels_index_placeholder, ts_placeholder

def inference(self, input_x, labels_placeholder=None, keep_probs_placeholder=None):
        if isinstance(input_x, (list, tuple)):
n = input_x[0].shape[2]
x = [tuple(x0[:, :, i] for x0 in input_x) for i in range(n)]
else:
n = input_x.shape[2]
x = [input_x[:, :, i] for i in range(n)]
with tf.variable_scope("LSTM") as vs:
cell = tf.nn.rnn_cell.BasicLSTMCell(self.num_units, name='LSTM_Layer')
def run_lstm(x_n):
output, _ = tf.nn.static_rnn(cell, x_n, dtype=tf.float32)
return tf.matmul(output[-1], self.out_weights) + self.out_bias
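            # Apply the shared cell to each series and stack the final-step outputs.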
res = tf.stack(list(map(run_lstm, x)), axis=1)[:, :, 0]
return res

    def predictions(self, logits):
        return logits


class AllRNN(SIFExplainer):
def __init__(self, x_dim, y_dim, time_steps, num_units, share_param, **kwargs):
self.time_steps = time_steps
self.x_dim = x_dim
self.num_units = num_units
self.cells = None
self.y_dim = y_dim
self.share_param = share_param
if share_param:
self.out_weights = tf.Variable(tf.random_normal([self.num_units, 1]))
self.out_bias = tf.Variable(tf.random_normal([1, 1]))
else:
self.out_weights = tf.Variable(tf.random_normal([self.num_units, self.y_dim]))
self.out_bias = tf.Variable(tf.random_normal([1, self.y_dim]))
super().__init__(**kwargs)

def get_all_params(self):
all_params = []
rnn_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="RNN")
all_params += rnn_variables
all_params.append(self.out_weights)
all_params.append(self.out_bias)
return all_params

def retrain(self, num_steps, feed_dict):
retrain_dataset = DataSet(feed_dict[self.input_index_placeholder], feed_dict[self.labels_index_placeholder])
for step in range(num_steps):
iter_feed_dict = self.fill_feed_dict_with_batch(retrain_dataset)
self.sess.run(self.train_op, feed_dict=iter_feed_dict)

def placeholder_inputs(self):
input_index_placeholder = tuple([tf.placeholder(
tf.int32,
shape=(None, 1),
name='input_index_placeholder_{}'.format(i)) for i in range(self.time_steps)])
labels_index_placeholder = tf.placeholder(
tf.int32,
shape=(None, 1),
name='labels_index_placeholder')
ts_placeholder = tf.placeholder(
tf.float32,
shape=[None, self.y_dim],
name='input_ts')
return input_index_placeholder, labels_index_placeholder, ts_placeholder

    def inference(self, input_x, labels_placeholder=None, keep_probs_placeholder=None):
        if isinstance(input_x, (list, tuple)):
            n = input_x[0].shape[2]
            x = [tuple(x0[:, :, i] for x0 in input_x) for i in range(n)]
        else:
            n = input_x.shape[2]
            x = [input_x[:, :, i] for i in range(n)]
        # As in AllLSTM, one shared cell is unrolled per series; note that this
        # class also uses a BasicLSTMCell, only the variable scope name differs.
        with tf.variable_scope("RNN") as vs:
            cell = tf.nn.rnn_cell.BasicLSTMCell(self.num_units, name='RNN_Layer')
def run_rnn(x_n):
output, _ = tf.nn.static_rnn(cell, x_n, dtype=tf.float32)
return tf.matmul(output[-1], self.out_weights) + self.out_bias

res = tf.stack(list(map(run_rnn, x)), axis=1)[:, :, 0]
return res

    def predictions(self, logits):
        return logits
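
Reviewer note: the core computation in AllAR.inference is the einsum contraction. The following standalone NumPy sketch (not part of the PR; the shapes are inferred from the stacking in inference and may differ in SIF.py) shows what tf.einsum('ijk,jk->ik', x, weight) computes:

# Standalone sketch of AllAR's einsum; shapes here are assumptions, not from the PR.
import numpy as np

batch, time_steps, y_dim = 4, 10, 3
x = np.random.randn(batch, time_steps, y_dim)   # inputs stacked over the AR window
weight = np.random.randn(time_steps, y_dim)     # one column of AR weights per series

# y_hat[i, k] = sum_j x[i, j, k] * weight[j, k]
y_hat = np.einsum('ijk,jk->ik', x, weight)

# The same contraction written as an explicit loop over time steps:
y_loop = np.zeros((batch, y_dim))
for j in range(time_steps):
    y_loop += x[:, j, :] * weight[j, :]
assert np.allclose(y_hat, y_loop)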