Commit 5b21803

add simple cartpole example; issue is with RL.jl

rejuvyesh committed Mar 11, 2022
1 parent 65ef927 commit 5b21803
Showing 3 changed files with 157 additions and 1 deletion.
3 changes: 2 additions & 1 deletion environments/rlenv.jl
@@ -38,4 +38,5 @@ function (env::DojoRLEnv)(a)
     env.done = d
     env.info = i
     return nothing
-end
+end
+(env::DojoRLEnv)(a::AbstractFloat) = env([a])
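
The added method is a small multiple-dispatch convenience: a scalar action is wrapped in a one-element vector and forwarded to the existing vector method. A minimal standalone sketch of the same pattern (ToyEnv is hypothetical, for illustration only):

struct ToyEnv end
(env::ToyEnv)(a::AbstractVector) = println("stepping with action ", a)
(env::ToyEnv)(a::AbstractFloat) = env([a])   # scalar forwards to the vector method

env = ToyEnv()
env(0.5)     # equivalent to env([0.5])
env([0.5])   # vector form, called directly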
88 changes: 88 additions & 0 deletions examples/deeprl/cartpole_ddpg.jl
@@ -0,0 +1,88 @@
using ReinforcementLearning
using Flux
using Flux.Losses

using Random
using Dojo

function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:DDPG},
    ::Val{:DojoCartpole},
    ::Nothing,
    save_dir = nothing,
    seed = 42
)
    rng = MersenneTwister(seed)
    inner_env = Dojo.DojoRLEnv("cartpole")
    Random.seed!(inner_env, seed)
    # TODO: read the actual action bounds from the env instead of hard-coding them
    low = -5.0
    high = 5.0
    ns, na = length(state(inner_env)), length(action_space(inner_env))
    @show na
    A = Dojo.BoxSpace(na)
    # Rescale the actor's tanh output from [-1, 1] to [low, high].
    env = ActionTransformedEnv(
        inner_env;
        action_mapping = x -> low .+ (x .+ 1) .* 0.5 .* (high - low),
        action_space_mapping = _ -> A
    )

    init = glorot_uniform(rng)

    create_actor() = Chain(
        Dense(ns, 30, relu; init = init),
        Dense(30, 30, relu; init = init),
        Dense(30, na, tanh; init = init),
    )
    create_critic() = Chain(
        Dense(ns + na, 30, relu; init = init),
        Dense(30, 30, relu; init = init),
        Dense(30, 1; init = init),
    )

    agent = Agent(
        policy = DDPGPolicy(
            behavior_actor = NeuralNetworkApproximator(
                model = create_actor(),
                optimizer = ADAM(),
            ),
            behavior_critic = NeuralNetworkApproximator(
                model = create_critic(),
                optimizer = ADAM(),
            ),
            target_actor = NeuralNetworkApproximator(
                model = create_actor(),
                optimizer = ADAM(),
            ),
            target_critic = NeuralNetworkApproximator(
                model = create_critic(),
                optimizer = ADAM(),
            ),
            γ = 0.99f0,      # discount factor
            ρ = 0.995f0,     # Polyak-averaging rate for the target networks
            na = na,
            batch_size = 64,
            start_steps = 1000,
            start_policy = RandomPolicy(A; rng = rng),
            update_after = 1000,
            update_freq = 1,
            act_limit = 1.0,
            act_noise = 0.1,
            rng = rng,
        ),
        trajectory = CircularArraySARTTrajectory(
            capacity = 10000,
            state = Vector{Float32} => (ns,),
            action = Float32 => (na,),
        ),
    )

    stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
    hook = TotalRewardPerEpisode()
    Experiment(agent, env, stop_condition, hook, "# Dojo Cartpole with DDPG")
end

ex = E`JuliaRL_DDPG_DojoCartpole`
run(ex)
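
A quick sanity check of the action_mapping above (a sketch, assuming low = -5.0 and high = 5.0 as in the experiment): the affine map sends the actor's tanh output in [-1, 1] onto the box bounds [low, high].

f(x) = -5.0 .+ (x .+ 1) .* 0.5 .* (5.0 - (-5.0))
f(-1.0)   # -5.0, lower bound
f(0.0)    #  0.0, midpoint
f(1.0)    #  5.0, upper bound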
67 changes: 67 additions & 0 deletions examples/deeprl/cartpole_ppo.jl
@@ -0,0 +1,67 @@
using ReinforcementLearning
using Flux
using Flux.Losses

using Random
using Dojo

function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:PPO},
    ::Val{:DojoCartpole},
    ::Nothing,
    save_dir = nothing,
    seed = 42
)
    rng = MersenneTwister(seed)
    N_ENV = 6
    UPDATE_FREQ = 32
    # Give each parallel environment its own derived seed.
    env_vec = [Dojo.DojoRLEnv("cartpole") for i in 1:N_ENV]
    for i in 1:N_ENV
        Random.seed!(env_vec[i], hash(seed + i))
    end
    env = MultiThreadEnv(env_vec)

    ns, na = length(state(env[1])), length(action_space(env[1]))
    RLBase.reset!(env; is_force = true)

    agent = Agent(
        policy = PPOPolicy(
            approximator = ActorCritic(
                actor = Chain(
                    Dense(ns, 256, relu; init = glorot_uniform(rng)),
                    Dense(256, na; init = glorot_uniform(rng)),
                ),
                critic = Chain(
                    Dense(ns, 256, relu; init = glorot_uniform(rng)),
                    Dense(256, 1; init = glorot_uniform(rng)),
                ),
                optimizer = ADAM(1e-3),
            ),
            γ = 0.99f0,    # discount factor
            λ = 0.95f0,    # GAE parameter
            clip_range = 0.1f0,
            max_grad_norm = 0.5f0,
            n_epochs = 4,
            n_microbatches = 4,
            actor_loss_weight = 1.0f0,
            critic_loss_weight = 0.5f0,
            entropy_loss_weight = 0.001f0,
            update_freq = UPDATE_FREQ,
        ),
        trajectory = PPOTrajectory(;
            capacity = UPDATE_FREQ,
            state = Matrix{Float32} => (ns, N_ENV),
            action = Vector{Int} => (N_ENV,),
            action_log_prob = Vector{Float32} => (N_ENV,),
            reward = Vector{Float32} => (N_ENV,),
            terminal = Vector{Bool} => (N_ENV,),
        ),
    )
    stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
    hook = TotalBatchRewardPerEpisode(N_ENV)
    Experiment(agent, env, stop_condition, hook, "# PPO with Dojo Cartpole")
end

ex = E`JuliaRL_PPO_DojoCartpole`
run(ex)
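
In both scripts, the E`...` string macro from ReinforcementLearning.jl splits the experiment name on underscores and dispatches to the RL.Experiment method defined above. A sketch of the explicit equivalent, assuming the macro expands this way:

ex = Experiment(Val(:JuliaRL), Val(:PPO), Val(:DojoCartpole), nothing)
run(ex)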
