From 4df319c1fba272c4b6ad2febe0cf4df1c93f50e3 Mon Sep 17 00:00:00 2001
From: GaspTO
Date: Sat, 23 Oct 2021 16:03:16 +0100
Subject: [PATCH] First Commit

---
 README.md | 8 +
 .../__pycache__/cart_pole.cpython-38.pyc | Bin 0 -> 2102 bytes
 .../__pycache__/environment.cpython-38.pyc | Bin 0 -> 2418 bytes
 .../__pycache__/minigrid.cpython-38.pyc | Bin 0 -> 2563 bytes
 .../__pycache__/tictactoe.cpython-38.pyc | Bin 0 -> 5420 bytes
 environments/cart_pole.py | 52 ++++
 environments/environment.py | 51 ++++
 environments/minigrid.py | 71 ++++++
 environments/tictactoe.py | 191 +++++++++++++++
 game.py | 57 +++++
 .../__pycache__/_functional.cpython-38.pyc | Bin 0 -> 1552 bytes
 ...ling_mask_value_reward_loss.cpython-38.pyc | Bin 0 -> 9397 bytes
 .../abstract_unrolling_mvr.cpython-38.pyc | Bin 0 -> 9097 bytes
 ...unrolling_value_reward_loss.cpython-38.pyc | Bin 0 -> 6974 bytes
 .../__pycache__/functional.cpython-38.pyc | Bin 0 -> 3173 bytes
 .../illegal_value_reward_loss.cpython-38.pyc | Bin 0 -> 3709 bytes
 loss_module/__pycache__/loss.cpython-38.pyc | Bin 0 -> 785 bytes
 .../loss_module_interfaces.cpython-38.pyc | Bin 0 -> 3043 bytes
 .../monte_carlo_loss.cpython-38.pyc | Bin 0 -> 3361 bytes
 ...arlo_mask_value_reward_loss.cpython-38.pyc | Bin 0 -> 2342 bytes
 .../monte_carlo_mvr.cpython-38.pyc | Bin 0 -> 2208 bytes
 ...nte_carlo_value_reward_loss.cpython-38.pyc | Bin 0 -> 2219 bytes
 .../node_bootstrap_loss.cpython-38.pyc | Bin 0 -> 2075 bytes
 ...e_td_mask_value_reward_loss.cpython-38.pyc | Bin 0 -> 2668 bytes
 .../__pycache__/offline_td_mvr.cpython-38.pyc | Bin 0 -> 2534 bytes
 ...ffline_td_value_reward_loss.cpython-38.pyc | Bin 0 -> 2353 bytes
 .../__pycache__/online_td_loss.cpython-38.pyc | Bin 0 -> 2847 bytes
 ...e_td_loss_constrained_state.cpython-38.pyc | Bin 0 -> 3510 bytes
 ...e_td_mask_value_reward_loss.cpython-38.pyc | Bin 0 -> 3433 bytes
 .../__pycache__/online_td_mvr.cpython-38.pyc | Bin 0 -> 3258 bytes
 .../online_td_value_rewardloss.cpython-38.pyc | Bin 0 -> 3085 bytes
 .../__pycache__/td_loss.cpython-38.pyc | Bin 0 -> 1997 bytes
 loss_module/abstract_unrolling_mvr.py | 226 ++++++++++++++++++
 loss_module/loss.py | 21 ++
 loss_module/monte_carlo_mvr.py | 66 +++++
 loss_module/offline_td_mvr.py | 76 ++++++
 loss_module/online_td_mvr.py | 108 +++++++++
 .../__pycache__/disjoint_mlp.cpython-38.pyc | Bin 0 -> 6930 bytes
 model_module/__pycache__/model.cpython-38.pyc | Bin 0 -> 2905 bytes
 model_module/__pycache__/nets.cpython-38.pyc | Bin 0 -> 6446 bytes
 .../__pycache__/operation.cpython-38.pyc | Bin 0 -> 327 bytes
 model_module/disjoint_mlp.py | 203 ++++++++++++++++
 model_module/operation.py | 2 +
 .../__pycache__/mask_op.cpython-38.pyc | Bin 0 -> 653 bytes
 .../__pycache__/next_state_op.cpython-38.pyc | Bin 0 -> 747 bytes
 .../representation_op.cpython-38.pyc | Bin 0 -> 711 bytes
 .../__pycache__/reward_op.cpython-38.pyc | Bin 0 -> 732 bytes
 .../__pycache__/state_value_op.cpython-38.pyc | Bin 0 -> 672 bytes
 .../__pycache__/observation_op.cpython-38.pyc | Bin 0 -> 666 bytes
 .../state_action_op.cpython-38.pyc | Bin 0 -> 897 bytes
 .../__pycache__/state_op.cpython-38.pyc | Bin 0 -> 631 bytes
 .../input_operations/observation_op.py | 7 +
 .../input_operations/state_action_op.py | 11 +
 .../input_operations/state_op.py | 7 +
 model_module/query_operations/mask_op.py | 10 +
 .../query_operations/next_state_op.py | 11 +
 .../query_operations/representation_op.py | 8 +
 model_module/query_operations/reward_op.py | 12 +
 .../query_operations/state_value_op.py | 8 +
 .../best_first_node.cpython-38.pyc | Bin 0 -> 1100 bytes
 .../__pycache__/mcts_node.cpython-38.pyc | Bin 0 -> 1064 bytes
 node_module/__pycache__/node.cpython-38.pyc | Bin 0 -> 5784 bytes
 node_module/best_first_node.py | 21 ++
 node_module/mcts_node.py | 17 ++
 node_module/node.py | 156 ++++++++++++
 .../abstract_best_first_search.cpython-38.pyc | Bin 0 -> 1887 bytes
 ...abstract_depth_first_search.cpython-38.pyc | Bin 0 -> 619 bytes
 .../average_minimax.cpython-38.pyc | Bin 0 -> 1295 bytes
 .../best_first_minimax.cpython-38.pyc | Bin 0 -> 4372 bytes
 .../__pycache__/expectimax.cpython-38.pyc | Bin 0 -> 1147 bytes
 .../__pycache__/minimax.cpython-38.pyc | Bin 0 -> 5119 bytes
 .../monte_carlo_tree_search.cpython-38.pyc | Bin 0 -> 4693 bytes
 .../__pycache__/planning.cpython-38.pyc | Bin 0 -> 825 bytes
 .../ucb_best_first_minimax.cpython-38.pyc | Bin 0 -> 3610 bytes
 ...ucb_monte_carlo_tree_search.cpython-38.pyc | Bin 0 -> 4391 bytes
 .../uct_best_first_minimax.cpython-38.pyc | Bin 0 -> 4365 bytes
 ...uct_monte_carlo_tree_search.cpython-38.pyc | Bin 0 -> 4704 bytes
 planning_module/abstract_best_first_search.py | 38 +++
 .../abstract_depth_first_search.py | 5 +
 planning_module/average_minimax.py | 31 +++
 planning_module/minimax.py | 143 +++++++++++
 planning_module/planning.py | 18 ++
 planning_module/ucb_best_first_minimax.py | 95 ++++++++
 .../ucb_monte_carlo_tree_search.py | 114 +++++++++
 ...sitional_adversarial_policy.cpython-38.pyc | Bin 0 -> 1629 bytes
 .../epsilon_greedy_value.cpython-38.pyc | Bin 0 -> 2163 bytes
 .../epsilon_greedy_visits.cpython-38.pyc | Bin 0 -> 2417 bytes
 .../exponentiated_visit_count.cpython-38.pyc | Bin 0 -> 2256 bytes
 .../__pycache__/policy.cpython-38.pyc | Bin 0 -> 1480 bytes
 .../__pycache__/simple_policy.cpython-38.pyc | Bin 0 -> 2889 bytes
 .../__pycache__/visit_ratio.cpython-38.pyc | Bin 0 -> 2153 bytes
 .../compositional_adversarial_policy.py | 33 +++
 policy_module/epsilon_greedy_value.py | 39 +++
 policy_module/epsilon_greedy_visits.py | 46 ++++
 policy_module/policy.py | 23 ++
 policy_module/simple_policy.py | 76 ++++++
 policy_module/visit_ratio.py | 41 ++++
 test.py | 188 +++++++++++++++
 .../Simple_Optimizer.cpython-38.pyc | Bin 0 -> 1123 bytes
 .../simple_optimizer.cpython-38.pyc | Bin 0 -> 1198 bytes
 utils/optimization/simple_optimizer.py | 29 +++
 .../priority_replay_buffer.cpython-38.pyc | Bin 0 -> 1942 bytes
 ...roportional_priority_buffer.cpython-38.pyc | Bin 0 -> 3088 bytes
 ...onal_priority_replay_buffer.cpython-38.pyc | Bin 0 -> 2504 bytes
 .../simple_replay_buffer.cpython-38.pyc | Bin 0 -> 1388 bytes
 .../storage_module_interface.cpython-38.pyc | Bin 0 -> 1217 bytes
 .../__pycache__/uniform_buffer.cpython-38.pyc | Bin 0 -> 1422 bytes
 .../uniform_game_replay_buffer.cpython-38.pyc | Bin 0 -> 1599 bytes
 utils/storage/proportional_priority_buffer.py | 82 +++++++
 utils/storage/uniform_buffer.py | 35 +++
 110 files changed, 2436 insertions(+)
 create mode 100644 README.md
 create mode 100644 environments/__pycache__/cart_pole.cpython-38.pyc
 create mode 100644 environments/__pycache__/environment.cpython-38.pyc
 create mode 100644 environments/__pycache__/minigrid.cpython-38.pyc
 create mode 100644 environments/__pycache__/tictactoe.cpython-38.pyc
 create mode 100644 environments/cart_pole.py
 create mode 100644 environments/environment.py
 create mode 100644 environments/minigrid.py
 create mode 100644 environments/tictactoe.py
 create mode 100644 game.py
 create mode 100644 loss_module/__pycache__/_functional.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/abstract_unrolling_mask_value_reward_loss.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/abstract_unrolling_mvr.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/abstract_unrolling_value_reward_loss.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/functional.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/illegal_value_reward_loss.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/loss.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/loss_module_interfaces.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/monte_carlo_loss.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/monte_carlo_mask_value_reward_loss.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/monte_carlo_mvr.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/monte_carlo_value_reward_loss.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/node_bootstrap_loss.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/offline_td_mask_value_reward_loss.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/offline_td_mvr.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/offline_td_value_reward_loss.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/online_td_loss.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/online_td_loss_constrained_state.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/online_td_mask_value_reward_loss.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/online_td_mvr.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/online_td_value_rewardloss.cpython-38.pyc
 create mode 100644 loss_module/__pycache__/td_loss.cpython-38.pyc
 create mode 100644 loss_module/abstract_unrolling_mvr.py
 create mode 100644 loss_module/loss.py
 create mode 100644 loss_module/monte_carlo_mvr.py
 create mode 100644 loss_module/offline_td_mvr.py
 create mode 100644 loss_module/online_td_mvr.py
 create mode 100644 model_module/__pycache__/disjoint_mlp.cpython-38.pyc
 create mode 100644 model_module/__pycache__/model.cpython-38.pyc
 create mode 100644 model_module/__pycache__/nets.cpython-38.pyc
 create mode 100644 model_module/__pycache__/operation.cpython-38.pyc
 create mode 100644 model_module/disjoint_mlp.py
 create mode 100644 model_module/operation.py
 create mode 100644 model_module/query_operations/__pycache__/mask_op.cpython-38.pyc
 create mode 100644 model_module/query_operations/__pycache__/next_state_op.cpython-38.pyc
 create mode 100644 model_module/query_operations/__pycache__/representation_op.cpython-38.pyc
 create mode 100644 model_module/query_operations/__pycache__/reward_op.cpython-38.pyc
 create mode 100644 model_module/query_operations/__pycache__/state_value_op.cpython-38.pyc
 create mode 100644 model_module/query_operations/input_operations/__pycache__/observation_op.cpython-38.pyc
 create mode 100644 model_module/query_operations/input_operations/__pycache__/state_action_op.cpython-38.pyc
 create mode 100644 model_module/query_operations/input_operations/__pycache__/state_op.cpython-38.pyc
 create mode 100644 model_module/query_operations/input_operations/observation_op.py
 create mode 100644 model_module/query_operations/input_operations/state_action_op.py
 create mode 100644 model_module/query_operations/input_operations/state_op.py
 create mode 100644 model_module/query_operations/mask_op.py
 create mode 100644 model_module/query_operations/next_state_op.py
 create mode 100644 model_module/query_operations/representation_op.py
 create mode 100644 model_module/query_operations/reward_op.py
 create mode 100644 model_module/query_operations/state_value_op.py
 create mode 100644 node_module/__pycache__/best_first_node.cpython-38.pyc
 create mode 100644 node_module/__pycache__/mcts_node.cpython-38.pyc
 create mode 100644 node_module/__pycache__/node.cpython-38.pyc
 create mode 100644 node_module/best_first_node.py
 create mode 100644 node_module/mcts_node.py
 create mode 100644 node_module/node.py
 create mode 100644 planning_module/__pycache__/abstract_best_first_search.cpython-38.pyc
 create mode 100644 planning_module/__pycache__/abstract_depth_first_search.cpython-38.pyc
 create mode 100644 planning_module/__pycache__/average_minimax.cpython-38.pyc
 create mode 100644 planning_module/__pycache__/best_first_minimax.cpython-38.pyc
 create mode 100644 planning_module/__pycache__/expectimax.cpython-38.pyc
 create mode 100644 planning_module/__pycache__/minimax.cpython-38.pyc
 create mode 100644 planning_module/__pycache__/monte_carlo_tree_search.cpython-38.pyc
 create mode 100644 planning_module/__pycache__/planning.cpython-38.pyc
 create mode 100644 planning_module/__pycache__/ucb_best_first_minimax.cpython-38.pyc
 create mode 100644 planning_module/__pycache__/ucb_monte_carlo_tree_search.cpython-38.pyc
 create mode 100644 planning_module/__pycache__/uct_best_first_minimax.cpython-38.pyc
 create mode 100644 planning_module/__pycache__/uct_monte_carlo_tree_search.cpython-38.pyc
 create mode 100644 planning_module/abstract_best_first_search.py
 create mode 100644 planning_module/abstract_depth_first_search.py
 create mode 100644 planning_module/average_minimax.py
 create mode 100644 planning_module/minimax.py
 create mode 100644 planning_module/planning.py
 create mode 100644 planning_module/ucb_best_first_minimax.py
 create mode 100644 planning_module/ucb_monte_carlo_tree_search.py
 create mode 100644 policy_module/__pycache__/compositional_adversarial_policy.cpython-38.pyc
 create mode 100644 policy_module/__pycache__/epsilon_greedy_value.cpython-38.pyc
 create mode 100644 policy_module/__pycache__/epsilon_greedy_visits.cpython-38.pyc
 create mode 100644 policy_module/__pycache__/exponentiated_visit_count.cpython-38.pyc
 create mode 100644 policy_module/__pycache__/policy.cpython-38.pyc
 create mode 100644 policy_module/__pycache__/simple_policy.cpython-38.pyc
 create mode 100644 policy_module/__pycache__/visit_ratio.cpython-38.pyc
 create mode 100644 policy_module/compositional_adversarial_policy.py
 create mode 100644 policy_module/epsilon_greedy_value.py
 create mode 100644 policy_module/epsilon_greedy_visits.py
 create mode 100644 policy_module/policy.py
 create mode 100644 policy_module/simple_policy.py
 create mode 100644 policy_module/visit_ratio.py
 create mode 100644 test.py
 create mode 100644 utils/optimization/__pycache__/Simple_Optimizer.cpython-38.pyc
 create mode 100644 utils/optimization/__pycache__/simple_optimizer.cpython-38.pyc
 create mode 100644 utils/optimization/simple_optimizer.py
 create mode 100644 utils/storage/__pycache__/priority_replay_buffer.cpython-38.pyc
 create mode 100644 utils/storage/__pycache__/proportional_priority_buffer.cpython-38.pyc
 create mode 100644 utils/storage/__pycache__/proportional_priority_replay_buffer.cpython-38.pyc
 create mode 100644 utils/storage/__pycache__/simple_replay_buffer.cpython-38.pyc
 create mode 100644 utils/storage/__pycache__/storage_module_interface.cpython-38.pyc
 create mode 100644 utils/storage/__pycache__/uniform_buffer.cpython-38.pyc
 create mode 100644 utils/storage/__pycache__/uniform_game_replay_buffer.cpython-38.pyc
 create mode 100644 utils/storage/proportional_priority_buffer.py
 create mode 100644 utils/storage/uniform_buffer.py
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..460f800
--- /dev/null
+++ b/README.md
@@ -0,0 +1,8 @@
+This is a modular architecture for model-based reinforcement learning using search.
+
+The components are kept separate, which facilitates the creation of agents and
+the extension of the existing components.
+
+
+Run test.py for an example: it will ask you to choose from the different components
+and then run.
diff --git a/environments/__pycache__/cart_pole.cpython-38.pyc b/environments/__pycache__/cart_pole.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/environments/__pycache__/cart_pole.cpython-38.pyc differ
diff --git a/environments/__pycache__/environment.cpython-38.pyc b/environments/__pycache__/environment.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/environments/__pycache__/environment.cpython-38.pyc differ
diff --git a/environments/__pycache__/minigrid.cpython-38.pyc b/environments/__pycache__/minigrid.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/environments/__pycache__/minigrid.cpython-38.pyc differ
diff --git a/environments/__pycache__/tictactoe.cpython-38.pyc b/environments/__pycache__/tictactoe.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/environments/__pycache__/tictactoe.cpython-38.pyc differ
diff --git a/environments/cart_pole.py b/environments/cart_pole.py
new file mode 100644
index 0000000..f371bd9
--- /dev/null
+++ b/environments/cart_pole.py
@@ -0,0 +1,52 @@
+from environments.environment import Environment
+import gym
+from gym.envs.classic_control.cartpole import CartPoleEnv
+
+
+
+'''
+Adapted from https://github.com/werner-duvaud/muzero-general
+'''
+
+class CartPole(Environment):
+    def __init__(self,max_steps=1000):
+        self.environment = gym.make("CartPole-v1")
+        self.max_steps = max_steps
+
+    def step(self,action):
+        assert not self.done, "can not execute steps when game has finished"
+        assert self.steps_taken < self.max_steps
+        self.steps_taken += 1
+        obs, reward, self.done, info = self.environment.step(action)
+        if self.steps_taken == self.max_steps:
+            self.done = True
+        self.current_observation = obs
+        return obs, reward, self.done, info
+
+    def reset(self):
+        self.done = False
+        self.steps_taken = 0
+        self.current_observation = self.environment.reset()
+        return self.current_observation
+
+    def close(self):
+        self.environment.close()
+
+    def render(self):
+        self.environment.render()
+
+    def get_action_size(self):
+        return 2
+
+    def get_input_shape(self):
+        return (4,)
+
+    def get_legal_actions(self):
+        """ In CartPole, the two actions are always legal """
+        if not self.done:
+            return [0,1]
+        else:
+            return []
+
+    def __str__(self):
+        return "CartPole-v1"
\ No newline at end of file
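For a sense of how these environments are driven, here is a minimal episode loop against the CartPole wrapper above. This is a sketch, not part of the patch: it assumes gym is installed, and the random action choice is only a stand-in for a real planner/policy.

import numpy as np
from environments.cart_pole import CartPole

env = CartPole(max_steps=200)          # any positive cap works; 200 is arbitrary
obs = env.reset()
total_reward = 0.0
while env.get_legal_actions():         # returns [] once the episode is over
    action = int(np.random.choice(env.get_legal_actions()))
    obs, reward, done, info = env.step(action)
    total_reward += reward
print(total_reward)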
diff --git a/environments/environment.py b/environments/environment.py
new file mode 100644
index 0000000..1a2da60
--- /dev/null
+++ b/environments/environment.py
@@ -0,0 +1,51 @@
+from typing import List, Tuple, Dict
+import numpy as np
+from abc import abstractmethod
+
+"""
+Environments work with numpy arrays, so don't forget to convert them to torch tensors when appropriate
+"""
+class Environment:
+
+    def step(self,action:int) -> Tuple[np.ndarray,float,bool,Dict]:
+        """ return next_observation, reward, done, info."""
+        raise NotImplementedError
+
+    def reset(self) -> np.ndarray:
+        raise NotImplementedError
+
+    def close(self) -> None:
+        return None
+
+    def render(self) -> None:
+        raise NotImplementedError
+
+    def get_action_size(self) -> int:
+        raise NotImplementedError
+
+    def get_input_shape(self) -> Tuple[int]:
+        raise NotImplementedError
+
+    def get_num_of_players(self) -> int:
+        """ default for single player environments """
+        return 1
+
+    def get_legal_actions(self) -> List[int]:
+        "return an empty list when environment has reached the end, for consistency"
+        raise NotImplementedError
+
+    def get_action_mask(self) -> np.ndarray:
+        legal_actions = self.get_legal_actions()
+        mask = np.zeros(self.get_action_size())
+        mask[legal_actions] = 1
+        assert (np.where(mask == 1)[0] == legal_actions).all()
+        return mask
+
+    def get_current_player(self) -> int:
+        """ default for single player environments.
+        return a player even when environment has reached the end, for consistency """
+        return 0
+
+
+
+
\ No newline at end of file
diff --git a/environments/minigrid.py b/environments/minigrid.py
new file mode 100644
index 0000000..cb75f4c
--- /dev/null
+++ b/environments/minigrid.py
@@ -0,0 +1,71 @@
+import gym
+from gym.core import Env
+try:
+    import gym_minigrid
+except ModuleNotFoundError:
+    raise ModuleNotFoundError('Please run "pip install gym_minigrid"')
+import numpy as np
+from environments.environment import Environment
+import random
+
+
+'''
+Original gym environment: https://github.com/maximecb/gym-minigrid
+
+set agent_start_pos to None for the start position to be random every time you reset
+'''
+
+
+class Minigrid(Environment):
+    def __init__(self,N=6,reward_scaling=1,max_steps=None,agent_start_pos=(1,1),seed=None):
+        self.reward_scaling = reward_scaling
+        self.max_steps = max_steps
+        self.environment = gym_minigrid.envs.empty.EmptyEnv(size=N+2,agent_start_pos=agent_start_pos)
+        self.environment = gym_minigrid.wrappers.ImgObsWrapper(self.environment)
+        if seed is not None:
+            self.environment.seed(seed)
+
+    def step(self, action):
+        assert not self.done, "can not execute steps when game has finished"
+        assert self.max_steps is None or self.steps_taken < self.max_steps
+        assert action in [0,1,2]
+        self.steps_taken += 1
+        obs, reward, self.done, info = self.environment.step(action)
+        if self.steps_taken == self.max_steps:
+            self.done = True
+        if reward > 0:
+            #This minigrid gives a reward that depends on how many steps were taken before,
+            #which goes against the Markov property
+            reward = 1
+        return obs, self.reward_scaling*reward, self.done, info
+
+    def reset(self):
+        self.done = False
+        self.steps_taken = 0
+        return np.array(self.environment.reset())
+
+    def close(self):
+        self.environment.close()
+
+    def render(self):
+        return self.environment.render()
+
+    def get_action_size(self):
+        return 3
+
+    def get_input_shape(self):
+        return (7,7,3)
+
+    def get_num_of_players(self):
+        return 1
+
+    def get_legal_actions(self):
+        if self.done:
+            return []
+        else:
+            return [0,1,2]
+
+    def __str__(self):
+        return "MiniGrid"
+
+
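The Environment base class above is also the extension point the README mentions: a new environment only has to fill in the abstract queries, and helpers like get_action_mask come for free. A minimal sketch follows (not part of the patch; the one-step CoinGuess game here is invented purely for illustration):

import numpy as np
from environments.environment import Environment

class CoinGuess(Environment):
    # Hypothetical one-step game: guess 0 or 1, reward 1 for a correct guess.
    def reset(self):
        self.done = False
        self.secret = np.random.randint(2)
        return np.zeros((1,), dtype="float32")

    def step(self, action):
        assert not self.done
        self.done = True
        reward = 1.0 if action == self.secret else 0.0
        return np.ones((1,), dtype="float32"), reward, True, {}

    def get_action_size(self):
        return 2

    def get_input_shape(self):
        return (1,)

    def get_legal_actions(self):
        return [] if self.done else [0, 1]

env = CoinGuess()
env.reset()
print(env.get_action_mask())   # [1. 1.] -- built by the base class from get_legal_actions()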
diff --git a/environments/tictactoe.py b/environments/tictactoe.py
new file mode 100644
index 0000000..0639484
--- /dev/null
+++ b/environments/tictactoe.py
@@ -0,0 +1,191 @@
+import numpy as np
+from copy import deepcopy
+from environments.environment import Environment
+
+
+""" This environment, when playing against the expert, has 3 difficulties:
+0 - only random plays
+1 - defends when obvious and attacks randomly
+2 - defends and attacks when obvious, random otherwise
+"""
+
+class TicTacToe(Environment):
+    def __init__(self,self_play=True,expert_start=False,expert_difficulty=2):
+        self.self_play = self_play
+        self.expert_start = expert_start
+        self.board = None
+        assert expert_difficulty in [0,1,2]
+        self.expert_difficulty = expert_difficulty
+
+    def step(self, action):
+        assert not self.done, "can not execute steps when game has finished"
+        if self.board is None: raise ValueError("Call Reset first")
+        board, reward, done, _ = self._step(action)
+        if not self.self_play and not done:
+            action = self._expert_action()
+            board, reward, done, _ = self._step(action)
+            reward = -1 * reward
+        return board, reward, done, {}
+
+    def reset(self):
+        self.done = False
+        self.board = np.zeros((2, 3, 3), dtype="int32")
+        self.current_player = 0
+        if not self.self_play:
+            if self.expert_start:
+                action = self._expert_action()
+                self._step(action)
+            self.expert_start = not self.expert_start #alternate
+        return deepcopy(self.board)
+
+    def render(self):
+        if self.board is None: raise ValueError("Call Reset first")
+        print(self.board[0] - self.board[1])
+
+    def get_action_size(self):
+        return 9
+
+    def get_input_shape(self):
+        return (2,3,3)
+
+    def get_num_of_players(self):
+        if self.self_play:
+            return 2
+        else:
+            return 1
+
+    def get_legal_actions(self):
+        if self.board is None: raise ValueError("Call Reset first")
+        if self.done: return []
+        legal = []
+        for action in range(9):
+            row, col = self._action_to_pos(action)
+            if self.board[0][row, col] == 0 and self.board[1][row, col] == 0:
+                legal.append(action)
+        return deepcopy(legal)
+
+    def get_current_player(self) -> int:
+        if self.board is None: raise ValueError("Call Reset first")
+        if self.self_play is False:
+            return 0
+        else:
+            assert self.current_player in [0,1]
+            return self.current_player
+
+
+    def _step(self,action):
+        if self.done:
+            raise ValueError("Game is over")
+        row,col = self._action_to_pos(action)
+        if self.board[0][row,col] != 0 or self.board[1][row,col] != 0:
+            raise ValueError("Playing in already filled position")
+
+        self.board[0,row, col] = 1
+        self.board = np.array([self.board[1],self.board[0]]) #switch planes: plane 0 is always the side to move
+        self.done = self._have_winner() or len(self.get_legal_actions()) == 0
+        reward = 1 if self._have_winner() else 0
+        self.current_player = (self.current_player + 1) % 2
+
+        return deepcopy(self.board), reward, self.done, {}
+
+    def _action_to_pos(self,action):
+        assert action >= 0 and action <= 8
+        row = action // 3
+        col = action % 3
+        return (row,col)
+
+    def _pos_to_action(self,row,col):
+        action = row * 3 + col
+        return action
+
+    def _have_winner(self):
+        # Horizontal and vertical checks
+        for i in range(3):
+            if (self.board[0,i] == 1).all() or (self.board[1,i] == 1).all():
+                return True #horizontal
+            if (self.board[0,:,i] == 1).all() or (self.board[1,:,i] == 1).all():
+                return True #verticals
+
+        #diagonals
+        if (self.board[0,0,0] == 1 and self.board[0,1,1] == 1 and self.board[0,2,2] == 1 or \
+            self.board[1,0,0] == 1 and self.board[1,1,1] == 1 and self.board[1,2,2] == 1
+        ):
+            return True
+
+        if (self.board[0,0,2] == 1 and self.board[0,1,1] == 1 and self.board[0,2,0] == 1 or \
+            self.board[1,0,2] == 1 and self.board[1,1,1] == 1 and self.board[1,2,0] == 1
+        ):
+            return True
+
+        return False
+
+    def _expert_action(self):
+        board = self.board
+        summed_board = 1*board[0] + -1*board[1]
+        action = np.random.choice(self.get_legal_actions())
+
+        # Horizontal and vertical checks
+        if self.expert_difficulty == 2:
+            for i in range(3):
+                if sum(summed_board[i,:]) == 2: #attacking row position
+                    col = np.where(summed_board[i, :] == 0)[0][0]
+                    action = self._pos_to_action(i,col)
+                    return action
+
+                if sum(summed_board[:,i]) == 2: #attacking col position
+                    row = np.where(summed_board[:,i] == 0)[0][0]
+                    action = self._pos_to_action(row,i)
+                    return action
+
+        if self.expert_difficulty >= 1:
+            for i in range(3):
+                if sum(summed_board[i,:]) == -2: #defending row position
+                    col = np.where(summed_board[i, :] == 0)[0][0]
+                    action = self._pos_to_action(i,col)
+                    return action
+
+                if sum(summed_board[:,i]) == -2: #defending col position
+                    row = np.where(summed_board[:,i] == 0)[0][0]
+                    action = self._pos_to_action(row,i)
+                    return action
+
+        # Diagonal checks
+        diag = summed_board.diagonal() #left_up-right_down
+        anti_diag = np.fliplr(summed_board).diagonal() #left_down-right_up
+        if self.expert_difficulty == 2:
+            if sum(diag) == 2: #attacking diag
+                ind = np.where(diag == 0)[0][0]
+                row = ind
+                col = ind
+                action = self._pos_to_action(row,col)
+                return action
+
+            if sum(anti_diag) == 2: #attacking anti-diag
+                ind = np.where(anti_diag == 0)[0][0]
+                row = ind
+                col = 2-ind
+                action = self._pos_to_action(row,col)
+                return action
+
+        if self.expert_difficulty >= 1:
+            if sum(diag) == -2: #defending diag
+                ind = np.where(diag == 0)[0][0]
+                row = ind
+                col = ind
+                action = self._pos_to_action(row,col)
+                return action
+
+            if sum(anti_diag) == -2: #defending anti-diag
+                ind = np.where(anti_diag == 0)[0][0]
+                row = ind
+                col = 2-ind
+                action = self._pos_to_action(row,col)
+                return action
+
+        return action
+
+    def __str__(self):
+        return "TicTacToe"
+
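The expert heuristic reads off a summed board (own stones count +1, opponent stones -1): a line summing to 2 holds two own stones plus a gap and is worth attacking, while a line summing to -2 must be defended. A small check of that arithmetic (not part of the patch; the hand-crafted position below is illustrative, and note that plane 0 is the side to move):

import numpy as np
from environments.tictactoe import TicTacToe

env = TicTacToe(self_play=False, expert_difficulty=2)
env.reset()
# Hand-craft a position where the mover (plane 0) owns (0,0) and (0,1):
env.board = np.zeros((2, 3, 3), dtype="int32")
env.board[0, 0, 0] = 1
env.board[0, 0, 1] = 1
# Row 0 of the summed board is [1, 1, 0], which sums to 2,
# so the expert attacks the gap: action 2 == position (0,2).
print(env._expert_action())   # 2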
diff --git a/game.py b/game.py
new file mode 100644
index 0000000..15c263a
--- /dev/null
+++ b/game.py
@@ -0,0 +1,57 @@
+from copy import deepcopy
+from typing import List, Tuple, Dict
+
+
+'''''''''''''''''''''''''''''''''''''''''''''''''''''''''
+*                                                       *
+*          PUBLIC INTERFACES OF SEARCH MODULE           *
+*                                                       *
+'''''''''''''''''''''''''''''''''''''''''''''''''''''''''
+class Game:
+    """ A game organizes a sequence of transitions in an environment. Each
+    state in the environment has a node and, between two nodes, there is an action
+    executed and a reward received; this class stores that sequence """
+    def __init__(self,observation_shape=None,action_size=None,num_players=None):
+        self.observation_shape = observation_shape
+        self.action_size = action_size
+        self.num_players = num_players
+        self.observations = []
+        self.nodes = []
+        self.actions = []
+        self.masks = []
+        self.players = []
+        self.rewards = []
+        self.dones = []
+        self.infos = []
+        self.info = {} # don't confuse this attribute with the above. The above is a list
+                       # of the infos returned by the environment. This one is for
+                       # any extra information we want to associate with this class instance
+
+    def __len__(self): #num of transitions
+        return len(self.actions)
+
+    def get_observations(self) -> List: #deprecated
+        assert False
+        return [n.get_observation() for n in self.nodes]
+
+    def get_players(self) -> List: #deprecated
+        return [n.get_player() for n in self.nodes]
+
+    def get_dones(self) -> List: #this is quite the useless method
+        return [n.is_terminal() for n in self.nodes]
+
+    def get_player_rewards(self):
+        rewards = {}
+        assert len(self.players)-1 == len(self.rewards)
+        for p,r in zip(self.players,self.rewards):
+            if p not in rewards:
+                rewards[p] = r
+            else:
+                rewards[p] += r
+        return rewards
+
+
+
+
+
+
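To make the bookkeeping concrete: players is recorded once per visited state (including the terminal one), while rewards is recorded once per transition, which is exactly the len(self.players)-1 == len(self.rewards) invariant that get_player_rewards asserts. A sketch of filling a Game from a random TicTacToe self-play episode (not part of the patch; the example output is one possible outcome):

import numpy as np
from environments.tictactoe import TicTacToe
from game import Game

env = TicTacToe(self_play=True)
game = Game(observation_shape=env.get_input_shape(),
            action_size=env.get_action_size(),
            num_players=env.get_num_of_players())
obs = env.reset()
game.observations.append(obs)
game.players.append(env.get_current_player())   # player to move at this state
while env.get_legal_actions():
    action = int(np.random.choice(env.get_legal_actions()))
    obs, reward, done, info = env.step(action)   # reward is for the player who just moved
    game.observations.append(obs)
    game.actions.append(action)
    game.rewards.append(reward)
    game.dones.append(done)
    game.infos.append(info)
    game.players.append(env.get_current_player())
print(len(game), game.get_player_rewards())   # e.g. 7 {0: 1, 1: 0}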
diff --git a/loss_module/__pycache__/_functional.cpython-38.pyc b/loss_module/__pycache__/_functional.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/_functional.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/abstract_unrolling_mask_value_reward_loss.cpython-38.pyc b/loss_module/__pycache__/abstract_unrolling_mask_value_reward_loss.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/abstract_unrolling_mask_value_reward_loss.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/abstract_unrolling_mvr.cpython-38.pyc b/loss_module/__pycache__/abstract_unrolling_mvr.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/abstract_unrolling_mvr.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/abstract_unrolling_value_reward_loss.cpython-38.pyc b/loss_module/__pycache__/abstract_unrolling_value_reward_loss.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/abstract_unrolling_value_reward_loss.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/functional.cpython-38.pyc b/loss_module/__pycache__/functional.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/functional.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/illegal_value_reward_loss.cpython-38.pyc b/loss_module/__pycache__/illegal_value_reward_loss.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/illegal_value_reward_loss.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/loss.cpython-38.pyc b/loss_module/__pycache__/loss.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/loss.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/loss_module_interfaces.cpython-38.pyc b/loss_module/__pycache__/loss_module_interfaces.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/loss_module_interfaces.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/monte_carlo_loss.cpython-38.pyc b/loss_module/__pycache__/monte_carlo_loss.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/monte_carlo_loss.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/monte_carlo_mask_value_reward_loss.cpython-38.pyc b/loss_module/__pycache__/monte_carlo_mask_value_reward_loss.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/monte_carlo_mask_value_reward_loss.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/monte_carlo_mvr.cpython-38.pyc b/loss_module/__pycache__/monte_carlo_mvr.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/monte_carlo_mvr.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/monte_carlo_value_reward_loss.cpython-38.pyc b/loss_module/__pycache__/monte_carlo_value_reward_loss.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/monte_carlo_value_reward_loss.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/node_bootstrap_loss.cpython-38.pyc b/loss_module/__pycache__/node_bootstrap_loss.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/node_bootstrap_loss.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/offline_td_mask_value_reward_loss.cpython-38.pyc b/loss_module/__pycache__/offline_td_mask_value_reward_loss.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/offline_td_mask_value_reward_loss.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/offline_td_mvr.cpython-38.pyc b/loss_module/__pycache__/offline_td_mvr.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/offline_td_mvr.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/offline_td_value_reward_loss.cpython-38.pyc b/loss_module/__pycache__/offline_td_value_reward_loss.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/offline_td_value_reward_loss.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/online_td_loss.cpython-38.pyc b/loss_module/__pycache__/online_td_loss.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/online_td_loss.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/online_td_loss_constrained_state.cpython-38.pyc b/loss_module/__pycache__/online_td_loss_constrained_state.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/online_td_loss_constrained_state.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/online_td_mask_value_reward_loss.cpython-38.pyc b/loss_module/__pycache__/online_td_mask_value_reward_loss.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/online_td_mask_value_reward_loss.cpython-38.pyc differ
diff --git a/loss_module/__pycache__/online_td_mvr.cpython-38.pyc b/loss_module/__pycache__/online_td_mvr.cpython-38.pyc
new file mode 100644
Binary files /dev/null and b/loss_module/__pycache__/online_td_mvr.cpython-38.pyc differ
zzrvCSQ43g^HlUD@p^dl!T=-B}8XZzP3h)Tgi5zF>Rb0S}udD+oI~dJJ$T?bR=kktU z`Ug#Ane1fmMcr#v_;!zv0hPTw>+|j)EAwvLU60FTE9lnGMlQRfT#ByP*cg};y$gcf za!UXks#UgSmvvnfNrqNTn3;P+fXx(;oYN^HBMED z27~%68M_Bt9+rkzo_nTBwl|LjK7*_R1wLS`Dy7hFF}5>_(`v*xLR(p;1Gs{in_Zl0 ze^gYL1--OB$YQz65{c>qU}aucTOwIN^tM*%6x&I0;hIY6do-uN8le#i<8(4|!i)w96i02d60&~leSU~kv6RP#m zS4_rbIi#a1aj(o|W0PZGOrm@B+y6c9RF*fX5$7WQD}%a_n#a3f*ReUlj0ucqP3iaF Ut6cIvMwPi(v>+q3&=QS*0a?65-2eap literal 0 HcmV?d00001 diff --git a/loss_module/__pycache__/td_loss.cpython-38.pyc b/loss_module/__pycache__/td_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77ea38c5e82e179567df528a9959b3c3ba262f49 GIT binary patch literal 1997 zcmah~OK&4Z5bmD0->W4XlO+cWA;5rQapD39ltf&1jkL)MGy=6|r`wL5@r=_w_Ab`= zvI*M1;T&^ENZj}xeS}l~00%Bm)#GHH6%sv~p6=>;b=6ljPaBOYf$`hRUi5`e$e*~` z91d*0f>#~D#0jSfnb43@Tw94Xu|wOe?Zi1onv=MpYrbCSeM7j*y$ixUVV~Qf&z&RE z^{+t_=~C@H%9Mg{^Kq8uVn32ewx-aJ`Y|$X4sX3teOL$*QcgmPQ(=#tu{#C><(E@< ztbzB63@Ep`b5Q}a#)&ZCJ`XNPSm703g>Us0aY$GLmulKOK0w>o6qFywwGG3##@B{d zy$@4P&dG`vw6F@ha0{>Si=e0!)yoPcg;Nk_adJeiXuoS~SIwqE>VUCPI?5T-e#%rX zrV30tN+bha6X|>;vveZTT)UHui$vFFsmzk3vk{2^GI^tiZd}_hFWsmdrerg)T6nO(NGw!E2=Au#YqHe*%kTwN$(_^h+aBc zEJV^f5b89~roH1)G|YPY*vH~{=FpLN)Qro4ARvGraoA|^S<(q+P)nO{gieTI?ts^1yM(en+bTRxPc8L(UVBNA&S^={>#Ig#U(!{RI~;Pb zvEof5wtgY-{25VMY&P4KWpo-zXKO6Wu>N1@d4p>%oB z!JO{h=J3OAM`CeGjBo$?9~j-1c4d?fg>2%1sS`U9sjg?wl#uf%A7!c1E-F&Ga;r+V zfL&KaIs+U;xzPUFE$tz8Qm&g{D^P&qhf-!zyLl$#6Yb?9Rj_sDBk{ajF_@AM(e^v2 zv2h0&VI!VN2~`cH&&7+dbqmjupxv@A=}3E}{WmZvcR(q=Y8NJAIdqFw={D_Hj`c3x zvbJg6s?$2XPZu9=!t;M>A;u`)3|wQG8Ni|$PICksJ*REx?`>#y2s_+Cs_lN*JU(EL zu=SKF4tLlZ^QsDjVSczlVeQ;2!Q0J<4i{{j*e6*d3> literal 0 HcmV?d00001 diff --git a/loss_module/abstract_unrolling_mvr.py b/loss_module/abstract_unrolling_mvr.py new file mode 100644 index 0000000..a2dae3f --- /dev/null +++ b/loss_module/abstract_unrolling_mvr.py @@ -0,0 +1,226 @@ +from model_module.query_operations.reward_op import RewardOp +from model_module.query_operations.next_state_op import NextStateOp +from model_module.query_operations.state_value_op import StateValueOp +from model_module.query_operations.representation_op import RepresentationOp +from model_module.query_operations.mask_op import MaskOp +from loss_module.loss import Loss +import torch +import numpy as np +from typing import Union, Optional, List, Tuple +import time + + +class AbstractUnrollingMVR(Loss): + def __init__(self, + model, + unroll_steps, + gamma_discount=1, + loss_fun_value=torch.nn.functional.mse_loss, + loss_fun_reward=torch.nn.functional.mse_loss, + loss_fun_mask=torch.nn.functional.binary_cross_entropy, + coef_loss_value = 1, + coef_loss_reward = 1, + coef_loss_mask = 1, + encoded_state_fidelity = False, #SHOULD OUR STATES MIMICK THE REAL OBSERVATIONS? 
+ coef_loss_state = 1, + loss_fun_state=torch.nn.functional.mse_loss, + average_loss=True): + self.model = model + ''' check if model has all the necessary operations ''' + assert isinstance(self.model,RewardOp) and \ + isinstance(self.model,NextStateOp) and \ + isinstance(self.model,StateValueOp) and \ + isinstance(self.model,MaskOp) and \ + isinstance(self.model,RepresentationOp) + self.unroll_steps = unroll_steps + self.gamma_discount = gamma_discount + self.loss_fun_value = loss_fun_value + self.loss_fun_reward = loss_fun_reward + self.loss_fun_mask = loss_fun_mask + self.coef_loss_value = coef_loss_value + self.coef_loss_reward = coef_loss_reward + self.coef_loss_mask = coef_loss_mask + self.encoded_state_fidelity = encoded_state_fidelity + self.coef_loss_state = coef_loss_state + self.loss_fun_state = loss_fun_state + self.average_loss = average_loss + + def get_loss(self,nodes:list,info={}): + if not isinstance(nodes,list): nodes = [nodes] + loss_reward, loss_value, loss_mask, loss_state, info = self._get_losses(nodes,info) + loss = loss_value + loss_reward + loss_mask + loss_state + assert self.encoded_state_fidelity is False or loss_state == 0 + if self.average_loss: + loss = loss/len(nodes) + return loss, info + + #TODO, we can pass the loss of rewards and states to the super class. + def _get_losses(self,nodes:list,info={}): + """ returns 4 losses: reward, value, masks and state. The state loss is optional, only + if we want the hidden state of our agent to mimick the real observations. if we want to update + the states like this, we have to be careful not to make updates to the state when doing backwards + on the values, masks and rewards; but only when we do backwards on the state loss """ + if self.encoded_state_fidelity: + predicted_rewards, predicted_states, actions = self._get_predicted_model(nodes,False) + predicted_values, predicted_masks = self._unrolled_prediction(predicted_states.detach()) + else: + predicted_rewards, predicted_states, actions = self._get_predicted_model(nodes,True) + predicted_values, predicted_masks = self._unrolled_prediction(predicted_states) + + #targets + target_values = self._get_target_values(nodes).float() + target_rewards = self._get_target_rewards(nodes).float() + target_masks = self._get_target_masks(nodes).float() + target_states = self._get_target_states(nodes).float() + + #losses + loss_reward_per_node = torch.stack([self.loss_fun_reward(predicted_rewards[i],target_rewards[i]) for i in range(len(nodes))]) + loss_value_per_node = torch.stack([self.loss_fun_value(predicted_values[i],target_values[i]) for i in range(len(nodes))]) + loss_mask_per_node = torch.stack([self.loss_fun_mask(predicted_masks[i],target_masks[i]) for i in range(len(nodes))]) + + if self.encoded_state_fidelity: + loss_state_per_node = torch.stack([self.loss_fun_state(predicted_states[i],target_states[i]) for i in range(len(nodes))]) #! 
+ else: + loss_state_per_node = torch.zeros((len(nodes))) + assert loss_state_per_node.shape == loss_mask_per_node.shape + + total_loss_per_node = loss_mask_per_node + loss_reward_per_node + loss_value_per_node + loss_state_per_node + loss_reward = torch.sum(loss_reward_per_node) * self.coef_loss_reward + loss_value = torch.sum(loss_value_per_node) * self.coef_loss_value + loss_mask = torch.sum(loss_mask_per_node) * self.coef_loss_mask + loss_state = torch.sum(loss_state_per_node) * self.coef_loss_state + + #debug info + info = {"loss_reward_per_node":loss_reward_per_node.detach(), + "loss_value_per_node":loss_value_per_node.detach(), + "loss_mask_per_node": loss_mask_per_node.detach(), + "loss_per_node":total_loss_per_node.detach(), + "loss_reward":loss_reward.detach(), + "loss_value":loss_value.detach(), + "loss_mask":loss_mask.detach(), + "loss_state":loss_state.detach(), + "predicted_values":predicted_values.detach(), + "target_values":target_values, + "predicted_rewards":predicted_rewards.detach(), + "target_rewards":target_rewards, + "predicted_masks":predicted_masks.detach(), + "target_masks":target_masks, + "predicted_states":predicted_states.detach(), + "target_states":target_states.detach(), + "actions":actions} + + return loss_reward, loss_value, loss_mask , loss_state, info + + def _get_predicted_model(self,nodes:list,state_grad_for_reward): + predicted_rewards, predicted_states, actions = self._state_reward_unrolling(nodes,self.model,self.unroll_steps,state_grad_for_reward=state_grad_for_reward) + assert predicted_rewards.shape[1] == self.unroll_steps and predicted_states.shape[1] == self.unroll_steps + 1 + assert predicted_states.shape[0] == predicted_rewards.shape[0] and predicted_rewards.shape[0] == len(nodes) + return predicted_rewards, predicted_states, actions + + def _state_reward_unrolling(self,nodes:list,model,unroll_steps:int,state_grad_for_reward): + """ Returns 3 tensors: + predicted values -> shape (len(nodes),unroll_step + 1,1) + predicted_rewards -> shape (len(nodes),unroll_step,1) + predicted_states -> shape (len(nodes),unroll_step + 1, hidden_state.shape) + predicted_mask-> shape (len(nodes),unroll_step, action_size) """ + + total_actions = [] + games = [node.get_game() for node in nodes] + game_indexes = [node.get_idx_at_game() for node in nodes] + observations = torch.tensor([games[n_idx].observations[game_indexes[n_idx]] for n_idx in range(len(nodes))]) + + current_states, = model.representation_query(observations,RepresentationOp.KEY) + predicted_states_list = [current_states] + predicted_rewards_list = [] + for delta_idx in range(unroll_steps): + actions = [] + for idx,game in zip(game_indexes,games): #collect actions per game + if idx + delta_idx < len(game.actions): + actions.append([game.actions[idx+delta_idx]]) + else: + actions.append([np.random.choice(game.action_size)]) + + if not state_grad_for_reward: + current_states = current_states.detach() + predicted_rewards, next_encoded_states = model.dynamic_query(current_states,actions,RewardOp.KEY,NextStateOp.KEY) + predicted_states_list.append(next_encoded_states) + predicted_rewards_list.append(predicted_rewards) + total_actions.append(actions) + current_states = next_encoded_states + #swap the two first dimensions, so that the first refers to each node and the second the each unrolling step + reward_tensor = torch.transpose(torch.stack(predicted_rewards_list),0,1) + state_tensor = torch.transpose(torch.stack(predicted_states_list),0,1) + action_tensor = 
torch.transpose(torch.tensor(total_actions),0,1) + return reward_tensor, state_tensor, action_tensor + + def _unrolled_prediction(self,encoded_states): + """ this method takes all the encoded_states in batch and calculates some predictions, + in this case, it calculates the value and mask for each one of these encoded states. it + was created to be used after unrolling""" + model = self.model + batch = encoded_states.shape[0] + steps = encoded_states.shape[1] + ''' first values - we do not want the maks for the first state of each node, only the value''' + encoded_states1 = encoded_states[:,0:1,] + flat_encoded_states1 = torch.flatten(encoded_states1,0,1) + predicted_values_first, = model.prediction_query(flat_encoded_states1,StateValueOp.KEY) + predicted_values_first = predicted_values_first.view(batch,1,1) #! change last to -1 + ''' second values and masks''' + encoded_states2 = encoded_states[:,1:] + flat_encoded_states2 = torch.flatten(encoded_states2,0,1) + predicted_values_second, predicted_masks = model.prediction_query(flat_encoded_states2,StateValueOp.KEY,MaskOp.KEY) + predicted_values_second = predicted_values_second.view(batch,steps-1,1) + predicted_masks = predicted_masks.view(batch,steps-1,-1) + predicted_values = torch.cat((predicted_values_first,predicted_values_second),dim=1) #! change last to -1 + return predicted_values, predicted_masks + + def _get_target_rewards(self,nodes:list): + """ returns a tensor with shape (len(nodes),unroll_step,1) """ + target_rewards_list = [] + for i,node in enumerate(nodes): + game = node.get_game() + idx = node.get_idx_at_game() + target_rewards = game.rewards[idx:idx+self.unroll_steps] + [0] * (self.unroll_steps-len(game.rewards[idx:idx+self.unroll_steps])) #fill rest with 0s + target_rewards_list.append(target_rewards) + return torch.tensor(target_rewards_list).unsqueeze(2) + + ''' masks ''' + def _get_target_masks(self,nodes:list): + """ learns the mask for the next unroll_steps states after the node. + It does not learn the mask for the current observation, since that is given by the environment, + but only the next unrolled states. + returns shape (len(nodes),unroll_step, action_size) """ + total_masks = [] + for node in nodes: + game = node.get_game() + idx = node.get_idx_at_game() + masks = [] + for delta_idx in range(1,self.unroll_steps+1): #starts at one because we do not need to learn the mask for the real observation + if (idx + delta_idx) < len(game.nodes): + mask = game.nodes[idx + delta_idx].get_action_mask() + else: + mask = torch.tensor([0]*game.action_size) #redundant + masks.append(mask) + total_masks.append(torch.stack(masks)) + return torch.stack(total_masks) + + def _get_target_states(self,nodes:list): + """ returns shape (len(nodes),unroll_step + 1, hidden_state.shape) """ + total_observations = [] + for node in nodes: + game = node.get_game() + idx = node.get_idx_at_game() + observations = [] + for delta_idx in range(0,self.unroll_steps+1): #starts at one because we do not need to learn the mask for the real observation + if (idx + delta_idx) < len(game.nodes): + obs = torch.tensor(game.observations[idx + delta_idx]) + else: + obs = torch.tensor(game.observations[-1]) #! 
JUST THE LAST ONE WITH VALUE 0 + observations.append(obs) + total_observations.append(torch.stack(observations)) + return torch.stack(total_observations) + + + def _get_target_values(self,nodes:list): + """ Extend this class to calculate the value in the way you want it""" + raise NotImplementedError diff --git a/loss_module/loss.py b/loss_module/loss.py new file mode 100644 index 0000000..35dd606 --- /dev/null +++ b/loss_module/loss.py @@ -0,0 +1,21 @@ +import torch +import numpy as np +from typing import Union, Optional, List, Tuple + + +''' +This is a simple common interface. +It has some useful functions that might come in handy. Extend them if necessary. + +The main method is get_loss that returns a loss tensor and a dictionary with any specific +values that might be specific to each class +''' + +class Loss: + def get_loss(self,nodes:list,info={}) -> Tuple[torch.tensor,dict]: + ''' + returns a loss for all nodes and a convenient dictionary with + any relevant information + ''' + raise NotImplementedError + diff --git a/loss_module/monte_carlo_mvr.py b/loss_module/monte_carlo_mvr.py new file mode 100644 index 0000000..b455e6c --- /dev/null +++ b/loss_module/monte_carlo_mvr.py @@ -0,0 +1,66 @@ +from loss_module.loss import Loss +from loss_module.abstract_unrolling_mvr import AbstractUnrollingMVR +import torch + + +class MonteCarloMVR(AbstractUnrollingMVR): + def __init__(self, + model, + unroll_steps, + gamma_discount=1, + loss_fun_value=torch.nn.functional.mse_loss, + loss_fun_reward=torch.nn.functional.mse_loss, + loss_fun_mask=torch.nn.functional.binary_cross_entropy, + coef_loss_value = 1, + coef_loss_reward = 1, + coef_loss_mask = 1, + encoded_state_fidelity = False, + coef_loss_state = 1, + loss_fun_state=torch.nn.functional.mse_loss, + average_loss=True): + super().__init__(model=model, + unroll_steps=unroll_steps, + gamma_discount=gamma_discount, + loss_fun_value=loss_fun_value, + loss_fun_reward=loss_fun_reward, + loss_fun_mask=loss_fun_mask, + coef_loss_value=coef_loss_value, + coef_loss_reward=coef_loss_reward, + coef_loss_mask=coef_loss_mask, + encoded_state_fidelity=encoded_state_fidelity, + coef_loss_state=coef_loss_state, + loss_fun_state=loss_fun_state, + average_loss=average_loss) + + def _get_target_values(self,nodes:list): + return self.discounted_accumulated_rewards(nodes,self.unroll_steps,self.gamma_discount) + + def discounted_accumulated_rewards(self,nodes:list,unroll_steps,gamma_discount=1)->torch.tensor: + accumulated_rewards = [] + for node in nodes: + game = node.get_game() + idx = node.get_idx_at_game() + total_reward = 0 + accumulated_rewards_per_game = [] + assert len(game.rewards) + 1 == len(game.players) + for i in range(len(game.rewards)-1,idx-1,-1): + if game.players[i] == game.players[i+1]: + total_reward = total_reward * gamma_discount + else: + total_reward = -total_reward * gamma_discount + total_reward += game.rewards[i] + accumulated_rewards_per_game.insert(0,[total_reward]) + + if len(accumulated_rewards_per_game) >= unroll_steps+1: + accumulated_rewards_per_game = accumulated_rewards_per_game[:unroll_steps+1] + else: + accumulated_rewards_per_game = accumulated_rewards_per_game + [[0.0]]*(unroll_steps+1-len(accumulated_rewards_per_game)) + assert len(accumulated_rewards_per_game) == unroll_steps+1 + accumulated_rewards.append(accumulated_rewards_per_game) + return torch.tensor(accumulated_rewards) + + + def __str__(self): + return "Monte_Carlo_loss" + "_unroll" + str(self.unroll_steps) + + diff --git a/loss_module/offline_td_mvr.py 
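Illustrative sketch, not part of the patch: the sign-flipping recursion in discounted_accumulated_rewards above computes, for every position from idx to the end of the game, the discounted return from the point of view of the player to move. The toy trace below is hypothetical.

    # Standalone version of the backward pass over the rewards.
    gamma = 0.9
    rewards = [0.0, 1.0, -2.0]     # reward received after each move
    players = [0, 1, 0, 1]         # player to move at each step (len(rewards) + 1)

    returns = []
    total = 0.0
    for i in range(len(rewards) - 1, -1, -1):
        # discount the return so far; flip its sign when the player to move changes
        total = total * gamma if players[i] == players[i + 1] else -total * gamma
        total += rewards[i]
        returns.insert(0, total)

    print(returns)  # [-2.52, 2.8, -2.0] -- value of each position for its player to move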
diff --git a/loss_module/offline_td_mvr.py b/loss_module/offline_td_mvr.py
new file mode 100644
index 0000000..6db3382
--- /dev/null
+++ b/loss_module/offline_td_mvr.py
@@ -0,0 +1,76 @@
+from loss_module.loss import Loss
+from loss_module.monte_carlo_mvr import MonteCarloMVR
+
+import torch
+import random
+
+class OfflineTDMVR(MonteCarloMVR):
+    def __init__(self,
+                 model,
+                 unroll_steps,
+                 n_steps,
+                 gamma_discount=1,
+                 loss_fun_value=torch.nn.functional.mse_loss,
+                 loss_fun_reward=torch.nn.functional.mse_loss,
+                 loss_fun_mask=torch.nn.functional.binary_cross_entropy,
+                 coef_loss_value=1,
+                 coef_loss_reward=1,
+                 coef_loss_mask=1,
+                 encoded_state_fidelity=False,  # should our states mimic the real observations?
+                 coef_loss_state=1,
+                 loss_fun_state=torch.nn.functional.mse_loss,
+                 average_loss=True):
+        super().__init__(
+            model=model,
+            unroll_steps=unroll_steps,
+            gamma_discount=gamma_discount,
+            loss_fun_value=loss_fun_value,
+            loss_fun_reward=loss_fun_reward,
+            loss_fun_mask=loss_fun_mask,
+            coef_loss_value=coef_loss_value,
+            coef_loss_reward=coef_loss_reward,
+            coef_loss_mask=coef_loss_mask,
+            encoded_state_fidelity=encoded_state_fidelity,
+            coef_loss_state=coef_loss_state,
+            loss_fun_state=loss_fun_state,
+            average_loss=average_loss)
+        self.n_steps = n_steps
+
+    def _get_target_values(self, nodes: list):
+        return self._get_node_bootstrapped_target_values(nodes)
+
+    def _get_node_bootstrapped_target_values(self, nodes: list):
+        """ For each of the unroll steps, sum the rewards for n_steps and then bootstrap with the
+        n_steps-th node.get_value().
+        TODO: this is a bit inefficient. Not too much, but... enough. """
+        total_target_values = []
+        for node in nodes:
+            game = node.get_game()
+            idx = node.get_idx_at_game()
+            target_values = []
+            for current_index in range(0, self.unroll_steps + 1):
+                bootstrap_index = current_index + self.n_steps
+                if idx + bootstrap_index < len(game.nodes):  # get bootstrapped value
+                    value = game.nodes[idx + bootstrap_index].get_value() * self.gamma_discount**self.n_steps
+                    if game.players[idx + current_index] != game.players[idx + bootstrap_index]:
+                        value = -value
+                else:
+                    value = 0
+
+                # collect the reward sum up to the bootstrapped value
+                for i, reward in enumerate(game.rewards[idx + current_index:idx + bootstrap_index]):
+                    if game.players[idx + current_index] == game.players[idx + current_index + i]:
+                        value += reward * self.gamma_discount**i
+                    else:
+                        value -= reward * self.gamma_discount**i
+
+                if current_index < len(game.observations[idx:]):  # is the current index imaginary?
+                    target_values.append(value)
+                else:
+                    assert value == 0
+                    target_values.append(0)
+            total_target_values.append(target_values)
+        return torch.tensor(total_target_values).unsqueeze(2)
+
+    def __str__(self):
+        return "Offline_TD_Loss" + "_unroll" + str(self.unroll_steps) + "_nsteps" + str(self.n_steps)
\ No newline at end of file

diff --git a/loss_module/online_td_mvr.py b/loss_module/online_td_mvr.py
new file mode 100644
index 0000000..bee258b
--- /dev/null
+++ b/loss_module/online_td_mvr.py
@@ -0,0 +1,108 @@
+from loss_module.abstract_unrolling_mvr import AbstractUnrollingMVR
+from model_module.query_operations.state_value_op import StateValueOp
+from model_module.query_operations.representation_op import RepresentationOp
+import torch
+import numpy as np
+
+
+'''
+    unroll_steps = number of states to predict
+    n_steps = the number of steps to bootstrap in each unrolled step
+'''
+class OnlineTDMVR(AbstractUnrollingMVR):
+    def __init__(
+            self,
+            model,
+            unroll_steps,
+            n_steps,
+            gamma_discount=1,
+            loss_fun_value=torch.nn.functional.mse_loss,
+            loss_fun_reward=torch.nn.functional.mse_loss,
+            loss_fun_mask=torch.nn.functional.binary_cross_entropy,
+            coef_loss_value=1,
+            coef_loss_reward=1,
+            coef_loss_mask=1,
+            encoded_state_fidelity=False,  # should our states mimic the real observations?
+            coef_loss_state=1,
+            loss_fun_state=torch.nn.functional.mse_loss,
+            average_loss=True):
+        super().__init__(
+            model=model,
+            unroll_steps=unroll_steps,
+            gamma_discount=gamma_discount,
+            loss_fun_value=loss_fun_value,
+            loss_fun_reward=loss_fun_reward,
+            loss_fun_mask=loss_fun_mask,
+            coef_loss_value=coef_loss_value,
+            coef_loss_reward=coef_loss_reward,
+            coef_loss_mask=coef_loss_mask,
+            encoded_state_fidelity=encoded_state_fidelity,
+            coef_loss_state=coef_loss_state,
+            loss_fun_state=loss_fun_state,
+            average_loss=average_loss)
+        self.n_steps = n_steps
+
+    def _get_target_values(self, nodes: list):
+        return self._get_bootstrapped_target_values(nodes)
+
+    def _get_bootstrapped_target_values(self, nodes: list):
+        observations_to_estimate = self._collect_bootstrapped_observations(nodes)
+
+        ''' Estimate the bootstrapping observations collected '''
+        raw_values = []
+        if len(observations_to_estimate) > 0:
+            with torch.no_grad():
+                states, = self.model.representation_query(observations_to_estimate, RepresentationOp.KEY)
+                raw_values, = self.model.prediction_query(states, StateValueOp.KEY)
+
+        ''' Create target values based on rewards and bootstrapping steps '''
+        target_values = []
+        raw_values_idx = 0
+        for node in nodes:
+            target_values_per_node = []
+            game = node.get_game()
+            idx = node.get_idx_at_game()
+            for current_index in range(0, self.unroll_steps + 1):
+                bootstrap_index = current_index + self.n_steps
+
+                if bootstrap_index < (len(game.nodes) - idx):  # bootstrapped value
+                    raw_value = raw_values[raw_values_idx]
+                    assert (torch.from_numpy(game.observations[idx + bootstrap_index]) == observations_to_estimate[raw_values_idx]).all()
+                    raw_values_idx += 1
+                    value = raw_value * self.gamma_discount**self.n_steps
+                    if game.players[idx + current_index] != game.players[idx + bootstrap_index]:
+                        value = -value
+                else:
+                    value = 0
+
+                for i, reward in enumerate(game.rewards[idx + current_index:idx + bootstrap_index]):
+                    if game.players[idx + current_index] == game.players[idx + current_index + i]:
+                        value += reward * self.gamma_discount**i
+                    else:
+                        value -= reward * self.gamma_discount**i
+                target_values_per_node.append([value])
+            target_values.append(target_values_per_node)
+        target_values = torch.tensor(target_values)
+        assert raw_values_idx == len(raw_values)
+        return target_values
+
+    def _collect_bootstrapped_observations(self, nodes: list):
+        """ Collect the bootstrapping observations into a batch,
+        for more efficiency in the model. """
+        observations_to_estimate = []
+        for node in nodes:
+            game = node.get_game()
+            idx = node.get_idx_at_game()
+            current_index = 0
+            bootstrap_index = current_index + self.n_steps
+            while (current_index < self.unroll_steps + 1) and (idx + bootstrap_index < len(game.nodes)):
+                obs = game.observations[idx + bootstrap_index]
+                observations_to_estimate.append(obs)
+                current_index += 1
+                bootstrap_index = current_index + self.n_steps
+        return torch.tensor(observations_to_estimate)
+
+    def __str__(self):
+        return "Online_TD_Loss" + "_unroll" + str(self.unroll_steps) + "_nsteps" + str(self.n_steps)
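For reference, both TD variants build the same n-step target, G_t = sum_i gamma^i * r_{t+i} + gamma^n * V(s_{t+n}); the offline version takes V from stored node values, the online version re-queries the current model. A minimal standalone sketch, assuming a single-player trace (so no sign flips) and hypothetical numbers:

    gamma, n_steps = 0.9, 2
    rewards = [1.0, 0.0, 0.5, 2.0]        # r_t received after each step
    values  = [0.3, 0.1, 0.4, 0.2, 0.0]   # V(s_t) estimates, one per state

    targets = []
    for t in range(len(values)):
        boot = t + n_steps
        # bootstrap only while the bootstrap state exists, exactly as in the loops above
        target = gamma**n_steps * values[boot] if boot < len(values) else 0.0
        for i, r in enumerate(rewards[t:boot]):
            target += gamma**i * r
        targets.append(target)
    print(targets)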
diff --git a/model_module/__pycache__/disjoint_mlp.cpython-38.pyc b/model_module/__pycache__/disjoint_mlp.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a38e2eab951cf339b13d666e8b33f587cb50d98d
GIT binary patch
literal 6930 (binary data omitted)

diff --git a/model_module/__pycache__/nets.cpython-38.pyc b/model_module/__pycache__/nets.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..796868434045f67715dbbcf0a38649d4ac79d372
GIT binary patch
literal 6446 (binary data omitted)

diff --git a/model_module/__pycache__/operation.cpython-38.pyc b/model_module/__pycache__/operation.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d296d180a194d814ee5b6d5878241021984636bb
GIT binary patch
literal 327 (binary data omitted)
(The diff for model_module/disjoint_mlp.py is garbled at this point in the patch; only the tail of the file survives.)

+    ... -> torch.tensor:
+        '''
+        Takes a tensor with all the states and a list of actions per state, and builds the
+        transition inputs for them,
+        e.g. states[0] is a state and actions[0] is a list of all the actions whose transitions
+        will be calculated.
+        '''
+        assert states.shape[0] == len(actions), "There needs to be a list of actions per encoded state"
+        assert len(states.shape) >= 2, "encoded_states needs to have a batch dimension"
+        input_vector = []
+        for state_idx in range(len(actions)):
+            for action in actions[state_idx]:
+                action_one_hot = torch.zeros(self.action_space_size).float()
+                action_one_hot[action] = 1
+                x = torch.cat((states[state_idx], action_one_hot))
+                input_vector.append(x.unsqueeze_(0))
+        input_vector = torch.cat(input_vector)
+        return input_vector
+
+    ''' instance specific '''
+    def get_optimizers(self) -> list:
+        if self.optimizer is None:
+            self.optimizer = torch.optim.Adam(self.parameters(), weight_decay=1e-04)
+        return [self.optimizer]
+
+    def get_schedulers(self) -> list:
+        #self.schedule = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.9)
+        self.schedulers = []
+        return self.schedulers

diff --git a/model_module/operation.py b/model_module/operation.py
new file mode 100644
index 0000000..3cccae6
--- /dev/null
+++ b/model_module/operation.py
@@ -0,0 +1,2 @@
+class Operation:
+    pass
\ No newline at end of file
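The surviving disjoint_mlp helper above concatenates each encoded state with a one-hot action, one row per (state, action) pair. A standalone sketch of that construction, with hypothetical sizes:

    import torch

    action_space_size = 4
    states = torch.randn(2, 8)       # batch of 2 encoded states of size 8
    actions = [[0, 3], [1]]          # actions to expand per state

    rows = []
    for s, acts in zip(states, actions):
        for a in acts:
            one_hot = torch.zeros(action_space_size)
            one_hot[a] = 1.0
            rows.append(torch.cat((s, one_hot)).unsqueeze(0))
    inputs = torch.cat(rows)
    print(inputs.shape)  # torch.Size([3, 12]) -- one row per (state, action) pair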
diff --git a/model_module/query_operations/__pycache__/mask_op.cpython-38.pyc b/model_module/query_operations/__pycache__/mask_op.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..390ad0617db145743a4f312b3c55ef3fc7758f08
GIT binary patch
literal 653 (binary data omitted)

diff --git a/model_module/query_operations/__pycache__/next_state_op.cpython-38.pyc b/model_module/query_operations/__pycache__/next_state_op.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b19efe881815ea3cffcf0f7839610231d6f61deb
GIT binary patch
literal 747 (binary data omitted)

diff --git a/model_module/query_operations/__pycache__/representation_op.cpython-38.pyc b/model_module/query_operations/__pycache__/representation_op.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93c5a30c8c295f94b8d3e8d2223e9cd9755e32f2
GIT binary patch
literal 711 (binary data omitted)

diff --git a/model_module/query_operations/input_operations/__pycache__/observation_op.cpython-38.pyc b/model_module/query_operations/input_operations/__pycache__/observation_op.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..807a1d06aa18c46259ba0d25f658170d5585d961
GIT binary patch
literal 666 (binary data omitted)

diff --git a/model_module/query_operations/input_operations/__pycache__/state_action_op.cpython-38.pyc b/model_module/query_operations/input_operations/__pycache__/state_action_op.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..552d0f5df65154cfea7bc9fcad7503229fccb3ad
GIT binary patch
literal 897 (binary data omitted)

diff --git a/model_module/query_operations/input_operations/__pycache__/state_op.cpython-38.pyc b/model_module/query_operations/input_operations/__pycache__/state_op.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e14ff6890c6420e316884f952ad3afac4af34067
GIT binary patch
literal 631 (binary data omitted)
diff --git a/model_module/query_operations/input_operations/observation_op.py b/model_module/query_operations/input_operations/observation_op.py
new file mode 100644
index 0000000..191709d
--- /dev/null
+++ b/model_module/query_operations/input_operations/observation_op.py
@@ -0,0 +1,7 @@
+import torch
+from model_module.operation import Operation
+
+class ObservationOp(Operation):
+
+    def representation_query(self, observations, *keys) -> torch.tensor:
+        raise NotImplementedError
\ No newline at end of file

diff --git a/model_module/query_operations/input_operations/state_action_op.py b/model_module/query_operations/input_operations/state_action_op.py
new file mode 100644
index 0000000..281aec2
--- /dev/null
+++ b/model_module/query_operations/input_operations/state_action_op.py
@@ -0,0 +1,11 @@
+import torch
+from typing import List
+from model_module.operation import Operation
+
+class StateActionOp(Operation):
+
+    def dynamic_query(self, states: torch.tensor, actions: List[list], *keys):
+        """ The output should have only two major dimensions. For instance,
+        a states input with shape (2, ...) and actions [[0,1,2,3,4],[1,2]] should return an
+        output with shape (7, ...). """
+        raise NotImplementedError

diff --git a/model_module/query_operations/input_operations/state_op.py b/model_module/query_operations/input_operations/state_op.py
new file mode 100644
index 0000000..9973140
--- /dev/null
+++ b/model_module/query_operations/input_operations/state_op.py
@@ -0,0 +1,7 @@
+import torch
+from model_module.operation import Operation
+
+class StateOp(Operation):
+
+    def prediction_query(self, states: torch.tensor, *keys):
+        raise NotImplementedError

diff --git a/model_module/query_operations/mask_op.py b/model_module/query_operations/mask_op.py
new file mode 100644
index 0000000..9e493d4
--- /dev/null
+++ b/model_module/query_operations/mask_op.py
@@ -0,0 +1,10 @@
+from model_module.query_operations.input_operations.state_op import StateOp
+import torch
+
+class MaskOp(StateOp):
+    KEY = "MaskOp"
+    """ Use an activation function with this - it is important to keep this
+    consistent so that loss functions know how to update the neural networks. """
+
+    def prediction_query(self, states: torch.tensor, *keys):
+        raise NotImplementedError

diff --git a/model_module/query_operations/next_state_op.py b/model_module/query_operations/next_state_op.py
new file mode 100644
index 0000000..e3e1164
--- /dev/null
+++ b/model_module/query_operations/next_state_op.py
@@ -0,0 +1,11 @@
+from typing import List
+from model_module.query_operations.input_operations.state_action_op import StateActionOp
+import torch
+
+
+class NextStateOp(StateActionOp):
+    KEY = "StateActionOp"
+
+    def dynamic_query(self, states: torch.tensor, actions: List[list], *keys):
+        raise NotImplementedError
+

diff --git a/model_module/query_operations/representation_op.py b/model_module/query_operations/representation_op.py
new file mode 100644
index 0000000..a3c1a17
--- /dev/null
+++ b/model_module/query_operations/representation_op.py
@@ -0,0 +1,8 @@
+from model_module.query_operations.input_operations.observation_op import ObservationOp
+import torch
+
+class RepresentationOp(ObservationOp):
+    KEY = "RepresentationOp"
+
+    def representation_query(self, observations, *keys) -> torch.tensor:
+        raise NotImplementedError
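These operation classes are capability tags plus query signatures: a model advertises what it can compute by inheriting them (which is what the isinstance checks in AbstractUnrollingMVR test for), and each query is expected to return one output per requested KEY, in order -- that is why the losses unpack with `states, = model.representation_query(...)`. A minimal sketch under those assumptions; TinyValueModel, its layer sizes, and the dict-based dispatch are illustrative, not part of the patch:

    import torch
    from model_module.query_operations.representation_op import RepresentationOp
    from model_module.query_operations.state_value_op import StateValueOp

    class TinyValueModel(torch.nn.Module, RepresentationOp, StateValueOp):
        def __init__(self):
            super().__init__()
            self.encoder = torch.nn.Linear(4, 8)      # hypothetical sizes
            self.value_head = torch.nn.Linear(8, 1)

        def representation_query(self, observations, *keys):
            # one output per requested key, returned in the order the keys were asked for
            outputs = {RepresentationOp.KEY: self.encoder(observations.float())}
            return tuple(outputs[k] for k in keys)

        def prediction_query(self, states: torch.Tensor, *keys):
            outputs = {StateValueOp.KEY: self.value_head(states)}
            return tuple(outputs[k] for k in keys)

    model = TinyValueModel()
    assert isinstance(model, RepresentationOp) and isinstance(model, StateValueOp)
    states, = model.representation_query(torch.randn(3, 4), RepresentationOp.KEY)
    values, = model.prediction_query(states, StateValueOp.KEY)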
diff --git a/model_module/query_operations/reward_op.py b/model_module/query_operations/reward_op.py
new file mode 100644
index 0000000..cbe698f
--- /dev/null
+++ b/model_module/query_operations/reward_op.py
@@ -0,0 +1,12 @@
+from typing import List
+from model_module.query_operations.input_operations.state_action_op import StateActionOp
+import torch
+
+
+class RewardOp(StateActionOp):
+    KEY = "RewardOp"
+
+    def dynamic_query(self, states: torch.tensor, actions: List[list], *keys):
+        raise NotImplementedError
+

diff --git a/model_module/query_operations/state_value_op.py b/model_module/query_operations/state_value_op.py
new file mode 100644
index 0000000..adc1e62
--- /dev/null
+++ b/model_module/query_operations/state_value_op.py
@@ -0,0 +1,8 @@
+from model_module.query_operations.input_operations.state_op import StateOp
+import torch
+
+class StateValueOp(StateOp):
+    KEY = "StateValueOp"
+
+    def prediction_query(self, states: torch.tensor, *keys):
+        raise NotImplementedError

diff --git a/node_module/__pycache__/best_first_node.cpython-38.pyc b/node_module/__pycache__/best_first_node.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..00807a62450992c2c5ae7dd29bf211c9bbad77a5
GIT binary patch
literal 1100 (binary data omitted)

diff --git a/node_module/__pycache__/mcts_node.cpython-38.pyc b/node_module/__pycache__/mcts_node.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e61a336e42f679fb515d55f264479b6f43e0c385
GIT binary patch
literal 1064 (binary data omitted)

diff --git a/node_module/__pycache__/node.cpython-38.pyc b/node_module/__pycache__/node.cpython-38.pyc
new file mode 100644
GIT binary patch
literal 5784 (binary data omitted)

diff --git a/node_module/best_first_node.py b/node_module/best_first_node.py
new file mode 100644
index 0000000..aa48f1d
--- /dev/null
+++ b/node_module/best_first_node.py
@@ -0,0 +1,21 @@
+import numpy as np
+import torch
+from game import Game
+
+from node_module.node import Node
+
+class BestFirstNode(Node):
+    def __init__(self):
+        super().__init__()
+        self._visits = 0
+
+    ''' visits '''
+    def get_visits(self):
+        return self._visits
+
+    def set_visits(self, visits):
+        self._visits = visits
+
+    def increment_visits(self, n=1):
+        self._visits += n
\ No newline at end of file

diff --git a/node_module/mcts_node.py b/node_module/mcts_node.py
new file mode 100644
index 0000000..2130294
--- /dev/null
+++ b/node_module/mcts_node.py
@@ -0,0 +1,17 @@
+from node_module.best_first_node import BestFirstNode
+
+class MCTSNode(BestFirstNode):
+    def __init__(self):
+        super().__init__()
+        self._total_value = 0
+        self._num_added_value = 0  # this is more robust than using the number of visits
+
+    def get_total_value(self):
+        return self._total_value
+
+    def get_value(self):
+        return self._total_value / self._num_added_value
+
+    def add_value(self, delta_value):
+        self._num_added_value += 1
+        self._total_value += delta_value
\ No newline at end of file

diff --git a/node_module/node.py b/node_module/node.py
new file mode 100644
index 0000000..e017678
--- /dev/null
+++ b/node_module/node.py
@@ -0,0 +1,156 @@
+import numpy as np
+import torch
+from game import Game
+
+class Node:
+    def __init__(self):
+        """ Don't use the attributes directly (except self.info), since subclasses might use their own,
+        and this interface is used everywhere throughout the architecture, so we're following stricter
+        OOP practices. E.g., in the monte-carlo tree search algorithm, this class is redefined so that
+        get_value() returns an average value and not self._value.
+
+        Also, don't use a batch dimension in the tensors you put here. """
+        self._game:Game = None
+        self._idx_at_game:int = None
+        self._action_mask:torch.Tensor = None
+        self._player:int = None
+        self._parent:Node = None  # make sure this triple (parent, action, reward) is set together via set_parent_info
+        self._parent_action:int = None
+        self._parent_reward:float = None
+        self._depth = None
+        self._value:float = None  # updatable value
+        self._encoded_state:torch.Tensor = None
+        self._children:dict = {}  # key is action:int, value is node
+        self.info:dict = {}
+
+    def _to_tensor(self, x):
+        if isinstance(x, np.ndarray):
+            return torch.tensor(x).float()
+        elif isinstance(x, torch.Tensor):
+            return x.float()
+        elif x is None:
+            return x
+        else:
+            raise ValueError("Can't convert input to torch Tensor")
+
+    def detach_from_tree(self):
+        self._children = {}
+        self._parent = None
+
+    ''' Game '''
+    def get_game(self):
+        return self._game
+
+    def get_idx_at_game(self):
+        return self._idx_at_game
+
+    def set_game(self, game, idx):
+        self._game = game
+        self._idx_at_game = idx
+        return self
+
+    ''' Actions '''
+    def get_action_mask(self):
+        return self._to_tensor(self._action_mask)
+
+    def set_action_mask(self, action_mask):
+        self._action_mask = self._to_tensor(action_mask)
+        return self
+
+    def get_legal_actions(self, threshold=1):
+        return torch.where(self.get_action_mask() >= threshold)[0]
+
+    def get_illegal_actions(self, threshold=0):
+        return torch.where(self.get_action_mask() <= threshold)[0]
+
+    ''' Parent '''
+    def get_parent(self):
+        return self._parent
+
+    def get_parent_action(self):
+        return self._parent_action
+
+    def get_parent_reward(self):
+        return self._parent_reward
+
+    def get_depth(self):
+        if self.get_parent() is None:
+            return 0
+        else:
+            return self._depth
+
+    def set_parent_info(self, parent_node, parent_action, parent_reward):
+        self._parent = parent_node
+        self._parent_action = parent_action
+        self._parent_reward = parent_reward
+        self._depth = self._parent.get_depth() + 1
+        return self
+
+    ''' Successors '''
+    def get_children(self) -> dict:
+        return self._children
+
+    def get_num_of_children(self):
+        return len(self._children)
+
+    def get_children_nodes(self):
+        ''' this does not guarantee order '''
+        return list(self._children.values())
+
+    def set_children(self, children: dict):
+        self._children = children
+        return self
+
+    def get_child(self, action):
+        return self._children[action]
+
+    def add_child(self, action, node):
+        self._children[action] = node
+
+    def get_child_reward(self, action):
+        return self._children[action].get_parent_reward()
+
+    def is_leaf(self):
+        return len(self._children) == 0
+
+    ''' Node Properties '''
+    # evaluation
+    def successor_value(self, action):
+        ''' value of a successor from the parent's point of view '''
+        successor = self.get_child(action)
+        if self.get_player() != successor.get_player():
+            return successor.get_parent_reward() - successor.get_value()
+        else:
+            return successor.get_parent_reward() + successor.get_value()
+
+    def get_value(self):
+        return self._value
+
+    def set_value(self, value):
+        self._value = value
+        return self
+
+    # encoded state
+    def get_encoded_state(self):
+        return self._to_tensor(self._encoded_state)
+
+    def set_encoded_state(self, encoded_state):
+        self._encoded_state = self._to_tensor(encoded_state)
+        return self
+
+    # player
+    def set_player(self, player):
+        self._player = player
+        return self
+
+    def get_player(self):
+        return self._player
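To make the node interface concrete, here is a hypothetical two-node tree exercising the parent triple and successor_value(); the players, reward and values are made up, and it assumes the patch is applied and importable:

    from node_module.node import Node
    from node_module.mcts_node import MCTSNode

    root = Node().set_player(0).set_value(0.0)
    child = Node().set_player(1).set_value(0.25)
    child.set_parent_info(root, parent_action=2, parent_reward=1.0)  # the triple is set together
    root.add_child(2, child)

    print(root.successor_value(2))  # players differ, so: 1.0 - 0.25 = 0.75
    print(child.get_depth())        # 1

    # MCTSNode replaces the stored value with a running average of backed-up values:
    mcts = MCTSNode()
    for v in [1.0, 0.0, 0.5]:
        mcts.add_value(v)
        mcts.increment_visits()
    print(mcts.get_value())         # 0.5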
zz?dussehOvHB?V1unS54p84VpKqH|9`VgLQj)OF!TqnpR}YTdBnmzqaGi=e%^_r$I&;< zspI|E&t3W^9TjM+q+fJ!NFTily3)ud8Dz;)MiCEuE=i5r+`A~k7^6h@o6wcR{4R$NPNQo_M@!&ym`mP=}8 z>Dt5+C?cbI3X-Rw=o9%8z;FE}{Q>jZC;tUM^*ggmi&872OU&%t_wbzaedo->xw$IC z^UoI_NB5T*`!{uF9|xUVDC)1M1QR@AUEboH+E!w9?UrrYcH(s1mTTHh;&uI&Z`y7W zbStfjX?sbvTWi%!+fVAK*gNfuj+A2CrCwP#G!Vt#D5mPBQn zHR~@R1#5C0+)tu3?WB8XFWz~YM%_+3d=UR7k|O-1kHN-ta65~#xENc$FJsYZXPsUe z{xM1h7(bTGm%=mkTbeXFf<%_zQ!$=_J#RA?H zu__kD(lg$wie<5ak=j?xVXZoJv!>S`-rl*_OSAZPB$M8bjN|P%lI{JrHzO6bNiZtv zI;xBvvoX*41V+uRoP}0yqqcJgwUfK3UBR~5bH3H|w5taFSn41QJ835iLtTpsg{sal z*7ND08}@d?eiA*6r7HTJES5&IO4pPTD2$)<6CCEKIN8;17phF;0%Va5surnQLiL>M zN*K12NU1P<&R!mD?)SR!X4)AJ<7D$rtPZnYe{-i3?e#Wq_qzQ7z7%e^J8{~MH+Nvt z&8@qihW8$}?mqZr6EYc(3Gul>5^oOL8Jf^m*d}#_8T1tv@@@2=%6WXBWKaelUY&d9 zH(WnQ$s1FRv?w^Bn$|j z5TPRQ>IE7q-K+}~YF9?-UTpkCH%ga!x4g2591FH z(8&3iO_+QW_LyRj9ay=ch_9`Ym9CQ?{J zOE)@cTgF|4gz#}kbu!hoPyV00fhA|<{yD~efuaohA@A_5^Q4_on}OUZX^AWdDMs4Q z(8v&Tj}bWp{)EL!;~Tk&LICD>9$6#L0NT!da)VKj2f2Od$aTEE+&|fOv<^CVa(mY{ zfFC;VcGBIZqut%47iGFW>jT;)I4WHy^Y0E)!}86_j2bo}ouvOPKR{I))wq!MN<+Iv z_kwv9Fb}g{*pJ|3*;y6-45$1WMUf>Ku(blC2YiVSo9C(WzufbbE`1#HKfpEuoXGB= zSbJ>c&V=J_Pgw3;Vk>OqW=;lZjlBspzs8OM%W3-xdZ1kwB<`)SHFynZ^M#>JUsp+( z;-3|`DcU7*fy4&wcC5gXCEYF1`WaS}WO_-_rZ}o}JipNREgI)-kh9U9t`{mT_1rL6 z3U_gMsas>5(_AKhg)=@vDMA`l8}L=W&WG2|Q}3+lPth1j`6{eEh5`svq_0CuH8TDs z&#Vkd#DQ}R-yK<*OB(#QCGaT+@e zE^*5c0|q zVw%eMDv*bfl6QW9Jr$ztqkrLC=#I2?Xno7oeIse{Eq>^J%Vk=8nZA0+utL%b;;g1! z1a0|S*rWM|UM80+Q!wKnfSI2EjY1rx%=WI7J*l0fgV69}>?l&M%hb9;l?m{3qRgg? z$ZjIQ>4;ewqzaSqFqXutlA@y|oXOu&Ma(6CNEP8vUZaW_4M}pGngl?uQ@6m2Nfb@m zSA_WzG~lJ4e3P~kAIrCBgGm+4-)dG*Xh{171{A`M$^VsJPU9zMatI|U)zzNCiPCmQ zJdwmTx?13fV%S+edJppyqD_{}5QgSK>u-aT8i@Z53pGlBOyDA5w1iB6@`3YdA5nN3 zZFf%56e)+)XDI(eGz2#k1*d>3UL%^y9U`m`K>xWV>*#^{g$1zHGy6+^RL2TFb|&mf z;S_r~#5_ex=jtVLjHySMfUR_e#)=qBq#*B6^)afZd!k)Qv6wQ!EK1}pT5*O`-p5#i zQj)d<_0+k;K-%w|$10~}DKm{z>_W<*M9C=RA7~T^F?`SvJVm-ghykwzZ zDG>OQS6=e#FL`t5uN$4-d_ig7&|a4}QDrxql~X~Xa8>R0WhX_RL@E(+Kik(fuqp3g zE0WyaAQkd+wB;ijsNsqdr3$H;ikr@fjkML5nx|(ivyVNB3gz4kPPz-robR5qhx5X3 z_)c?Y7o?d%T?dZ best_masked_value: + best_masked_value,best_child = masked_penalized_child_value,child + return best_child + + def _ucb_function(self,child,c): + parent = child.get_parent() + child_action = child.get_parent_action() + value_for_parent = parent.successor_value(child_action) + if c != 0: + exploration = sqrt(log(parent.get_visits())/(child.get_visits()+ 1)) + else: + exploration = 0 + unmasked_value = value_for_parent + c * exploration + valid_value = parent.get_action_mask()[child_action].item() + return self._value_after_mask(unmasked_value,valid_value), unmasked_value + + def _value_after_mask(self,value,valid): + return (value * valid) + (1-valid) * self.invalid_penalty + + diff --git a/planning_module/abstract_depth_first_search.py b/planning_module/abstract_depth_first_search.py new file mode 100644 index 0000000..42469d5 --- /dev/null +++ b/planning_module/abstract_depth_first_search.py @@ -0,0 +1,5 @@ +from planning_module.planning import Planning + +class AbstractDepthFirstSearch(Planning): + def __init__(self,model): + super().__init__(model) \ No newline at end of file diff --git a/planning_module/average_minimax.py b/planning_module/average_minimax.py new file mode 100644 index 0000000..88bc900 --- /dev/null +++ b/planning_module/average_minimax.py @@ -0,0 +1,31 @@ +from planning_module.minimax import Minimax + +class AverageMinimax(Minimax): + def __init__(self, + model, + action_size, + num_of_players, + max_depth, + invalid_penalty): + super().__init__( + 
diff --git a/planning_module/abstract_depth_first_search.py b/planning_module/abstract_depth_first_search.py
new file mode 100644
index 0000000..42469d5
--- /dev/null
+++ b/planning_module/abstract_depth_first_search.py
@@ -0,0 +1,5 @@
+from planning_module.planning import Planning
+
+class AbstractDepthFirstSearch(Planning):
+    def __init__(self,model):
+        super().__init__(model)
\ No newline at end of file
diff --git a/planning_module/average_minimax.py b/planning_module/average_minimax.py
new file mode 100644
index 0000000..88bc900
--- /dev/null
+++ b/planning_module/average_minimax.py
@@ -0,0 +1,31 @@
+from planning_module.minimax import Minimax
+
+class AverageMinimax(Minimax):
+    def __init__(self,
+                 model,
+                 action_size,
+                 num_of_players,
+                 max_depth,
+                 invalid_penalty):
+        super().__init__(
+            model=model,
+            action_size=action_size,
+            num_of_players=num_of_players,
+            max_depth=max_depth,
+            invalid_penalty=invalid_penalty)
+
+    def _update_node(self,node):
+        value = self._get_children_average_successor_value(node)
+        node.set_value(value)
+
+    def _get_children_average_successor_value(self,node):
+        total_value = 0.
+        total_mask = 0.
+        mask = node.get_action_mask()
+        for child in node.get_children_nodes():
+            total_value += node.successor_value(child.get_parent_action()) * mask[child.get_parent_action()].item()
+            total_mask += mask[child.get_parent_action()].item()
+        return total_value/total_mask if total_mask != 0. else 0.
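
AverageMinimax replaces the max backup of Minimax with a legality-weighted mean, so masked-out actions carry no weight instead of contaminating the estimate. A worked example with invented numbers:

    # mask = [1, 1, 0], successor values = [0.5, -0.2, 7.0]
    # total_value = 0.5*1 + (-0.2)*1 + 7.0*0 = 0.3
    # total_mask  = 1 + 1 + 0 = 2
    # averaged value = 0.3 / 2 = 0.15   (the illegal action's 7.0 is ignored)
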
diff --git a/planning_module/minimax.py b/planning_module/minimax.py
new file mode 100644
index 0000000..98a9abf
--- /dev/null
+++ b/planning_module/minimax.py
@@ -0,0 +1,143 @@
+from planning_module.abstract_depth_first_search import AbstractDepthFirstSearch
+from model_module.query_operations.reward_op import RewardOp
+from model_module.query_operations.next_state_op import NextStateOp
+from model_module.query_operations.state_value_op import StateValueOp
+from model_module.query_operations.representation_op import RepresentationOp
+from model_module.query_operations.mask_op import MaskOp
+from node_module.node import Node
+import numpy as np
+import torch
+
+
+class Minimax(AbstractDepthFirstSearch):
+    def __init__(self,
+                 model,
+                 action_size,
+                 num_of_players,
+                 max_depth,
+                 invalid_penalty,
+                 debug=False):
+        super().__init__(model)
+        self.max_depth = max_depth
+        self.action_size = action_size
+        self.num_of_players = num_of_players
+        self.model = model
+        self.invalid_penalty = invalid_penalty
+        self.debug = debug
+
+    def plan(self,observation,player,mask):
+        with torch.no_grad():
+            encoded_state, = self.model.representation_query(torch.tensor([observation]),RepresentationOp.KEY)
+        node = Node()
+        node.set_player(player).set_encoded_state(encoded_state[0]).set_action_mask(mask)
+        self._expand_minimax_tree(node)
+        if self.debug: print(self._transverse(node))
+        return node
+
+    '''
+    Search
+    '''
+    def _expand_minimax_tree(self,node):
+        layers = [[node]]
+        for current_depth in range(self.max_depth - 1):
+            current_layer = []
+            self._expand_nodes(layers[-1],estimate=False)
+            for n in layers[-1]:
+                current_layer.extend(n.get_children_nodes())
+            layers.append(current_layer)
+
+        self._expand_nodes(layers[-1],estimate=True)
+
+        #back up values from the deepest layer to the root
+        for layer in reversed(layers):
+            for n in layer:
+                self._update_node(n)
+
+    def _update_node(self,node):
+        best_node,action,value = self._get_best_node(node)
+        node.set_value(value)
+
+    '''
+    Gets the best successor of a node
+    '''
+    def _get_best_node(self,node):
+        best_masked_value,best_value,best_action,best_child = float("-inf"),None,None,None
+        mask = node.get_action_mask().numpy()
+        for action,child in node.get_children().items():
+            valid = mask[action]
+            raw_child_value_to_parent = node.successor_value(action)
+            masked_child_value_to_parent = self._value_after_mask(raw_child_value_to_parent,valid)
+            if masked_child_value_to_parent > best_masked_value:
+                best_masked_value = masked_child_value_to_parent
+                best_value, best_action, best_child = raw_child_value_to_parent,action,child
+        return best_child,best_action,best_value
+
+    def _value_after_mask(self,value,valid):
+        return (value * valid) + (1-valid) * self.invalid_penalty
+
+    '''
+    Expansion and networks
+    '''
+    def _expand_nodes(self,nodes:list,estimate=False):
+        encoded_states = torch.cat([n.get_encoded_state().unsqueeze(0) for n in nodes])
+        actions = [list(range(self.action_size))] * len(nodes)
+        with torch.no_grad():
+            rewards, next_encoded_states = self.model.dynamic_query(encoded_states,actions,RewardOp.KEY,NextStateOp.KEY)
+
+        if estimate:
+            with torch.no_grad():
+                values, = self.model.prediction_query(next_encoded_states,StateValueOp.KEY)
+            mask = torch.zeros((next_encoded_states.shape[0],self.action_size)) #leaves are never expanded, so their mask shouldn't matter
+        else:
+            values = torch.ones(rewards.shape) * float("nan")
+            with torch.no_grad():
+                mask, = self.model.prediction_query(next_encoded_states,MaskOp.KEY)
+
+        flat_idx = 0
+        for st_idx in range(len(encoded_states)):
+            node = nodes[st_idx]
+            for a_idx in range(len(actions[st_idx])):
+                action = actions[st_idx][a_idx]
+                child_node = Node()
+                child_node.set_player((node.get_player() + 1) % self.num_of_players)
+                child_node.set_parent_info(node,action,rewards[flat_idx].item())
+                child_node.set_value(values[flat_idx].item()).set_encoded_state(next_encoded_states[flat_idx])
+                child_node.set_action_mask(mask[flat_idx])
+                node.add_child(action,child_node)
+                assert (st_idx * self.action_size + a_idx) == flat_idx
+                flat_idx += 1
+
+    '''
+    Traverses the tree, returns a summary string and asserts certain facts for consistency
+    '''
+    def _transverse(self,node):
+        def transversal(node,action,valid,current_depth):
+            string = ""
+            best_child,best_action,best_value = self._get_best_node(node)
+            depth_string = "(" + str(current_depth) + ") "
+            transition_string = "|action: " + str(action) + " |valid_bit: " + str(round(valid,2))
+            if node.get_parent() is not None:
+                parent_string = "|parent_reward:" + str(round(node.get_parent_reward(),4)) +\
+                    " |masked_parent_value:" + str(round(node.get_parent().successor_value(node.get_parent_action())*valid,4)) +\
+                    " |raw_parent_value:" + str(round(node.get_parent().successor_value(node.get_parent_action()),4))
+                #assert node.get_parent().successor_value(node.get_parent_action()) <= node.get_parent().get_value()
+            else:
+                parent_string = ""
+            value_string = "|value: " + str(round(node.get_value(),4)) + " ba:" + str(best_action)
+
+            string += "\t" * current_depth + depth_string + " " + transition_string + " " + parent_string + " " + value_string + "\n"
+
+            mask = node.get_action_mask()
+            for action,child in node.get_children().items():
+                string += transversal(child,action,mask[action].item(),current_depth+1)
+            return string
+
+        string = transversal(node,None,1,0)
+        return string
diff --git a/planning_module/planning.py b/planning_module/planning.py
new file mode 100644
index 0000000..38bb320
--- /dev/null
+++ b/planning_module/planning.py
@@ -0,0 +1,18 @@
+from node_module.node import Node
+
+class Planning:
+    def __init__(self,model):
+        self.model = model
+        self.info = {}
+
+    def plan(self,observation,player,mask) -> Node:
+        '''
+        Creates a root node from the observation and runs the search algorithm.
+        Returns the root node.
+        '''
+        raise NotImplementedError
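
Every planner in this module follows the same contract: plan() builds a search tree rooted at the current observation and hands back the root Node, leaving action selection to the policy layer. A minimal sketch of a caller, assuming an environment and model wired up the way test.py below does it:

    # illustrative caller; `environment` and `model` follow the interfaces used in test.py
    planner = Minimax(model, environment.get_action_size(),
                      environment.get_num_of_players(), max_depth=3, invalid_penalty=-1)
    obs = environment.reset()
    root = planner.plan(obs, environment.get_current_player(), environment.get_action_mask())
    best = max(root.get_legal_actions().tolist(), key=lambda a: root.successor_value(a))
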
diff --git a/planning_module/ucb_best_first_minimax.py b/planning_module/ucb_best_first_minimax.py
new file mode 100644
index 0000000..b1b78b9
--- /dev/null
+++ b/planning_module/ucb_best_first_minimax.py
@@ -0,0 +1,95 @@
+from planning_module.abstract_best_first_search import AbstractBestFirstSearch
+from node_module.best_first_node import BestFirstNode
+from model_module.query_operations.reward_op import RewardOp
+from model_module.query_operations.next_state_op import NextStateOp
+from model_module.query_operations.state_value_op import StateValueOp
+from model_module.query_operations.representation_op import RepresentationOp
+from model_module.query_operations.mask_op import MaskOp
+
+from math import sqrt, log
+import torch
+
+
+class UCBBestFirstMinimax(AbstractBestFirstSearch):
+    def __init__(self,
+                 model,
+                 action_size,
+                 num_of_players,
+                 num_iterations,
+                 search_expl,
+                 invalid_penalty):
+        super().__init__(model)
+        self.action_size = action_size
+        self.num_of_players = num_of_players
+        self.num_iterations = num_iterations
+        self.search_expl = search_expl
+        self.invalid_penalty = invalid_penalty
+
+    def plan(self,observation,player,mask):
+        with torch.no_grad():
+            encoded_state, = self.model.representation_query(torch.tensor([observation]),RepresentationOp.KEY)
+        node = BestFirstNode()
+        node.set_player(player).set_encoded_state(encoded_state[0]).set_action_mask(mask)
+        for i in range(self.num_iterations):
+            self._search_iteration(node)
+        return node
+
+    '''
+    Main specific algorithm methods
+    '''
+    def _search_iteration(self,node):
+        if node.is_leaf():
+            self._expand_node(node)
+        else:
+            best_node = self._get_ucb_best_node(node,exploration=self.search_expl)
+            self._search_iteration(best_node)
+        self._update_node(node)
+
+    def _update_node(self,node):
+        best_node = self._get_ucb_best_node(node,exploration=0)
+        action = best_node.get_parent_action()
+        unmasked_value = node.successor_value(action)
+        node.set_value(unmasked_value)
+        node.increment_visits()
+
+    '''
+    Expansion and networks
+    '''
+    def _expand_node(self,node):
+        assert node.get_num_of_children() == 0
+        actions = list(range(self.action_size))
+        with torch.no_grad():
+            rewards, next_encoded_states = self.model.dynamic_query(node.get_encoded_state().unsqueeze(0),[actions],RewardOp.KEY,NextStateOp.KEY)
+            if node.get_action_mask() is None:
+                mask, = self.model.prediction_query(node.get_encoded_state().unsqueeze(0),MaskOp.KEY)
+                node.set_action_mask(mask[0])
+            else:
+                assert node.get_parent() is None #only the root has its mask set by the environment
+            values, = self.model.prediction_query(next_encoded_states,StateValueOp.KEY)
+
+        for idx in range(len(actions)):
+            action = actions[idx]
+            child_node = BestFirstNode()
+            child_node.set_player((node.get_player() + 1) % self.num_of_players)
+            child_node.set_parent_info(node,action,rewards[idx].item())
+            child_node.set_value(values[idx].item()).set_encoded_state(next_encoded_states[idx])
+            node.add_child(action,child_node)
+            assert child_node.get_depth() == node.get_depth() + 1
+        assert node.get_num_of_children() == self.action_size
+
+    def _transverse(self,node):
+        pr = 0 if node.get_parent_reward() is None else node.get_parent_reward()
+        print(node.get_depth()*"\t" + "(" + str(node.get_parent_action()) + ")" + " value:" + str(round(node.get_value(),2)) + " r:" + str(round(pr,2)) + " t:" + str(round(node.get_value() + pr,2)))
+        for n in node.get_children_nodes():
+            self._transverse(n)
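
Unlike the fixed-depth Minimax above, this planner is anytime: each call to _search_iteration performs exactly one expansion along the current UCB-best path, so num_iterations directly bounds both compute and tree size. A rough illustration of the growth, assuming action_size = 2:

    # iteration 1: root is a leaf -> expand root            (tree: 1 + 2 nodes)
    # iteration 2: descend to the UCB-best child -> expand  (tree: 3 + 2 nodes)
    # iteration k: the tree has 1 + 2k nodes; depth grows only along
    #              the branches the UCB rule keeps selecting
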
diff --git a/planning_module/ucb_monte_carlo_tree_search.py b/planning_module/ucb_monte_carlo_tree_search.py
new file mode 100644
index 0000000..d32b073
--- /dev/null
+++ b/planning_module/ucb_monte_carlo_tree_search.py
@@ -0,0 +1,114 @@
+from planning_module.abstract_best_first_search import AbstractBestFirstSearch
+from model_module.query_operations.reward_op import RewardOp
+from model_module.query_operations.next_state_op import NextStateOp
+from model_module.query_operations.state_value_op import StateValueOp
+from model_module.query_operations.representation_op import RepresentationOp
+from model_module.query_operations.mask_op import MaskOp
+from node_module.mcts_node import MCTSNode
+from math import sqrt, log
+import torch
+
+
+class UCBMonteCarloTreeSearch(AbstractBestFirstSearch):
+    def __init__(self,
+                 model,
+                 action_size,
+                 num_of_players,
+                 num_iterations,
+                 search_expl,
+                 invalid_penalty):
+        super().__init__(model)
+        self.action_size = action_size
+        self.num_of_players = num_of_players
+        self.num_iterations = num_iterations
+        self.search_expl = search_expl
+        self.invalid_penalty = invalid_penalty
+
+    def plan(self,observation,player,mask):
+        self.player = player
+        with torch.no_grad():
+            encoded_state, = self.model.representation_query(torch.tensor([observation]),RepresentationOp.KEY)
+            value, = self.model.prediction_query(encoded_state,StateValueOp.KEY)
+        node = MCTSNode()
+        node.set_player(self.player)
+        node.add_value(value.item())
+        node.set_encoded_state(encoded_state[0]).set_action_mask(mask)
+        for i in range(self.num_iterations):
+            self._search_iteration(node)
+        return node
+
+    '''
+    Main specific algorithm methods
+    '''
+    def _search_iteration(self,node):
+        if node.is_leaf():
+            self._expand_node(node)
+            rollout = self._get_children_average_successor_value(node,strict=True)
+        else:
+            best_node = self._get_ucb_best_node(node,exploration=self.search_expl)
+            rollout = self._search_iteration(best_node)
+        self._update_node(node,rollout)
+        rollout = self._increment_rollout(node,rollout)
+        return rollout
+
+    def _increment_rollout(self,node,rollout):
+        if node.get_parent() is not None:
+            if node.get_parent().get_player() != node.get_player():
+                rollout = node.get_parent_reward() - rollout
+            else:
+                rollout = node.get_parent_reward() + rollout
+        return rollout
+
+    def _update_node(self,node,rollout):
+        node.add_value(rollout)
+        node.increment_visits()
+
+    '''
+    Expansion and networks
+    '''
+    def _expand_node(self,node):
+        assert node.get_num_of_children() == 0
+
+        actions = list(range(self.action_size))
+        with torch.no_grad():
+            rewards, next_encoded_states = self.model.dynamic_query(node.get_encoded_state().unsqueeze(0),[actions],RewardOp.KEY,NextStateOp.KEY)
+            if node.get_action_mask() is None:
+                assert node.get_parent() is not None #only non-root nodes reach this point without a mask
+                mask, = self.model.prediction_query(node.get_encoded_state().unsqueeze(0),MaskOp.KEY)
+                node.set_action_mask(mask[0])
+            else:
+                assert node.get_parent() is None
+            values, = self.model.prediction_query(next_encoded_states,StateValueOp.KEY)
+
+        for idx in range(len(actions)):
+            action = actions[idx]
+            child_node = MCTSNode()
+            child_node.set_player((node.get_player() + 1) % self.num_of_players)
+            child_node.set_parent_info(node,action,rewards[idx].item())
+            child_node.add_value(values[idx].item())
+            child_node.set_encoded_state(next_encoded_states[idx])
+            node.add_child(action,child_node)
+        assert node.get_num_of_children() == self.action_size
+        assert node.get_player() >= 0 and node.get_player() < self.num_of_players
+
+    def _get_children_average_successor_value(self,node,strict=True):
+        #don't apply the invalid penalty here
+        total_value = 0.
+        total_mask = 0.
+        mask = node.get_action_mask()
+        for child in node.get_children_nodes():
+            total_value += node.successor_value(child.get_parent_action()) * mask[child.get_parent_action()].item()
+            total_mask += mask[child.get_parent_action()].item()
+            if strict: assert child.get_visits() == 0
+        return total_value/total_mask if total_mask != 0. else 0.
+
+    def _transverse(self,node):
+        pr = 0 if node.get_parent_reward() is None else node.get_parent_reward()
+        print(node.get_depth()*"\t" + "(" + str(node.get_parent_action()) + ")" + " value:" + str(round(node.get_value(),2)) + " r:" + str(round(pr,2)) + " t:" + str(round(node.get_value() + pr,2)))
+        for n in node.get_children_nodes():
+            self._transverse(n)
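
_increment_rollout is the backup rule: as the rollout value bubbles up from a leaf toward the root, each edge adds its transition reward and flips the sign whenever the edge crosses between opponents. A hand-traced example with invented numbers (two alternating players, all edge rewards 0):

    # Leaf evaluates to +0.6 from the leaf player's point of view.
    # leaf -> parent (opponent):       rollout = 0 - 0.6    = -0.6
    # parent -> grandparent (flip):    rollout = 0 - (-0.6) = +0.6
    # The leaf's +0.6 stays good for the leaf's player at every even distance
    # above it and bad at every odd distance, as expected in a zero-sum game.
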
diff --git a/policy_module/compositional_adversarial_policy.py b/policy_module/compositional_adversarial_policy.py
new file mode 100644
--- /dev/null
+++ b/policy_module/compositional_adversarial_policy.py
@@ -0,0 +1,33 @@
+    def play_game(self) -> Game:
+        ret = super().play_game()
+        #players swap sides after every game
+        self.player0, self.player1 = self.player1, self.player0
+        return ret
+
+    def play_move(self,observation,player,mask) -> Tuple[Node,int]:
+        if player == 0:
+            ret = self.player0.play_move(observation,player,mask)
+        elif player == 1:
+            ret = self.player1.play_move(observation,player,mask)
+        return ret
+
+    ''' class specific '''
+    def get_players(self):
+        return [self.player0,self.player1]
diff --git a/policy_module/epsilon_greedy_value.py b/policy_module/epsilon_greedy_value.py
new file mode 100644
index 0000000..bf10c8d
--- /dev/null
+++ b/policy_module/epsilon_greedy_value.py
@@ -0,0 +1,39 @@
+from policy_module.simple_policy import SimplePolicy
+from typing import List, Tuple, Dict
+from node_module.node import Node
+import numpy as np
+import random
+
+
+class EpsilonGreedyValue(SimplePolicy):
+    def __init__(self,environment,planning,epsilon,reduction='successors'):
+        super().__init__(environment,planning,reduction=reduction)
+        self.epsilon = epsilon
+
+    def play_game(self):
+        ''' simple inheritance '''
+        return super().play_game()
+
+    def play_move(self,observation,player,mask:np.ndarray) -> Tuple[Node,int]:
+        ''' returns the node for the observation and the chosen action '''
+        if not isinstance(mask,np.ndarray):
+            raise ValueError("Mask should be a numpy array")
+        node = self.get_planning_algorithm().plan(observation,player,mask)
+        legal_actions, = np.where(mask == 1)
+        if random.random() >= self.get_epsilon():
+            best_action = max(legal_actions,key=lambda a:node.successor_value(a))
+        else:
+            best_action = random.choice(legal_actions)
+
+        return node,best_action
+
+    def get_epsilon(self):
+        return self.epsilon
+
+    def set_epsilon(self,epsilon):
+        self.epsilon = epsilon
+
+    def __str__(self):
+        return str(self.get_epsilon()) + "-epsilonValueGreedy"
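
With epsilon = 0.05 the policy is greedy with respect to the planner's successor values 95% of the time and explores uniformly over legal actions the remaining 5%. Since set_epsilon exists, the exploration rate can also be annealed over training; a purely illustrative schedule (not part of this patch):

    # hypothetical linear anneal from 0.5 down to 0.05 over the first 200 episodes
    policy = EpsilonGreedyValue(environment, planning, epsilon=0.5, reduction='root')
    for ep in range(episodes):
        policy.set_epsilon(max(0.05, 0.5 - (0.5 - 0.05) * ep / 200))
        game = policy.play_game()
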
diff --git a/policy_module/epsilon_greedy_visits.py b/policy_module/epsilon_greedy_visits.py
new file mode 100644
index 0000000..b93aeaf
--- /dev/null
+++ b/policy_module/epsilon_greedy_visits.py
@@ -0,0 +1,46 @@
+from planning_module.abstract_best_first_search import AbstractBestFirstSearch
+from policy_module.simple_policy import SimplePolicy
+from typing import List, Tuple, Dict
+from game import Game
+from node_module.best_first_node import BestFirstNode
+import numpy as np
+import random
+
+
+class EpsilonGreedyVisits(SimplePolicy):
+    def __init__(self,environment,planning_algorithm,epsilon,reduction='successors'):
+        super().__init__(environment,planning_algorithm,reduction=reduction)
+        self.epsilon = epsilon
+
+    def play_game(self):
+        ''' simple inheritance '''
+        return super().play_game()
+
+    def play_move(self,observation,player,mask:np.ndarray):
+        ''' returns the node for the observation and the chosen action '''
+        if not isinstance(mask,np.ndarray):
+            raise ValueError("Mask should be a numpy array")
+        node = self.get_planning_algorithm().plan(observation,player,mask)
+        if not isinstance(node,BestFirstNode):
+            raise ValueError("Epsilon greedy based on visits needs a best-first planning algorithm")
+        legal_actions, = np.where(mask == 1)
+        if random.random() > self.get_epsilon():
+            best_action = max(legal_actions,key=lambda a:node.get_child(a).get_visits())
+        else:
+            best_action = random.choice(legal_actions)
+
+        return node,best_action
+
+    def get_epsilon(self):
+        return self.epsilon
+
+    def set_epsilon(self,epsilon):
+        self.epsilon = epsilon
+
+    def __str__(self):
+        return str(self.epsilon) + "-epsilonVisitGreedy"
\ No newline at end of file
diff --git a/policy_module/policy.py b/policy_module/policy.py
new file mode 100644
index 0000000..cdb4bff
--- /dev/null
+++ b/policy_module/policy.py
@@ -0,0 +1,23 @@
+from game import Game
+from node_module.node import Node
+from typing import List, Tuple, Dict
+
+class Policy:
+    def __init__(self,environment):
+        self.environment = environment
+
+    def play_game(self) -> Game:
+        """ Plays one full episode in the environment and returns the filled Game.
+        Subclasses iterate through the environment, using the planning algorithm
+        to choose a decision at each step, and fill the game and nodes with the
+        appropriate information. """
+        raise NotImplementedError
+
+    def play_move(self,observation,player,mask) -> Tuple[Node,int]:
+        """ Calls the search and uses its information to choose an action.
+        Returns the root node (holding the observation and player) and the
+        number of the chosen action. """
+        raise NotImplementedError
\ No newline at end of file
diff --git a/policy_module/simple_policy.py b/policy_module/simple_policy.py
new file mode 100644
index 0000000..265a7e9
--- /dev/null
+++ b/policy_module/simple_policy.py
@@ -0,0 +1,76 @@
+from game import Game
+from node_module.node import Node
+from policy_module.policy import Policy
+from typing import List, Tuple, Dict
+from copy import deepcopy
+
+
+class SimplePolicy(Policy):
+    def __init__(self,environment,planning,reduction='successors',debug=False):
+        super().__init__(environment)
+        self.planning_algorithm = planning
+        self.reduction_operations = ["root","successors","none"]
+        assert reduction in self.reduction_operations, "reduction needs to be in " + str(self.reduction_operations)
+        self.reduction = reduction
+        self.debug = debug
+        self.info = {}
+
+    def play_game(self) -> Game:
+        """ Override this method if necessary, but this one should be good enough
+        for most applications. It iterates through the environment, using the
+        planning algorithm to choose a decision, and at each step fills the game
+        and node with the appropriate information. """
+        current_observation = self.environment.reset()
+        player = self.environment.get_current_player()
+        mask = self.environment.get_action_mask()
+        game = Game(self.environment.get_input_shape(),self.environment.get_action_size(),self.environment.get_num_of_players())
+        game.observations.append(current_observation)
+        game.players.append(player)
+        game.masks.append(mask)
+        done = False
+        while not done:
+            if self.debug:
+                env = deepcopy(self.environment)
+                env.render()
+                input("Press Ok for next step")
+                env.close()
+            node,action = self.play_move(current_observation,player,mask)
+            current_observation, reward, done, info = self.environment.step(action)
+            player = self.environment.get_current_player()
+            mask = self.environment.get_action_mask()
+
+            self._reduce_node(node)
+            #set info
+            node.set_game(game,len(game.nodes))
+            game.observations.append(current_observation)
+            game.players.append(player)
+            game.masks.append(mask)
+            game.nodes.append(node)
+            game.actions.append(action)
+            game.rewards.append(reward)
+            game.dones.append(done)
+            game.infos.append(info)
+        return game
+
+    def _reduce_node(self,node):
+        """ cut some successors of the node, to save memory """
+        if self.reduction == "root": #cut the root's successors
+            node.detach_from_tree()
+        elif self.reduction == "successors": #cut the successors' successors
+            for succ in node.get_children_nodes():
+                succ.detach_from_tree()
+        elif self.reduction == "none": #don't do anything
+            return
+        else:
+            raise ValueError("reduction operation is invalid")
+
+    def get_planning_algorithm(self):
+        return self.planning_algorithm
+
+    def set_planning_algorithm(self,planning):
+        self.planning_algorithm = planning
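
The reduction setting trades replay-buffer memory against how much of the search tree the losses can later exploit: 'root' keeps only the root's own statistics, 'successors' keeps the root plus its immediate children (enough for successor-value targets), and 'none' keeps everything. Roughly, for a full tree with branching factor b and depth d:

    # nodes kept per stored move:
    #   reduction='none'       -> O(b**d)  (the whole search tree)
    #   reduction='successors' -> 1 + b    (root and its children)
    #   reduction='root'       -> 1        (root only)
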
diff --git a/policy_module/visit_ratio.py b/policy_module/visit_ratio.py
new file mode 100644
index 0000000..5c48be3
--- /dev/null
+++ b/policy_module/visit_ratio.py
@@ -0,0 +1,41 @@
+from policy_module.simple_policy import SimplePolicy
+from typing import List, Tuple, Dict
+from node_module.best_first_node import BestFirstNode
+import numpy as np
+
+
+class VisitRatio(SimplePolicy):
+    def __init__(self,environment,planning_algorithm,temperature=1,reduction='successors'):
+        super().__init__(environment,planning_algorithm,reduction=reduction)
+        self.temperature = temperature
+
+    def play_game(self):
+        ''' simple inheritance '''
+        return super().play_game()
+
+    def play_move(self,observation,player,mask:np.ndarray):
+        ''' returns the node for the observation and an action sampled from the visit distribution '''
+        if not isinstance(mask,np.ndarray):
+            raise ValueError("Mask should be a numpy array")
+        node = self.get_planning_algorithm().plan(observation,player,mask)
+        if not isinstance(node,BestFirstNode):
+            raise ValueError("Exponentiated visit counts need a best-first planning algorithm")
+        probabilities = []
+        legal_actions, = np.where(mask == 1)
+        for action in legal_actions:
+            probabilities.append(node.get_child(action).get_visits())
+        probabilities = np.array(probabilities) ** (1/self.get_temperature())
+        probabilities = probabilities/probabilities.sum()
+        action = np.random.choice(legal_actions,p=probabilities)
+        return node,action
+
+    ''' instance specific '''
+    def get_temperature(self):
+        return self.temperature
+
+    def set_temperature(self,temperature):
+        self.temperature = temperature
+
+    def __str__(self):
+        return str(self.get_temperature()) + "-exponentiatedVisitCount"
\ No newline at end of file
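
The temperature controls how sharply visit counts are exponentiated into probabilities: at temperature 1 the policy reproduces the raw visit ratios, and as the temperature approaches 0 the mass collapses onto the most-visited action. A worked example with invented counts:

    # visit counts over legal actions: [1, 3]
    # temperature = 1:   [1, 3] ** 1 -> p = [0.25, 0.75]
    # temperature = 0.5: [1, 3] ** 2 -> p = [0.10, 0.90]
    # temperature -> 0:  probability mass collapses onto the most-visited action
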
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..9a83bd7
--- /dev/null
+++ b/test.py
@@ -0,0 +1,188 @@
+from environments.cart_pole import CartPole
+from environments.minigrid import Minigrid
+from environments.tictactoe import TicTacToe
+
+from planning_module.minimax import Minimax
+from planning_module.average_minimax import AverageMinimax
+from planning_module.ucb_best_first_minimax import UCBBestFirstMinimax
+from planning_module.ucb_monte_carlo_tree_search import UCBMonteCarloTreeSearch
+
+from policy_module.epsilon_greedy_value import EpsilonGreedyValue
+from policy_module.epsilon_greedy_visits import EpsilonGreedyVisits
+from policy_module.visit_ratio import VisitRatio
+
+from utils.optimization.simple_optimizer import SimpleOptimizer
+from utils.storage.proportional_priority_buffer import ProportionalPriorityBuffer
+#from utils.storage.uniform_game_r_buffer import UniformGameReplayBuffer
+
+from loss_module.monte_carlo_mvr import MonteCarloMVR
+from loss_module.offline_td_mvr import OfflineTDMVR
+from loss_module.online_td_mvr import OnlineTDMVR
+
+#from loss_module.online_td_loss_constrained_state import OnlineTDLossConstrainedSate
+
+from utils.storage.uniform_buffer import UniformBuffer
+
+from model_module.disjoint_mlp import Disjoint_MLP
+from torch.utils.tensorboard import SummaryWriter
+
+from math import sqrt
+import datetime
+
+
+now = datetime.datetime.now()
+time_str = now.strftime("%Y-%m-%d-%H-%M-%S")
+
+''' environment '''
+environment = input("choose environment:\n 1 - cartpole\n 2 - minigrid\n 3 - tictactoe\n")
+if environment == "1":
+    print("Cartpole chosen")
+    environment = CartPole(500)
+elif environment == "2":
+    print("Minigrid chosen")
+    environment = Minigrid(max_steps=18)
+elif environment == "3":
+    print("Tictactoe chosen")
+    environment = TicTacToe(self_play=False)
+else:
+    print("couldn't understand choice. Choosing cartpole by default.")
+    environment = CartPole(500)
+
+
+experiment_name = str(environment) + "_" + time_str
+writer = SummaryWriter(log_dir="logs/runs/"+str(time_str)+ "_" + str(experiment_name))
+
+
+model = Disjoint_MLP(
+    observation_shape = environment.get_input_shape(),
+    action_space_size = environment.get_action_size(),
+    encoding_shape = (8,),
+    fc_reward_layers = [300],
+    fc_value_layers = [300],
+    fc_representation_layers = [300],
+    fc_dynamics_layers = [300],
+    fc_mask_layers = [300],
+    bool_normalize_encoded_states = False
+)
+
+
+action_size = environment.get_action_size()
+num_of_players = environment.get_num_of_players()
+
+
+''' planning '''
+planning = input("choose planning:\n 1 - minimax\n 2 - averaged minimax\n 3 - BFMMS\n 4 - MCTS\n")
+if planning == "1":
+    print("Minimax chosen")
+    planning = Minimax(model,action_size,num_of_players,max_depth=3,invalid_penalty=-1)
+elif planning == "2":
+    print("Averaged Minimax chosen")
+    planning = AverageMinimax(model,action_size,num_of_players,max_depth=3,invalid_penalty=-1)
+elif planning == "3":
+    print("BFMMS chosen")
+    planning = UCBBestFirstMinimax(model,action_size,num_of_players,num_iterations=15,search_expl=sqrt(2),invalid_penalty=-1)
+elif planning == "4":
+    print("MCTS chosen")
+    planning = UCBMonteCarloTreeSearch(model,action_size,num_of_players,num_iterations=15,search_expl=sqrt(2),invalid_penalty=-1)
+else:
+    print("couldn't understand choice. Choosing Minimax by default.")
+    planning = Minimax(model,action_size,num_of_players,max_depth=3,invalid_penalty=-1)
+
+''' policy '''
+policy = input("choose policy:\n 1 - epsilon value greedy\n 2 - epsilon visit greedy\n 3 - visit distribution\n")
+if policy == "1":
+    print("Epsilon value greedy chosen")
+    policy = EpsilonGreedyValue(environment,planning,epsilon=0.05,reduction='root')
+elif policy == "2":
+    print("Epsilon visit greedy chosen")
+    policy = EpsilonGreedyVisits(environment,planning,epsilon=0.05,reduction='root')
+elif policy == "3":
+    print("Visit ratio chosen")
+    policy = VisitRatio(environment,planning,temperature=0.05,reduction='root')
+else:
+    print("couldn't understand choice. Choosing Epsilon value greedy by default.")
+    policy = EpsilonGreedyValue(environment,planning,epsilon=0.05,reduction='root')
+
+
+''' loss '''
+loss_module = input("choose loss:\n 1 - Monte Carlo\n 2 - Offline TD\n 3 - Online TD\n")
+if loss_module == "1":
+    print("Monte Carlo chosen")
+    loss_module = MonteCarloMVR(model,unroll_steps=5,gamma_discount=0.99)
+elif loss_module == "2":
+    print("Offline TD chosen")
+    loss_module = OfflineTDMVR(model,unroll_steps=5,n_steps=1,gamma_discount=0.99)
+elif loss_module == "3":
+    print("Online TD chosen")
+    loss_module = OnlineTDMVR(model,unroll_steps=5,n_steps=1,gamma_discount=0.99)
+else:
+    print("couldn't understand choice. Choosing Monte Carlo by default.")
+    loss_module = MonteCarloMVR(model,unroll_steps=5,gamma_discount=0.99)
+
+''' optimizer '''
+optimizer = SimpleOptimizer(model.parameters(),model.get_optimizers(),model.get_schedulers(),max_grad_norm=20)
+
+''' storage '''
+storage = input("choose storage:\n 1 - uniform buffer \n 2 - priority buffer\n")
+if storage == "1":
+    print("Uniform buffer chosen")
+    storage = UniformBuffer(max_buffer_size=1000)
+elif storage == "2":
+    print("Priority buffer chosen")
+    storage = ProportionalPriorityBuffer(max_buffer_size=1000)
+else:
+    print("couldn't understand choice. Choosing uniform by default.")
+    storage = UniformBuffer(max_buffer_size=1000)
+
+
+episodes = int(input("how many episodes?\n"))
+updates_per_episode = int(input("how many updates per episode?\n"))
+batch_size = int(input("What's the update batch size?\n"))
+scores = []
+for ep in range(episodes):
+    game = policy.play_game()
+
+    #! log score
+    score = sum(game.rewards)
+    print("episode:"+str(ep)+ " score:"+str(score))
+    scores.append(score)
+    scores = scores[-100:]
+    writer.add_scalar("Score/avg_100_score",sum(scores)/len(scores),ep)
+    writer.add_scalar("Score/score",score,ep)
+
+    #! store game
+    if isinstance(storage,ProportionalPriorityBuffer):
+        #priority buffers need an initial priority for the new nodes
+        new_loss, new_info = loss_module.get_loss(game.nodes)
+        storage.add(game.nodes,new_info["loss_per_node"])
+    else:
+        storage.add(game.nodes)
+
+    #! learn
+    for lep in range(updates_per_episode):
+        nodes = storage.sample(batch_size)
+        loss, info = loss_module.get_loss(nodes)
+        optimizer.optimize(loss)
+        if isinstance(storage,ProportionalPriorityBuffer):
+            storage.updated_priorities(nodes,info["loss_per_node"])
+        writer.add_scalar("Loss/loss",loss,ep)
+        writer.add_scalar("Loss/loss_value",info["loss_value"],ep)
+        writer.add_scalar("Loss/loss_reward",info["loss_reward"],ep)
+        writer.add_scalar("Loss/loss_mask",info["loss_mask"],ep)
+
+writer.close()
diff --git a/utils/optimization/simple_optimizer.py b/utils/optimization/simple_optimizer.py
new file mode 100644
index 0000000..21a9c8d
--- /dev/null
+++ b/utils/optimization/simple_optimizer.py
@@ -0,0 +1,29 @@
+import torch
+
+class SimpleOptimizer:
+    def __init__(self,parameters,optimizers:list,schedulers:list=[],max_grad_norm=20):
+        if not isinstance(optimizers,list): optimizers = [optimizers]
+        if not isinstance(schedulers,list): schedulers = [schedulers]
+        self.parameters = list(parameters) #materialize the generator: it is iterated on every optimize() call
+        self.optimizers = optimizers
+        self.schedulers = schedulers
+        self.max_grad_norm = max_grad_norm
+
+    def optimize(self,loss):
+        if loss.grad_fn is None:
+            return 0.
+        for optim in self.optimizers:
+            optim.zero_grad()
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(self.parameters, self.max_grad_norm) #clip gradients to help stabilise training
+        for optim in self.optimizers:
+            optim.step()
+        total_norm = 0
+        for p in self.parameters:
+            if p.grad is None:
+                continue
+            param_norm = p.grad.detach().data.norm(2)
+            total_norm += param_norm.item() ** 2
+        total_norm = total_norm ** 0.5
+        for scheduler in self.schedulers:
+            scheduler.step()
+        return total_norm
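
A sketch of how this wrapper is driven. It assumes, as test.py does, that the model exposes get_optimizers()/get_schedulers(); the single Adam optimizer here is illustrative only:

    import torch
    # one optimizer over all parameters stands in for the model's per-module optimizers
    params = list(model.parameters())
    optimizer = SimpleOptimizer(params, [torch.optim.Adam(params, lr=1e-3)], max_grad_norm=20)
    loss, info = loss_module.get_loss(nodes)
    grad_norm = optimizer.optimize(loss)   # returns the post-clip gradient L2 norm
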
diff --git a/utils/storage/__pycache__/simple_replay_buffer.cpython-38.pyc b/utils/storage/__pycache__/simple_replay_buffer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..adecea2d68627a3a6fb1fd94864eee3f55ecd65b
GIT binary patch
literal 1388
[base85 payload omitted]

literal 0
HcmV?d00001

diff --git a/utils/storage/__pycache__/storage_module_interface.cpython-38.pyc b/utils/storage/__pycache__/storage_module_interface.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6919991826170e37b0d7af39649b3cbaa9845d7d
GIT binary patch
literal 1217
[base85 payload omitted]

literal 0
HcmV?d00001

diff --git a/utils/storage/__pycache__/uniform_buffer.cpython-38.pyc b/utils/storage/__pycache__/uniform_buffer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5eedaf4b4acd0f0c5f7031d9013cbdd9277edb6
GIT binary patch
literal 1422
[base85 payload omitted]

literal 0
HcmV?d00001
diff --git a/utils/storage/__pycache__/uniform_game_replay_buffer.cpython-38.pyc b/utils/storage/__pycache__/uniform_game_replay_buffer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86d3c486069d56bcb4ae237c9d8abf6c2891c6ff
GIT binary patch
literal 1599
[base85 payload omitted]

literal 0
HcmV?d00001

diff --git a/utils/storage/proportional_priority_buffer.py b/utils/storage/proportional_priority_buffer.py
new file mode 100644
index 0000000..111ebcc
--- /dev/null
+++ b/utils/storage/proportional_priority_buffer.py
@@ -0,0 +1,82 @@
+import random
+import numpy as np
+import torch
+
+'''
+* This is the proportional variant of the prioritized replay buffer - https://arxiv.org/pdf/1511.05952.pdf
+*
+* Every time a game is added, a priority has to be passed along with it.
+* The higher the priority, the higher the probability of being sampled from the buffer.
+* In the original article, the priority was the temporal-difference error.
+*
+* As training progresses, it is important to refresh the priorities of stored games. Call update_priorities for that.
+*
+* This implementation lets you control two properties:
+* 1) priority_decay makes updates blend into the previous priority instead of replacing it outright, using
+*    the formula: decay * new_priority + (1 - decay) * previous_priority -- this is not described in the paper
+*
+* 2) alpha regulates how strongly a game's priority influences its probability
+*    of being sampled.
The conversion of priority to probability is:
+*
+*        probability = (priority ** alpha) / Σ (priority_k ** alpha)
+*
+'''
+class ProportionalPriorityBuffer():
+    PriorityKey = "ProportionalBufferPriority"
+    def __init__(self,max_buffer_size,priority_decay=1,alpha=1,beta=1,minimum_priority=0.001):
+        self.buffer = []
+        self.trainer = None
+        self.max_buffer_size = max_buffer_size
+        self.priority_decay = priority_decay
+        self.alpha = alpha
+        self.beta = beta #reserved for importance-sampling correction as in the paper; currently unused
+        self.minimum_priority = minimum_priority #keeps every node's sampling probability strictly positive
+        self.num = 0
+
+    def add(self,nodes:list,priorities:list=None):
+        if priorities is None: priorities = [0] * len(nodes)
+        assert len(nodes) == len(priorities), "the number of nodes and the number of priorities need to be the same"
+        for i in range(len(nodes)):
+            self._add_node(nodes[i],priorities[i])
+
+    def _add_node(self,node,priority=0):
+        self._update_priority(node,priority)
+        self.buffer.append(node)
+        self.num += 1
+        if len(self.buffer) > self.max_buffer_size:
+            self.buffer.pop(0) #FIFO eviction: drop the oldest node
+        assert len(self.buffer) <= self.max_buffer_size
+
+    def sample(self, num=1):
+        node_priorities = []
+        for node in self.buffer:
+            priority = node.info[self.PriorityKey]**self.alpha + self.minimum_priority
+            node_priorities.append(priority)
+        node_priorities = np.array(node_priorities,dtype="float32")
+        node_probabilities = node_priorities / np.sum(node_priorities)
+        number_of_nodes = min(num,len(self.buffer))
+        nodes = np.random.choice(self.buffer, number_of_nodes, p=node_probabilities,replace=False)
+        return list(nodes)
+
+    def update_priorities(self,nodes:list,priorities:list):
+        for i in range(len(nodes)):
+            self._update_priority(nodes[i],priorities[i])
+
+    def _update_priority(self,node,priority):
+        ''' the priority is usually the loss or the td error '''
+        if isinstance(priority,torch.Tensor): priority = priority.item()
+        if self.PriorityKey not in node.info:
+            node.info[self.PriorityKey] = priority
+        else:
+            node.info[self.PriorityKey] = self.priority_decay * priority + (1-self.priority_decay) * node.info[self.PriorityKey]
+
+    def get_size(self):
+        return len(self.buffer)
+
+    def get_num_of_added_samples(self):
+        return self.num
+
+    def __str__(self):
+        return "ProportionalPrioritizedBuffer"+str(self.max_buffer_size)+"_alpha"+str(self.alpha)
diff --git a/utils/storage/uniform_buffer.py b/utils/storage/uniform_buffer.py
new file mode 100644
index 0000000..774f2ca
--- /dev/null
+++ b/utils/storage/uniform_buffer.py
@@ -0,0 +1,35 @@
+import random
+import numpy as np
+
+
+'''
+* This is the simplest replay buffer.
+* It appends the nodes of each game to the end of a list
+* and samples nodes uniformly at random, so a game's chance of contributing
+* to a batch is proportional to the length of its trajectory.
+*
+* This follows the original replay-buffer design in https://www.nature.com/articles/nature14236
+'''
+class UniformBuffer():
+    def __init__(self,max_buffer_size):
+        self.buffer = []
+        self.trainer = None
+        self.max_buffer_size = max_buffer_size
+        self.num = 0
+
+    def add(self,nodes:list):
+        assert isinstance(nodes,list), "Insert a list of nodes to be added"
+        self.buffer.extend(nodes)
+        self.num += len(nodes)
+        while len(self.buffer) > self.max_buffer_size:
+            self.buffer.pop(0) #FIFO eviction: drop the oldest nodes first
+        assert len(self.buffer) <= self.max_buffer_size
+
+    def sample(self, num=1):
+        nodes = np.random.choice(self.buffer, num) #uniform, with replacement
+        return list(nodes)
+
+    def get_size(self):
+        return len(self.buffer)
+
+    def get_num_of_added_samples(self):
+        return self.num
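As a quick numeric check of the priority-to-probability formula in proportional_priority_buffer.py (values picked arbitrarily, ignoring the small minimum_priority offset): with priorities [4, 1], alpha = 1 gives sampling probabilities [0.8, 0.2], while alpha = 0 flattens them back to uniform:

    import numpy as np

    priorities = np.array([4.0, 1.0])
    for alpha in (1.0, 0.0):
        scaled = priorities ** alpha
        print(alpha, scaled / scaled.sum())  # 1.0 -> [0.8 0.2], 0.0 -> [0.5 0.5]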
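And a hedged end-to-end sketch of the buffer API as the training loop in test.py uses it. FakeNode is a hypothetical stand-in: real nodes come from game.nodes, and all the buffer requires of them here is an info dict:

    from utils.storage.proportional_priority_buffer import ProportionalPriorityBuffer

    class FakeNode:  # hypothetical stand-in for a search-tree node
        def __init__(self):
            self.info = {}

    buffer = ProportionalPriorityBuffer(max_buffer_size=1000, alpha=1)
    nodes = [FakeNode() for _ in range(8)]
    buffer.add(nodes, [float(i) for i in range(8)])  # e.g. per-node losses as priorities
    batch = buffer.sample(4)  # biased toward high-priority nodes, drawn without replacement
    buffer.update_priorities(batch, [0.5] * len(batch))  # refresh priorities after a learning step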