From 08188a602843ad95d978d515a7d4183de17bf6d4 Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Sun, 13 Mar 2022 02:47:01 +0000 Subject: [PATCH] fix ppo policy --- examples/deeprl/ant_ppo.jl | 14 +++++++++----- examples/deeprl/cartpole_ppo.jl | 12 ++++++++---- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/examples/deeprl/ant_ppo.jl b/examples/deeprl/ant_ppo.jl index 5359ef5f4..e2a13f463 100644 --- a/examples/deeprl/ant_ppo.jl +++ b/examples/deeprl/ant_ppo.jl @@ -28,13 +28,17 @@ function RL.Experiment( agent = Agent( policy = PPOPolicy( approximator = ActorCritic( - actor = Chain( - Dense(ns, 256, relu; init = glorot_uniform(rng)), - Dense(256, na; init = glorot_uniform(rng)), - ), + actor = GaussianNetwork( + pre = Chain( + Dense(ns, 64, relu; init = glorot_uniform(rng)), + Dense(64, 64, relu; init = glorot_uniform(rng)), + ), + μ = Chain(Dense(64, na, tanh; init = glorot_uniform(rng)), vec), + logσ = Chain(Dense(64, na; init = glorot_uniform(rng)), vec), + ), critic = Chain( Dense(ns, 256, relu; init = glorot_uniform(rng)), - Dense(256, 1; init = glorot_uniform(rng)), + Dense(256, na; init = glorot_uniform(rng)), ), optimizer = ADAM(1e-3), ), diff --git a/examples/deeprl/cartpole_ppo.jl b/examples/deeprl/cartpole_ppo.jl index ebfd86516..7ba8bb5c0 100644 --- a/examples/deeprl/cartpole_ppo.jl +++ b/examples/deeprl/cartpole_ppo.jl @@ -28,10 +28,14 @@ function RL.Experiment( agent = Agent( policy = PPOPolicy( approximator = ActorCritic( - actor = Chain( - Dense(ns, 256, relu; init = glorot_uniform(rng)), - Dense(256, na; init = glorot_uniform(rng)), - ), + actor = GaussianNetwork( + pre = Chain( + Dense(ns, 64, relu; init = glorot_uniform(rng)), + Dense(64, 64, relu; init = glorot_uniform(rng)), + ), + μ = Chain(Dense(64, na, tanh; init = glorot_uniform(rng)), vec), + logσ = Chain(Dense(64, na; init = glorot_uniform(rng)), vec), + ), critic = Chain( Dense(ns, 256, relu; init = glorot_uniform(rng)), Dense(256, 1; init = glorot_uniform(rng)),