Skip to content

Commit

Permalink
Add a deep-RL example: PPO on the Dojo ant environment
Browse files Browse the repository at this point in the history
  • Loading branch information
rejuvyesh committed Mar 10, 2022
1 parent 9736031 commit c656da1
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 4 deletions.
5 changes: 3 additions & 2 deletions environments/environment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ function MeshCat.render(env::Environment,
return nothing
end

# Reseed the environment's random number generator.
#
# `s` is the integer seed (default 0); it is positional (not a keyword) so
# call sites such as `Dojo.seed(env, seed)` work. Replaces `env.rng[1]` with
# a freshly seeded `MersenneTwister` and returns `nothing`.
#
# (The earlier keyword form `seed(env; s=0)` was buggy: its body called
# `MersenneTwister(seed)`, referencing the function itself instead of `s`.)
function seed(env::Environment, s=0)
    env.rng[1] = MersenneTwister(s)
    return nothing
end

Expand Down Expand Up @@ -227,6 +227,7 @@ function contains(s::BoxSpace{T,N}, v::AbstractVector{T}) where {T,N}
end

# For compat with RLBase
# Dimensionality of the box (number of coordinates).
Base.length(s::BoxSpace) = s.n
# Membership: `v in s` iff every component lies within [low, high] (inclusive).
Base.in(v::AbstractVector{T}, s::BoxSpace{T,N}) where {T,N} = all(s.low .<= v .<= s.high)
# Uniform sample: scale unit-uniform draws into [low, high].
# (Dropped the redundant `return` — a one-line definition returns its expression.)
Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = rand(rng, T, N) .* (s.high .- s.low) .+ s.low

Expand Down
4 changes: 2 additions & 2 deletions environments/rlenv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ RLBase.is_terminated(env::DojoRLEnv) = env.done

RLBase.reset!(env::DojoRLEnv) = reset(env.dojoenv)

# Reward from the most recent step (cached by the env functor).
# The stale placeholder `RLBase.reward(env) = error()` was dead code:
# it was immediately overridden by this definition, so it is removed.
RLBase.reward(env::DojoRLEnv) = env.reward
# Most recent observation vector, cached on step/reset.
RLBase.state(env::DojoRLEnv) = env.state

# Forward seeding to the wrapped Dojo environment.
Random.seed!(env::DojoRLEnv, seed) = Dojo.seed(env.dojoenv, seed)
Expand All @@ -33,7 +33,7 @@ Random.seed!(env::DojoRLEnv, seed) = Dojo.seed(env.dojoenv, seed)

function (env::DojoRLEnv)(a)
s, r, d, i = step(env.dojoenv, a)
env.state = s
env.state .= s
env.reward = r
env.done = d
env.info = i
Expand Down
5 changes: 5 additions & 0 deletions examples/deeprl/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[deps]
Dojo = "ac60b53e-8d92-4c83-b960-e78698fa1916"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReinforcementLearning = "158674fc-8238-5cab-b5ba-03dfc80d1318"
67 changes: 67 additions & 0 deletions examples/deeprl/ant_ppo.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using ReinforcementLearning
using Flux
using Flux.Losses

using Random
using Dojo

# Build a PPO training experiment on the Dojo "ant" environment using six
# parallel workers. `save_dir` is accepted for API symmetry but currently
# unused; `seed` drives both network initialization and per-worker seeding.
function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:PPO},
    ::Val{:DojoAnt},
    ::Nothing,
    save_dir = nothing,
    seed = 42
)
    # Reproducible weight initialization.
    init_rng = MersenneTwister(seed)

    n_envs = 6        # parallel environment copies
    update_freq = 32  # environment steps collected per PPO update

    # One ant environment per worker, each seeded from the base seed.
    envs = [Dojo.DojoRLEnv("ant") for _ in 1:n_envs]
    for (k, e) in enumerate(envs)
        Random.seed!(e, hash(seed + k))
    end
    env = MultiThreadEnv(envs; is_force = true)

    # Observation / action dimensions, read off the first worker.
    n_obs = length(state(env[1]))
    n_act = length(action_space(env[1]))
    RLBase.reset!(env)

    ppo_agent = Agent(
        policy = PPOPolicy(
            approximator = ActorCritic(
                actor = Chain(
                    Dense(n_obs, 256, relu; init = glorot_uniform(init_rng)),
                    Dense(256, n_act; init = glorot_uniform(init_rng)),
                ),
                critic = Chain(
                    Dense(n_obs, 256, relu; init = glorot_uniform(init_rng)),
                    Dense(256, 1; init = glorot_uniform(init_rng)),
                ),
                optimizer = ADAM(1e-3),
            ),
            γ = 0.99f0,
            λ = 0.95f0,
            clip_range = 0.1f0,
            max_grad_norm = 0.5f0,
            n_epochs = 4,
            n_microbatches = 4,
            actor_loss_weight = 1.0f0,
            critic_loss_weight = 0.5f0,
            entropy_loss_weight = 0.001f0,
            update_freq = update_freq,
        ),
        # NOTE(review): actions are buffered as Vector{Int}, yet the ant task
        # looks continuous — confirm against DojoRLEnv's action space.
        trajectory = PPOTrajectory(;
            capacity = update_freq,
            state = Matrix{Float32} => (n_obs, n_envs),
            action = Vector{Int} => (n_envs,),
            action_log_prob = Vector{Float32} => (n_envs,),
            reward = Vector{Float32} => (n_envs,),
            terminal = Vector{Bool} => (n_envs,),
        ),
    )

    # Run 10k steps; show a progress bar only outside CI.
    stopper = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
    reward_hook = TotalBatchRewardPerEpisode(n_envs)
    return Experiment(ppo_agent, env, stopper, reward_hook, "# PPO with Dojo Ant")
end

# Instantiate the experiment defined above via RL.jl's experiment string
# macro, then execute the training run.
ex = E`JuliaRL_PPO_DojoAnt`
run(ex)

0 comments on commit c656da1

Please sign in to comment.