-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpolicy_eval.py
77 lines (65 loc) · 1.98 KB
/
policy_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from functools import partial
from typing import Optional, Tuple, Sequence
import trio
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
from gym import register
from camera.image import ImageShape
from controller.policy_executor import PolicyExecutor
from env.robot_arm_env import RobotArmEnv
@gin.configurable
async def policy_eval(
    serial_port_name: str,
    saved_model_path: str,
    checkpoint_path: str,
    env_name='',
    sequence_length=2,
    target_update_delta_time=0.1,
    command_delta_time=0.01,
    observations: Optional[Sequence] = None,
):
    """Evaluate a saved policy through ``PolicyExecutor``.

    Registers ``RobotArmEnv`` with gym under the id ``'ScalaArm-v0'``,
    builds a ``PolicyExecutor`` from the saved model and checkpoint paths,
    and awaits ``PolicyExecutor.run``.

    Args:
        serial_port_name: Name of the serial port, e.g.
            ``/dev/tty.usbserial-001``. Passed to ``run`` as the last
            positional argument.
        saved_model_path: Path of the saved policy.
        checkpoint_path: Path of the checkpoint of the policy.
        env_name: Forwarded unchanged to ``PolicyExecutor.run``.
        sequence_length: Forwarded unchanged to ``PolicyExecutor.run``.
        target_update_delta_time: Forwarded unchanged to
            ``PolicyExecutor.run``.
        command_delta_time: Forwarded unchanged to ``PolicyExecutor.run``.
        observations: Forwarded unchanged to ``PolicyExecutor.run``;
            ``None`` by default.
    """
    register(id='ScalaArm-v0', entry_point=RobotArmEnv)
    policy_executor = PolicyExecutor(saved_model_path, checkpoint_path)

    # The crop/reshape image shape is optional: when the gin lookup (or the
    # ImageShape construction) raises ValueError, run without an image shape.
    try:
        configured_shape = gin.query_parameter(
            'apply_crop_and_reshape.image_shape')
        image_shape = ImageShape(*configured_shape)
    except ValueError:
        image_shape = None

    # Positional order matters: it mirrors the order expected by run().
    run_args = (
        env_name,
        sequence_length,
        target_update_delta_time,
        command_delta_time,
        observations,
        image_shape,
        serial_port_name,
    )
    await policy_executor.run(*run_args)
# Global absl flag registry; values are readable once app.run() has parsed argv.
FLAGS = flags.FLAGS

# Gin configuration inputs (both may be given multiple times).
flags.DEFINE_multi_string(
    name='gin_file',
    default=None,
    help='Paths to the gin-config files.',
)
flags.DEFINE_multi_string(
    name='gin_bindings',
    default=None,
    help='Gin binding parameters.',
)

# Hardware and policy locations handed to policy_eval by main().
flags.DEFINE_string(
    name='serial_port',
    default='',
    help='name of serial port e.g.: /dev/tty.usbserial-001',
)
flags.DEFINE_string(
    name='saved_model_path',
    default='',
    help='path of saved policy',
)
flags.DEFINE_string(
    name='checkpoint_path',
    default='',
    help='path of checkpoint of policy',
)
def main(_):
    """Configure logging, gin and TensorFlow, then run ``policy_eval``.

    Args:
        _: Unused positional argument supplied by the caller (``app.run``
            invokes this function with one argument).
    """
    logging.set_verbosity(logging.INFO)

    # skip_unknown=False: unknown configurables in the gin inputs are errors.
    gin.parse_config_files_and_bindings(
        FLAGS.gin_file,
        FLAGS.gin_bindings,
        skip_unknown=False,
    )

    # Explicitly disable eager execution of tf.functions.
    tf.config.run_functions_eagerly(False)

    # Bind the flag values as keyword arguments so trio.run receives a
    # callable that needs no further arguments.
    eval_entry = partial(
        policy_eval,
        serial_port_name=FLAGS.serial_port,
        saved_model_path=FLAGS.saved_model_path,
        checkpoint_path=FLAGS.checkpoint_path,
    )
    trio.run(eval_entry)
# Script entry point: hand control to absl, which then calls main().
if __name__ == '__main__':
    app.run(main)