diff --git a/examples/deeprl/ant_ppo.jl b/examples/deeprl/ant_ppo.jl index 5359ef5f4..e2a13f463 100644 --- a/examples/deeprl/ant_ppo.jl +++ b/examples/deeprl/ant_ppo.jl @@ -28,13 +28,17 @@ function RL.Experiment( agent = Agent( policy = PPOPolicy( approximator = ActorCritic( - actor = Chain( - Dense(ns, 256, relu; init = glorot_uniform(rng)), - Dense(256, na; init = glorot_uniform(rng)), - ), + actor = GaussianNetwork( + pre = Chain( + Dense(ns, 64, relu; init = glorot_uniform(rng)), + Dense(64, 64, relu; init = glorot_uniform(rng)), + ), + μ = Chain(Dense(64, na, tanh; init = glorot_uniform(rng)), vec), + logσ = Chain(Dense(64, na; init = glorot_uniform(rng)), vec), + ), critic = Chain( Dense(ns, 256, relu; init = glorot_uniform(rng)), - Dense(256, 1; init = glorot_uniform(rng)), + Dense(256, 1; init = glorot_uniform(rng)), ), optimizer = ADAM(1e-3), ), diff --git a/examples/deeprl/cartpole_ppo.jl b/examples/deeprl/cartpole_ppo.jl index ebfd86516..7ba8bb5c0 100644 --- a/examples/deeprl/cartpole_ppo.jl +++ b/examples/deeprl/cartpole_ppo.jl @@ -28,10 +28,14 @@ function RL.Experiment( agent = Agent( policy = PPOPolicy( approximator = ActorCritic( - actor = Chain( - Dense(ns, 256, relu; init = glorot_uniform(rng)), - Dense(256, na; init = glorot_uniform(rng)), - ), + actor = GaussianNetwork( + pre = Chain( + Dense(ns, 64, relu; init = glorot_uniform(rng)), + Dense(64, 64, relu; init = glorot_uniform(rng)), + ), + μ = Chain(Dense(64, na, tanh; init = glorot_uniform(rng)), vec), + logσ = Chain(Dense(64, na; init = glorot_uniform(rng)), vec), + ), critic = Chain( Dense(ns, 256, relu; init = glorot_uniform(rng)), Dense(256, 1; init = glorot_uniform(rng)),