Skip to content

Commit

Permalink
fix space related definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
findmyway committed Mar 12, 2022
1 parent b606aa1 commit f79e9e3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 32 deletions.
52 changes: 23 additions & 29 deletions environments/environment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ mutable struct Environment{X,T,M,A,O,I}
dynamics_jacobian_state::Matrix{T}
dynamics_jacobian_input::Matrix{T}
input_previous::Vector{T}
control_map::Matrix{T}
control_map::Matrix{T}
num_states::Int
num_inputs::Int
num_observations::Int
Expand Down Expand Up @@ -66,33 +66,33 @@ end
attitude_decompress: flag for pre- and post-concatenating Jacobians with attitude Jacobians
"""
function Base.step(env::Environment, x, u;
gradients=false,
attitude_decompress=false)
gradients = false,
attitude_decompress = false)

mechanism = env.mechanism
timestep= mechanism.timestep
timestep = mechanism.timestep

x0 = x
# u = clip(env.input_space, u) # control limits
env.input_previous .= u # for rendering in Gym
u_scaled = env.control_map * u
u_scaled = env.control_map * u

z0 = env.representation == :minimal ? minimal_to_maximal(mechanism, x0) : x0
z1 = step!(mechanism, z0, u_scaled; opts=env.opts_step)
z1 = step!(mechanism, z0, u_scaled; opts = env.opts_step)
env.state .= env.representation == :minimal ? maximal_to_minimal(mechanism, z1) : z1

# Compute cost
costs = cost(env, x, u)

# Check termination
done = is_done(env, x)
# Check termination
done = is_done(env, x)

# Gradients
if gradients
if env.representation == :minimal
fx, fu = get_minimal_gradients!(env.mechanism, z0, u_scaled, opts=env.opts_grad)
fx, fu = get_minimal_gradients!(env.mechanism, z0, u_scaled, opts = env.opts_grad)
elseif env.representation == :maximal
fx, fu = get_maximal_gradients!(env.mechanism, z0, u_scaled, opts=env.opts_grad)
fx, fu = get_maximal_gradients!(env.mechanism, z0, u_scaled, opts = env.opts_grad)
if attitude_decompress
A0 = attitude_jacobian(z0, length(env.mechanism.bodies))
A1 = attitude_jacobian(z1, length(env.mechanism.bodies))
Expand All @@ -109,11 +109,11 @@ function Base.step(env::Environment, x, u;
end

function Base.step(env::Environment, u;
gradients=false,
attitude_decompress=false)
step(env, env.state, u;
gradients=gradients,
attitude_decompress=attitude_decompress)
gradients = false,
attitude_decompress = false)
step(env, env.state, u;
gradients = gradients,
attitude_decompress = attitude_decompress)
end

"""
Expand Down Expand Up @@ -156,7 +156,7 @@ is_done(env::Environment, x) = false
x: state
"""
function Base.reset(env::Environment{X};
x=nothing) where X
x = nothing) where {X}

initialize!(env.mechanism, type2symbol(X))
if x != nothing
Expand All @@ -172,14 +172,14 @@ function Base.reset(env::Environment{X};
return get_observation(env)
end

function MeshCat.render(env::Environment,
mode="human")
function MeshCat.render(env::Environment,
mode = "human")
z = env.representation == :minimal ? minimal_to_maximal(env.mechanism, env.state) : env.state
set_robot(env.vis, env.mechanism, z, name=:robot)
set_robot(env.vis, env.mechanism, z, name = :robot)
return nothing
end

function seed(env::Environment, s=0)
function seed(env::Environment, s = 0)
env.rng[1] = MersenneTwister(s)
return nothing
end
Expand Down Expand Up @@ -214,26 +214,20 @@ mutable struct BoxSpace{T,N} <: Space{T,N}
dtype::DataType # this is always T, it's needed to interface with Stable-Baselines
end

function BoxSpace(n::Int; low::AbstractVector{T} = -ones(n), high::AbstractVector{T} = ones(n)) where T
function BoxSpace(n::Int; low::AbstractVector{T} = -ones(n), high::AbstractVector{T} = ones(n)) where {T}
return BoxSpace{T,n}(n, low, high, (n,), T)
end

function sample(s::BoxSpace{T,N}) where {T,N}
return rand(T,N) .* (s.high .- s.low) .+ s.low
return rand(T, N) .* (s.high .- s.low) .+ s.low
end

function contains(s::BoxSpace{T,N}, v::AbstractVector{T}) where {T,N}
all(v .>= s.low) && all(v .<= s.high)
end

# For compat with RLBase
Base.length(s::BoxSpace) = s.n
Base.in(v::AbstractVector{T}, s::BoxSpace{T,N}) where {T,N} = all(v .>= s.low) && all(v .<= s.high)
Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = return rand(rng, T,N) .* (s.high .- s.low) .+ s.low

function clip(s::BoxSpace, u)
clamp.(u, s.low, s.high)
end



Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = return rand(rng, T, N) .* (s.high .- s.low) .+ s.low
10 changes: 7 additions & 3 deletions environments/rlenv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@ function DojoRLEnv(name::String; kwargs...)
DojoRLEnv(Dojo.get_environment(name; kwargs...))
end

RLBase.action_space(env::DojoRLEnv) = env.dojoenv.input_space
RLBase.state_space(env::DojoRLEnv) = env.dojoenv.observation_space
function Base.convert(::Type{RLBase.Space}, s::BoxSpace)
RLBase.Space([BoxSpace(1; low = s.low[i:i], high = s.high[i:i]) for i in 1:s.n])
end

RLBase.action_space(env::DojoRLEnv) = convert(RLBase.Space, env.dojoenv.input_space)
RLBase.state_space(env::DojoRLEnv) = convert(RLBase.Space, env.dojoenv.observation_space)
RLBase.is_terminated(env::DojoRLEnv) = env.done

RLBase.reset!(env::DojoRLEnv) = reset(env.dojoenv)
Expand All @@ -39,4 +43,4 @@ function (env::DojoRLEnv)(a)
env.info = i
return nothing
end
(env::DojoRLEnv)(a::AbstractFloat) = env([a])
(env::DojoRLEnv)(a::Number) = env([a])

0 comments on commit f79e9e3

Please sign in to comment.