-
Notifications
You must be signed in to change notification settings - Fork 6
/
franka_param_inference.py
133 lines (100 loc) · 4.71 KB
/
franka_param_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#%%
import os
import sys
# Make the current working directory importable so that the project package
# imported below (``src.utils.param_inference``) can be resolved.
# NOTE(review): this presumably assumes the script is launched from the
# repository root — confirm.
cur_root_dir = os.getcwd()
print(f"Current root dir is {cur_root_dir}")
# Only append when missing, so re-running this cell does not add duplicates.
if cur_root_dir not in sys.path:
    sys.path.append(cur_root_dir)
#%%
from src.utils.param_inference import *
import numpy as np
#%%
# ---------------------------------------------------------------------------
# Stiffness only (one inferred parameter).
# Load the stored dataset, train an MDN and an MDRFF model on it, then query
# each trained model at two observed stiffness values.
# ---------------------------------------------------------------------------
params_dim = 1
# NOTE(review): 198 is presumably the length of the summary-statistics vector
# stored in the dataset. It is passed both to the generator (data_dim) and to
# train (stats_dim), so it is kept in one variable to keep them in sync.
data_dim = 198
data_file = os.path.join(cur_root_dir, "assets", "data", "data_stiffness_5k.pkl")
#%%
g = FrankaDataGenerator(data_file=data_file, load_from_disk=True,
                        params_dim=params_dim, data_dim=data_dim)
# g.gen(1) is used here only to read off the parameter / statistic widths.
params, stats = g.gen(1)
shapes = {"params": params.shape[1], "data": stats.shape[1]}
#%%
# train(...) presumably returns (training log, inference object) — confirm.
log_mdn, inf_mdn = train(epochs=1000, batch_size=150, generator=g, model="MDN",
                         stats_dim=data_dim, params_dim=params_dim,
                         num_sampled_points=5000)
#%%
log_rff, inf_rff = train(epochs=1000, batch_size=150, generator=g, model="MDRFF",
                         stats_dim=data_dim, params_dim=params_dim,
                         num_sampled_points=5000)
#%%
# NOTE(review): true_obs is given in raw stiffness units while p_lower/p_upper
# are [0, 1]; presumably get_results_from_true_obs / the generator handles the
# scaling — confirm.
true_obs = np.array([[6000]])
get_results_from_true_obs(env_params=["stiffness"], true_obs=true_obs, generator=g,
                          inf=inf_mdn, shapes=shapes, p_lower=[0.], p_upper=[1.0])
#%%
true_obs = np.array([[40000]])
get_results_from_true_obs(env_params=["stiffness"], true_obs=true_obs, generator=g,
                          inf=inf_mdn, shapes=shapes, p_lower=[0.], p_upper=[1.0])
#%%
true_obs = np.array([[6000]])
get_results_from_true_obs(env_params=["stiffness"], true_obs=true_obs, generator=g,
                          inf=inf_rff, shapes=shapes, p_lower=[0.], p_upper=[1.0])
#%%
true_obs = np.array([[40000]])
get_results_from_true_obs(env_params=["stiffness"], true_obs=true_obs, generator=g,
                          inf=inf_rff, shapes=shapes, p_lower=[0.], p_upper=[1.0])
#%%
# ---------------------------------------------------------------------------
# Friction and stiffness (two inferred parameters).
# Train an MDN and an MDRFF model, then query both at the same observation.
# ---------------------------------------------------------------------------
params_dim = 2
# NOTE(review): 198 is presumably the summary-statistics length; shared by the
# generator (data_dim) and train (stats_dim) so the two cannot drift apart.
data_dim = 198
data_file = os.path.join(cur_root_dir, "assets", "data",
                         "data_friction_stiffness_5k.pkl")
#%%
g = FrankaDataGenerator(data_file=data_file, load_from_disk=True,
                        params_dim=params_dim, data_dim=data_dim)
# g.gen(1) is used here only to read off the parameter / statistic widths.
params, stats = g.gen(1)
shapes = {"params": params.shape[1], "data": stats.shape[1]}
#%%
log_mdn, inf_mdn = train(epochs=3000, batch_size=100, generator=g, model="MDN",
                         stats_dim=data_dim, params_dim=params_dim,
                         num_sampled_points=5000)
#%%
log_rff, inf_rff = train(epochs=3000, batch_size=100, generator=g, model="MDRFF",
                         stats_dim=data_dim, params_dim=params_dim,
                         num_sampled_points=5000)
