forked from tf-encrypted/tf-encrypted
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
50 lines (38 loc) · 1.55 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""Example of a simple average using TF Encrypted."""
import logging
import sys
import tensorflow as tf
import tf_encrypted as tfe
# Configuration selection: when a config file path is passed as the first
# command-line argument, load it as a RemoteConfig and make it active;
# without that argument nothing is changed here and TF Encrypted's default
# LocalConfig stays in effect.
if len(sys.argv) > 1:
    config_file = sys.argv[1]
    config = tfe.RemoteConfig.load(config_file)
    tfe.set_config(config)
    # Install a fresh Pond protocol after switching configs (presumably so
    # it is constructed against the players of the loaded config -- confirm).
    tfe.set_protocol(tfe.protocol.Pond())
@tfe.local_computation(name_scope='provide_input')
def provide_input() -> tf.Tensor:
    """Produce one inputter's private contribution to the average.

    The player that evaluates this is chosen by the caller through a
    ``player_name`` keyword argument (presumably consumed by the
    ``tfe.local_computation`` decorator -- confirm).

    Returns:
        A tensor of shape ``(10,)`` sampled with ``tf.random_normal``.
    """
    sample_shape = (10,)
    return tf.random_normal(shape=sample_shape)
@tfe.local_computation('result-receiver', name_scope='receive_output')
def receive_output(average: tf.Tensor) -> tf.Operation:
    """Print the computed average.

    Presumably executed on the ``'result-receiver'`` player named in the
    decorator (confirm against ``tfe.local_computation``).

    Args:
        average: The averaged tensor to print.

    Returns:
        The ``tf.print`` operation; in TF1 graph mode it presumably has to be
        run in a session before anything is printed.
    """
    print_op = tf.print("Average:", average)
    return print_op
def main(num_inputters: int = 5) -> None:
    """Average one private input per inputter and print the result.

    Each player ``inputter-0`` .. ``inputter-{num_inputters - 1}`` provides a
    private tensor via ``provide_input``. The tensors are summed with
    ``tfe.add_n`` and divided by the number of inputs, and the average is
    handed to ``receive_output``. The resulting graph is evaluated once.

    Args:
        num_inputters: How many inputter players contribute. Every generated
            player name must exist in the active TF Encrypted config
            (NOTE(review): the default of 5 assumes the config defines
            ``inputter-0`` .. ``inputter-4`` -- confirm for custom configs).

    Raises:
        ValueError: If ``num_inputters`` is less than 1; there would be
            nothing to average and the division below would be by zero.
    """
    if num_inputters < 1:
        raise ValueError(
            "num_inputters must be at least 1, got {}".format(num_inputters))

    # get input from inputters as private values
    inputs = [
        provide_input(player_name='inputter-{}'.format(i))  # pylint: disable=unexpected-keyword-arg
        for i in range(num_inputters)
    ]

    # sum all inputs and divide by count
    result = tfe.add_n(inputs) / len(inputs)

    # send result to receiver
    result_op = receive_output(result)

    # evaluate the graph once; `tag` labels this run (presumably used by
    # tfe for logging/tracing -- confirm)
    with tfe.Session() as sess:
        sess.run(result_op, tag='average')


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    main()