Skip to content

Commit

Permalink
fix ppo policy
Browse files Browse the repository at this point in the history
  • Loading branch information
rejuvyesh committed Mar 13, 2022
1 parent b7b11e0 commit 08188a6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
14 changes: 9 additions & 5 deletions examples/deeprl/ant_ppo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@ function RL.Experiment(
  agent = Agent(
  policy = PPOPolicy(
  approximator = ActorCritic(
- actor = Chain(
- Dense(ns, 256, relu; init = glorot_uniform(rng)),
- Dense(256, na; init = glorot_uniform(rng)),
- ),
+ actor = GaussianNetwork(
+ pre = Chain(
+ Dense(ns, 64, relu; init = glorot_uniform(rng)),
+ Dense(64, 64, relu; init = glorot_uniform(rng)),
+ ),
+ μ = Chain(Dense(64, na, tanh; init = glorot_uniform(rng)), vec),
+ logσ = Chain(Dense(64, na; init = glorot_uniform(rng)), vec),
+ ),
  critic = Chain(
  Dense(ns, 256, relu; init = glorot_uniform(rng)),
- Dense(256, 1; init = glorot_uniform(rng)),
+ Dense(256, na; init = glorot_uniform(rng)),
  ),
  optimizer = ADAM(1e-3),
  ),
Expand Down
12 changes: 8 additions & 4 deletions examples/deeprl/cartpole_ppo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,14 @@ function RL.Experiment(
  agent = Agent(
  policy = PPOPolicy(
  approximator = ActorCritic(
- actor = Chain(
- Dense(ns, 256, relu; init = glorot_uniform(rng)),
- Dense(256, na; init = glorot_uniform(rng)),
- ),
+ actor = GaussianNetwork(
+ pre = Chain(
+ Dense(ns, 64, relu; init = glorot_uniform(rng)),
+ Dense(64, 64, relu; init = glorot_uniform(rng)),
+ ),
+ μ = Chain(Dense(64, na, tanh; init = glorot_uniform(rng)), vec),
+ logσ = Chain(Dense(64, na; init = glorot_uniform(rng)), vec),
+ ),
  critic = Chain(
  Dense(ns, 256, relu; init = glorot_uniform(rng)),
  Dense(256, 1; init = glorot_uniform(rng)),
Expand Down

0 comments on commit 08188a6

Please sign in to comment.