#%%
# Observation ordered as [friction, stiffness], matching env_params below.
true_obs = np.array([[0.6, 6000]])
#%%
get_results_from_true_obs(env_params=["friction", "stiffness"], true_obs=true_obs,
                          generator=g, inf=inf_mdn, shapes=shapes,
                          p_lower=[0., 0.], p_upper=[1.0, 1.0])
#%%
get_results_from_true_obs(env_params=["friction", "stiffness"], true_obs=true_obs,
                          generator=g, inf=inf_rff, shapes=shapes,
                          p_lower=[0., 0.], p_upper=[1.0, 1.0])
####### Friction and Mass
# Two inferred parameters: train an MDN and an MDRFF model on the
# friction/mass dataset and query each of them at one observation.
#%%
params_dim = 2
# NOTE(review): 198 is presumably the summary-statistics length; shared by the
# generator (data_dim) and train (stats_dim) so the two cannot drift apart.
data_dim = 198
data_file = os.path.join(cur_root_dir, "assets", "data", "data_friction_mass_5k.pkl")
#%%
g = FrankaDataGenerator(data_file=data_file, load_from_disk=True,
                        params_dim=params_dim, data_dim=data_dim)
# g.gen(1) is used here only to read off the parameter / statistic widths.
params, stats = g.gen(1)
shapes = {"params": params.shape[1], "data": stats.shape[1]}
#%%
log_mdn, inf_mdn = train(epochs=1000, batch_size=150, generator=g, model="MDN",
                         stats_dim=data_dim, params_dim=params_dim,
                         num_sampled_points=5000)
#%%
log_rff, inf_rff = train(epochs=1000, batch_size=150, generator=g, model="MDRFF",
                         stats_dim=data_dim, params_dim=params_dim,
                         num_sampled_points=5000)
#%%
# NOTE(review): unlike the other sections, MDN and MDRFF are queried at
# different observations here ([0.1, 0.2] vs [1.0, 1.0]), so their results
# are not directly comparable — confirm this is intended.
true_obs = np.array([[0.1, 0.2]])
get_results_from_true_obs(env_params=["friction", "mass"], true_obs=true_obs,
                          generator=g, inf=inf_mdn, shapes=shapes,
                          p_lower=[0., 0.], p_upper=[1.0, 1.0])
#%%
true_obs = np.array([[1.0, 1.0]])
get_results_from_true_obs(env_params=["friction", "mass"], true_obs=true_obs,
                          generator=g, inf=inf_rff, shapes=shapes,
                          p_lower=[0., 0.], p_upper=[1.0, 1.0])
####### Mass
# One inferred parameter. This is the only section that passes
# scale_params=True to the generator.
#%%
params_dim = 1
# NOTE(review): 198 is presumably the summary-statistics length; shared by the
# generator (data_dim) and train (stats_dim) so the two cannot drift apart.
data_dim = 198
data_file = os.path.join(cur_root_dir, "assets", "data", "data_mass_5k.pkl")
#%%
g = FrankaDataGenerator(data_file=data_file, load_from_disk=True,
                        params_dim=params_dim, data_dim=data_dim, scale_params=True)
# g.gen(1) is used here only to read off the parameter / statistic widths.
params, stats = g.gen(1)
shapes = {"params": params.shape[1], "data": stats.shape[1]}
#%%
# NOTE(review): the MDN is trained for 500 epochs here vs 1000 for MDRFF —
# confirm the asymmetry is intended.
log_mdn, inf_mdn = train(epochs=500, batch_size=150, generator=g, model="MDN",
                         stats_dim=data_dim, params_dim=params_dim,
                         num_sampled_points=5000)
#%%
log_rff, inf_rff = train(epochs=1000, batch_size=150, generator=g, model="MDRFF",
                         stats_dim=data_dim, params_dim=params_dim,
                         num_sampled_points=5000)
#%%
true_obs = np.array([[0.3]])
get_results_from_true_obs(env_params=["mass"], true_obs=true_obs, generator=g,
                          inf=inf_mdn, shapes=shapes, p_lower=[0.], p_upper=[1.0])
#%%
true_obs = np.array([[0.5]])
get_results_from_true_obs(env_params=["mass"], true_obs=true_obs, generator=g,
                          inf=inf_rff, shapes=shapes, p_lower=[0.], p_upper=[1.0])