From b7758f6257af628aae9cf2343a7a96a0dbd3890e Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Tue, 31 May 2022 01:12:26 +0800 Subject: [PATCH 01/62] feat(pix-obs dmc): add init for visualization --- envpool/mujoco/dmc/mujoco_env.cc | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index b4d4faa0..37953391 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -52,6 +52,15 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, // create model and data model_ = mj_loadXML(model_filename.c_str(), vfs.get(), error_.begin(), 1000); data_ = mj_makeData(model_); + // create visualization + mjv_defaultCamera(&cam); + mjv_defaultOption(&opt); + mjv_defaultScene(&scn); + mjr_defaultContext(&con); + + // create scene and context + mjv_makeScene(m, &scn, 2000); + mjr_makeContext(m, &con, 200); #ifdef ENVPOOL_TEST qpos0_.reset(new mjtNum[model_->nq]); #endif @@ -60,6 +69,8 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, MujocoEnv::~MujocoEnv() { mj_deleteModel(model_); mj_deleteData(data_); + mjr_freeContext(&con); + mjv_freeScene(&scn); } // rl control Environment @@ -159,6 +170,26 @@ void MujocoEnv::PhysicsStep(int nstep, const mjtNum* action) { mj_step1(model_, data_); } +// https://github.com/deepmind/dm_control/blob/1.0.2/dm_control/mujoco/engine.py#L165 +void MujocoEnv::PhysicsRender(height = 240, width = 320, camera_id = -1, + overlays = (), depth = False, + segmentation = False, scene_option = None, + render_flag_overrides = None, ) { + if (keyframe_id < 0) { + mj_resetData(model_, data_); + } else { + // actually no one steps to this line + assert(keyframe_id < model_->nkey); + mj_resetDataKeyframe(model_, data_, keyframe_id); + } + + // PhysicsAfterReset may be overwritten? + int old_flags = model_->opt.disableflags; + model_->opt.disableflags |= mjDSBL_ACTUATION; + PhysicsForward(); + model_->opt.disableflags = old_flags; +} + // randomizer // https://github.com/deepmind/dm_control/blob/1.0.2/dm_control/suite/utils/randomizers.py#L35 void MujocoEnv::RandomizeLimitedAndRotationalJoints(std::mt19937* gen) { From 15250245e547780d7ad05039000e990e2d4c4429 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Tue, 31 May 2022 21:11:59 +0800 Subject: [PATCH 02/62] feat(pix-obs dmc): init Render method --- envpool/mujoco/dmc/mujoco_env.cc | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 37953391..e797000e 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -52,7 +52,7 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, // create model and data model_ = mj_loadXML(model_filename.c_str(), vfs.get(), error_.begin(), 1000); data_ = mj_makeData(model_); - // create visualization + // init visualization mjv_defaultCamera(&cam); mjv_defaultOption(&opt); mjv_defaultScene(&scn); @@ -60,7 +60,14 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, // create scene and context mjv_makeScene(m, &scn, 2000); - mjr_makeContext(m, &con, 200); + mjr_makeContext(m, &con, 200); + + // center and scale view + cam.lookat[0] = m->stat.center[0]; + cam.lookat[1] = m->stat.center[1]; + cam.lookat[2] = m->stat.center[2]; + cam.distance = 1.5 * m->stat.extent; + #ifdef ENVPOOL_TEST qpos0_.reset(new mjtNum[model_->nq]); #endif From fecae8e8eebb294c0caee765f0218f4ab1ce1dee Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Fri, 3 Jun 2022 23:37:48 +0800 Subject: [PATCH 03/62] feat(pix-obs dmc): physicsrender method finished without segmentation --- envpool/mujoco/dmc/mujoco_env.cc | 206 ++++++++++++++++++++++++++----- envpool/mujoco/dmc/mujoco_env.h | 11 ++ 2 files changed, 185 insertions(+), 32 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index e797000e..b8857960 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -23,11 +23,21 @@ namespace mujoco_dmc { MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, - int n_sub_steps, int max_episode_steps) + int n_sub_steps, int max_episode_steps, + int height, int width, + const std::string& camera_id, + bool depth, bool segmentation, + ) : n_sub_steps_(n_sub_steps), max_episode_steps_(max_episode_steps), elapsed_step_(max_episode_steps + 1), - done_(true) { + done_(true), + height_(240), + width_(320), + camera_id_("-1"), + depth_(false), + segmentation_(false) { + initOpenGL(); // initialize vfs from common assets and raw xml // https://github.com/deepmind/dm_control/blob/1.0.2/dm_control/mujoco/wrapper/core.py#L158 // https://github.com/deepmind/mujoco/blob/main/python/mujoco/structs.cc @@ -52,21 +62,32 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, // create model and data model_ = mj_loadXML(model_filename.c_str(), vfs.get(), error_.begin(), 1000); data_ = mj_makeData(model_); + // MuJoCo visualization + mjvScene scene_; + mjvCamera camera_; + mjvOption option_; + mjrContext context_; // init visualization - mjv_defaultCamera(&cam); - mjv_defaultOption(&opt); - mjv_defaultScene(&scn); - mjr_defaultContext(&con); + mjv_defaultCamera(&camera_); + mjv_defaultOption(&option_); + mjv_defaultScene(&scene_); + mjr_defaultContext(&context_); // create scene and context - mjv_makeScene(m, &scn, 2000); - mjr_makeContext(m, &con, 200); - - // center and scale view - cam.lookat[0] = m->stat.center[0]; - cam.lookat[1] = m->stat.center[1]; - cam.lookat[2] = m->stat.center[2]; - cam.distance = 1.5 * m->stat.extent; + // void mjv_makeScene(const mjModel* m, mjvScene* scn, int maxgeom); + mjv_makeScene(model_, &scene_, 2000); + // void mjr_makeContext(const mjModel* m, mjrContext* con, int fontscale); + mjr_makeContext(model_, &context_, 200); + // set rendering to offscreen buffer + mjr_setBuffer(mjFB_OFFSCREEN, &context_); + // allocate rgb and depth buffers + unsigned char* rgb_array_ = (unsigned char*)std::malloc(3*width_*height_); + float* depth_array_ = (float*)std::malloc(sizeof(float)*width_*height_); + // camera configuration + // cam.lookat[0] = m->stat.center[0]; + // cam.lookat[1] = m->stat.center[1]; + // cam.lookat[2] = m->stat.center[2]; + // cam.distance = 1.5 * m->stat.extent; #ifdef ENVPOOL_TEST qpos0_.reset(new mjtNum[model_->nq]); @@ -76,8 +97,10 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, MujocoEnv::~MujocoEnv() { mj_deleteModel(model_); mj_deleteData(data_); - mjr_freeContext(&con); - mjv_freeScene(&scn); + mjr_freeContext(&context_); + mjv_freeScene(&scene_); + closeOpenGL(); + } // rl control Environment @@ -178,23 +201,20 @@ void MujocoEnv::PhysicsStep(int nstep, const mjtNum* action) { } // https://github.com/deepmind/dm_control/blob/1.0.2/dm_control/mujoco/engine.py#L165 -void MujocoEnv::PhysicsRender(height = 240, width = 320, camera_id = -1, - overlays = (), depth = False, - segmentation = False, scene_option = None, - render_flag_overrides = None, ) { - if (keyframe_id < 0) { - mj_resetData(model_, data_); - } else { - // actually no one steps to this line - assert(keyframe_id < model_->nkey); - mj_resetDataKeyframe(model_, data_, keyframe_id); - } +void MujocoEnv::PhysicsRender(int height, int width, + const std::string& camera_id, + bool depth, bool segmentation) { + // update abstract scene + mjv_updateScene(model_, data_, &option_, NULL, &camera_, mjCAT_ALL, &scene_); + mjrRect viewport = {0, 0, width_, height_}; + // render scene in offscreen buffer + mjr_render(viewport, &scene_, &context_); - // PhysicsAfterReset may be overwritten? - int old_flags = model_->opt.disableflags; - model_->opt.disableflags |= mjDSBL_ACTUATION; - PhysicsForward(); - model_->opt.disableflags = old_flags; + // read rgb and depth buffers + mjr_readPixels(rgb_array_, depth_array_, viewport, &context_); + + // segmentation results not implemented + return {rgb_array_, depth_array_, segmentation_array_} } // randomizer @@ -241,5 +261,127 @@ void MujocoEnv::RandomizeLimitedAndRotationalJoints(std::mt19937* gen) { } } } +// create OpenGL context/window +void initOpenGL(void) { + //------------------------ EGL +#if defined(MJ_EGL) + // desired config + const EGLint configAttribs[] = { + EGL_RED_SIZE, 8, + EGL_GREEN_SIZE, 8, + EGL_BLUE_SIZE, 8, + EGL_ALPHA_SIZE, 8, + EGL_DEPTH_SIZE, 24, + EGL_STENCIL_SIZE, 8, + EGL_COLOR_BUFFER_TYPE, EGL_RGB_BUFFER, + EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, + EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT, + EGL_NONE + }; + + // get default display + EGLDisplay eglDpy = eglGetDisplay(EGL_DEFAULT_DISPLAY); + if (eglDpy==EGL_NO_DISPLAY) { + mju_error_i("Could not get EGL display, error 0x%x\n", eglGetError()); + } + + // initialize + EGLint major, minor; + if (eglInitialize(eglDpy, &major, &minor)!=EGL_TRUE) { + mju_error_i("Could not initialize EGL, error 0x%x\n", eglGetError()); + } + + // choose config + EGLint numConfigs; + EGLConfig eglCfg; + if (eglChooseConfig(eglDpy, configAttribs, &eglCfg, 1, &numConfigs)!=EGL_TRUE) { + mju_error_i("Could not choose EGL config, error 0x%x\n", eglGetError()); + } + + // bind OpenGL API + if (eglBindAPI(EGL_OPENGL_API)!=EGL_TRUE) { + mju_error_i("Could not bind EGL OpenGL API, error 0x%x\n", eglGetError()); + } + + // create context + EGLContext eglCtx = eglCreateContext(eglDpy, eglCfg, EGL_NO_CONTEXT, NULL); + if (eglCtx==EGL_NO_CONTEXT) { + mju_error_i("Could not create EGL context, error 0x%x\n", eglGetError()); + } + + // make context current, no surface (let OpenGL handle FBO) + if (eglMakeCurrent(eglDpy, EGL_NO_SURFACE, EGL_NO_SURFACE, eglCtx)!=EGL_TRUE) { + mju_error_i("Could not make EGL context current, error 0x%x\n", eglGetError()); + } + + //------------------------ OSMESA +#elif defined(MJ_OSMESA) + // create context + ctx = OSMesaCreateContextExt(GL_RGBA, 24, 8, 8, 0); + if (!ctx) { + mju_error("OSMesa context creation failed"); + } + + // make current + if (!OSMesaMakeCurrent(ctx, buffer, GL_UNSIGNED_BYTE, 800, 800)) { + mju_error("OSMesa make current failed"); + } + + //------------------------ GLFW +#else + // init GLFW + if (!glfwInit()) { + mju_error("Could not initialize GLFW"); + } + + // create invisible window, single-buffered + glfwWindowHint(GLFW_VISIBLE, 0); + glfwWindowHint(GLFW_DOUBLEBUFFER, GLFW_FALSE); + GLFWwindow* window = glfwCreateWindow(800, 800, "Invisible window", NULL, NULL); + if (!window) { + mju_error("Could not create GLFW window"); + } + + // make context current + glfwMakeContextCurrent(window); +#endif +} + + +// close OpenGL context/window +void closeOpenGL(void) { + //------------------------ EGL +#if defined(MJ_EGL) + // get current display + EGLDisplay eglDpy = eglGetCurrentDisplay(); + if (eglDpy==EGL_NO_DISPLAY) { + return; + } + + // get current context + EGLContext eglCtx = eglGetCurrentContext(); + + // release context + eglMakeCurrent(eglDpy, EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT); + // destroy context if valid + if (eglCtx!=EGL_NO_CONTEXT) { + eglDestroyContext(eglDpy, eglCtx); + } + + // terminate display + eglTerminate(eglDpy); + + //------------------------ OSMESA +#elif defined(MJ_OSMESA) + OSMesaDestroyContext(ctx); + + //------------------------ GLFW +#else + // terminate GLFW (crashes with Linux NVidia drivers) + #if defined(__APPLE__) || defined(_WIN32) + glfwTerminate(); + #endif +#endif +} } // namespace mujoco_dmc diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index 4e4434a1..591e4a66 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -19,7 +19,18 @@ #include #include +// select EGL, OSMESA or GLFW +#if defined(MJ_EGL) + #include +#elif defined(MJ_OSMESA) + #include + OSMesaContext ctx; + unsigned char buffer[10000000]; +#else + #include +#endif +#include "array_safety.h" #include #include #include From 781ca9790023a5c5813acace6dcf01b32a324d07 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sat, 4 Jun 2022 23:17:16 +0800 Subject: [PATCH 04/62] feat(pix-obs dmc): physicsrender method finished without segmentation --- envpool/mujoco/dmc/mujoco_env.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index b8857960..c81fc987 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -214,6 +214,7 @@ void MujocoEnv::PhysicsRender(int height, int width, mjr_readPixels(rgb_array_, depth_array_, viewport, &context_); // segmentation results not implemented + return {rgb_array_, depth_array_, segmentation_array_} } From adb239262add158f805e24e95963279a751bdc1a Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sun, 5 Jun 2022 02:47:07 +0800 Subject: [PATCH 05/62] feat(pix-obs dmc): init bazel BUILD --- third_party/mujoco/mujoco.BUILD | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/mujoco/mujoco.BUILD b/third_party/mujoco/mujoco.BUILD index 36f1d112..248ce99c 100644 --- a/third_party/mujoco/mujoco.BUILD +++ b/third_party/mujoco/mujoco.BUILD @@ -3,7 +3,8 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "mujoco_lib", srcs = glob(["lib/*"]), - hdrs = glob(["include/mujoco/*.h"]), + hdrs = glob(["include/mujoco/*.h", + "sample/*.h"]), includes = [ "include", "include/mujoco", From b0b3e51f05015887b30d5b795d7d55967397f810 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sun, 12 Jun 2022 11:00:54 +0800 Subject: [PATCH 06/62] fix(pix-obs dmc): merge upstream --- envpool/mujoco/dmc/mujoco_env.cc | 88 +++++++++++++++++--------------- envpool/mujoco/dmc/mujoco_env.h | 12 ++--- third_party/mujoco/mujoco.BUILD | 6 ++- 3 files changed, 58 insertions(+), 48 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index c81fc987..e23aebbe 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -23,11 +23,9 @@ namespace mujoco_dmc { MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, - int n_sub_steps, int max_episode_steps, - int height, int width, - const std::string& camera_id, - bool depth, bool segmentation, - ) + int n_sub_steps, int max_episode_steps, int height, + int width, const std::string& camera_id, bool depth, + bool segmentation, ) : n_sub_steps_(n_sub_steps), max_episode_steps_(max_episode_steps), elapsed_step_(max_episode_steps + 1), @@ -81,14 +79,15 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, // set rendering to offscreen buffer mjr_setBuffer(mjFB_OFFSCREEN, &context_); // allocate rgb and depth buffers - unsigned char* rgb_array_ = (unsigned char*)std::malloc(3*width_*height_); - float* depth_array_ = (float*)std::malloc(sizeof(float)*width_*height_); + unsigned char* rgb_array_ = (unsigned char*)std::malloc(3 * width_ * height_); + auto* depth_array_ = + (reinterpret_cast)std::malloc(sizeof(float) * width_ * height_); // camera configuration // cam.lookat[0] = m->stat.center[0]; // cam.lookat[1] = m->stat.center[1]; // cam.lookat[2] = m->stat.center[2]; // cam.distance = 1.5 * m->stat.extent; - + #ifdef ENVPOOL_TEST qpos0_.reset(new mjtNum[model_->nq]); #endif @@ -100,7 +99,6 @@ MujocoEnv::~MujocoEnv() { mjr_freeContext(&context_); mjv_freeScene(&scene_); closeOpenGL(); - } // rl control Environment @@ -202,8 +200,8 @@ void MujocoEnv::PhysicsStep(int nstep, const mjtNum* action) { // https://github.com/deepmind/dm_control/blob/1.0.2/dm_control/mujoco/engine.py#L165 void MujocoEnv::PhysicsRender(int height, int width, - const std::string& camera_id, - bool depth, bool segmentation) { + const std::string& camera_id, bool depth, + bool segmentation) { // update abstract scene mjv_updateScene(model_, data_, &option_, NULL, &camera_, mjCAT_ALL, &scene_); mjrRect viewport = {0, 0, width_, height_}; @@ -214,8 +212,8 @@ void MujocoEnv::PhysicsRender(int height, int width, mjr_readPixels(rgb_array_, depth_array_, viewport, &context_); // segmentation results not implemented - - return {rgb_array_, depth_array_, segmentation_array_} + + return { rgb_array_, depth_array_, segmentation_array_ } } // randomizer @@ -267,52 +265,62 @@ void initOpenGL(void) { //------------------------ EGL #if defined(MJ_EGL) // desired config - const EGLint configAttribs[] = { - EGL_RED_SIZE, 8, - EGL_GREEN_SIZE, 8, - EGL_BLUE_SIZE, 8, - EGL_ALPHA_SIZE, 8, - EGL_DEPTH_SIZE, 24, - EGL_STENCIL_SIZE, 8, - EGL_COLOR_BUFFER_TYPE, EGL_RGB_BUFFER, - EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, - EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT, - EGL_NONE - }; + const EGLint configAttribs[] = {EGL_RED_SIZE, + 8, + EGL_GREEN_SIZE, + 8, + EGL_BLUE_SIZE, + 8, + EGL_ALPHA_SIZE, + 8, + EGL_DEPTH_SIZE, + 24, + EGL_STENCIL_SIZE, + 8, + EGL_COLOR_BUFFER_TYPE, + EGL_RGB_BUFFER, + EGL_SURFACE_TYPE, + EGL_PBUFFER_BIT, + EGL_RENDERABLE_TYPE, + EGL_OPENGL_BIT, + EGL_NONE}; // get default display EGLDisplay eglDpy = eglGetDisplay(EGL_DEFAULT_DISPLAY); - if (eglDpy==EGL_NO_DISPLAY) { + if (eglDpy == EGL_NO_DISPLAY) { mju_error_i("Could not get EGL display, error 0x%x\n", eglGetError()); } // initialize EGLint major, minor; - if (eglInitialize(eglDpy, &major, &minor)!=EGL_TRUE) { + if (eglInitialize(eglDpy, &major, &minor) != EGL_TRUE) { mju_error_i("Could not initialize EGL, error 0x%x\n", eglGetError()); } // choose config EGLint numConfigs; EGLConfig eglCfg; - if (eglChooseConfig(eglDpy, configAttribs, &eglCfg, 1, &numConfigs)!=EGL_TRUE) { + if (eglChooseConfig(eglDpy, configAttribs, &eglCfg, 1, &numConfigs) != + EGL_TRUE) { mju_error_i("Could not choose EGL config, error 0x%x\n", eglGetError()); } // bind OpenGL API - if (eglBindAPI(EGL_OPENGL_API)!=EGL_TRUE) { + if (eglBindAPI(EGL_OPENGL_API) != EGL_TRUE) { mju_error_i("Could not bind EGL OpenGL API, error 0x%x\n", eglGetError()); } // create context EGLContext eglCtx = eglCreateContext(eglDpy, eglCfg, EGL_NO_CONTEXT, NULL); - if (eglCtx==EGL_NO_CONTEXT) { + if (eglCtx == EGL_NO_CONTEXT) { mju_error_i("Could not create EGL context, error 0x%x\n", eglGetError()); } // make context current, no surface (let OpenGL handle FBO) - if (eglMakeCurrent(eglDpy, EGL_NO_SURFACE, EGL_NO_SURFACE, eglCtx)!=EGL_TRUE) { - mju_error_i("Could not make EGL context current, error 0x%x\n", eglGetError()); + if (eglMakeCurrent(eglDpy, EGL_NO_SURFACE, EGL_NO_SURFACE, eglCtx) != + EGL_TRUE) { + mju_error_i("Could not make EGL context current, error 0x%x\n", + eglGetError()); } //------------------------ OSMESA @@ -338,7 +346,8 @@ void initOpenGL(void) { // create invisible window, single-buffered glfwWindowHint(GLFW_VISIBLE, 0); glfwWindowHint(GLFW_DOUBLEBUFFER, GLFW_FALSE); - GLFWwindow* window = glfwCreateWindow(800, 800, "Invisible window", NULL, NULL); + GLFWwindow* window = + glfwCreateWindow(800, 800, "Invisible window", NULL, NULL); if (!window) { mju_error("Could not create GLFW window"); } @@ -348,14 +357,13 @@ void initOpenGL(void) { #endif } - // close OpenGL context/window void closeOpenGL(void) { //------------------------ EGL #if defined(MJ_EGL) // get current display EGLDisplay eglDpy = eglGetCurrentDisplay(); - if (eglDpy==EGL_NO_DISPLAY) { + if (eglDpy == EGL_NO_DISPLAY) { return; } @@ -366,7 +374,7 @@ void closeOpenGL(void) { eglMakeCurrent(eglDpy, EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT); // destroy context if valid - if (eglCtx!=EGL_NO_CONTEXT) { + if (eglCtx != EGL_NO_CONTEXT) { eglDestroyContext(eglDpy, eglCtx); } @@ -379,10 +387,10 @@ void closeOpenGL(void) { //------------------------ GLFW #else - // terminate GLFW (crashes with Linux NVidia drivers) - #if defined(__APPLE__) || defined(_WIN32) - glfwTerminate(); - #endif +// terminate GLFW (crashes with Linux NVidia drivers) +#if defined(__APPLE__) || defined(_WIN32) + glfwTerminate(); +#endif #endif } } // namespace mujoco_dmc diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index 591e4a66..0d290762 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -21,20 +21,20 @@ #include // select EGL, OSMESA or GLFW #if defined(MJ_EGL) - #include +#include #elif defined(MJ_OSMESA) - #include - OSMesaContext ctx; - unsigned char buffer[10000000]; +#include +OSMesaContext ctx; +unsigned char buffer[10000000]; #else - #include +#include #endif -#include "array_safety.h" #include #include #include +#include "array_safety.h" #include "envpool/mujoco/dmc/utils.h" namespace mujoco_dmc { diff --git a/third_party/mujoco/mujoco.BUILD b/third_party/mujoco/mujoco.BUILD index 248ce99c..596d5bfb 100644 --- a/third_party/mujoco/mujoco.BUILD +++ b/third_party/mujoco/mujoco.BUILD @@ -3,8 +3,10 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "mujoco_lib", srcs = glob(["lib/*"]), - hdrs = glob(["include/mujoco/*.h", - "sample/*.h"]), + hdrs = glob([ + "include/mujoco/*.h", + "sample/*.h", + ]), includes = [ "include", "include/mujoco", From 4a1c3d7231b74fe4acb1a97e2fdbe4e5b2247235 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 16 Jun 2022 20:46:58 +0800 Subject: [PATCH 07/62] fix(pix-obs dmc): fix glfw --- envpool/mujoco/BUILD | 2 ++ envpool/workspace0.bzl | 9 +++++++++ third_party/glfw/BUILD | 0 third_party/glfw/glfw.BUILD | 16 ++++++++++++++++ third_party/mujoco/mujoco.BUILD | 1 + 5 files changed, 28 insertions(+) create mode 100644 third_party/glfw/BUILD create mode 100644 third_party/glfw/glfw.BUILD diff --git a/envpool/mujoco/BUILD b/envpool/mujoco/BUILD index bd8fb982..118329dc 100644 --- a/envpool/mujoco/BUILD +++ b/envpool/mujoco/BUILD @@ -90,6 +90,8 @@ cc_library( "//envpool/core:async_envpool", "@mujoco//:mujoco_lib", "@pugixml", + "@glfw//:glfw", + ], ) diff --git a/envpool/workspace0.bzl b/envpool/workspace0.bzl index 344224a4..9c47e886 100644 --- a/envpool/workspace0.bzl +++ b/envpool/workspace0.bzl @@ -16,6 +16,8 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") +load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository") + def workspace(): """Load requested packages.""" @@ -314,6 +316,13 @@ def workspace(): ], build_file = "//third_party/vizdoom_extra_maps:vizdoom_extra_maps.BUILD", ) + + new_git_repository( + name = "glfw", + remote = "https://github.com/glfw/glfw.git", + commit = "8d7e5cdb49a1a5247df612157ecffdd8e68923d2", + build_file = "@//third_party/glfw:glfw.BUILD", + ) maybe( http_archive, diff --git a/third_party/glfw/BUILD b/third_party/glfw/BUILD new file mode 100644 index 00000000..e69de29b diff --git a/third_party/glfw/glfw.BUILD b/third_party/glfw/glfw.BUILD new file mode 100644 index 00000000..457364dd --- /dev/null +++ b/third_party/glfw/glfw.BUILD @@ -0,0 +1,16 @@ +package(default_visibility = ["//visibility:public"]) + +LINUX_LINKOPTS = [] + +cc_library( + name = "glfw", + hdrs = [ + "include/GLFW/glfw3.h", + "include/GLFW/glfw3native.h", + ], + linkopts = select({ + "@bazel_tools//src/conditions:linux_x86_64": LINUX_LINKOPTS, + }), + deps = [], + strip_include_prefix = "include", +) \ No newline at end of file diff --git a/third_party/mujoco/mujoco.BUILD b/third_party/mujoco/mujoco.BUILD index 596d5bfb..33122319 100644 --- a/third_party/mujoco/mujoco.BUILD +++ b/third_party/mujoco/mujoco.BUILD @@ -10,6 +10,7 @@ cc_library( includes = [ "include", "include/mujoco", + "sample", ], linkopts = ["-Wl,-rpath,'$$ORIGIN'"], linkstatic = 0, From f8e6036952e7d0a4d31a6b58c1b6e2c2f4eb9764 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 16 Jun 2022 20:51:07 +0800 Subject: [PATCH 08/62] fix(pix-obs dmc): fix glfw --- envpool/mujoco/BUILD | 3 +-- envpool/workspace0.bzl | 3 +-- third_party/glfw/glfw.BUILD | 4 ++-- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/envpool/mujoco/BUILD b/envpool/mujoco/BUILD index fa8403b0..789925ae 100644 --- a/envpool/mujoco/BUILD +++ b/envpool/mujoco/BUILD @@ -88,10 +88,9 @@ cc_library( data = [":gen_mujoco_dmc_xml"], deps = [ "//envpool/core:async_envpool", + "@glfw", "@mujoco//:mujoco_lib", "@pugixml", - "@glfw//:glfw", - ], ) diff --git a/envpool/workspace0.bzl b/envpool/workspace0.bzl index 9c47e886..fcf22a11 100644 --- a/envpool/workspace0.bzl +++ b/envpool/workspace0.bzl @@ -18,7 +18,6 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository") - def workspace(): """Load requested packages.""" maybe( @@ -316,7 +315,7 @@ def workspace(): ], build_file = "//third_party/vizdoom_extra_maps:vizdoom_extra_maps.BUILD", ) - + new_git_repository( name = "glfw", remote = "https://github.com/glfw/glfw.git", diff --git a/third_party/glfw/glfw.BUILD b/third_party/glfw/glfw.BUILD index 457364dd..208bfeff 100644 --- a/third_party/glfw/glfw.BUILD +++ b/third_party/glfw/glfw.BUILD @@ -11,6 +11,6 @@ cc_library( linkopts = select({ "@bazel_tools//src/conditions:linux_x86_64": LINUX_LINKOPTS, }), - deps = [], strip_include_prefix = "include", -) \ No newline at end of file + deps = [], +) From a2d754b6754784685d34a84d6c8ba1350629720f Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 16 Jun 2022 20:55:29 +0800 Subject: [PATCH 09/62] fix(pix-obs dmc): fix mujoco_env --- envpool/mujoco/dmc/mujoco_env.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index e23aebbe..68a6fe60 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -25,7 +25,7 @@ namespace mujoco_dmc { MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, int n_sub_steps, int max_episode_steps, int height, int width, const std::string& camera_id, bool depth, - bool segmentation, ) + bool segmentation) : n_sub_steps_(n_sub_steps), max_episode_steps_(max_episode_steps), elapsed_step_(max_episode_steps + 1), From c4eaca1710766b2ff555439a5ed1d76bf4c414d6 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 16 Jun 2022 21:09:06 +0800 Subject: [PATCH 10/62] fix(pix-obs dmc): fix mujoco_env --- envpool/mujoco/dmc/mujoco_env.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 68a6fe60..383203be 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -94,8 +94,8 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, } MujocoEnv::~MujocoEnv() { - mj_deleteModel(model_); mj_deleteData(data_); + mj_deleteModel(model_); mjr_freeContext(&context_); mjv_freeScene(&scene_); closeOpenGL(); @@ -261,7 +261,7 @@ void MujocoEnv::RandomizeLimitedAndRotationalJoints(std::mt19937* gen) { } } // create OpenGL context/window -void initOpenGL(void) { +void MujocoEnv::initOpenGL(void) { //------------------------ EGL #if defined(MJ_EGL) // desired config @@ -358,7 +358,7 @@ void initOpenGL(void) { } // close OpenGL context/window -void closeOpenGL(void) { +void MujocoEnv::closeOpenGL(void) { //------------------------ EGL #if defined(MJ_EGL) // get current display From 242b2f00cf5d242b49bb1c575bac4b3f079b9a5d Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 16 Jun 2022 21:13:30 +0800 Subject: [PATCH 11/62] fix(pix-obs dmc): fix mujoco_env --- envpool/mujoco/dmc/mujoco_env.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index 0d290762..f364050d 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -55,6 +55,10 @@ class MujocoEnv { protected: mjModel* model_; mjData* data_; + mjvScene scene_; + mjvCamera camera_; + mjvOption option_; + mjrContext context_; int n_sub_steps_, max_episode_steps_, elapsed_step_; float reward_, discount_; bool done_; @@ -105,6 +109,10 @@ class MujocoEnv { // randomizer // https://github.com/deepmind/dm_control/blob/1.0.2/dm_control/suite/utils/randomizers.py#L35 void RandomizeLimitedAndRotationalJoints(std::mt19937* gen); + // create OpenGL context/window + void initOpenGL(void); + // close OpenGL context/window + void closeOpenGL(void); }; } // namespace mujoco_dmc From 5b6a77c61a3edae931036d81e44d2a585e87bcc5 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 16 Jun 2022 21:17:22 +0800 Subject: [PATCH 12/62] fix(pix-obs dmc): fix mujoco_env --- envpool/mujoco/dmc/mujoco_env.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index f364050d..090b30f5 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -106,6 +106,10 @@ class MujocoEnv { // https://github.com/deepmind/dm_control/blob/1.0.2/dm_control/mujoco/engine.py#L146 void PhysicsStep(int nstep, const mjtNum* action); + // https://github.com/deepmind/dm_control/blob/1.0.2/dm_control/mujoco/engine.py#L165 + void MujocoEnv::PhysicsRender(int height, int width, + const std::string& camera_id, bool depth, + bool segmentation); // randomizer // https://github.com/deepmind/dm_control/blob/1.0.2/dm_control/suite/utils/randomizers.py#L35 void RandomizeLimitedAndRotationalJoints(std::mt19937* gen); From e346072a38088bfa5932f1c7420f0c1f64ebf376 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 16 Jun 2022 21:21:10 +0800 Subject: [PATCH 13/62] fix(pix-obs dmc): fix mujoco_env --- envpool/mujoco/dmc/mujoco_env.cc | 2 +- envpool/mujoco/dmc/mujoco_env.h | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 383203be..c8148732 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -213,7 +213,7 @@ void MujocoEnv::PhysicsRender(int height, int width, // segmentation results not implemented - return { rgb_array_, depth_array_, segmentation_array_ } + return {rgb_array_, depth_array_, segmentation_array_}; } // randomizer diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index 090b30f5..250e3b78 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -62,6 +62,8 @@ class MujocoEnv { int n_sub_steps_, max_episode_steps_, elapsed_step_; float reward_, discount_; bool done_; + unsigned char* rgb_array_; + auto* depth_array_; #ifdef ENVPOOL_TEST std::unique_ptr qpos0_; #endif From 950af9934bbea946e6801b5312e79a8ee72084a4 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 16 Jun 2022 21:24:06 +0800 Subject: [PATCH 14/62] fix(pix-obs dmc): fix mujoco_env --- envpool/mujoco/dmc/mujoco_env.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index c8148732..55328585 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -213,7 +213,7 @@ void MujocoEnv::PhysicsRender(int height, int width, // segmentation results not implemented - return {rgb_array_, depth_array_, segmentation_array_}; + // return {rgb_array_, depth_array_, segmentation_array_}; } // randomizer From dd643f4f821f79ce1bd9a87af66d8c1fcb54fcee Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 16 Jun 2022 21:28:37 +0800 Subject: [PATCH 15/62] fix(pix-obs dmc): fix mujoco_env --- envpool/mujoco/dmc/mujoco_env.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index 250e3b78..29851d0d 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -62,6 +62,9 @@ class MujocoEnv { int n_sub_steps_, max_episode_steps_, elapsed_step_; float reward_, discount_; bool done_; + int height_, width_; + bool depth_, segmentation_; + const std::string& camera_id_; unsigned char* rgb_array_; auto* depth_array_; #ifdef ENVPOOL_TEST From fca61e0613e5ee406e09adc21f774a5e501c3be7 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 16 Jun 2022 21:30:56 +0800 Subject: [PATCH 16/62] fix(pix-obs dmc): fix mujoco_env --- envpool/mujoco/dmc/mujoco_env.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index 29851d0d..3496ac83 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -112,9 +112,8 @@ class MujocoEnv { void PhysicsStep(int nstep, const mjtNum* action); // https://github.com/deepmind/dm_control/blob/1.0.2/dm_control/mujoco/engine.py#L165 - void MujocoEnv::PhysicsRender(int height, int width, - const std::string& camera_id, bool depth, - bool segmentation); + void PhysicsRender(int height, int width, const std::string& camera_id, + bool depth, bool segmentation); // randomizer // https://github.com/deepmind/dm_control/blob/1.0.2/dm_control/suite/utils/randomizers.py#L35 void RandomizeLimitedAndRotationalJoints(std::mt19937* gen); From 93afd2f5e02fcbb577fb94c0bee7d0cf6487368e Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 16 Jun 2022 21:33:05 +0800 Subject: [PATCH 17/62] fix(pix-obs dmc): fix mujoco_env --- envpool/mujoco/dmc/mujoco_env.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index 3496ac83..12a46520 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -66,7 +66,7 @@ class MujocoEnv { bool depth_, segmentation_; const std::string& camera_id_; unsigned char* rgb_array_; - auto* depth_array_; + float* depth_array_; #ifdef ENVPOOL_TEST std::unique_ptr qpos0_; #endif From 3621eec49e9fc57523bfb93f47303c236d1b29a1 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 16 Jun 2022 21:35:28 +0800 Subject: [PATCH 18/62] fix(pix-obs dmc): fix mujoco_env --- envpool/mujoco/dmc/mujoco_env.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index 12a46520..b1187402 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -73,7 +73,8 @@ class MujocoEnv { public: MujocoEnv(const std::string& base_path, const std::string& raw_xml, - int n_sub_steps, int max_episode_steps); + int n_sub_steps, int max_episode_steps, int height, int width, + const std::string& camera_id, bool depth, bool segmentation); ~MujocoEnv(); // rl control Environment From 15d56515ad2571645d984ccdcdcdcfd97d1322f2 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 16 Jun 2022 21:39:14 +0800 Subject: [PATCH 19/62] fix(pix-obs dmc): fix mujoco_env --- envpool/mujoco/dmc/mujoco_env.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 55328585..05f0c9a8 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -80,8 +80,7 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, mjr_setBuffer(mjFB_OFFSCREEN, &context_); // allocate rgb and depth buffers unsigned char* rgb_array_ = (unsigned char*)std::malloc(3 * width_ * height_); - auto* depth_array_ = - (reinterpret_cast)std::malloc(sizeof(float) * width_ * height_); + auto* depth_array_ = (float*)std::malloc(sizeof(float) * width_ * height_); // camera configuration // cam.lookat[0] = m->stat.center[0]; // cam.lookat[1] = m->stat.center[1]; From f67e9888a6d0f64e04a621613e54a34d85e2e629 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 16 Jun 2022 21:46:47 +0800 Subject: [PATCH 20/62] fix(pix-obs dmc): fix lint --- envpool/mujoco/dmc/mujoco_env.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 05f0c9a8..da2fef30 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -80,7 +80,8 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, mjr_setBuffer(mjFB_OFFSCREEN, &context_); // allocate rgb and depth buffers unsigned char* rgb_array_ = (unsigned char*)std::malloc(3 * width_ * height_); - auto* depth_array_ = (float*)std::malloc(sizeof(float) * width_ * height_); + auto* depth_array_ = + reinterpret_cast std::malloc(sizeof(float) * width_ * height_); // camera configuration // cam.lookat[0] = m->stat.center[0]; // cam.lookat[1] = m->stat.center[1]; From f89b4251fc68526b3f28abf0a96d6bf5dd522fe2 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sat, 16 Jul 2022 21:25:54 +0800 Subject: [PATCH 21/62] fix: pass lint --- .../mujoco_dmc_suite_ext_render_align_test.py | 103 +++++++++++++++++ envpool/mujoco/dmc/mujoco_env.cc | 104 +++++++++--------- envpool/workspace0.bzl | 1 - 3 files changed, 156 insertions(+), 52 deletions(-) create mode 100644 envpool/mujoco/dmc/mujoco_dmc_suite_ext_render_align_test.py diff --git a/envpool/mujoco/dmc/mujoco_dmc_suite_ext_render_align_test.py b/envpool/mujoco/dmc/mujoco_dmc_suite_ext_render_align_test.py new file mode 100644 index 00000000..ed078fdf --- /dev/null +++ b/envpool/mujoco/dmc/mujoco_dmc_suite_ext_render_align_test.py @@ -0,0 +1,103 @@ +# Copyright 2022 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for Mujoco dm_control suite align check.""" + +from typing import Any, List, no_type_check + +import dm_env +import numpy as np +from absl import logging +from absl.testing import absltest +from dm_control import suite + +from envpool.mujoco.dmc import DmcHumanoidCMUDMEnvPool, DmcHumanoidCMUEnvSpec + + +class _MujocoDmcSuiteExtAlignTest(absltest.TestCase): + + @no_type_check + def run_space_check(self, env0: dm_env.Environment, env1: Any) -> None: + """Check observation_spec() and action_spec().""" + obs0, obs1 = env0.observation_spec(), env1.observation_spec() + for k in obs0: + self.assertTrue(hasattr(obs1, k)) + np.testing.assert_allclose(obs0[k].shape, getattr(obs1, k).shape) + act0, act1 = env0.action_spec(), env1.action_spec() + np.testing.assert_allclose(act0.shape, act1.shape) + np.testing.assert_allclose(act0.minimum, act1.minimum) + np.testing.assert_allclose(act0.maximum, act1.maximum) + + @no_type_check + def reset_state( + self, env: dm_env.Environment, ts: dm_env.TimeStep, domain: str, task: str + ) -> None: + # manually reset, mimic initialize_episode + with env.physics.reset_context(): + env.physics.data.qpos = ts.observation.qpos0[0] + if domain in ["humanoid_CMU"]: + env.physics.after_reset() + + def sample_action(self, action_spec: dm_env.specs.Array) -> np.ndarray: + return np.random.uniform( + low=action_spec.minimum, + high=action_spec.maximum, + size=action_spec.shape, + ) + + def run_align_check( + self, env0: dm_env.Environment, env1: Any, domain: str, task: str + ) -> None: + logging.info(f"align check for {domain} {task}") + obs_spec, action_spec = env0.observation_spec(), env0.action_spec() + for i in range(3): + np.random.seed(i) + env0.reset() + a = self.sample_action(action_spec) + ts = env1.reset(np.array([0])) + self.reset_state(env0, ts, domain, task) + logging.info(f"reset qpos {ts.observation.qpos0[0]}") + cnt = 0 + done = False + while not done: + cnt += 1 + a = self.sample_action(action_spec) + # logging.info(f"{cnt} {a}") + ts0 = env0.step(a) + ts1 = env1.step(np.array([a]), np.array([0])) + done = ts0.step_type == dm_env.StepType.LAST + o0, o1 = ts0.observation, ts1.observation + for k in obs_spec: + np.testing.assert_allclose(o0[k], getattr(o1, k)[0]) + np.testing.assert_allclose(ts0.step_type, ts1.step_type[0]) + np.testing.assert_allclose(ts0.reward, ts1.reward[0], atol=1e-8) + np.testing.assert_allclose(ts0.discount, ts1.discount[0]) + + def run_align_check_entry( + self, domain: str, tasks: List[str], spec_cls: Any, envpool_cls: Any + ) -> None: + for task in tasks: + env0 = suite.load(domain, task) + env1 = envpool_cls(spec_cls(spec_cls.gen_config(task_name=task))) + self.run_space_check(env0, env1) + self.run_align_check(env0, env1, domain, task) + + def test_humanoid_CMU(self) -> None: + self.run_align_check_entry( + "humanoid_CMU", ["stand", "run"], DmcHumanoidCMUEnvSpec, + DmcHumanoidCMUDMEnvPool + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index daac9eda..08b36dce 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -14,10 +14,12 @@ #include "envpool/mujoco/dmc/mujoco_env.h" +#include #include #include #include #include +#include #include namespace mujoco_dmc { @@ -139,9 +141,11 @@ void MujocoEnv::ControlStep(const mjtNum* action) { void MujocoEnv::TaskBeforeStep(const mjtNum* action) { PhysicsSetControl(action); } + float MujocoEnv::TaskGetReward() { throw std::runtime_error("GetReward not implemented"); } + float MujocoEnv::TaskGetDiscount() { return 1.0; } bool MujocoEnv::TaskShouldTerminateEpisode() { return false; } @@ -261,64 +265,62 @@ void MujocoEnv::RandomizeLimitedAndRotationalJoints(std::mt19937* gen) { } } - /** - * FrameStack env wrapper implementation. - * - * The original gray scale image are saved inside maxpool_buf_. - * The stacked result is in stack_buf_ where len(stack_buf_) == stack_num_. - * - * At reset time, we need to clear all data in stack_buf_ with push_all = - * true and maxpool = false (there is only one observation); at step time, - * we push max(maxpool_buf_[0], maxpool_buf_[1]) at the end of - * stack_buf_, and pop the first item in stack_buf_, with push_all = false - * and maxpool = true. - * - * @param push_all whether to use the most recent observation to write all - * of the data in stack_buf_. - * @param maxpool whether to perform maxpool operation on the last two - * observation. Maybe there is only one? - */ - void PushStack(bool push_all, bool maxpool) { - auto* ptr = static_cast(maxpool_buf_[0].Data()); - if (maxpool) { - auto* ptr1 = static_cast(maxpool_buf_[1].Data()); - for (std::size_t i = 0; i < maxpool_buf_[0].size; ++i) { - ptr[i] = std::max(ptr[i], ptr1[i]); - } +/** + * FrameStack env wrapper implementation. + * + * The original gray scale image are saved inside maxpool_buf_. + * The stacked result is in stack_buf_ where len(stack_buf_) == stack_num_. + * + * At reset time, we need to clear all data in stack_buf_ with push_all = + * true and maxpool = false (there is only one observation); at step time, + * we push max(maxpool_buf_[0], maxpool_buf_[1]) at the end of + * stack_buf_, and pop the first item in stack_buf_, with push_all = false + * and maxpool = true. + * + * @param push_all whether to use the most recent observation to write all + * of the data in stack_buf_. + * @param maxpool whether to perform maxpool operation on the last two + * observation. Maybe there is only one? + */ +void MujocoEnv::PushStack(bool push_all, bool maxpool) { + auto* ptr = static_cast(maxpool_buf_[0].Data()); + if (maxpool) { + auto* ptr1 = static_cast(maxpool_buf_[1].Data()); + for (std::size_t i = 0; i < maxpool_buf_[0].size; ++i) { + ptr[i] = std::max(ptr[i], ptr1[i]); } - Resize(maxpool_buf_[0], &resize_img_, use_inter_area_resize_); - Array tgt = std::move(*stack_buf_.begin()); - ptr = static_cast(tgt.Data()); - stack_buf_.pop_front(); - if (gray_scale_) { - tgt.Assign(resize_img_); - } else { - auto* ptr1 = static_cast(resize_img_.Data()); - // tgt = resize_img_.transpose(1, 2, 0) - // tgt[i, j, k] = resize_img_[j, k, i] - std::size_t h = resize_img_.Shape(0); - std::size_t w = resize_img_.Shape(1); - for (std::size_t j = 0; j < h; ++j) { - for (std::size_t k = 0; k < w; ++k) { - for (std::size_t i = 0; i < 3; ++i) { - ptr[i * h * w + j * w + k] = ptr1[j * w * 3 + k * 3 + i]; - } + } + Resize(maxpool_buf_[0], &resize_img_, use_inter_area_resize_); + Array tgt = std::move(*stack_buf_.begin()); + ptr = static_cast(tgt.Data()); + stack_buf_.pop_front(); + if (gray_scale_) { + tgt.Assign(resize_img_); + } else { + auto* ptr1 = static_cast(resize_img_.Data()); + // tgt = resize_img_.transpose(1, 2, 0) + // tgt[i, j, k] = resize_img_[j, k, i] + std::size_t h = resize_img_.Shape(0); + std::size_t w = resize_img_.Shape(1); + for (std::size_t j = 0; j < h; ++j) { + for (std::size_t k = 0; k < w; ++k) { + for (std::size_t i = 0; i < 3; ++i) { + ptr[i * h * w + j * w + k] = ptr1[j * w * 3 + k * 3 + i]; } } } - std::size_t size = tgt.size; - stack_buf_.push_back(std::move(tgt)); - if (push_all) { - for (auto& s : stack_buf_) { - auto* ptr_s = static_cast(s.Data()); - if (ptr != ptr_s) { - std::memcpy(ptr_s, ptr, size); - } + } + std::size_t size = tgt.size; + stack_buf_.push_back(std::move(tgt)); + if (push_all) { + for (auto& s : stack_buf_) { + auto* ptr_s = static_cast(s.Data()); + if (ptr != ptr_s) { + std::memcpy(ptr_s, ptr, size); } } } -}; - +} // create OpenGL context/window void MujocoEnv::initOpenGL(void) { diff --git a/envpool/workspace0.bzl b/envpool/workspace0.bzl index ee4786de..608a17ea 100644 --- a/envpool/workspace0.bzl +++ b/envpool/workspace0.bzl @@ -16,7 +16,6 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") - load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository") load("//third_party/cuda:cuda.bzl", "cuda_configure") From a1ed423eca46feb93efe5d66f8b32d8130d074ef Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 14:41:49 +0800 Subject: [PATCH 22/62] fix: update mujoco BUILD file --- envpool/mujoco/BUILD | 5 +- ...dmc_suite_ext_render_deterministic_test.py | 81 +++++++++++++++++++ third_party/egl/BUILD | 0 third_party/egl/egl.BUILD | 16 ++++ 4 files changed, 101 insertions(+), 1 deletion(-) create mode 100644 envpool/mujoco/dmc/mujoco_dmc_suite_ext_render_deterministic_test.py create mode 100644 third_party/egl/BUILD create mode 100644 third_party/egl/egl.BUILD diff --git a/envpool/mujoco/BUILD b/envpool/mujoco/BUILD index 789925ae..4637b528 100644 --- a/envpool/mujoco/BUILD +++ b/envpool/mujoco/BUILD @@ -86,9 +86,12 @@ cc_library( "dmc/walker.h", ], data = [":gen_mujoco_dmc_xml"], + linkopts = [ + "-lEGL", + "-lGL", + ], deps = [ "//envpool/core:async_envpool", - "@glfw", "@mujoco//:mujoco_lib", "@pugixml", ], diff --git a/envpool/mujoco/dmc/mujoco_dmc_suite_ext_render_deterministic_test.py b/envpool/mujoco/dmc/mujoco_dmc_suite_ext_render_deterministic_test.py new file mode 100644 index 00000000..9337b93e --- /dev/null +++ b/envpool/mujoco/dmc/mujoco_dmc_suite_ext_render_deterministic_test.py @@ -0,0 +1,81 @@ +# Copyright 2022 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for Mujoco dm_control deterministic check.""" + +from typing import Any, List, Optional + +import dm_env +import numpy as np +from absl.testing import absltest + +from envpool.mujoco.dmc import DmcHumanoidCMUDMEnvPool, DmcHumanoidCMUEnvSpec + + +class _MujocoDmcSuiteExtDeterministicTest(absltest.TestCase): + + def check( + self, + spec_cls: Any, + envpool_cls: Any, + task: str, + obs_keys: List[str], + blacklist: Optional[List[str]] = None, + num_envs: int = 4, + ) -> None: + np.random.seed(0) + env0 = envpool_cls( + spec_cls(spec_cls.gen_config(num_envs=num_envs, seed=0, task_name=task)) + ) + env1 = envpool_cls( + spec_cls(spec_cls.gen_config(num_envs=num_envs, seed=0, task_name=task)) + ) + env2 = envpool_cls( + spec_cls(spec_cls.gen_config(num_envs=num_envs, seed=1, task_name=task)) + ) + act_spec = env0.action_spec() + for t in range(3000): + action = np.array( + [ + np.random.uniform( + low=act_spec.minimum, high=act_spec.maximum, size=act_spec.shape + ) for _ in range(num_envs) + ] + ) + ts0 = env0.step(action) + obs0 = ts0.observation + obs1 = env1.step(action).observation + obs2 = env2.step(action).observation + for k in obs_keys: + o0 = getattr(obs0, k) + o1 = getattr(obs1, k) + o2 = getattr(obs2, k) + np.testing.assert_allclose(o0, o1) + if blacklist and k in blacklist: + continue + if np.abs(o0).sum() > 0 and ts0.step_type[0] != dm_env.StepType.FIRST: + self.assertFalse(np.allclose(o0, o2), (t, k, o0, o2)) + + def test_humanoid_CMU(self) -> None: + obs_keys = [ + "joint_angles", "head_height", "extremities", "torso_vertical", + "com_velocity", "velocity" + ] + for task in ["stand", "run"]: + self.check( + DmcHumanoidCMUEnvSpec, DmcHumanoidCMUDMEnvPool, task, obs_keys + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/third_party/egl/BUILD b/third_party/egl/BUILD new file mode 100644 index 00000000..e69de29b diff --git a/third_party/egl/egl.BUILD b/third_party/egl/egl.BUILD new file mode 100644 index 00000000..208bfeff --- /dev/null +++ b/third_party/egl/egl.BUILD @@ -0,0 +1,16 @@ +package(default_visibility = ["//visibility:public"]) + +LINUX_LINKOPTS = [] + +cc_library( + name = "glfw", + hdrs = [ + "include/GLFW/glfw3.h", + "include/GLFW/glfw3native.h", + ], + linkopts = select({ + "@bazel_tools//src/conditions:linux_x86_64": LINUX_LINKOPTS, + }), + strip_include_prefix = "include", + deps = [], +) From a5f10606d0bb6250251ba072df39f9d4e0d98ede Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 15:02:36 +0800 Subject: [PATCH 23/62] fix: update mujoco BUILD file --- envpool/mujoco/dmc/mujoco_env.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index b1187402..dd531d7b 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -21,13 +21,13 @@ #include // select EGL, OSMESA or GLFW #if defined(MJ_EGL) -#include +#include #elif defined(MJ_OSMESA) #include OSMesaContext ctx; unsigned char buffer[10000000]; #else -#include +#include #endif #include From f8e0e4b3eb1075ad490cc50e254ab3a7c7b2a7e6 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 15:42:57 +0800 Subject: [PATCH 24/62] fix: use EGL header --- envpool/mujoco/dmc/mujoco_env.cc | 91 +++++++++++++++++--------------- envpool/mujoco/dmc/mujoco_env.h | 2 +- 2 files changed, 49 insertions(+), 44 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 08b36dce..82a7adab 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -324,8 +324,41 @@ void MujocoEnv::PushStack(bool push_all, bool maxpool) { // create OpenGL context/window void MujocoEnv::initOpenGL(void) { + //------------------------ GLFW + +#if defined(MJ_GLFW) + // init GLFW + if (!glfwInit()) { + mju_error("Could not initialize GLFW"); + } + + // create invisible window, single-buffered + glfwWindowHint(GLFW_VISIBLE, 0); + glfwWindowHint(GLFW_DOUBLEBUFFER, GLFW_FALSE); + GLFWwindow* window = + glfwCreateWindow(800, 800, "Invisible window", NULL, NULL); + if (!window) { + mju_error("Could not create GLFW window"); + } + + // make context current + glfwMakeContextCurrent(window); + //------------------------ OSMESA +#elif defined(MJ_OSMESA) + // create context + ctx = OSMesaCreateContextExt(GL_RGBA, 24, 8, 8, 0); + if (!ctx) { + mju_error("OSMesa context creation failed"); + } + + // make current + if (!OSMesaMakeCurrent(ctx, buffer, GL_UNSIGNED_BYTE, 800, 800)) { + mju_error("OSMesa make current failed"); + } + //------------------------ EGL -#if defined(MJ_EGL) + +#else // desired config const EGLint configAttribs[] = {EGL_RED_SIZE, 8, @@ -385,44 +418,26 @@ void MujocoEnv::initOpenGL(void) { eglGetError()); } - //------------------------ OSMESA -#elif defined(MJ_OSMESA) - // create context - ctx = OSMesaCreateContextExt(GL_RGBA, 24, 8, 8, 0); - if (!ctx) { - mju_error("OSMesa context creation failed"); - } - - // make current - if (!OSMesaMakeCurrent(ctx, buffer, GL_UNSIGNED_BYTE, 800, 800)) { - mju_error("OSMesa make current failed"); - } - - //------------------------ GLFW -#else - // init GLFW - if (!glfwInit()) { - mju_error("Could not initialize GLFW"); - } - - // create invisible window, single-buffered - glfwWindowHint(GLFW_VISIBLE, 0); - glfwWindowHint(GLFW_DOUBLEBUFFER, GLFW_FALSE); - GLFWwindow* window = - glfwCreateWindow(800, 800, "Invisible window", NULL, NULL); - if (!window) { - mju_error("Could not create GLFW window"); - } - - // make context current - glfwMakeContextCurrent(window); #endif } // close OpenGL context/window void MujocoEnv::closeOpenGL(void) { + //------------------------ GLFW + +#if defined(MJ_GLFW) +// terminate GLFW (crashes with Linux NVidia drivers) +#if defined(__APPLE__) || defined(_WIN32) + glfwTerminate(); +#endif + + //------------------------ OSMESA +#elif defined(MJ_OSMESA) + OSMesaDestroyContext(ctx); + //------------------------ EGL -#if defined(MJ_EGL) + +#else // get current display EGLDisplay eglDpy = eglGetCurrentDisplay(); if (eglDpy == EGL_NO_DISPLAY) { @@ -443,16 +458,6 @@ void MujocoEnv::closeOpenGL(void) { // terminate display eglTerminate(eglDpy); - //------------------------ OSMESA -#elif defined(MJ_OSMESA) - OSMesaDestroyContext(ctx); - - //------------------------ GLFW -#else -// terminate GLFW (crashes with Linux NVidia drivers) -#if defined(__APPLE__) || defined(_WIN32) - glfwTerminate(); -#endif #endif } } // namespace mujoco_dmc diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index dd531d7b..c71ac1b1 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -20,7 +20,7 @@ #include #include // select EGL, OSMESA or GLFW -#if defined(MJ_EGL) +#if defined(MJ_GLFW) #include #elif defined(MJ_OSMESA) #include From ec8aeb35cb36fa0fed9e962ea45eb528d73e7b69 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 15:57:37 +0800 Subject: [PATCH 25/62] fix: update types --- envpool/mujoco/dmc/mujoco_env.cc | 5 ++--- envpool/mujoco/dmc/mujoco_env.h | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 82a7adab..b1ec7a85 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -81,9 +81,8 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, // set rendering to offscreen buffer mjr_setBuffer(mjFB_OFFSCREEN, &context_); // allocate rgb and depth buffers - unsigned char* rgb_array_ = (unsigned char*)std::malloc(3 * width_ * height_); - auto* depth_array_ = - reinterpret_cast std::malloc(sizeof(float) * width_ * height_); + std::array rgb_array_; + std::array depth_array_; // camera configuration // cam.lookat[0] = m->stat.center[0]; // cam.lookat[1] = m->stat.center[1]; diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index c71ac1b1..df94a598 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -65,7 +65,7 @@ class MujocoEnv { int height_, width_; bool depth_, segmentation_; const std::string& camera_id_; - unsigned char* rgb_array_; + float* rgb_array_; float* depth_array_; #ifdef ENVPOOL_TEST std::unique_ptr qpos0_; From c40b5330c578f8bf1dbf016ae9c5c3fc2c8e91b6 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 16:09:44 +0800 Subject: [PATCH 26/62] fix: update types --- envpool/mujoco/dmc/mujoco_env.cc | 8 ++++---- envpool/mujoco/dmc/mujoco_env.h | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index b1ec7a85..90aaa387 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -25,8 +25,8 @@ namespace mujoco_dmc { MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, - int n_sub_steps, int max_episode_steps, int height, - int width, const std::string& camera_id, bool depth, + int n_sub_steps, int max_episode_steps, const int height, + const int width, const std::string& camera_id, bool depth, bool segmentation) : n_sub_steps_(n_sub_steps), max_episode_steps_(max_episode_steps), @@ -81,8 +81,8 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, // set rendering to offscreen buffer mjr_setBuffer(mjFB_OFFSCREEN, &context_); // allocate rgb and depth buffers - std::array rgb_array_; - std::array depth_array_; + std::vector rgb_array_({3, width_, height_}); + std::vector depth_array_({width_, height_}); // camera configuration // cam.lookat[0] = m->stat.center[0]; // cam.lookat[1] = m->stat.center[1]; diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index df94a598..c71ac1b1 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -65,7 +65,7 @@ class MujocoEnv { int height_, width_; bool depth_, segmentation_; const std::string& camera_id_; - float* rgb_array_; + unsigned char* rgb_array_; float* depth_array_; #ifdef ENVPOOL_TEST std::unique_ptr qpos0_; From 2a706bd3d0c2c09b424e10a2382e8f88059cbddb Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 16:21:20 +0800 Subject: [PATCH 27/62] fix: update types --- envpool/mujoco/dmc/mujoco_env.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 90aaa387..77fdbee9 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -81,8 +81,8 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, // set rendering to offscreen buffer mjr_setBuffer(mjFB_OFFSCREEN, &context_); // allocate rgb and depth buffers - std::vector rgb_array_({3, width_, height_}); - std::vector depth_array_({width_, height_}); + std::vector rgb_array_(3 * width_ * height_); + std::vector depth_array_(width_ * height_); // camera configuration // cam.lookat[0] = m->stat.center[0]; // cam.lookat[1] = m->stat.center[1]; From 881734b9f999bf072f3299ad98fa1f43535d7be5 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 16:25:48 +0800 Subject: [PATCH 28/62] fix: update stack wrapper --- envpool/mujoco/dmc/mujoco_env.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index c71ac1b1..1acfbc9d 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -59,12 +59,12 @@ class MujocoEnv { mjvCamera camera_; mjvOption option_; mjrContext context_; + const std::string& camera_id_; int n_sub_steps_, max_episode_steps_, elapsed_step_; float reward_, discount_; bool done_; int height_, width_; bool depth_, segmentation_; - const std::string& camera_id_; unsigned char* rgb_array_; float* depth_array_; #ifdef ENVPOOL_TEST @@ -122,6 +122,8 @@ class MujocoEnv { void initOpenGL(void); // close OpenGL context/window void closeOpenGL(void); + + void PushStack(bool push_all, bool maxpool); }; } // namespace mujoco_dmc From a7e81383bdde6129507f01ae4cc97999ebb08ea0 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 16:27:04 +0800 Subject: [PATCH 29/62] fix: update stack wrapper --- envpool/mujoco/dmc/mujoco_env.cc | 78 ++++++++++++++++---------------- envpool/mujoco/dmc/mujoco_env.h | 2 +- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 77fdbee9..a56cb7f0 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -281,45 +281,45 @@ void MujocoEnv::RandomizeLimitedAndRotationalJoints(std::mt19937* gen) { * @param maxpool whether to perform maxpool operation on the last two * observation. Maybe there is only one? */ -void MujocoEnv::PushStack(bool push_all, bool maxpool) { - auto* ptr = static_cast(maxpool_buf_[0].Data()); - if (maxpool) { - auto* ptr1 = static_cast(maxpool_buf_[1].Data()); - for (std::size_t i = 0; i < maxpool_buf_[0].size; ++i) { - ptr[i] = std::max(ptr[i], ptr1[i]); - } - } - Resize(maxpool_buf_[0], &resize_img_, use_inter_area_resize_); - Array tgt = std::move(*stack_buf_.begin()); - ptr = static_cast(tgt.Data()); - stack_buf_.pop_front(); - if (gray_scale_) { - tgt.Assign(resize_img_); - } else { - auto* ptr1 = static_cast(resize_img_.Data()); - // tgt = resize_img_.transpose(1, 2, 0) - // tgt[i, j, k] = resize_img_[j, k, i] - std::size_t h = resize_img_.Shape(0); - std::size_t w = resize_img_.Shape(1); - for (std::size_t j = 0; j < h; ++j) { - for (std::size_t k = 0; k < w; ++k) { - for (std::size_t i = 0; i < 3; ++i) { - ptr[i * h * w + j * w + k] = ptr1[j * w * 3 + k * 3 + i]; - } - } - } - } - std::size_t size = tgt.size; - stack_buf_.push_back(std::move(tgt)); - if (push_all) { - for (auto& s : stack_buf_) { - auto* ptr_s = static_cast(s.Data()); - if (ptr != ptr_s) { - std::memcpy(ptr_s, ptr, size); - } - } - } -} +// void MujocoEnv::PushStack(bool push_all, bool maxpool) { +// auto* ptr = static_cast(maxpool_buf_[0].Data()); +// if (maxpool) { +// auto* ptr1 = static_cast(maxpool_buf_[1].Data()); +// for (std::size_t i = 0; i < maxpool_buf_[0].size; ++i) { +// ptr[i] = std::max(ptr[i], ptr1[i]); +// } +// } +// Resize(maxpool_buf_[0], &resize_img_, use_inter_area_resize_); +// Array tgt = std::move(*stack_buf_.begin()); +// ptr = static_cast(tgt.Data()); +// stack_buf_.pop_front(); +// if (gray_scale_) { +// tgt.Assign(resize_img_); +// } else { +// auto* ptr1 = static_cast(resize_img_.Data()); +// // tgt = resize_img_.transpose(1, 2, 0) +// // tgt[i, j, k] = resize_img_[j, k, i] +// std::size_t h = resize_img_.Shape(0); +// std::size_t w = resize_img_.Shape(1); +// for (std::size_t j = 0; j < h; ++j) { +// for (std::size_t k = 0; k < w; ++k) { +// for (std::size_t i = 0; i < 3; ++i) { +// ptr[i * h * w + j * w + k] = ptr1[j * w * 3 + k * 3 + i]; +// } +// } +// } +// } +// std::size_t size = tgt.size; +// stack_buf_.push_back(std::move(tgt)); +// if (push_all) { +// for (auto& s : stack_buf_) { +// auto* ptr_s = static_cast(s.Data()); +// if (ptr != ptr_s) { +// std::memcpy(ptr_s, ptr, size); +// } +// } +// } +// } // create OpenGL context/window void MujocoEnv::initOpenGL(void) { diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index 1acfbc9d..851f006c 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -123,7 +123,7 @@ class MujocoEnv { // close OpenGL context/window void closeOpenGL(void); - void PushStack(bool push_all, bool maxpool); + // void PushStack(bool push_all, bool maxpool); }; } // namespace mujoco_dmc From 5ab45df0604b7be166eaf0828587c9ba8e90642b Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 16:45:59 +0800 Subject: [PATCH 30/62] fix: update mujoco env argument --- envpool/mujoco/dmc/mujoco_env.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index a56cb7f0..6da5a97a 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -25,9 +25,10 @@ namespace mujoco_dmc { MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, - int n_sub_steps, int max_episode_steps, const int height, - const int width, const std::string& camera_id, bool depth, - bool segmentation) + int n_sub_steps, int max_episode_steps, + const int height = 240, const int width = 320, + const std::string& camera_id = std::string("-1"), + bool depth = false, bool segmentation = false) : n_sub_steps_(n_sub_steps), max_episode_steps_(max_episode_steps), elapsed_step_(max_episode_steps + 1), From ce366122affe1a24594f66bec0be2eb770ffb953 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 20:36:30 +0800 Subject: [PATCH 31/62] fix: update mujoco env arguments --- envpool/mujoco/dmc/mujoco_env.cc | 3 --- envpool/mujoco/dmc/mujoco_env.h | 6 ++++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 6da5a97a..fcb1de3b 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -417,7 +417,6 @@ void MujocoEnv::initOpenGL(void) { mju_error_i("Could not make EGL context current, error 0x%x\n", eglGetError()); } - #endif } @@ -436,7 +435,6 @@ void MujocoEnv::closeOpenGL(void) { OSMesaDestroyContext(ctx); //------------------------ EGL - #else // get current display EGLDisplay eglDpy = eglGetCurrentDisplay(); @@ -457,7 +455,6 @@ void MujocoEnv::closeOpenGL(void) { // terminate display eglTerminate(eglDpy); - #endif } } // namespace mujoco_dmc diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index 851f006c..4a670168 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -73,8 +73,10 @@ class MujocoEnv { public: MujocoEnv(const std::string& base_path, const std::string& raw_xml, - int n_sub_steps, int max_episode_steps, int height, int width, - const std::string& camera_id, bool depth, bool segmentation); + int n_sub_steps, int max_episode_steps, const int height = 240, + const int width = 320, + const std::string& camera_id = std::string("-1"), + bool depth = false, bool segmentation = false); ~MujocoEnv(); // rl control Environment From 66b0c8c6efb37e64da364f19af73a40d80bfcbad Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 20:39:05 +0800 Subject: [PATCH 32/62] fix: update mujoco env arguments --- envpool/mujoco/dmc/mujoco_env.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index fcb1de3b..69372733 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -26,9 +26,9 @@ namespace mujoco_dmc { MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, int n_sub_steps, int max_episode_steps, - const int height = 240, const int width = 320, - const std::string& camera_id = std::string("-1"), - bool depth = false, bool segmentation = false) + const int height, const int, + const std::string& camera_id, + bool depth, bool segmentation) : n_sub_steps_(n_sub_steps), max_episode_steps_(max_episode_steps), elapsed_step_(max_episode_steps + 1), From 3e4169c809bd06c4a7fbcade9df6e61d89aed112 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 20:40:33 +0800 Subject: [PATCH 33/62] fix: update mujoco env arguments --- envpool/mujoco/dmc/mujoco_env.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 69372733..6be2ca6c 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -25,10 +25,9 @@ namespace mujoco_dmc { MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, - int n_sub_steps, int max_episode_steps, - const int height, const int, - const std::string& camera_id, - bool depth, bool segmentation) + int n_sub_steps, int max_episode_steps, const int height, + const int, const std::string& camera_id, bool depth, + bool segmentation) : n_sub_steps_(n_sub_steps), max_episode_steps_(max_episode_steps), elapsed_step_(max_episode_steps + 1), From 138c3827e37f2aea3e13603704795d1b54c4fe91 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 22:23:49 +0800 Subject: [PATCH 34/62] fix: update egl bazel build file --- envpool/mujoco/BUILD | 1 + envpool/workspace0.bzl | 1 + third_party/egl/egl.BUILD | 49 ++++++++++++++++++++++++++++++--------- 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/envpool/mujoco/BUILD b/envpool/mujoco/BUILD index 559b0539..6ff804ce 100644 --- a/envpool/mujoco/BUILD +++ b/envpool/mujoco/BUILD @@ -95,6 +95,7 @@ cc_library( "@mujoco//:mujoco_lib", "@pugixml", ], + alwayslink = 1, ) pybind_extension( diff --git a/envpool/workspace0.bzl b/envpool/workspace0.bzl index 66ee6dae..3a45b9d3 100644 --- a/envpool/workspace0.bzl +++ b/envpool/workspace0.bzl @@ -18,6 +18,7 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository") load("//third_party/cuda:cuda.bzl", "cuda_configure") +load("//third_party/cuda:cuda.bzl", "cuda_configure") def workspace(): """Load requested packages.""" diff --git a/third_party/egl/egl.BUILD b/third_party/egl/egl.BUILD index 208bfeff..c5ba93ba 100644 --- a/third_party/egl/egl.BUILD +++ b/third_party/egl/egl.BUILD @@ -1,16 +1,43 @@ -package(default_visibility = ["//visibility:public"]) +# Copyright 2021 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This is loaded in workspace0.bzl to provide EGL library.""" -LINUX_LINKOPTS = [] +_EGL_DIR = "/usr/include" + +def _impl(rctx): + cuda_dir = rctx.os.environ.get(_EGL_DIR, default = "/usr/include") + rctx.file("WORKSPACE") + rctx.file("BUILD", content = """ +package(default_visibility = ["//visibility:public"]) cc_library( - name = "glfw", - hdrs = [ - "include/GLFW/glfw3.h", - "include/GLFW/glfw3native.h", + name = "EGL_headers", + srcs = [ + "EGL/egl.h", + "EGL/eglext.h", + "EGL/eglplatform.h", + "KHR/khrplatform.h", + ], + defines = ["USE_OZONE"], + includes = ["."], +) +""") + +egl_headers = repository_rule( + implementation = _impl, + environ = [ + _CUDA_DIR, ], - linkopts = select({ - "@bazel_tools//src/conditions:linux_x86_64": LINUX_LINKOPTS, - }), - strip_include_prefix = "include", - deps = [], ) From 0dee3d1534e3561e20237550bf2eb455f1d26ae0 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 22:30:34 +0800 Subject: [PATCH 35/62] fix: update egl bazel build file --- envpool/workspace0.bzl | 2 +- third_party/egl/{egl.BUILD => egl.bzl} | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) rename third_party/egl/{egl.BUILD => egl.bzl} (78%) diff --git a/envpool/workspace0.bzl b/envpool/workspace0.bzl index 3a45b9d3..19e1951a 100644 --- a/envpool/workspace0.bzl +++ b/envpool/workspace0.bzl @@ -18,7 +18,7 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository") load("//third_party/cuda:cuda.bzl", "cuda_configure") -load("//third_party/cuda:cuda.bzl", "cuda_configure") +load("//third_party/egl:egl.bzl", "egl_headers") def workspace(): """Load requested packages.""" diff --git a/third_party/egl/egl.BUILD b/third_party/egl/egl.bzl similarity index 78% rename from third_party/egl/egl.BUILD rename to third_party/egl/egl.bzl index c5ba93ba..eae99075 100644 --- a/third_party/egl/egl.BUILD +++ b/third_party/egl/egl.bzl @@ -17,7 +17,8 @@ _EGL_DIR = "/usr/include" def _impl(rctx): - cuda_dir = rctx.os.environ.get(_EGL_DIR, default = "/usr/include") + egl_dir = rctx.os.environ.get(_EGL_DIR, default = "/usr/include") + rctx.symlink("{}/include".format(egl_dir), "include") rctx.file("WORKSPACE") rctx.file("BUILD", content = """ package(default_visibility = ["//visibility:public"]) @@ -25,10 +26,10 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "EGL_headers", srcs = [ - "EGL/egl.h", - "EGL/eglext.h", - "EGL/eglplatform.h", - "KHR/khrplatform.h", + "include/EGL/egl.h", + "include/EGL/eglext.h", + "include/EGL/eglplatform.h", + "include/KHR/khrplatform.h", ], defines = ["USE_OZONE"], includes = ["."], @@ -38,6 +39,6 @@ cc_library( egl_headers = repository_rule( implementation = _impl, environ = [ - _CUDA_DIR, + _EGL_DIR, ], ) From d4e3a4168aa5564a37e7de8e14f3fb5762f83f46 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 22:37:54 +0800 Subject: [PATCH 36/62] fix: update egl bazel build file --- envpool/mujoco/BUILD | 1 + envpool/workspace0.bzl | 5 +++++ third_party/egl/egl.bzl | 5 ++--- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/envpool/mujoco/BUILD b/envpool/mujoco/BUILD index 6ff804ce..f1ed51af 100644 --- a/envpool/mujoco/BUILD +++ b/envpool/mujoco/BUILD @@ -94,6 +94,7 @@ cc_library( "//envpool/core:async_envpool", "@mujoco//:mujoco_lib", "@pugixml", + "@egl//:egl_headers", ], alwayslink = 1, ) diff --git a/envpool/workspace0.bzl b/envpool/workspace0.bzl index 19e1951a..d747ea07 100644 --- a/envpool/workspace0.bzl +++ b/envpool/workspace0.bzl @@ -402,4 +402,9 @@ def workspace(): name = "cuda", ) + maybe( + egl_headers, + name = "egl", + ) + workspace0 = workspace diff --git a/third_party/egl/egl.bzl b/third_party/egl/egl.bzl index eae99075..fcb055a3 100644 --- a/third_party/egl/egl.bzl +++ b/third_party/egl/egl.bzl @@ -24,15 +24,14 @@ def _impl(rctx): package(default_visibility = ["//visibility:public"]) cc_library( - name = "EGL_headers", - srcs = [ + name = "egl_headers", + hdrs = [ "include/EGL/egl.h", "include/EGL/eglext.h", "include/EGL/eglplatform.h", "include/KHR/khrplatform.h", ], defines = ["USE_OZONE"], - includes = ["."], ) """) From 10f36ff2b242e7db728bce8496d7202bdf50d08d Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 22:38:18 +0800 Subject: [PATCH 37/62] fix: update egl bazel build file --- envpool/mujoco/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/envpool/mujoco/BUILD b/envpool/mujoco/BUILD index f1ed51af..4046cc8c 100644 --- a/envpool/mujoco/BUILD +++ b/envpool/mujoco/BUILD @@ -92,9 +92,9 @@ cc_library( ], deps = [ "//envpool/core:async_envpool", + "@egl//:egl_headers", "@mujoco//:mujoco_lib", "@pugixml", - "@egl//:egl_headers", ], alwayslink = 1, ) From fd0a3ad69b2f58bd198bd3947284ad4bbe448485 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 22:52:12 +0800 Subject: [PATCH 38/62] fix: update egl bazel build file --- third_party/egl/egl.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/egl/egl.bzl b/third_party/egl/egl.bzl index fcb055a3..36693c31 100644 --- a/third_party/egl/egl.bzl +++ b/third_party/egl/egl.bzl @@ -17,7 +17,7 @@ _EGL_DIR = "/usr/include" def _impl(rctx): - egl_dir = rctx.os.environ.get(_EGL_DIR, default = "/usr/include") + egl_dir = rctx.os.environ.get(_EGL_DIR, default = "/usr") rctx.symlink("{}/include".format(egl_dir), "include") rctx.file("WORKSPACE") rctx.file("BUILD", content = """ From e0a2e85616e3668c244e673c3155d6cd8f49297d Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 22:56:23 +0800 Subject: [PATCH 39/62] fix: update egl bazel build file --- envpool/mujoco/BUILD | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/envpool/mujoco/BUILD b/envpool/mujoco/BUILD index 4046cc8c..63891627 100644 --- a/envpool/mujoco/BUILD +++ b/envpool/mujoco/BUILD @@ -87,12 +87,11 @@ cc_library( ], data = [":gen_mujoco_dmc_xml"], linkopts = [ - "-lEGL", "-lGL", + "-lEGL", ], deps = [ "//envpool/core:async_envpool", - "@egl//:egl_headers", "@mujoco//:mujoco_lib", "@pugixml", ], From 4b168eec92083e25a725bd1657e8e3e47e914bf8 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 15 Aug 2022 23:01:21 +0800 Subject: [PATCH 40/62] fix: update egl bazel build file --- .bazelrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.bazelrc b/.bazelrc index 4b603b72..96a8c017 100644 --- a/.bazelrc +++ b/.bazelrc @@ -3,7 +3,7 @@ build --action_env=BAZEL_LINKOPTS=-static-libgcc build --action_env=CUDA_DIR=/usr/local/cuda build --incompatible_strict_action_env --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 --client_env=BAZEL_CXXOPTS=-std=c++17 build:debug --cxxopt=-DENVPOOL_TEST --compilation_mode=dbg -s -build:test --cxxopt=-DENVPOOL_TEST --copt=-g0 --copt=-O3 --copt=-DNDEBUG --copt=-msse --copt=-msse2 --copt=-mmmx +build:test --cxxopt=-DENVPOOL_TEST --copt=-g0 --copt=-O3 --copt=-DNDEBUG --copt=-msse --copt=-msse2 --copt=-mmmx --copt -DMESA_EGL_NO_X11_HEADERS --copt -DEGL_NO_X11 build:release --copt=-g0 --copt=-O3 --copt=-DNDEBUG --copt=-msse --copt=-msse2 --copt=-mmmx build:clang-tidy --aspects @bazel_clang_tidy//clang_tidy:clang_tidy.bzl%clang_tidy_aspect From bf4a08bd68a941f994a97415ba7032f5286a8da4 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 18 Aug 2022 23:08:00 +0800 Subject: [PATCH 41/62] fix: egl initialization error --- envpool/mujoco/dmc/mujoco_env.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 6be2ca6c..f8078bcb 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -387,9 +387,11 @@ void MujocoEnv::initOpenGL(void) { // initialize EGLint major, minor; - if (eglInitialize(eglDpy, &major, &minor) != EGL_TRUE) { - mju_error_i("Could not initialize EGL, error 0x%x\n", eglGetError()); - } + eglInitialize(eglDpy, &major, &minor); + + // if (eglInitialize(eglDpy, &major, &minor) != EGL_TRUE) { + // mju_error_i("Could not initialize EGL, error 0x%x\n", eglGetError()); + // } // choose config EGLint numConfigs; From ba0a868997a638d8fcf3d8e35b5f8d39f13985a3 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 14:48:31 +0800 Subject: [PATCH 42/62] fix: switch to osmesa gl backend --- .bazelrc | 2 +- envpool/mujoco/BUILD | 4 ++-- envpool/mujoco/dmc/mujoco_env.cc | 7 +++++++ envpool/mujoco/dmc/mujoco_env.h | 4 ++-- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/.bazelrc b/.bazelrc index 96a8c017..423926f2 100644 --- a/.bazelrc +++ b/.bazelrc @@ -3,7 +3,7 @@ build --action_env=BAZEL_LINKOPTS=-static-libgcc build --action_env=CUDA_DIR=/usr/local/cuda build --incompatible_strict_action_env --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 --client_env=BAZEL_CXXOPTS=-std=c++17 build:debug --cxxopt=-DENVPOOL_TEST --compilation_mode=dbg -s -build:test --cxxopt=-DENVPOOL_TEST --copt=-g0 --copt=-O3 --copt=-DNDEBUG --copt=-msse --copt=-msse2 --copt=-mmmx --copt -DMESA_EGL_NO_X11_HEADERS --copt -DEGL_NO_X11 +build:test --cxxopt=-DENVPOOL_TEST --copt=-g0 --copt=-O3 --copt=-DNDEBUG --copt=-msse --copt=-msse2 --copt=-mmmx --copt=-DMESA_EGL_NO_X11_HEADERS --copt=-DEGL_NO_X11 --copt=-DMJ_OSMESA build:release --copt=-g0 --copt=-O3 --copt=-DNDEBUG --copt=-msse --copt=-msse2 --copt=-mmmx build:clang-tidy --aspects @bazel_clang_tidy//clang_tidy:clang_tidy.bzl%clang_tidy_aspect diff --git a/envpool/mujoco/BUILD b/envpool/mujoco/BUILD index 63891627..ac6da150 100644 --- a/envpool/mujoco/BUILD +++ b/envpool/mujoco/BUILD @@ -87,8 +87,8 @@ cc_library( ], data = [":gen_mujoco_dmc_xml"], linkopts = [ - "-lGL", - "-lEGL", + "-lOSMesa", + "-lrt", ], deps = [ "//envpool/core:async_envpool", diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index f8078bcb..4cef3100 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -21,6 +21,10 @@ #include #include #include +#if defined(MJ_OSMESA) +OSMesaContext ctx; +unsigned char buffer[10000000]; +#endif namespace mujoco_dmc { @@ -78,6 +82,9 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, mjv_makeScene(model_, &scene_, 2000); // void mjr_makeContext(const mjModel* m, mjrContext* con, int fontscale); mjr_makeContext(model_, &context_, 200); + + // default free camera + mjv_defaultFreeCamera(model_, &camera_); // set rendering to offscreen buffer mjr_setBuffer(mjFB_OFFSCREEN, &context_); // allocate rgb and depth buffers diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index 4a670168..1517ffd2 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -24,8 +24,8 @@ #include #elif defined(MJ_OSMESA) #include -OSMesaContext ctx; -unsigned char buffer[10000000]; +// OSMesaContext ctx; +// unsigned char buffer[10000000]; #else #include #endif From 71c53e9a469d66ff6782154170ffc24577464f40 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 15:02:11 +0800 Subject: [PATCH 43/62] fix: switch to osmesa gl backend --- envpool/mujoco/dmc/mujoco_env.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 4cef3100..9be05564 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -84,7 +84,7 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, mjr_makeContext(model_, &context_, 200); // default free camera - mjv_defaultFreeCamera(model_, &camera_); + // mjv_defaultFreeCamera(model_, &camera_); // set rendering to offscreen buffer mjr_setBuffer(mjFB_OFFSCREEN, &context_); // allocate rgb and depth buffers From 8a0bc15d00e9c16c998528b721b8144548a36e20 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 15:13:30 +0800 Subject: [PATCH 44/62] fix: switch to osmesa gl backend --- envpool/mujoco/dmc/mujoco_env.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 9be05564..7af538ce 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -106,7 +106,7 @@ MujocoEnv::~MujocoEnv() { mj_deleteModel(model_); mjr_freeContext(&context_); mjv_freeScene(&scene_); - closeOpenGL(); + // closeOpenGL(); } // rl control Environment From 129057a93a641293ede9d1d4ab421bc39a3f6989 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 16:41:29 +0800 Subject: [PATCH 45/62] fix: switch to osmesa gl backend --- envpool/mujoco/BUILD | 1 - envpool/mujoco/dmc/mujoco_env.cc | 20 +++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/envpool/mujoco/BUILD b/envpool/mujoco/BUILD index ac6da150..024b6941 100644 --- a/envpool/mujoco/BUILD +++ b/envpool/mujoco/BUILD @@ -88,7 +88,6 @@ cc_library( data = [":gen_mujoco_dmc_xml"], linkopts = [ "-lOSMesa", - "-lrt", ], deps = [ "//envpool/core:async_envpool", diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 7af538ce..d20a7458 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -21,10 +21,6 @@ #include #include #include -#if defined(MJ_OSMESA) -OSMesaContext ctx; -unsigned char buffer[10000000]; -#endif namespace mujoco_dmc { @@ -99,14 +95,21 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, #ifdef ENVPOOL_TEST qpos0_.reset(new mjtNum[model_->nq]); #endif +#if defined(MJ_OSMESA) + OSMesaContext ctx_; + unsigned char buffer_[10000000]; +#endif } MujocoEnv::~MujocoEnv() { + std::free(rgb_array_); + std::free(depth_array_); + std::free(buffer_); mj_deleteData(data_); mj_deleteModel(model_); mjr_freeContext(&context_); mjv_freeScene(&scene_); - // closeOpenGL(); + closeOpenGL(); } // rl control Environment @@ -352,13 +355,13 @@ void MujocoEnv::initOpenGL(void) { //------------------------ OSMESA #elif defined(MJ_OSMESA) // create context - ctx = OSMesaCreateContextExt(GL_RGBA, 24, 8, 8, 0); - if (!ctx) { + ctx_ = OSMesaCreateContextExt(GL_RGBA, 24, 8, 8, 0); + if (!ctx_) { mju_error("OSMesa context creation failed"); } // make current - if (!OSMesaMakeCurrent(ctx, buffer, GL_UNSIGNED_BYTE, 800, 800)) { + if (!OSMesaMakeCurrent(ctx_, buffer_, GL_UNSIGNED_BYTE, 800, 800)) { mju_error("OSMesa make current failed"); } @@ -441,7 +444,6 @@ void MujocoEnv::closeOpenGL(void) { //------------------------ OSMESA #elif defined(MJ_OSMESA) OSMesaDestroyContext(ctx); - //------------------------ EGL #else // get current display From e760f245edf2048899ff8fdf303fcc191b83f3ee Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 16:42:15 +0800 Subject: [PATCH 46/62] fix: switch to osmesa gl backend --- envpool/mujoco/dmc/mujoco_env.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index d20a7458..7488211e 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -443,7 +443,7 @@ void MujocoEnv::closeOpenGL(void) { //------------------------ OSMESA #elif defined(MJ_OSMESA) - OSMesaDestroyContext(ctx); + OSMesaDestroyContext(ctx_); //------------------------ EGL #else // get current display From ff02df9393527313bf670b0547186d2f9f950dda Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 16:46:45 +0800 Subject: [PATCH 47/62] fix: switch to osmesa gl backend --- envpool/mujoco/dmc/mujoco_env.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 7488211e..52ed03e6 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -37,6 +37,10 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, camera_id_("-1"), depth_(false), segmentation_(false) { +#if defined(MJ_OSMESA) + OSMesaContext ctx_; + unsigned char buffer_[10000000]; +#endif initOpenGL(); // initialize vfs from common assets and raw xml // https://github.com/deepmind/dm_control/blob/1.0.2/dm_control/mujoco/wrapper/core.py#L158 @@ -95,10 +99,6 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, #ifdef ENVPOOL_TEST qpos0_.reset(new mjtNum[model_->nq]); #endif -#if defined(MJ_OSMESA) - OSMesaContext ctx_; - unsigned char buffer_[10000000]; -#endif } MujocoEnv::~MujocoEnv() { From 0832add847950507070b34ccf5ce40bd52f66dff Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 16:53:01 +0800 Subject: [PATCH 48/62] fix: switch to osmesa gl backend --- envpool/mujoco/dmc/mujoco_env.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index 1517ffd2..5644a25b 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -70,6 +70,10 @@ class MujocoEnv { #ifdef ENVPOOL_TEST std::unique_ptr qpos0_; #endif +#if defined(MJ_OSMESA) + OSMesaContext ctx_; + unsigned char buffer_[10000000]; +#endif public: MujocoEnv(const std::string& base_path, const std::string& raw_xml, From e4a4f289e9836ce4ee649db4cfebc5cc37bfe925 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 16:56:45 +0800 Subject: [PATCH 49/62] fix: switch to osmesa gl backend --- envpool/mujoco/dmc/mujoco_env.cc | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 52ed03e6..fff3e8b3 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -37,10 +37,6 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, camera_id_("-1"), depth_(false), segmentation_(false) { -#if defined(MJ_OSMESA) - OSMesaContext ctx_; - unsigned char buffer_[10000000]; -#endif initOpenGL(); // initialize vfs from common assets and raw xml // https://github.com/deepmind/dm_control/blob/1.0.2/dm_control/mujoco/wrapper/core.py#L158 @@ -102,9 +98,9 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, } MujocoEnv::~MujocoEnv() { - std::free(rgb_array_); - std::free(depth_array_); - std::free(buffer_); + // std::free(rgb_array_); + // std::free(depth_array_); + // std::free(buffer_); mj_deleteData(data_); mj_deleteModel(model_); mjr_freeContext(&context_); From 314db6ef98e67fa2d00b4a6fa58176cdd41098d6 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 16:58:05 +0800 Subject: [PATCH 50/62] fix: switch to osmesa gl backend --- envpool/mujoco/dmc/mujoco_env.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index fff3e8b3..4aa57209 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -105,7 +105,7 @@ MujocoEnv::~MujocoEnv() { mj_deleteModel(model_); mjr_freeContext(&context_); mjv_freeScene(&scene_); - closeOpenGL(); + // closeOpenGL(); } // rl control Environment From 55bebfba67ab1475eaf0ac7d24b860669de17dc7 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 17:00:21 +0800 Subject: [PATCH 51/62] fix: switch to osmesa gl backend --- envpool/mujoco/dmc/mujoco_env.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 4aa57209..7bfed468 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -84,8 +84,8 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, // set rendering to offscreen buffer mjr_setBuffer(mjFB_OFFSCREEN, &context_); // allocate rgb and depth buffers - std::vector rgb_array_(3 * width_ * height_); - std::vector depth_array_(width_ * height_); + // std::vector rgb_array_(3 * width_ * height_); + // std::vector depth_array_(width_ * height_); // camera configuration // cam.lookat[0] = m->stat.center[0]; // cam.lookat[1] = m->stat.center[1]; From 699397c970c6fbaf3c7d2d8f94e55f6173912918 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 17:02:02 +0800 Subject: [PATCH 52/62] fix: switch to osmesa gl backend --- envpool/mujoco/dmc/mujoco_env.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/envpool/mujoco/dmc/mujoco_env.h b/envpool/mujoco/dmc/mujoco_env.h index 5644a25b..76de1a45 100644 --- a/envpool/mujoco/dmc/mujoco_env.h +++ b/envpool/mujoco/dmc/mujoco_env.h @@ -72,7 +72,7 @@ class MujocoEnv { #endif #if defined(MJ_OSMESA) OSMesaContext ctx_; - unsigned char buffer_[10000000]; + unsigned char buffer_[100]; #endif public: From 6368c9d0a273a3b3cdeb6afbabc5b7fdc320bfff Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 17:07:07 +0800 Subject: [PATCH 53/62] fix: switch to osmesa gl backend --- envpool/mujoco/dmc/mujoco_env.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 7bfed468..80254868 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -100,12 +100,12 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, MujocoEnv::~MujocoEnv() { // std::free(rgb_array_); // std::free(depth_array_); - // std::free(buffer_); + std::free(buffer_); mj_deleteData(data_); mj_deleteModel(model_); mjr_freeContext(&context_); mjv_freeScene(&scene_); - // closeOpenGL(); + closeOpenGL(); } // rl control Environment From b61fc5879a4e70acb7a7958d59295dfa16bc1bc2 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 17:09:32 +0800 Subject: [PATCH 54/62] fix: switch to osmesa gl backend --- envpool/mujoco/dmc/mujoco_env.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index 80254868..f5f60a2c 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -100,7 +100,7 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, MujocoEnv::~MujocoEnv() { // std::free(rgb_array_); // std::free(depth_array_); - std::free(buffer_); + mj_deleteData(data_); mj_deleteModel(model_); mjr_freeContext(&context_); @@ -440,6 +440,7 @@ void MujocoEnv::closeOpenGL(void) { //------------------------ OSMESA #elif defined(MJ_OSMESA) OSMesaDestroyContext(ctx_); + std::free(buffer_); //------------------------ EGL #else // get current display From 15f0a3a72471b55f20a54b6487cfefbc2808a2c2 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 17:10:30 +0800 Subject: [PATCH 55/62] fix: switch to osmesa gl backend --- envpool/mujoco/dmc/mujoco_env.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/envpool/mujoco/dmc/mujoco_env.cc b/envpool/mujoco/dmc/mujoco_env.cc index f5f60a2c..7db0d17f 100644 --- a/envpool/mujoco/dmc/mujoco_env.cc +++ b/envpool/mujoco/dmc/mujoco_env.cc @@ -100,12 +100,11 @@ MujocoEnv::MujocoEnv(const std::string& base_path, const std::string& raw_xml, MujocoEnv::~MujocoEnv() { // std::free(rgb_array_); // std::free(depth_array_); - mj_deleteData(data_); mj_deleteModel(model_); - mjr_freeContext(&context_); - mjv_freeScene(&scene_); - closeOpenGL(); + // mjr_freeContext(&context_); + // mjv_freeScene(&scene_); + // closeOpenGL(); } // rl control Environment @@ -440,7 +439,7 @@ void MujocoEnv::closeOpenGL(void) { //------------------------ OSMESA #elif defined(MJ_OSMESA) OSMesaDestroyContext(ctx_); - std::free(buffer_); + // std::free(buffer_); //------------------------ EGL #else // get current display From 42dd8199e2cd5e4370821e4edd7f2403d1b016ed Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 17:33:43 +0800 Subject: [PATCH 56/62] fix: switch to osmesa gl backend --- .bazelrc | 1 + 1 file changed, 1 insertion(+) diff --git a/.bazelrc b/.bazelrc index 423926f2..45bf7bcb 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1,3 +1,4 @@ +build --action_env=DISPLAY='' build --action_env=BAZEL_LINKLIBS=-l%:libstdc++.a:-lm build --action_env=BAZEL_LINKOPTS=-static-libgcc build --action_env=CUDA_DIR=/usr/local/cuda From 9f0bbd5572f522792996861427fb2dcf4c3c81a2 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 17:42:36 +0800 Subject: [PATCH 57/62] fix: switch to osmesa gl backend --- third_party/mujoco/mujoco.BUILD | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/third_party/mujoco/mujoco.BUILD b/third_party/mujoco/mujoco.BUILD index 20d8f951..7f3c17f2 100644 --- a/third_party/mujoco/mujoco.BUILD +++ b/third_party/mujoco/mujoco.BUILD @@ -12,7 +12,10 @@ cc_library( "include/mujoco", "sample", ], - linkopts = ["-Wl,-rpath,'$$ORIGIN'"], + linkopts = [ + "-Wl,-rpath,'$$ORIGIN'", + "-lOSMesa", + ], linkstatic = 0, ) From 778db9bea1052b6a1b8301edaf80a8330cc33f9e Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 17:47:48 +0800 Subject: [PATCH 58/62] chore: comment on the unfixed code --- envpool/make_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/envpool/make_test.py b/envpool/make_test.py index 9225bb9d..0c98b62c 100644 --- a/envpool/make_test.py +++ b/envpool/make_test.py @@ -126,7 +126,7 @@ def test_make_mujoco_gym(self) -> None: "Walker2d-v4", ] ) - + # make test with GL loading error def test_make_mujoco_dmc(self) -> None: self.check_step( [ From b94f4f1f6fc58d9f72ac19405bfd11ad9714431f Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 22 Aug 2022 17:50:48 +0800 Subject: [PATCH 59/62] fix: pass lint --- envpool/make_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/envpool/make_test.py b/envpool/make_test.py index 0c98b62c..a70f0450 100644 --- a/envpool/make_test.py +++ b/envpool/make_test.py @@ -126,6 +126,7 @@ def test_make_mujoco_gym(self) -> None: "Walker2d-v4", ] ) + # make test with GL loading error def test_make_mujoco_dmc(self) -> None: self.check_step( From 9055bc80eb13cf296fde9ae861d1a40be7537bc3 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sun, 26 Mar 2023 07:19:39 +0800 Subject: [PATCH 60/62] merge: sync with upstream main --- .clang-tidy | 7 +- .github/workflows/lint.yml | 84 +++--- .github/workflows/release.yml | 71 +++--- .github/workflows/test.yml | 32 +-- CODE_OF_CONDUCT.md | 128 ++++++++++ Makefile | 74 ++++-- README.md | 10 +- WORKSPACE | 14 + docker/dev-cn.dockerfile | 34 ++- docker/dev.dockerfile | 25 +- docker/release.dockerfile | 79 ++++-- docs/content/build.rst | 12 +- docs/content/new_env.rst | 26 +- docs/content/python_interface.rst | 12 +- docs/content/xla_interface.rst | 2 +- docs/env/atari.rst | 6 +- docs/env/dm_control.rst | 2 +- docs/env/minigrid.rst | 18 ++ docs/env/procgen.rst | 100 ++++++++ docs/index.rst | 2 + docs/spelling_wordlist.txt | 2 + envpool/BUILD | 2 + envpool/__init__.py | 4 +- envpool/atari/__init__.py | 3 +- envpool/atari/atari_env.h | 22 +- envpool/atari/atari_envpool_test.py | 13 +- envpool/atari/registration.py | 8 +- envpool/box2d/BUILD | 2 +- envpool/box2d/__init__.py | 19 +- envpool/box2d/bipedal_walker_env.cc | 4 +- envpool/box2d/bipedal_walker_env.h | 4 +- envpool/box2d/car_dynamics.cc | 6 +- envpool/box2d/car_dynamics.h | 4 +- envpool/box2d/car_racing_env.cc | 5 +- envpool/box2d/car_racing_env.h | 2 +- envpool/box2d/lunar_lander_env.cc | 5 +- envpool/box2d/lunar_lander_env.h | 4 +- envpool/box2d/registration.py | 5 + envpool/box2d/utils.cc | 4 +- envpool/classic_control/__init__.py | 42 ++- envpool/classic_control/acrobot.h | 11 +- envpool/classic_control/cartpole.h | 5 +- envpool/classic_control/mountain_car.h | 5 +- .../classic_control/mountain_car_continuous.h | 5 +- envpool/classic_control/pendulum.h | 5 +- envpool/classic_control/registration.py | 7 + envpool/core/array.h | 4 +- envpool/core/async_envpool.h | 2 +- envpool/core/dict.h | 9 +- envpool/core/env.h | 5 +- envpool/core/env_spec.h | 4 +- envpool/core/envpool.h | 1 + envpool/core/py_envpool.h | 6 +- envpool/core/spec.h | 2 +- envpool/core/state_buffer_queue_test.cc | 4 + envpool/dummy/__init__.py | 3 +- envpool/dummy/dummy_envpool.h | 5 +- envpool/entry.py | 46 +++- envpool/minigrid/BUILD | 86 +++++++ envpool/minigrid/__init__.py | 28 ++ envpool/minigrid/empty.h | 92 +++++++ envpool/minigrid/impl/minigrid_empty_env.cc | 64 +++++ envpool/minigrid/impl/minigrid_empty_env.h | 35 +++ envpool/minigrid/impl/minigrid_env.cc | 241 ++++++++++++++++++ envpool/minigrid/impl/minigrid_env.h | 58 +++++ envpool/minigrid/impl/utils.h | 154 +++++++++++ envpool/minigrid/minigrid.cc | 21 ++ envpool/minigrid/minigrid_align_test.py | 105 ++++++++ .../minigrid/minigrid_deterministic_test.py | 61 +++++ envpool/minigrid/registration.py | 86 +++++++ envpool/mujoco/dmc/__init__.py | 106 +++++--- envpool/mujoco/dmc/manipulator.h | 2 +- .../mujoco/dmc/mujoco_dmc_suite_align_test.py | 2 +- envpool/mujoco/dmc/registration.py | 8 +- envpool/mujoco/gym/__init__.py | 62 +++-- envpool/mujoco/gym/mujoco_env.h | 5 +- envpool/mujoco/gym/registration.py | 8 +- envpool/pip.bzl | 2 + envpool/procgen/BUILD | 85 ++++++ envpool/procgen/__init__.py | 31 +++ envpool/procgen/procgen_env.h | 213 ++++++++++++++++ envpool/procgen/procgen_env_test.cc | 66 +++++ envpool/procgen/procgen_envpool.cc | 23 ++ envpool/procgen/procgen_test.py | 94 +++++++ envpool/procgen/registration.py | 58 +++++ envpool/python/BUILD | 26 +- envpool/python/api.py | 6 +- envpool/python/data.py | 45 +++- envpool/python/dm_envpool.py | 12 +- envpool/python/env_spec.py | 57 +++++ envpool/python/envpool.py | 10 +- envpool/python/gym_envpool.py | 30 +-- envpool/python/gymnasium_envpool.py | 97 +++++++ envpool/registration.py | 17 +- envpool/toy_text/__init__.py | 40 +-- envpool/toy_text/blackjack.h | 5 +- envpool/toy_text/catch.h | 5 +- envpool/toy_text/cliffwalking.h | 4 +- envpool/toy_text/frozen_lake.h | 5 +- envpool/toy_text/nchain.h | 5 +- envpool/toy_text/registration.py | 7 + envpool/toy_text/taxi.h | 3 +- envpool/vizdoom/__init__.py | 6 +- envpool/vizdoom/registration.py | 5 +- envpool/vizdoom/vizdoom_env.h | 13 +- envpool/workspace0.bzl | 155 ++++++----- envpool/workspace1.bzl | 3 + setup.cfg | 20 +- third_party/gym3_libenv/BUILD | 13 + third_party/gym3_libenv/gym3_libenv.BUILD | 5 + third_party/pip_requirements/.gitignore | 1 + .../pip_requirements/requirements-dev.txt | 21 ++ .../pip_requirements/requirements-release.txt | 9 + third_party/pip_requirements/requirements.txt | 50 ---- third_party/procgen/BUILD | 13 + third_party/procgen/procgen.BUILD | 31 +++ 116 files changed, 2962 insertions(+), 580 deletions(-) create mode 100644 CODE_OF_CONDUCT.md create mode 100644 docs/env/minigrid.rst create mode 100644 docs/env/procgen.rst create mode 100644 envpool/minigrid/BUILD create mode 100644 envpool/minigrid/__init__.py create mode 100644 envpool/minigrid/empty.h create mode 100644 envpool/minigrid/impl/minigrid_empty_env.cc create mode 100644 envpool/minigrid/impl/minigrid_empty_env.h create mode 100644 envpool/minigrid/impl/minigrid_env.cc create mode 100644 envpool/minigrid/impl/minigrid_env.h create mode 100644 envpool/minigrid/impl/utils.h create mode 100644 envpool/minigrid/minigrid.cc create mode 100644 envpool/minigrid/minigrid_align_test.py create mode 100644 envpool/minigrid/minigrid_deterministic_test.py create mode 100644 envpool/minigrid/registration.py create mode 100644 envpool/procgen/BUILD create mode 100644 envpool/procgen/__init__.py create mode 100644 envpool/procgen/procgen_env.h create mode 100644 envpool/procgen/procgen_env_test.cc create mode 100644 envpool/procgen/procgen_envpool.cc create mode 100644 envpool/procgen/procgen_test.py create mode 100644 envpool/procgen/registration.py create mode 100644 envpool/python/gymnasium_envpool.py mode change 100644 => 100755 envpool/workspace0.bzl create mode 100644 third_party/gym3_libenv/BUILD create mode 100644 third_party/gym3_libenv/gym3_libenv.BUILD create mode 100644 third_party/pip_requirements/.gitignore create mode 100644 third_party/pip_requirements/requirements-dev.txt create mode 100644 third_party/pip_requirements/requirements-release.txt delete mode 100644 third_party/pip_requirements/requirements.txt create mode 100644 third_party/procgen/BUILD create mode 100644 third_party/procgen/procgen.BUILD diff --git a/.clang-tidy b/.clang-tidy index e835e13f..d62bd5b9 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -20,9 +20,12 @@ Checks: ' performance-*, portability-*, readability-*, - -clang-diagnostic-delete-non-abstract-non-virtual-dtor, - -google-runtime-references, + -bugprone-easily-swappable-parameters, + -bugprone-implicit-widening-of-multiplication-result, + -bugprone-narrowing-conversions, -modernize-use-trailing-return-type, + -readability-function-cognitive-complexity, + -readability-identifier-length, -readability-magic-numbers, -readability-static-accessed-through-instance, -readability-uppercase-literal-suffix, diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 8c4f79b3..04dc618e 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -6,45 +6,45 @@ jobs: lint: runs-on: ubuntu-latest steps: - - name: Cancel previous run - uses: styfle/cancel-workflow-action@0.9.1 - with: - access_token: ${{ github.token }} - - uses: actions/checkout@v2 - - uses: actions/setup-go@v2 - with: - go-version: '^1.17.3' - - name: Set up Python 3.8 - uses: actions/setup-python@v2 - with: - python-version: 3.8 - - name: Upgrade pip - run: | - python -m pip install --upgrade pip - - name: flake8 - run: | - make flake8 - - name: isort and yapf - run: | - make py-format - - name: cpplint - run: | - make cpplint - - name: clang-format - run: | - make clang-format - - name: buildifier - run: | - make buildifier - - name: addlicense - run: | - make addlicense - - name: mypy - run: | - make mypy - - name: docstyle - run: | - make docstyle - - name: spelling - run: | - make spelling + - name: Cancel previous run + uses: styfle/cancel-workflow-action@0.11.0 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v3 + - uses: actions/setup-go@v3 + with: + go-version: ">=1.16.0" + - name: Set up Python 3.10 + uses: actions/setup-python@v4 + with: + python-version: "3.10" + - name: Upgrade pip + run: | + python -m pip install --upgrade pip + - name: flake8 + run: | + make flake8 + - name: isort and yapf + run: | + make py-format + - name: cpplint + run: | + make cpplint + - name: clang-format + run: | + make clang-format + - name: buildifier + run: | + make buildifier + - name: addlicense + run: | + make addlicense + - name: mypy + run: | + make mypy + - name: docstyle + run: | + make docstyle + - name: spelling + run: | + make spelling diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e7e04f15..ff2c6e2b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -5,40 +5,49 @@ on: push: branches: - main + tags: + - v* jobs: release: - runs-on: ${{ matrix.python-version }} + runs-on: ubuntu-latest + container: trinkle23897/envpool-release:2023-01-02-5f1a5fd strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10'] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] steps: - - name: Cancel previous run - uses: styfle/cancel-workflow-action@0.9.1 - with: - access_token: ${{ github.token }} - - uses: actions/checkout@v2 - - uses: actions/setup-go@v2 - with: - go-version: '^1.17.3' - - name: Set up Python ${{ matrix.python-version }} - run: | - pyenv global ${{ matrix.python-version }}-dev - # uses: actions/setup-python@v2 - # with: - # python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip setuptools wheel - - name: Build - run: | - make pypi-wheel BAZELOPT="--remote_cache=http://bazel-cache.sail:8080" - pip3 install wheelhouse/*.whl --force-reinstall - - name: Test - run: | - make release-test - - name: Upload artifact - uses: actions/upload-artifact@main - with: - name: wheel - path: wheelhouse/ + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + run: | + ln -sf /root/.cache $HOME/.cache + ln -sf /root/.pyenv $HOME/.pyenv + pyenv global ${{ matrix.python-version }}-dev + - name: Build + run: | + make pypi-wheel + pip3 install dist/*.whl --force-reinstall + - name: Test + run: | + make release-test + - name: Upload artifact + uses: actions/upload-artifact@main + with: + name: wheel + path: wheelhouse/ + + publish: + runs-on: ubuntu-latest + needs: [release] + steps: + - uses: actions/download-artifact@v3 + with: + path: artifact + - name: Move files so the next action can find them + run: | + mkdir dist && mv artifact/wheel/* dist/ + ls dist/ + - name: Publish distribution to PyPI + if: startsWith(github.ref, 'refs/tags') + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7613f464..5bd98c8a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,24 +6,14 @@ jobs: test: runs-on: [self-hosted, Linux, X64, Test] steps: - - name: Cancel previous run - uses: styfle/cancel-workflow-action@0.9.1 - with: - access_token: ${{ github.token }} - - uses: actions/checkout@v2 - - uses: actions/setup-go@v2 - with: - go-version: '^1.17.3' - - name: Set up Python 3.8 - uses: actions/setup-python@v2 - with: - python-version: 3.8 - - name: Install dependencies - run: | - python -m pip install --upgrade pip setuptools wheel - - name: Test - run: | - make bazel-test BAZELOPT="--remote_cache=http://bazel-cache.sail:8080" - - name: Run clang-tidy - run: | - make clang-tidy BAZELOPT="--remote_cache=http://bazel-cache.sail:8080" + - name: Cancel previous run + uses: styfle/cancel-workflow-action@0.11.0 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v3 + - name: Test + run: | + make bazel-test + - name: Run clang-tidy + run: | + make clang-tidy diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..18c91471 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/Makefile b/Makefile index 01f34a8d..0a024181 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,9 @@ BAZEL_FILES = $(shell find . -type f -name "*BUILD" -o -name "*.bzl") COMMIT_HASH = $(shell git log -1 --format=%h) COPYRIGHT = "Garena Online Private Limited" BAZELOPT = +DATE = $(shell date "+%Y-%m-%d") +DOCKER_TAG = $(DATE)-$(COMMIT_HASH) +DOCKER_USER = trinkle23897 PATH := $(HOME)/go/bin:$(PATH) # installation @@ -29,14 +32,14 @@ cpplint-install: $(call check_install, cpplint) clang-format-install: - command -v clang-format-11 || sudo apt-get install -y clang-format-11 + command -v clang-format || sudo apt-get install -y clang-format clang-tidy-install: command -v clang-tidy || sudo apt-get install -y clang-tidy go-install: # requires go >= 1.16 - command -v go || (sudo apt-get install -y golang-1.16 && sudo ln -sf /usr/lib/go-1.16/bin/go /usr/bin/go) + command -v go || (sudo apt-get install -y golang-1.18 && sudo ln -sf /usr/lib/go-1.18/bin/go /usr/bin/go) bazel-install: go-install command -v bazel || (go install github.com/bazelbuild/bazelisk@latest && ln -sf $(HOME)/go/bin/bazelisk $(HOME)/go/bin/bazel) @@ -55,7 +58,7 @@ doc-install: $(call check_install_extra, sphinxcontrib.spelling, sphinxcontrib.spelling pyenchant) auditwheel-install: - $(call check_install_extra, auditwheel, auditwheel typed-ast) + $(call check_install_extra, auditwheel, auditwheel typed-ast patchelf) # python linter @@ -74,7 +77,7 @@ cpplint: cpplint-install cpplint $(CPP_FILES) clang-format: clang-format-install - clang-format-11 --style=file -i $(CPP_FILES) -n --Werror + clang-format --style=file -i $(CPP_FILES) -n --Werror # bazel file linter @@ -83,28 +86,31 @@ buildifier: buildifier-install # bazel build/test -clang-tidy: clang-tidy-install +bazel-pip-requirement-dev: + cd third_party/pip_requirements && (cmp requirements.txt requirements-dev.txt || ln -sf requirements-dev.txt requirements.txt) + +bazel-pip-requirement-release: + cd third_party/pip_requirements && (cmp requirements.txt requirements-release.txt || ln -sf requirements-release.txt requirements.txt) + +clang-tidy: clang-tidy-install bazel-pip-requirement-dev bazel build $(BAZELOPT) //... --config=clang-tidy --config=test -bazel-debug: bazel-install - bazel build $(BAZELOPT) //... --config=debug +bazel-debug: bazel-install bazel-pip-requirement-dev bazel run $(BAZELOPT) //:setup --config=debug -- bdist_wheel mkdir -p dist cp bazel-bin/setup.runfiles/$(PROJECT_NAME)/dist/*.whl ./dist -bazel-build: bazel-install - bazel build $(BAZELOPT) //... --config=test +bazel-build: bazel-install bazel-pip-requirement-dev bazel run $(BAZELOPT) //:setup --config=test -- bdist_wheel mkdir -p dist cp bazel-bin/setup.runfiles/$(PROJECT_NAME)/dist/*.whl ./dist -bazel-release: bazel-install - bazel build $(BAZELOPT) //... --config=release +bazel-release: bazel-install bazel-pip-requirement-release bazel run $(BAZELOPT) //:setup --config=release -- bdist_wheel mkdir -p dist cp bazel-bin/setup.runfiles/$(PROJECT_NAME)/dist/*.whl ./dist -bazel-test: bazel-install +bazel-test: bazel-install bazel-pip-requirement-dev bazel test --test_output=all $(BAZELOPT) //... --config=test --spawn_strategy=local --color=yes bazel-clean: bazel-install @@ -113,7 +119,7 @@ bazel-clean: bazel-install # documentation addlicense: addlicense-install - addlicense -c $(COPYRIGHT) -l apache -y 2022 -check $(PROJECT_FOLDER) + addlicense -c $(COPYRIGHT) -l apache -y 2023 -check $(PROJECT_FOLDER) docstyle: doc-install pydocstyle $(PROJECT_NAME) && doc8 docs && cd docs && make html SPHINXOPTS="-W" @@ -136,31 +142,45 @@ lint: buildifier flake8 py-format clang-format cpplint clang-tidy mypy docstyle format: py-format-install clang-format-install buildifier-install addlicense-install isort $(PYTHON_FILES) yapf -ir $(PYTHON_FILES) - clang-format-11 -style=file -i $(CPP_FILES) + clang-format -style=file -i $(CPP_FILES) buildifier -r -lint=fix $(BAZEL_FILES) - addlicense -c $(COPYRIGHT) -l apache -y 2022 $(PROJECT_FOLDER) + addlicense -c $(COPYRIGHT) -l apache -y 2023 $(PROJECT_FOLDER) # Build docker images -docker-dev: - docker build --network=host -t $(PROJECT_NAME):$(COMMIT_HASH) -f docker/dev.dockerfile . - docker run --network=host -v /:/host -it $(PROJECT_NAME):$(COMMIT_HASH) bash - echo successfully build docker image with tag $(PROJECT_NAME):$(COMMIT_HASH) +docker-ci: + docker build --network=host -t $(PROJECT_NAME):$(DOCKER_TAG) -f docker/dev.dockerfile . + echo successfully build docker image with tag $(PROJECT_NAME):$(DOCKER_TAG) + +docker-ci-push: docker-ci + docker tag $(PROJECT_NAME):$(DOCKER_TAG) $(DOCKER_USER)/$(PROJECT_NAME):$(DOCKER_TAG) + docker push $(DOCKER_USER)/$(PROJECT_NAME):$(DOCKER_TAG) + +docker-ci-launch: docker-ci + docker run --network=host -v /home/ubuntu:/home/github-action --shm-size=4gb -it $(PROJECT_NAME):$(DOCKER_TAG) bash + +docker-dev: docker-ci + docker run --network=host -v /:/host -v $(shell pwd):/app -v $(HOME)/.cache:/root/.cache --shm-size=4gb -it $(PROJECT_NAME):$(DOCKER_TAG) zsh # for mainland China docker-dev-cn: - docker build --network=host -t $(PROJECT_NAME):$(COMMIT_HASH) -f docker/dev-cn.dockerfile . - docker run --network=host -v /:/host -it $(PROJECT_NAME):$(COMMIT_HASH) bash - echo successfully build docker image with tag $(PROJECT_NAME):$(COMMIT_HASH) + docker build --network=host -t $(PROJECT_NAME):$(DOCKER_TAG) -f docker/dev-cn.dockerfile . + echo successfully build docker image with tag $(PROJECT_NAME):$(DOCKER_TAG) + docker run --network=host -v /:/host -v $(shell pwd):/app -v $(HOME)/.cache:/root/.cache --shm-size=4gb -it $(PROJECT_NAME):$(DOCKER_TAG) zsh docker-release: - docker build --network=host -t $(PROJECT_NAME)-release:$(COMMIT_HASH) -f docker/release.dockerfile . - mkdir -p wheelhouse - docker run --network=host -v `pwd`/wheelhouse:/whl -it $(PROJECT_NAME)-release:$(COMMIT_HASH) bash -c "cp wheelhouse/* /whl" - echo successfully build docker image with tag $(PROJECT_NAME)-release:$(COMMIT_HASH) + docker build --network=host -t $(PROJECT_NAME)-release:$(DOCKER_TAG) -f docker/release.dockerfile . + echo successfully build docker image with tag $(PROJECT_NAME)-release:$(DOCKER_TAG) + +docker-release-push: docker-release + docker tag $(PROJECT_NAME)-release:$(DOCKER_TAG) $(DOCKER_USER)/$(PROJECT_NAME)-release:$(DOCKER_TAG) + docker push $(DOCKER_USER)/$(PROJECT_NAME)-release:$(DOCKER_TAG) + +docker-release-launch: docker-release + docker run --network=host -v /:/host -v $(shell pwd):/app -v $(HOME)/.cache:/root/.cache --shm-size=4gb -it $(PROJECT_NAME)-release:$(DOCKER_TAG) zsh pypi-wheel: auditwheel-install bazel-release - ls dist/*.whl -Art | tail -n 1 | xargs auditwheel repair --plat manylinux_2_17_x86_64 + ls dist/*.whl -Art | tail -n 1 | xargs auditwheel repair --plat manylinux_2_24_x86_64 release-test1: cd envpool && python3 make_test.py diff --git a/README.md b/README.md index ad43fb4b..e2a4b2fb 100644 --- a/README.md +++ b/README.md @@ -16,12 +16,12 @@ - [x] [ViZDoom single player](https://envpool.readthedocs.io/en/latest/env/vizdoom.html) - [x] [DeepMind Control Suite](https://envpool.readthedocs.io/en/latest/env/dm_control.html) - [x] [Box2D](https://envpool.readthedocs.io/en/latest/env/box2d.html) -- [ ] Procgen -- [ ] Minigrid +- [x] [Procgen](https://envpool.readthedocs.io/en/latest/env/procgen.html) +- [x] [Minigrid](https://envpool.readthedocs.io/en/latest/env/minigrid.html) Here are EnvPool's several highlights: -- Compatible with OpenAI `gym` APIs and DeepMind `dm_env` APIs; +- Compatible with OpenAI `gym` APIs, DeepMind `dm_env` APIs, and [`gymnasium`](https://github.com/Farama-Foundation/Gymnasium) APIs; - Manage a pool of envs, interact with the envs in batched APIs by default; - Support both synchronous execution and asynchronous execution; - Support both single player and multi-player environment; @@ -115,7 +115,7 @@ env = envpool.make("Pong-v5", env_type="gym", num_envs=100) # or use envpool.make_gym(...) obs = env.reset() # should be (100, 4, 84, 84) act = np.zeros(100, dtype=int) -obs, rew, done, info = env.step(act) +obs, rew, term, trunc, info = env.step(act) ``` Under the synchronous mode, `envpool` closely resembles `openai-gym`/`dm-env`. It has the `reset` and `step` functions with the same meaning. However, there is one exception in `envpool`: batch interaction is the default. Therefore, during the creation of the envpool, there is a `num_envs` argument that denotes how many envs you like to run in parallel. @@ -145,7 +145,7 @@ env = envpool.make("Pong-v5", env_type="gym", num_envs=num_envs, batch_size=batc action_num = env.action_space.n env.async_reset() # send the initial reset signal to all envs while True: - obs, rew, done, info = env.recv() + obs, rew, term, trunc, info = env.recv() env_id = info["env_id"] action = np.random.randint(action_num, size=batch_size) env.send(action, env_id) diff --git a/WORKSPACE b/WORKSPACE index 24d317e6..bb0ab8e3 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -8,6 +8,20 @@ load("//envpool:workspace1.bzl", workspace1 = "workspace") workspace1() +# QT special, cannot move to workspace2.bzl, not sure why + +load("@local_config_qt//:local_qt.bzl", "local_qt_path") + +new_local_repository( + name = "qt", + build_file = "@com_justbuchanan_rules_qt//:qt.BUILD", + path = local_qt_path(), +) + +load("@com_justbuchanan_rules_qt//tools:qt_toolchain.bzl", "register_qt_toolchains") + +register_qt_toolchains() + load("//envpool:pip.bzl", pip_workspace = "workspace") pip_workspace() diff --git a/docker/dev-cn.dockerfile b/docker/dev-cn.dockerfile index a457b98a..0880d279 100644 --- a/docker/dev-cn.dockerfile +++ b/docker/dev-cn.dockerfile @@ -1,22 +1,40 @@ -FROM ubuntu:20.04 +# Need docker >= 20.10.9, see https://stackoverflow.com/questions/71941032/why-i-cannot-run-apt-update-inside-a-fresh-ubuntu22-04 + +FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04 ARG DEBIAN_FRONTEND=noninteractive ARG HOME=/root ARG PATH=$PATH:$HOME/go/bin RUN apt-get update \ - && apt-get install -y python3-pip python3-dev golang-1.16 clang-format-11 git wget swig \ + && apt-get install -y python3-pip python3-dev golang-1.18 git wget curl zsh tmux vim \ && rm -rf /var/lib/apt/lists/* RUN ln -s /usr/bin/python3 /usr/bin/python -RUN ln -sf /usr/lib/go-1.16/bin/go /usr/bin/go +RUN ln -sf /usr/lib/go-1.18/bin/go /usr/bin/go +RUN sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)" +WORKDIR $HOME +RUN git clone https://github.com/gpakosz/.tmux.git +RUN ln -s -f .tmux/.tmux.conf +RUN cp .tmux/.tmux.conf.local . +RUN echo "set-option -g default-shell /bin/zsh" >> .tmux.conf.local +RUN echo "set-option -g history-limit 10000" >> .tmux.conf.local RUN go env -w GOPROXY=https://goproxy.cn -RUN wget https://mirrors.huaweicloud.com/bazel/5.1.1/bazel-5.1.1-linux-x86_64 -RUN chmod +x bazel-5.1.1-linux-x86_64 +RUN wget https://mirrors.huaweicloud.com/bazel/6.0.0/bazel-6.0.0-linux-x86_64 +RUN chmod +x bazel-6.0.0-linux-x86_64 RUN mkdir -p $HOME/go/bin -RUN mv bazel-5.1.1-linux-x86_64 $HOME/go/bin/bazel +RUN mv bazel-6.0.0-linux-x86_64 $HOME/go/bin/bazel RUN go install github.com/bazelbuild/buildtools/buildifier@latest -RUN pip3 install --upgrade pip isort yapf cpplint flake8 flake8_bugbear mypy && rm -rf ~/.pip/cache +RUN $HOME/go/bin/bazel version + +RUN useradd -ms /bin/zsh github-action + +RUN apt-get update \ + && apt-get install -y clang-format clang-tidy swig qtdeclarative5-dev \ + && rm -rf /var/lib/apt/lists/* + +RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple +RUN sed -i "s@http://.*archive.ubuntu.com@https://mirrors.tuna.tsinghua.edu.cn@g" /etc/apt/sources.list +RUN sed -i "s@http://.*security.ubuntu.com@https://mirrors.tuna.tsinghua.edu.cn@g" /etc/apt/sources.list WORKDIR /app -COPY . . diff --git a/docker/dev.dockerfile b/docker/dev.dockerfile index c9814ec1..3360235c 100644 --- a/docker/dev.dockerfile +++ b/docker/dev.dockerfile @@ -1,18 +1,33 @@ -FROM ubuntu:20.04 +# Need docker >= 20.10.9, see https://stackoverflow.com/questions/71941032/why-i-cannot-run-apt-update-inside-a-fresh-ubuntu22-04 + +FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04 ARG DEBIAN_FRONTEND=noninteractive ARG HOME=/root ARG PATH=$PATH:$HOME/go/bin RUN apt-get update \ - && apt-get install -y python3-pip python3-dev golang-1.16 clang-format-11 git wget swig \ + && apt-get install -y python3-pip python3-dev golang-1.18 git wget curl zsh tmux vim \ && rm -rf /var/lib/apt/lists/* RUN ln -s /usr/bin/python3 /usr/bin/python -RUN ln -sf /usr/lib/go-1.16/bin/go /usr/bin/go +RUN ln -sf /usr/lib/go-1.18/bin/go /usr/bin/go +RUN sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)" +WORKDIR $HOME +RUN git clone https://github.com/gpakosz/.tmux.git +RUN ln -s -f .tmux/.tmux.conf +RUN cp .tmux/.tmux.conf.local . +RUN echo "set-option -g default-shell /bin/zsh" >> .tmux.conf.local +RUN echo "set-option -g history-limit 10000" >> .tmux.conf.local +RUN echo "export PATH=$PATH:$HOME/go/bin" >> .zshrc RUN go install github.com/bazelbuild/bazelisk@latest && ln -sf $HOME/go/bin/bazelisk $HOME/go/bin/bazel RUN go install github.com/bazelbuild/buildtools/buildifier@latest -RUN pip3 install --upgrade pip isort yapf cpplint flake8 flake8_bugbear mypy && rm -rf ~/.pip/cache +RUN $HOME/go/bin/bazel version + +RUN useradd -ms /bin/zsh github-action + +RUN apt-get update \ + && apt-get install -y clang-format clang-tidy swig qtdeclarative5-dev \ + && rm -rf /var/lib/apt/lists/* WORKDIR /app -COPY . . diff --git a/docker/release.dockerfile b/docker/release.dockerfile index f0427c91..d3c88e85 100644 --- a/docker/release.dockerfile +++ b/docker/release.dockerfile @@ -1,48 +1,79 @@ -FROM ubuntu:16.04 +FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu16.04 ARG DEBIAN_FRONTEND=noninteractive ARG HOME=/root -ENV PATH=$HOME/go/bin:$PATH +ENV PATH=$HOME/go/bin:$HOME/.pyenv/shims:$HOME/.pyenv/bin:$PATH +ENV LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH WORKDIR $HOME -# install base dependencies +# setup env + +RUN apt-get update && apt-get install -y software-properties-common && add-apt-repository ppa:ubuntu-toolchain-r/test -RUN apt-get update && apt-get install -y software-properties-common && add-apt-repository ppa:ubuntu-toolchain-r/test && add-apt-repository ppa:deadsnakes/ppa RUN apt-get update \ - && apt-get install -y git curl wget gcc-9 g++-9 build-essential patchelf make libssl-dev zlib1g-dev \ - libbz2-dev libreadline-dev libsqlite3-dev llvm \ - libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev swig \ - python3.7 python3.8 python3.9 python3.10 \ - python3.7-dev python3.8-dev python3.9-dev python3.10-dev \ - python3.8-distutils python3.9-distutils python3.10-distutils -RUN ln -sf /usr/bin/python3 /usr/bin/python -RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 60 --slave /usr/bin/g++ g++ /usr/bin/g++-9 + && apt-get install -y git curl wget zsh gcc-9 g++-9 build-essential make tmux \ + && rm -rf /var/lib/apt/lists/* -# install pip +RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 60 --slave /usr/bin/g++ g++ /usr/bin/g++-9 +RUN sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)" +WORKDIR $HOME +RUN git clone https://github.com/gpakosz/.tmux.git +RUN ln -s -f .tmux/.tmux.conf +RUN cp .tmux/.tmux.conf.local . +RUN echo "set-option -g default-shell /bin/zsh" >> .tmux.conf.local +RUN echo "set-option -g history-limit 10000" >> .tmux.conf.local -RUN wget https://bootstrap.pypa.io/get-pip.py -RUN for i in 7 8 9 10; do ln -sf /usr/bin/python3.$i /usr/bin/python3; python3 get-pip.py; done +RUN curl https://pyenv.run | sh # install go from source -RUN wget https://golang.org/dl/go1.17.3.linux-amd64.tar.gz -RUN rm -rf /usr/local/go && tar -C /usr/local -xzf go1.17.3.linux-amd64.tar.gz +RUN wget https://golang.org/dl/go1.19.4.linux-amd64.tar.gz +RUN rm -rf /usr/local/go && tar -C /usr/local -xzf go1.19.4.linux-amd64.tar.gz RUN ln -sf /usr/local/go/bin/go /usr/bin/go # install bazel RUN go install github.com/bazelbuild/bazelisk@latest && ln -sf $HOME/go/bin/bazelisk $HOME/go/bin/bazel -# install big wheels - -RUN for i in 7 8 9 10; do ln -sf /usr/bin/python3.$i /usr/bin/python3; pip3 install torch opencv-python-headless; done - RUN bazel version -WORKDIR /app +# install base dependencies + +RUN apt-get update \ + && apt-get install -y swig zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev libncursesw5-dev libffi-dev liblzma-dev \ + llvm xz-utils tk-dev libxml2-dev libxmlsec1-dev qtdeclarative5-dev \ + && rm -rf /var/lib/apt/lists/* +# use self-compiled openssl instead of system provided (1.0.2) +RUN apt-get remove -y libssl-dev + +# install newest openssl (for py3.10 and py3.11) + +RUN wget https://www.openssl.org/source/openssl-1.1.1s.tar.gz +RUN tar xf openssl-1.1.1s.tar.gz +WORKDIR $HOME/openssl-1.1.1s +RUN ./config +RUN make -j +RUN make install + +# install python + +RUN echo 'export PYENV_ROOT="$HOME/.pyenv"' >> /etc/profile +RUN echo 'command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"' >> /etc/profile +RUN echo 'eval "$(pyenv init -)"' >> /etc/profile + +RUN LDFLAGS="-Wl,-rpath,/root/openssl-1.1.1s/lib" CONFIGURE_OPTS="-with-openssl=/root/openssl-1.1.1s" pyenv install -v 3.11-dev +RUN LDFLAGS="-Wl,-rpath,/root/openssl-1.1.1s/lib" CONFIGURE_OPTS="-with-openssl=/root/openssl-1.1.1s" pyenv install -v 3.10-dev +RUN LDFLAGS="-Wl,-rpath,/root/openssl-1.1.1s/lib" CONFIGURE_OPTS="-with-openssl=/root/openssl-1.1.1s" pyenv install -v 3.9-dev +RUN LDFLAGS="-Wl,-rpath,/root/openssl-1.1.1s/lib" CONFIGURE_OPTS="-with-openssl=/root/openssl-1.1.1s" pyenv install -v 3.8-dev +RUN LDFLAGS="-Wl,-rpath,/root/openssl-1.1.1s/lib" CONFIGURE_OPTS="-with-openssl=/root/openssl-1.1.1s" pyenv install -v 3.7-dev + +WORKDIR /__w/envpool/envpool COPY . . -# compile and test release wheels +# cache bazel build (cpp only) + +RUN bazel build //envpool/utils:image_process_test --config=release +RUN bazel build //envpool/vizdoom/bin:vizdoom_bin --config=release -RUN for i in 7 8 9 10; do ln -sf /usr/bin/python3.$i /usr/bin/python3; make pypi-wheel BAZELOPT="--remote_cache=http://bazel-cache.sail:8080"; pip3 install wheelhouse/*cp3$i*.whl; rm dist/*.whl; make release-test; done +WORKDIR /app diff --git a/docs/content/build.rst b/docs/content/build.rst index 022a6d56..ac5242cf 100644 --- a/docs/content/build.rst +++ b/docs/content/build.rst @@ -1,7 +1,7 @@ Build From Source ================= -We recommend building EnvPool on Ubuntu 20.04 environment. +We recommend building EnvPool on Ubuntu 22.04 environment. We use `bazel `_ to build EnvPool. Comparing with `pip `_, using Bazel to build python package with C++ .so @@ -50,10 +50,10 @@ or `golang `_ with version >= 1.16: # then follow the instructions on golang official website go env -w GOPROXY=https://goproxy.cn - wget https://mirrors.huaweicloud.com/bazel/5.1.1/bazel-5.1.1-linux-x86_64 - chmod +x bazel-5.1.1-linux-x86_64 + wget https://mirrors.huaweicloud.com/bazel/6.0.0/bazel-6.0.0-linux-x86_64 + chmod +x bazel-6.0.0-linux-x86_64 mkdir -p $HOME/go/bin - mv bazel-5.1.1-linux-x86_64 $HOME/go/bin/bazel + mv bazel-6.0.0-linux-x86_64 $HOME/go/bin/bazel export PATH=$PATH:$HOME/go/bin # or write to .bashrc / .zshrc @@ -99,6 +99,7 @@ To build a release version, type: .. code-block:: bash + cp third_party/pip_requirements/requirements-release.txt third_party/pip_requirements/requirements.txt bazel run --config=release //:setup -- bdist_wheel This creates a wheel under ``bazel-bin/setup.runfiles/envpool/dist``. @@ -142,6 +143,9 @@ We provide several shortcuts to make things easier. # This will automatically run the tests make bazel-test + # This will build a wheel for release + make bazel-release + Use Docker to Create Develop Environment ---------------------------------------- diff --git a/docs/content/new_env.rst b/docs/content/new_env.rst index e6234759..706e1317 100644 --- a/docs/content/new_env.rst +++ b/docs/content/new_env.rst @@ -452,8 +452,8 @@ After that, you can import ``_CartPoleEnvSpec`` and ``_CartPoleEnvPool`` from The next step is to apply python-side wrapper (gym/dm_env APIs) to raw classes. In ``envpool/classic_control/__init__.py``, use ``py_env`` function to -instantiate ``CartPoleEnvSpec``, ``CartPoleDMEnvPool``, and -``CartPoleGymEnvPool``. +instantiate ``CartPoleEnvSpec``, ``CartPoleDMEnvPool``, +``CartPoleGymEnvPool`` and ``CartPoleGymnasiumEnvPool``. :: @@ -461,14 +461,18 @@ instantiate ``CartPoleEnvSpec``, ``CartPoleDMEnvPool``, and from .classic_control_envpool import _CartPoleEnvPool, _CartPoleEnvSpec - CartPoleEnvSpec, CartPoleDMEnvPool, CartPoleGymEnvPool = py_env( - _CartPoleEnvSpec, _CartPoleEnvPool - ) + ( + CartPoleEnvSpec, + CartPoleDMEnvPool, + CartPoleGymEnvPool, + CartPoleGymnasiumEnvPool, + ) = py_env(_CartPoleEnvSpec, _CartPoleEnvPool) __all__ = [ "CartPoleEnvSpec", "CartPoleDMEnvPool", "CartPoleGymEnvPool", + "CartPoleGymnasiumEnvPool", ] @@ -487,6 +491,7 @@ To register a task in EnvPool, you need to call ``register`` function in spec_cls="CartPoleEnvSpec", dm_cls="CartPoleDMEnvPool", gym_cls="CartPoleGymEnvPool", + gymnasium_cls="CartPoleGymnasiumEnvPool", max_episode_steps=200, reward_threshold=195.0, ) @@ -497,15 +502,16 @@ To register a task in EnvPool, you need to call ``register`` function in spec_cls="CartPoleEnvSpec", dm_cls="CartPoleDMEnvPool", gym_cls="CartPoleGymEnvPool", + gymnasium_cls="CartPoleGymnasiumEnvPool", max_episode_steps=500, reward_threshold=475.0, ) -``task_id``, ``import_path``, ``spec_cls``, ``dm_cls``, and ``gym_cls`` are -required arguments. Other arguments such as ``max_episode_steps`` and -``reward_threshold`` are env-specific. For example, if someone use -``envpool.make("CartPole-v1")``, the ``reward_threshold`` will be set to 475.0 -at ``CartPoleEnvPool`` initialization. +``task_id``, ``import_path``, ``spec_cls``, ``dm_cls``, ``gym_cls`` and +``gymnasium_cls`` are required arguments. Other arguments such as +``max_episode_steps`` and ``reward_threshold`` are env-specific. For example, +if someone use ``envpool.make("CartPole-v1")``, the ``reward_threshold`` will +be set to 475.0 at ``CartPoleEnvPool`` initialization. Finally, it is crucial to let the top-level module import this file. In ``envpool/entry.py``, add the following line: diff --git a/docs/content/python_interface.rst b/docs/content/python_interface.rst index 8c71ae5e..8504c5ae 100644 --- a/docs/content/python_interface.rst +++ b/docs/content/python_interface.rst @@ -11,7 +11,7 @@ batched environments: * ``task_id (str)``: task id, use ``envpool.list_all_envs()`` to see all support tasks; * ``env_type (str)``: generate with ``gym.Env`` or ``dm_env.Environment`` - interface, available options are ``dm`` and ``gym``; + interface, available options are ``dm``, ``gym``, and ``gymnasium``; * ``num_envs (int)``: how many envs are in the envpool, default to ``1``; * ``batch_size (int)``: async configuration, see the last section, default to ``num_envs``; @@ -20,7 +20,8 @@ batched environments: * ``seed (int)``: set seed over all environments. The i-th environment seed will be set with i+seed, default to ``42``; * ``max_episode_steps (int)``: set the max steps in one episode. This value is - env-specific (108000 in Atari for example); + env-specific (27000 steps or 27000 * 4 = 108000 frames in Atari for + example); * ``max_num_players (int)``: the maximum number of player in one env, useful in multi-agent env. In single agent environment, it is always ``1``; * ``thread_affinity_offset (int)``: the start id of binding thread. ``-1`` @@ -43,8 +44,9 @@ The observation space and action space of resulted environment are observation/action's first dimension is always equal to ``num_envs`` (sync mode) or equal to ``batch_size`` (async mode). -``envpool.make_gym`` and ``envpool.make_dm`` are shortcuts for -``envpool.make(..., env_type="gym" | "dm")``, respectively. +``envpool.make_gym``, ``envpool.make_dm``, and ``envpool.make_gymnasium`` are +shortcuts for ``envpool.make(..., env_type="gym" | "dm" | "gymnasium")``, +respectively. envpool.make_spec ----------------- @@ -117,7 +119,7 @@ Data Output Format ------------------ +----------+----------------------------------------------------------------------+------------------------------------------------------------------+ -| function | gym | dm | +| function | gym & gymnasium | dm | | | | | +==========+======================================================================+==================================================================+ | reset | | env_id -> obs array (single observation) | env_id -> TimeStep(FIRST, obs|info|env_id, rew=0, discount or 1) | diff --git a/docs/content/xla_interface.rst b/docs/content/xla_interface.rst index 38f37778..89371688 100644 --- a/docs/content/xla_interface.rst +++ b/docs/content/xla_interface.rst @@ -27,7 +27,7 @@ These functions can be obtained from the envpool instance which we created from the Python API. :: - env = envpool.make(..., env_type="gym" | "dm") + env = envpool.make(..., env_type="gym" | "dm" | "gymnasium") handle, recv, send, step = env.xla() diff --git a/docs/env/atari.rst b/docs/env/atari.rst index 88754a77..0f23a622 100644 --- a/docs/env/atari.rst +++ b/docs/env/atari.rst @@ -25,7 +25,9 @@ Options ``env.step``, default to ``batch_size``; * ``seed (int)``: the environment seed, default to ``42``; * ``max_episode_steps (int)``: the maximum number of steps for one episode, - default to ``108000``; + default to ``27000``, which corresponds to 108000 frames or roughly 30 + minutes of game-play (`Hessel et al. 2018, Table 3 + `_) because of the 4 skipped frames; * ``img_height (int)``: the desired observation image height, default to ``84``; * ``img_width (int)``: the desired observation image width, default to ``84``; @@ -52,6 +54,8 @@ Options image resize, default to ``True``. * ``use_fire_reset (bool)``: whether to use ``fire-reset`` wrapper, default to ``True``. +* ``full_action_space (bool)``: whether to use full action space of ALE of 18 + actions, default to ``False``. Observation Space diff --git a/docs/env/dm_control.rst b/docs/env/dm_control.rst index 4832496e..e96728e6 100644 --- a/docs/env/dm_control.rst +++ b/docs/env/dm_control.rst @@ -46,7 +46,7 @@ BallInCupCatch-v1 - ``max_episode_steps``: 1000; -CartpoleBalance-v1, CartpoleBalanceSparse-v1, CarpoletSwingup-v1, CartpoleSwingupSparse-v1, CartpoleTwoPoles-v1, CartpoleThreePoles-v1 +CartpoleBalance-v1, CartpoleBalanceSparse-v1, CartpoleSwingup-v1, CartpoleSwingupSparse-v1, CartpoleTwoPoles-v1, CartpoleThreePoles-v1 -------------------------------------------------------------------------------------------------------------------------------------- `dm_control suite cartpole source code diff --git a/docs/env/minigrid.rst b/docs/env/minigrid.rst new file mode 100644 index 00000000..e35c740f --- /dev/null +++ b/docs/env/minigrid.rst @@ -0,0 +1,18 @@ +Minigrid +======== + +We use ``minigrid==2.1.0`` as the codebase. +See https://github.com/Farama-Foundation/Minigrid/tree/v2.1.0 + + +Empty +----- + +Registered Configurations + +- `MiniGrid-Empty-5x5-v0` +- `MiniGrid-Empty-Random-5x5-v0` +- `MiniGrid-Empty-6x6-v0` +- `MiniGrid-Empty-Random-6x6-v0` +- `MiniGrid-Empty-8x8-v0` +- `MiniGrid-Empty-16x16-v0` diff --git a/docs/env/procgen.rst b/docs/env/procgen.rst new file mode 100644 index 00000000..a2cd8605 --- /dev/null +++ b/docs/env/procgen.rst @@ -0,0 +1,100 @@ +Procgen +======= + +We use ``procgen==0.10.7`` as the codebase. +See https://github.com/openai/procgen/tree/0.10.7 + + +Options +------- + +* ``task_id (str)``: see available tasks below; +* ``num_envs (int)``: how many environments you would like to create; +* ``batch_size (int)``: the expected batch size for return result, default to + ``num_envs``; +* ``num_threads (int)``: the maximum thread number for executing the actual + ``env.step``, default to ``batch_size``; +* ``seed (int)``: the environment seed, default to ``42``; +* ``max_episode_steps (int)``: the maximum number of steps for one episode, + each procgen game has different timeout value; +* ``channel_first (bool)``: whether to transpose the observation image to + ``(3, 64, 64)``, default to ``True``; +* ``env_name (str)``: one of 16 procgen env name; +* ``num_levels (int)``: default to ``0``; +* ``start_level (int)``: default to ``0``; +* ``use_sequential_levels (bool)``: default to ``False``; +* ``center_agent (bool)``: default to ``True``; +* ``use_backgrounds (bool)``: default to ``True``; +* ``use_monochrome_assets (bool)``: default to ``False``; +* ``restrict_themes (bool)``: default to ``False``; +* ``use_generated_assets (bool)``: default to ``False``; +* ``paint_vel_info (bool)``: default to ``False``; +* ``use_easy_jump (bool)``: default to ``False``; +* ``distribution_mode (int)``: one of ``(0, 1, 2, 10)``; ``0`` stands for easy + mode, ``1`` stands for hard mode, ``2`` stands for extreme mode, ``10`` + stands for memory mode. The default value is determined by ``task_id``. + +Note: arguments after ``env_name`` are provided by procgen environment itself. +We keep the default value as-is. We haven't tested the setting of +``use_sequential_levels == True``, and have no promise it is aligned with the +original version of procgen (PRs for fixing this issue are highly welcome). + + +Observation Space +----------------- + +The observation image shape is ``(3, 64, 64)`` when ``channel_first`` is +``True`` (default), ``(64, 64, 3)`` when ``channel_first`` is ``False``. + + +Action Space +------------ + +15 action buttons in total, ranging from 0 to 14. + + +Available Tasks +--------------- + +* ``BigfishEasy-v0`` +* ``BigfishHard-v0`` +* ``BossfightEasy-v0`` +* ``BossfightHard-v0`` +* ``CaveflyerEasy-v0`` +* ``CaveflyerHard-v0`` +* ``CaveflyerMemory-v0`` +* ``ChaserEasy-v0`` +* ``ChaserHard-v0`` +* ``ChaserExtreme-v0`` +* ``ClimberEasy-v0`` +* ``ClimberHard-v0`` +* ``CoinrunEasy-v0`` +* ``CoinrunHard-v0`` +* ``DodgeballEasy-v0`` +* ``DodgeballHard-v0`` +* ``DodgeballExtreme-v0`` +* ``DodgeballMemory-v0`` +* ``FruitbotEasy-v0`` +* ``FruitbotHard-v0`` +* ``HeistEasy-v0`` +* ``HeistHard-v0`` +* ``HeistMemory-v0`` +* ``JumperEasy-v0`` +* ``JumperHard-v0`` +* ``JumperMemory-v0`` +* ``LeaperEasy-v0`` +* ``LeaperHard-v0`` +* ``LeaperExtreme-v0`` +* ``MazeEasy-v0`` +* ``MazeHard-v0`` +* ``MazeMemory-v0`` +* ``MinerEasy-v0`` +* ``MinerHard-v0`` +* ``MinerMemory-v0`` +* ``NinjaEasy-v0`` +* ``NinjaHard-v0`` +* ``PlunderEasy-v0`` +* ``PlunderHard-v0`` +* ``StarpilotEasy-v0`` +* ``StarpilotHard-v0`` +* ``StarpilotExtreme-v0`` diff --git a/docs/index.rst b/docs/index.rst index cf6c0830..403f76f1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -90,7 +90,9 @@ stable version through `envpool.readthedocs.io/en/stable/ env/box2d env/classic_control env/dm_control + env/minigrid env/mujoco_gym + env/procgen env/toy_text env/vizdoom diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 29ac4a00..2ef42414 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -3,6 +3,7 @@ async dm env pybind +cpplint envs subprocess codebase @@ -68,3 +69,4 @@ api jit mins lidar +procgen diff --git a/envpool/BUILD b/envpool/BUILD index 4dd760cb..93a7fc6d 100644 --- a/envpool/BUILD +++ b/envpool/BUILD @@ -35,6 +35,7 @@ py_library( "//envpool/classic_control:classic_control_registration", "//envpool/mujoco:mujoco_dmc_registration", "//envpool/mujoco:mujoco_gym_registration", + "//envpool/procgen:procgen_registration", "//envpool/toy_text:toy_text_registration", "//envpool/vizdoom:vizdoom_registration", ], @@ -51,6 +52,7 @@ py_library( "//envpool/classic_control", "//envpool/mujoco:mujoco_dmc", "//envpool/mujoco:mujoco_gym", + "//envpool/procgen", "//envpool/python", "//envpool/toy_text", "//envpool/vizdoom", diff --git a/envpool/__init__.py b/envpool/__init__.py index 303bcf27..c9f8e5da 100644 --- a/envpool/__init__.py +++ b/envpool/__init__.py @@ -19,16 +19,18 @@ make, make_dm, make_gym, + make_gymnasium, make_spec, register, ) -__version__ = "0.6.6" +__version__ = "0.8.2" __all__ = [ "register", "make", "make_dm", "make_gym", + "make_gymnasium", "make_spec", "list_all_envs", ] diff --git a/envpool/atari/__init__.py b/envpool/atari/__init__.py index 41bba129..dc4011be 100644 --- a/envpool/atari/__init__.py +++ b/envpool/atari/__init__.py @@ -18,7 +18,7 @@ from .atari_envpool import _AtariEnvPool, _AtariEnvSpec -AtariEnvSpec, AtariDMEnvPool, AtariGymEnvPool = py_env( +AtariEnvSpec, AtariDMEnvPool, AtariGymEnvPool, AtariGymnasiumEnvPool = py_env( _AtariEnvSpec, _AtariEnvPool ) @@ -26,4 +26,5 @@ "AtariEnvSpec", "AtariDMEnvPool", "AtariGymEnvPool", + "AtariGymnasiumEnvPool", ] diff --git a/envpool/atari/atari_env.h b/envpool/atari/atari_env.h index bca85d13..48cb83ce 100644 --- a/envpool/atari/atari_env.h +++ b/envpool/atari/atari_env.h @@ -32,14 +32,14 @@ namespace atari { -bool TurnOffVerbosity() { +auto TurnOffVerbosity() { ale::Logger::setMode(ale::Logger::Error); return true; } static bool verbosity_off = TurnOffVerbosity(); -std::string GetRomPath(const std::string& base_path, const std::string& task) { +auto GetRomPath(const std::string& base_path, const std::string& task) { std::stringstream ss; // hardcode path here :( ss << base_path << "/atari/roms/" << task << ".bin"; @@ -54,7 +54,7 @@ class AtariEnvFns { "zero_discount_on_life_loss"_.Bind(false), "episodic_life"_.Bind(false), "reward_clip"_.Bind(false), "use_fire_reset"_.Bind(true), "img_height"_.Bind(84), "img_width"_.Bind(84), - "task"_.Bind(std::string("pong")), + "task"_.Bind(std::string("pong")), "full_action_space"_.Bind(false), "repeat_action_probability"_.Bind(0.0f), "use_inter_area_resize"_.Bind(true), "gray_scale"_.Bind(true)); } @@ -72,7 +72,9 @@ class AtariEnvFns { static decltype(auto) ActionSpec(const Config& conf) { ale::ALEInterface env; env.loadROM(GetRomPath(conf["base_path"_], conf["task"_])); - int action_size = env.getMinimalActionSet().size(); + int action_size = conf["full_action_space"_] + ? env.getLegalActionSet().size() + : env.getMinimalActionSet().size(); return MakeDict("action"_.Bind(Spec({-1}, {0, action_size - 1}))); } }; @@ -88,9 +90,9 @@ class AtariEnv : public Env { std::unique_ptr env_; ale::ActionVect action_set_; int max_episode_steps_, elapsed_step_, stack_num_, frame_skip_; - bool fire_reset_, reward_clip_, zero_discount_on_life_loss_; + bool fire_reset_{false}, reward_clip_, zero_discount_on_life_loss_; bool gray_scale_, episodic_life_, use_inter_area_resize_; - bool done_; + bool done_{true}; int lives_; FrameSpec raw_spec_, resize_spec_, transpose_spec_; std::deque stack_buf_; @@ -107,13 +109,11 @@ class AtariEnv : public Env { elapsed_step_(max_episode_steps_ + 1), stack_num_(spec.config["stack_num"_]), frame_skip_(spec.config["frame_skip"_]), - fire_reset_(false), reward_clip_(spec.config["reward_clip"_]), zero_discount_on_life_loss_(spec.config["zero_discount_on_life_loss"_]), gray_scale_(spec.config["gray_scale"_]), episodic_life_(spec.config["episodic_life"_]), use_inter_area_resize_(spec.config["use_inter_area_resize"_]), - done_(true), raw_spec_({kRawHeight, kRawWidth, gray_scale_ ? 1 : 3}), resize_spec_({spec.config["img_height"_], spec.config["img_width"_], gray_scale_ ? 1 : 3}), @@ -126,7 +126,11 @@ class AtariEnv : public Env { spec.config["repeat_action_probability"_]); env_->setInt("random_seed", seed_); env_->loadROM(rom_path_); - action_set_ = env_->getMinimalActionSet(); + if (spec.config["full_action_space"_]) { + action_set_ = env_->getLegalActionSet(); + } else { + action_set_ = env_->getMinimalActionSet(); + } if (spec.config["use_fire_reset"_]) { // https://github.com/sail-sg/envpool/issues/221 for (auto a : action_set_) { diff --git a/envpool/atari/atari_envpool_test.py b/envpool/atari/atari_envpool_test.py index a6bcb5fe..a687060e 100644 --- a/envpool/atari/atari_envpool_test.py +++ b/envpool/atari/atari_envpool_test.py @@ -25,7 +25,7 @@ import envpool.atari.registration # noqa: F401 from envpool.atari.atari_envpool import _AtariEnvPool, _AtariEnvSpec -from envpool.registration import make_dm, make_gym +from envpool.registration import make_dm, make_gym, make_gymnasium class _AtariEnvPoolTest(absltest.TestCase): @@ -61,19 +61,30 @@ def test_raw_envpool(self) -> None: fps = total * batch / duration * 4 logging.info(f"Raw envpool FPS = {fps:.6f}") + def test_full_action_space(self) -> None: + env = make_gym("Pong-v5", full_action_space=True) + self.assertEqual(env.action_space.n, 18) + env = make_gym("Breakout-v5", full_action_space=True) + self.assertEqual(env.action_space.n, 18) + def test_align(self) -> None: """Make sure gym's envpool and dm_env's envpool generate the same data.""" num_envs = 4 env0 = make_gym("SpaceInvaders-v5", num_envs=num_envs) env1 = make_dm("SpaceInvaders-v5", num_envs=num_envs) + env2 = make_gymnasium("SpaceInvaders-v5", num_envs=num_envs) obs0, _ = env0.reset() obs1 = env1.reset().observation.obs + obs2, _ = env2.reset() np.testing.assert_allclose(obs0, obs1) + np.testing.assert_allclose(obs1, obs2) for _ in range(1000): action = np.random.randint(6, size=num_envs) obs0 = env0.step(action)[0] obs1 = env1.step(action).observation.obs + obs2 = env2.step(action)[0] np.testing.assert_allclose(obs0, obs1) + np.testing.assert_allclose(obs1, obs2) # cv2.imwrite(f"/tmp/log/align{i}.png", obs0[0, 1:].transpose(1, 2, 0)) def test_reset_life(self) -> None: diff --git a/envpool/atari/registration.py b/envpool/atari/registration.py index 2e44d6b8..66467e6e 100644 --- a/envpool/atari/registration.py +++ b/envpool/atari/registration.py @@ -15,9 +15,7 @@ import os -from envpool.registration import register - -base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +from envpool.registration import base_path, register atari_rom_path = os.path.join(base_path, "atari", "roms") atari_game_list = sorted( @@ -32,7 +30,7 @@ spec_cls="AtariEnvSpec", dm_cls="AtariDMEnvPool", gym_cls="AtariGymEnvPool", + gymnasium_cls="AtariGymnasiumEnvPool", task=game, - base_path=base_path, - max_episode_steps=108000, + max_episode_steps=27000, ) diff --git a/envpool/box2d/BUILD b/envpool/box2d/BUILD index 71a0564f..a987e59e 100644 --- a/envpool/box2d/BUILD +++ b/envpool/box2d/BUILD @@ -89,7 +89,7 @@ py_test( ":box2d_registration", requirement("absl-py"), requirement("gym"), - requirement("box2d"), + requirement("box2d-py"), requirement("pygame"), requirement("opencv-python-headless"), requirement("numpy"), diff --git a/envpool/box2d/__init__.py b/envpool/box2d/__init__.py index d4b3785e..80c80665 100644 --- a/envpool/box2d/__init__.py +++ b/envpool/box2d/__init__.py @@ -26,24 +26,28 @@ _LunarLanderDiscreteEnvSpec, ) -CarRacingEnvSpec, CarRacingDMEnvPool, CarRacingGymEnvPool = py_env( - _CarRacingEnvSpec, _CarRacingEnvPool -) +( + BipedalWalkerEnvSpec, BipedalWalkerDMEnvPool, BipedalWalkerGymEnvPool, + BipedalWalkerGymnasiumEnvPool +) = py_env(_BipedalWalkerEnvSpec, _BipedalWalkerEnvPool) -BipedalWalkerEnvSpec, BipedalWalkerDMEnvPool, BipedalWalkerGymEnvPool = py_env( - _BipedalWalkerEnvSpec, _BipedalWalkerEnvPool -) +( + CarRacingEnvSpec, CarRacingDMEnvPool, CarRacingGymEnvPool, + CarRacingGymnasiumEnvPool +) = py_env(_CarRacingEnvSpec, _CarRacingEnvPool) ( LunarLanderContinuousEnvSpec, LunarLanderContinuousDMEnvPool, LunarLanderContinuousGymEnvPool, + LunarLanderContinuousGymnasiumEnvPool, ) = py_env(_LunarLanderContinuousEnvSpec, _LunarLanderContinuousEnvPool) ( LunarLanderDiscreteEnvSpec, LunarLanderDiscreteDMEnvPool, LunarLanderDiscreteGymEnvPool, + LunarLanderDiscreteGymnasiumEnvPool, ) = py_env(_LunarLanderDiscreteEnvSpec, _LunarLanderDiscreteEnvPool) __all__ = [ @@ -53,10 +57,13 @@ "BipedalWalkerEnvSpec", "BipedalWalkerDMEnvPool", "BipedalWalkerGymEnvPool", + "BipedalWalkerGymnasiumEnvPool", "LunarLanderContinuousEnvSpec", "LunarLanderContinuousDMEnvPool", "LunarLanderContinuousGymEnvPool", + "LunarLanderContinuousGymnasiumEnvPool", "LunarLanderDiscreteEnvSpec", "LunarLanderDiscreteDMEnvPool", "LunarLanderDiscreteGymEnvPool", + "LunarLanderDiscreteGymnasiumEnvPool", ] diff --git a/envpool/box2d/bipedal_walker_env.cc b/envpool/box2d/bipedal_walker_env.cc index 9e68cc05..d97a9cee 100644 --- a/envpool/box2d/bipedal_walker_env.cc +++ b/envpool/box2d/bipedal_walker_env.cc @@ -65,9 +65,7 @@ BipedalWalkerBox2dEnv::BipedalWalkerBox2dEnv(bool hardcore, : max_episode_steps_(max_episode_steps), elapsed_step_(max_episode_steps + 1), hardcore_(hardcore), - done_(true), - world_(new b2World(b2Vec2(0.0, -10.0))), - hull_(nullptr) { + world_(new b2World(b2Vec2(0.0, -10.0))) { for (const auto* p : kHullPoly) { hull_poly_.emplace_back(Vec2(p[0] / kScaleDouble, p[1] / kScaleDouble)); } diff --git a/envpool/box2d/bipedal_walker_env.h b/envpool/box2d/bipedal_walker_env.h index 640476e7..9d0179dd 100644 --- a/envpool/box2d/bipedal_walker_env.h +++ b/envpool/box2d/bipedal_walker_env.h @@ -75,7 +75,7 @@ class BipedalWalkerBox2dEnv { protected: int max_episode_steps_, elapsed_step_; float reward_, prev_shaping_; - bool hardcore_, done_; + bool hardcore_, done_{true}; std::array obs_; // info float scroll_; @@ -83,7 +83,7 @@ class BipedalWalkerBox2dEnv { // box2d related std::unique_ptr world_; - b2Body* hull_; + b2Body* hull_{nullptr}; std::vector hull_poly_; std::vector terrain_; std::array legs_; diff --git a/envpool/box2d/car_dynamics.cc b/envpool/box2d/car_dynamics.cc index 573c0b49..c2081cc0 100644 --- a/envpool/box2d/car_dynamics.cc +++ b/envpool/box2d/car_dynamics.cc @@ -36,7 +36,7 @@ b2PolygonShape GeneratePolygon(const float (*poly)[2], int size) { // NOLINT Car::Car(std::shared_ptr world, float init_angle, float init_x, float init_y) - : world_(std::move(world)), hull_(nullptr), fuel_spent_(0) { + : world_(std::move(world)) { // Create hull b2BodyDef bd; bd.position.Set(init_x, init_y); @@ -186,7 +186,7 @@ void Car::Step(float dt) { } else if (w->skid_start == nullptr) { w->skid_start = std::make_unique(w->body->GetPosition()); } else { - w->skid_particle = CreateParticle(*(w->skid_start.get()), + w->skid_particle = CreateParticle(*(w->skid_start.get()), // NOLINT w->body->GetPosition(), grass); w->skid_start = nullptr; } @@ -264,7 +264,7 @@ void Car::Draw(const cv::Mat& surf, float zoom, cv::fillPoly(surf, poly, color); auto* user_data = - reinterpret_cast(body->GetUserData().pointer); + reinterpret_cast(body->GetUserData().pointer); // NOLINT if (user_data == nullptr || user_data->type != WHEEL_TYPE) { continue; } diff --git a/envpool/box2d/car_dynamics.h b/envpool/box2d/car_dynamics.h index 41ff9828..21a2cfd9 100644 --- a/envpool/box2d/car_dynamics.h +++ b/envpool/box2d/car_dynamics.h @@ -138,9 +138,9 @@ class Car { std::deque> particles_; std::vector drawlist_; std::shared_ptr world_; - b2Body* hull_; + b2Body* hull_{nullptr}; std::vector wheels_; - float fuel_spent_; + float fuel_spent_{0}; std::shared_ptr CreateParticle(b2Vec2 point1, b2Vec2 point2, bool grass); diff --git a/envpool/box2d/car_racing_env.cc b/envpool/box2d/car_racing_env.cc index c0e10fdf..f2c6f154 100644 --- a/envpool/box2d/car_racing_env.cc +++ b/envpool/box2d/car_racing_env.cc @@ -34,9 +34,9 @@ void CarRacingFrictionDetector::Contact(b2Contact* contact, bool begin) { Tile* tile = nullptr; Wheel* obj = nullptr; - auto* u1 = reinterpret_cast( + auto* u1 = reinterpret_cast( // NOLINT contact->GetFixtureA()->GetBody()->GetUserData().pointer); - auto* u2 = reinterpret_cast( + auto* u2 = reinterpret_cast( // NOLINT contact->GetFixtureB()->GetBody()->GetUserData().pointer); if (u1 == nullptr || u2 == nullptr) { @@ -84,7 +84,6 @@ CarRacingBox2dEnv::CarRacingBox2dEnv(int max_episode_steps, : lap_complete_percent_(lap_complete_percent), max_episode_steps_(max_episode_steps), elapsed_step_(max_episode_steps + 1), - done_(true), world_(new b2World(b2Vec2(0.0, 0.0))) { b2PolygonShape shape; std::array vertices = {b2Vec2(0, 0), b2Vec2(1, 0), b2Vec2(1, -1), diff --git a/envpool/box2d/car_racing_env.h b/envpool/box2d/car_racing_env.h index 97fa2b94..aa1c23ad 100644 --- a/envpool/box2d/car_racing_env.h +++ b/envpool/box2d/car_racing_env.h @@ -80,7 +80,7 @@ class CarRacingBox2dEnv { float reward_{0}; float prev_reward_{0}; float step_reward_{0}; - bool done_; + bool done_{true}; cv::Mat surf_; cv::Mat img_array_; diff --git a/envpool/box2d/lunar_lander_env.cc b/envpool/box2d/lunar_lander_env.cc index d8dd22aa..d72d0a2b 100644 --- a/envpool/box2d/lunar_lander_env.cc +++ b/envpool/box2d/lunar_lander_env.cc @@ -52,10 +52,7 @@ LunarLanderBox2dEnv::LunarLanderBox2dEnv(bool continuous, int max_episode_steps) : max_episode_steps_(max_episode_steps), elapsed_step_(max_episode_steps + 1), continuous_(continuous), - done_(true), - world_(new b2World(b2Vec2(0.0, -10.0))), - moon_(nullptr), - lander_(nullptr) { + world_(new b2World(b2Vec2(0.0, -10.0))) { for (const auto* p : kLanderPoly) { lander_poly_.emplace_back(Vec2(p[0] / kScale, p[1] / kScale)); } diff --git a/envpool/box2d/lunar_lander_env.h b/envpool/box2d/lunar_lander_env.h index 81ba8644..f5299786 100644 --- a/envpool/box2d/lunar_lander_env.h +++ b/envpool/box2d/lunar_lander_env.h @@ -53,12 +53,12 @@ class LunarLanderBox2dEnv { protected: int max_episode_steps_, elapsed_step_; float reward_, prev_shaping_; - bool continuous_, done_; + bool continuous_, done_{true}; std::array obs_; // box2d related std::unique_ptr world_; - b2Body *moon_, *lander_; + b2Body *moon_{nullptr}, *lander_{nullptr}; std::vector particles_; std::vector lander_poly_; std::array legs_; diff --git a/envpool/box2d/registration.py b/envpool/box2d/registration.py index 1252f21a..7d446480 100644 --- a/envpool/box2d/registration.py +++ b/envpool/box2d/registration.py @@ -21,6 +21,7 @@ spec_cls="CarRacingEnvSpec", dm_cls="CarRacingDMEnvPool", gym_cls="CarRacingGymEnvPool", + gymnasium_cls="CarRacingGymnasiumEnvPool", max_episode_steps=1000, ) @@ -30,6 +31,7 @@ spec_cls="BipedalWalkerEnvSpec", dm_cls="BipedalWalkerDMEnvPool", gym_cls="BipedalWalkerGymEnvPool", + gymnasium_cls="BipedalWalkerGymnasiumEnvPool", hardcore=False, max_episode_steps=1600, ) @@ -40,6 +42,7 @@ spec_cls="BipedalWalkerEnvSpec", dm_cls="BipedalWalkerDMEnvPool", gym_cls="BipedalWalkerGymEnvPool", + gymnasium_cls="BipedalWalkerGymnasiumEnvPool", hardcore=True, max_episode_steps=2000, ) @@ -50,6 +53,7 @@ spec_cls="LunarLanderDiscreteEnvSpec", dm_cls="LunarLanderDiscreteDMEnvPool", gym_cls="LunarLanderDiscreteGymEnvPool", + gymnasium_cls="LunarLanderDiscreteGymnasiumEnvPool", max_episode_steps=1000, ) @@ -59,5 +63,6 @@ spec_cls="LunarLanderContinuousEnvSpec", dm_cls="LunarLanderContinuousDMEnvPool", gym_cls="LunarLanderContinuousGymEnvPool", + gymnasium_cls="LunarLanderContinuousGymnasiumEnvPool", max_episode_steps=1000, ) diff --git a/envpool/box2d/utils.cc b/envpool/box2d/utils.cc index 8298365d..a9a6145b 100644 --- a/envpool/box2d/utils.cc +++ b/envpool/box2d/utils.cc @@ -17,7 +17,7 @@ namespace box2d { b2Vec2 Vec2(double x, double y) { - return b2Vec2(static_cast(x), static_cast(y)); + return {static_cast(x), static_cast(y)}; } float Sign(double val, double eps) { @@ -43,7 +43,7 @@ b2Vec2 RotateRad(const b2Vec2& v, float angle) { b2Vec2 Multiply(const b2Transform& trans, const b2Vec2& v) { float x = (trans.q.c * v.x - trans.q.s * v.y) + trans.p.x; float y = (trans.q.s * v.x + trans.q.c * v.y) + trans.p.y; - return b2Vec2(x, y); + return {x, y}; } } // namespace box2d diff --git a/envpool/classic_control/__init__.py b/envpool/classic_control/__init__.py index faa8bbb9..511c8259 100644 --- a/envpool/classic_control/__init__.py +++ b/envpool/classic_control/__init__.py @@ -28,40 +28,58 @@ _PendulumEnvSpec, ) -CartPoleEnvSpec, CartPoleDMEnvPool, CartPoleGymEnvPool = py_env( - _CartPoleEnvSpec, _CartPoleEnvPool -) +( + CartPoleEnvSpec, + CartPoleDMEnvPool, + CartPoleGymEnvPool, + CartPoleGymnasiumEnvPool, +) = py_env(_CartPoleEnvSpec, _CartPoleEnvPool) -PendulumEnvSpec, PendulumDMEnvPool, PendulumGymEnvPool = py_env( - _PendulumEnvSpec, _PendulumEnvPool -) +( + PendulumEnvSpec, + PendulumDMEnvPool, + PendulumGymEnvPool, + PendulumGymnasiumEnvPool, +) = py_env(_PendulumEnvSpec, _PendulumEnvPool) -(MountainCarEnvSpec, MountainCarDMEnvPool, - MountainCarGymEnvPool) = py_env(_MountainCarEnvSpec, _MountainCarEnvPool) +( + MountainCarEnvSpec, + MountainCarDMEnvPool, + MountainCarGymEnvPool, + MountainCarGymnasiumEnvPool, +) = py_env(_MountainCarEnvSpec, _MountainCarEnvPool) ( MountainCarContinuousEnvSpec, MountainCarContinuousDMEnvPool, - MountainCarContinuousGymEnvPool + MountainCarContinuousGymEnvPool, MountainCarContinuousGymnasiumEnvPool ) = py_env(_MountainCarContinuousEnvSpec, _MountainCarContinuousEnvPool) -AcrobotEnvSpec, AcrobotDMEnvPool, AcrobotGymEnvPool = py_env( - _AcrobotEnvSpec, _AcrobotEnvPool -) +( + AcrobotEnvSpec, + AcrobotDMEnvPool, + AcrobotGymEnvPool, + AcrobotGymnasiumEnvPool, +) = py_env(_AcrobotEnvSpec, _AcrobotEnvPool) __all__ = [ "CartPoleEnvSpec", "CartPoleDMEnvPool", "CartPoleGymEnvPool", + "CartPoleGymnasiumEnvPool", "PendulumEnvSpec", "PendulumDMEnvPool", "PendulumGymEnvPool", + "PendulumGymnasiumEnvPool", "MountainCarEnvSpec", "MountainCarDMEnvPool", "MountainCarGymEnvPool", + "MountainCarGymnasiumEnvPool", "MountainCarContinuousEnvSpec", "MountainCarContinuousDMEnvPool", "MountainCarContinuousGymEnvPool", + "MountainCarContinuousGymnasiumEnvPool", "AcrobotEnvSpec", "AcrobotDMEnvPool", "AcrobotGymEnvPool", + "AcrobotGymnasiumEnvPool", ] diff --git a/envpool/classic_control/acrobot.h b/envpool/classic_control/acrobot.h index bc466d51..afbeae59 100644 --- a/envpool/classic_control/acrobot.h +++ b/envpool/classic_control/acrobot.h @@ -53,10 +53,10 @@ class AcrobotEnv : public Env { V5(double s0, double s1, double s2, double s3, double s4) : s0(s0), s1(s1), s2(s2), s3(s3), s4(s4) {} V5 operator+(const V5& v) const { - return V5(s0 + v.s0, s1 + v.s1, s2 + v.s2, s3 + v.s3, s4 + v.s4); + return {s0 + v.s0, s1 + v.s1, s2 + v.s2, s3 + v.s3, s4 + v.s4}; } V5 operator*(double v) const { - return V5(s0 * v, s1 * v, s2 * v, s3 * v, s4 * v); + return {s0 * v, s1 * v, s2 * v, s3 * v, s4 * v}; } }; @@ -74,15 +74,14 @@ class AcrobotEnv : public Env { int max_episode_steps_, elapsed_step_; V5 s_; std::uniform_real_distribution<> dist_; - bool done_; + bool done_{true}; public: AcrobotEnv(const Spec& spec, int env_id) : Env(spec, env_id), max_episode_steps_(spec.config["max_episode_steps"_]), elapsed_step_(max_episode_steps_ + 1), - dist_(-kInitRange, kInitRange), - done_(true) {} + dist_(-kInitRange, kInitRange) {} bool IsDone() override { return done_; } @@ -164,7 +163,7 @@ class AcrobotEnv : public Env { kM * kL * kLC * dtheta1 * dtheta1 * std::sin(theta2) - phi2) / (kM * kLC * kLC + kI - d2 * d2 / d1); double ddtheta1 = -(d2 * ddtheta2 + phi1) / d1; - return V5(dtheta1, dtheta2, ddtheta1, ddtheta2, 0); + return {dtheta1, dtheta2, ddtheta1, ddtheta2, 0}; } void WriteState(float reward) { diff --git a/envpool/classic_control/cartpole.h b/envpool/classic_control/cartpole.h index 98fb8099..57dca894 100644 --- a/envpool/classic_control/cartpole.h +++ b/envpool/classic_control/cartpole.h @@ -63,15 +63,14 @@ class CartPoleEnv : public Env { int max_episode_steps_, elapsed_step_; double x_, x_dot_, theta_, theta_dot_; std::uniform_real_distribution<> dist_; - bool done_; + bool done_{true}; public: CartPoleEnv(const Spec& spec, int env_id) : Env(spec, env_id), max_episode_steps_(spec.config["max_episode_steps"_]), elapsed_step_(max_episode_steps_ + 1), - dist_(-kInitRange, kInitRange), - done_(true) {} + dist_(-kInitRange, kInitRange) {} bool IsDone() override { return done_; } diff --git a/envpool/classic_control/mountain_car.h b/envpool/classic_control/mountain_car.h index 69793e04..f1535091 100644 --- a/envpool/classic_control/mountain_car.h +++ b/envpool/classic_control/mountain_car.h @@ -56,15 +56,14 @@ class MountainCarEnv : public Env { int max_episode_steps_, elapsed_step_; double pos_, vel_; std::uniform_real_distribution<> dist_; - bool done_; + bool done_{true}; public: MountainCarEnv(const Spec& spec, int env_id) : Env(spec, env_id), max_episode_steps_(spec.config["max_episode_steps"_]), elapsed_step_(max_episode_steps_ + 1), - dist_(-0.6, -0.4), - done_(true) {} + dist_(-0.6, -0.4) {} bool IsDone() override { return done_; } diff --git a/envpool/classic_control/mountain_car_continuous.h b/envpool/classic_control/mountain_car_continuous.h index 21a1f231..a86739cb 100644 --- a/envpool/classic_control/mountain_car_continuous.h +++ b/envpool/classic_control/mountain_car_continuous.h @@ -56,15 +56,14 @@ class MountainCarContinuousEnv : public Env { int max_episode_steps_, elapsed_step_; double pos_, vel_; std::uniform_real_distribution<> dist_; - bool done_; + bool done_{true}; public: MountainCarContinuousEnv(const Spec& spec, int env_id) : Env(spec, env_id), max_episode_steps_(spec.config["max_episode_steps"_]), elapsed_step_(max_episode_steps_ + 1), - dist_(-0.6, -0.4), - done_(true) {} + dist_(-0.6, -0.4) {} bool IsDone() override { return done_; } diff --git a/envpool/classic_control/pendulum.h b/envpool/classic_control/pendulum.h index d3cc891a..f2a594ad 100644 --- a/envpool/classic_control/pendulum.h +++ b/envpool/classic_control/pendulum.h @@ -53,7 +53,7 @@ class PendulumEnv : public Env { int version_; double theta_, theta_dot_; std::uniform_real_distribution<> dist_, dist_dot_; - bool done_; + bool done_{true}; public: PendulumEnv(const Spec& spec, int env_id) @@ -62,8 +62,7 @@ class PendulumEnv : public Env { elapsed_step_(max_episode_steps_ + 1), version_(spec.config["version"_]), dist_(-M_PI, M_PI), - dist_dot_(-1, 1), - done_(true) {} + dist_dot_(-1, 1) {} bool IsDone() override { return done_; } diff --git a/envpool/classic_control/registration.py b/envpool/classic_control/registration.py index b7b2bb5b..fdbd90ce 100644 --- a/envpool/classic_control/registration.py +++ b/envpool/classic_control/registration.py @@ -21,6 +21,7 @@ spec_cls="CartPoleEnvSpec", dm_cls="CartPoleDMEnvPool", gym_cls="CartPoleGymEnvPool", + gymnasium_cls="CartPoleGymnasiumEnvPool", max_episode_steps=200, reward_threshold=195.0, ) @@ -31,6 +32,7 @@ spec_cls="CartPoleEnvSpec", dm_cls="CartPoleDMEnvPool", gym_cls="CartPoleGymEnvPool", + gymnasium_cls="CartPoleGymnasiumEnvPool", max_episode_steps=500, reward_threshold=475.0, ) @@ -41,6 +43,7 @@ spec_cls="PendulumEnvSpec", dm_cls="PendulumDMEnvPool", gym_cls="PendulumGymEnvPool", + gymnasium_cls="PendulumGymnasiumEnvPool", version=0, max_episode_steps=200, ) @@ -51,6 +54,7 @@ spec_cls="PendulumEnvSpec", dm_cls="PendulumDMEnvPool", gym_cls="PendulumGymEnvPool", + gymnasium_cls="PendulumGymnasiumEnvPool", version=1, max_episode_steps=200, ) @@ -61,6 +65,7 @@ spec_cls="MountainCarEnvSpec", dm_cls="MountainCarDMEnvPool", gym_cls="MountainCarGymEnvPool", + gymnasium_cls="MountainCarGymnasiumEnvPool", max_episode_steps=200, ) @@ -70,6 +75,7 @@ spec_cls="MountainCarContinuousEnvSpec", dm_cls="MountainCarContinuousDMEnvPool", gym_cls="MountainCarContinuousGymEnvPool", + gymnasium_cls="MountainCarContinuousGymnasiumEnvPool", max_episode_steps=999, ) @@ -79,5 +85,6 @@ spec_cls="AcrobotEnvSpec", dm_cls="AcrobotDMEnvPool", gym_cls="AcrobotGymEnvPool", + gymnasium_cls="AcrobotGymnasiumEnvPool", max_episode_steps=500, ) diff --git a/envpool/core/array.h b/envpool/core/array.h index e99cd51b..84aabf35 100644 --- a/envpool/core/array.h +++ b/envpool/core/array.h @@ -116,8 +116,8 @@ class Array { if (shape_[0] > 0) { offset = start * size / shape_[0]; } - return Array(ptr_.get() + offset * element_size, std::move(new_shape), - element_size, [](char* p) {}); + return {ptr_.get() + offset * element_size, std::move(new_shape), + element_size, [](char* p) {}}; } /** diff --git a/envpool/core/async_envpool.h b/envpool/core/async_envpool.h index 089f998f..4ce04f40 100644 --- a/envpool/core/async_envpool.h +++ b/envpool/core/async_envpool.h @@ -117,7 +117,7 @@ class AsyncEnvPool : public EnvPool { } } - ~AsyncEnvPool() { + ~AsyncEnvPool() override { stop_ = 1; // LOG(INFO) << "envpool send: " << dur_send_.count(); // LOG(INFO) << "envpool recv: " << dur_recv_.count(); diff --git a/envpool/core/dict.h b/envpool/core/dict.h index 064b1eb1..263ce461 100644 --- a/envpool/core/dict.h +++ b/envpool/core/dict.h @@ -205,8 +205,11 @@ class Dict : public std::decay_t { std::enable_if_t::value, bool> = true> [[nodiscard]] std::vector AllValues() const { std::vector rets; - std::apply([&](auto&&... value) { (rets.push_back(Type(value)), ...); }, - *static_cast(this)); + std::apply( + [&](auto&&... value) { + (rets.push_back(static_cast(value)), ...); + }, + *static_cast(this)); return rets; } @@ -297,7 +300,7 @@ std::vector MakeArray(const std::tuple& specs) { * Takes a vector of `ShapeSpec`. */ std::vector MakeArray(const std::vector& specs) { - return std::vector(specs.begin(), specs.end()); + return {specs.begin(), specs.end()}; } #endif // ENVPOOL_CORE_DICT_H_ diff --git a/envpool/core/env.h b/envpool/core/env.h index d8f3381f..4d866072 100644 --- a/envpool/core/env.h +++ b/envpool/core/env.h @@ -58,7 +58,7 @@ class Env { private: StateBufferQueue* sbq_; - int order_, current_step_; + int order_, current_step_{-1}; bool is_single_player_; StateBuffer::WritableSlice slice_; // for parsing single env action from input action batch @@ -79,7 +79,6 @@ class Env { env_id_(env_id), seed_(spec.config["seed"_] + env_id), gen_(seed_), - current_step_(-1), is_single_player_(max_num_players_ == 1), action_specs_(spec.action_spec.template AllValues()), is_player_action_(Transform(action_specs_, [](const ShapeSpec& s) { @@ -88,6 +87,8 @@ class Env { slice_.done_write = [] { LOG(INFO) << "Use `Allocate` to write state."; }; } + virtual ~Env() = default; + void SetAction(std::shared_ptr> action_batch, int env_index) { action_batch_ = std::move(action_batch); diff --git a/envpool/core/env_spec.h b/envpool/core/env_spec.h index 32c4e866..123dae92 100644 --- a/envpool/core/env_spec.h +++ b/envpool/core/env_spec.h @@ -51,8 +51,8 @@ class EnvSpec { using Config = decltype(ConcatDict(common_config, EnvFns::DefaultConfig())); using ConfigKeys = typename Config::Keys; using ConfigValues = typename Config::Values; - using StateSpec = decltype( - ConcatDict(common_state_spec, EnvFns::StateSpec(std::declval()))); + using StateSpec = decltype(ConcatDict( + common_state_spec, EnvFns::StateSpec(std::declval()))); using ActionSpec = decltype(ConcatDict( common_action_spec, EnvFns::ActionSpec(std::declval()))); using StateKeys = typename StateSpec::Keys; diff --git a/envpool/core/envpool.h b/envpool/core/envpool.h index 75e945f5..0798581a 100644 --- a/envpool/core/envpool.h +++ b/envpool/core/envpool.h @@ -33,6 +33,7 @@ class EnvPool { using State = NamedVector>; using Action = NamedVector>; explicit EnvPool(EnvSpec spec) : spec(std::move(spec)) {} + virtual ~EnvPool() = default; protected: virtual void Send(const std::vector& action) { diff --git a/envpool/core/py_envpool.h b/envpool/core/py_envpool.h index f41f48fb..35a33e45 100644 --- a/envpool/core/py_envpool.h +++ b/envpool/core/py_envpool.h @@ -72,8 +72,8 @@ struct ArrayToNumpyHelper> { reinterpret_cast((*inner_ptr)->Data()), capsule); } } - return py::array(py::dtype("object"), a.Shape(), - reinterpret_cast(ptr->get()), capsule); + return {py::dtype("object"), a.Shape(), + reinterpret_cast(ptr->get()), capsule}; } }; @@ -83,7 +83,7 @@ Array NumpyToArray(const py::array& arr) { ArrayT arr_t(arr); ShapeSpec spec(arr_t.itemsize(), std::vector(arr_t.shape(), arr_t.shape() + arr_t.ndim())); - return Array(spec, reinterpret_cast(arr_t.mutable_data())); + return {spec, reinterpret_cast(arr_t.mutable_data())}; } template diff --git a/envpool/core/spec.h b/envpool/core/spec.h index fceb5310..909977f9 100644 --- a/envpool/core/spec.h +++ b/envpool/core/spec.h @@ -44,7 +44,7 @@ class ShapeSpec { [[nodiscard]] ShapeSpec Batch(int batch_size) const { std::vector new_shape = {batch_size}; new_shape.insert(new_shape.end(), shape.begin(), shape.end()); - return ShapeSpec(element_size, std::move(new_shape)); + return {element_size, std::move(new_shape)}; } [[nodiscard]] std::vector Shape() const { auto s = std::vector(shape.size()); diff --git a/envpool/core/state_buffer_queue_test.cc b/envpool/core/state_buffer_queue_test.cc index caac0394..2c8890fb 100644 --- a/envpool/core/state_buffer_queue_test.cc +++ b/envpool/core/state_buffer_queue_test.cc @@ -31,14 +31,18 @@ TEST(StateBufferQueueTest, Basic) { std::srand(std::time(nullptr)); std::size_t size = 0; for (std::size_t i = 0; i < batch; ++i) { + LOG(INFO) << i << " start"; std::size_t num_players = 1; auto slice = queue.Allocate(num_players); + LOG(INFO) << i << " allocate"; slice.done_write(); + LOG(INFO) << i << " done_write"; EXPECT_EQ(slice.arr[0].Shape(0), 10); EXPECT_EQ(slice.arr[1].Shape(0), 1); size += num_players; } std::vector out = queue.Wait(); + LOG(INFO) << "finish wait"; EXPECT_EQ(out[0].Shape(0), size); EXPECT_EQ(out[1].Shape(0), size); EXPECT_EQ(batch, size); diff --git a/envpool/dummy/__init__.py b/envpool/dummy/__init__.py index b9577efe..ad4692d2 100644 --- a/envpool/dummy/__init__.py +++ b/envpool/dummy/__init__.py @@ -18,7 +18,7 @@ from .dummy_envpool import _DummyEnvPool, _DummyEnvSpec -DummyEnvSpec, DummyDMEnvPool, DummyGymEnvPool = py_env( +DummyEnvSpec, DummyDMEnvPool, DummyGymEnvPool, DummyGymnasiumEnvPool = py_env( _DummyEnvSpec, _DummyEnvPool ) @@ -26,4 +26,5 @@ "DummyEnvSpec", "DummyDMEnvPool", "DummyGymEnvPool", + "DummyGymnasiumEnvPool", ] diff --git a/envpool/dummy/dummy_envpool.h b/envpool/dummy/dummy_envpool.h index 9e2b82a2..a0ff212d 100644 --- a/envpool/dummy/dummy_envpool.h +++ b/envpool/dummy/dummy_envpool.h @@ -115,15 +115,14 @@ using DummyEnvSpec = EnvSpec; */ class DummyEnv : public Env { protected: - int state_; + int state_{0}; public: /** * Initilize the env, in this function we perform tasks like loading the game * rom etc. */ - DummyEnv(const Spec& spec, int env_id) - : Env(spec, env_id), state_(0) { + DummyEnv(const Spec& spec, int env_id) : Env(spec, env_id) { if (seed_ < 1) { seed_ = 1; } diff --git a/envpool/entry.py b/envpool/entry.py index f21e93ca..eed70a29 100644 --- a/envpool/entry.py +++ b/envpool/entry.py @@ -13,10 +13,42 @@ # limitations under the License. """Entry point for all envs' registration.""" -import envpool.atari.registration # noqa: F401 -import envpool.box2d.registration # noqa: F401 -import envpool.classic_control.registration # noqa: F401 -import envpool.mujoco.dmc.registration # noqa: F401 -import envpool.mujoco.gym.registration # noqa: F401 -import envpool.toy_text.registration # noqa: F401 -import envpool.vizdoom.registration # noqa: F401 +try: + import envpool.atari.registration # noqa: F401 +except ImportError: + pass + +try: + import envpool.box2d.registration # noqa: F401 +except ImportError: + pass + +try: + import envpool.classic_control.registration # noqa: F401 +except ImportError: + pass + +try: + import envpool.mujoco.dmc.registration # noqa: F401 +except ImportError: + pass + +try: + import envpool.mujoco.gym.registration # noqa: F401 +except ImportError: + pass + +try: + import envpool.procgen.registration # noqa: F401 +except ImportError: + pass + +try: + import envpool.toy_text.registration # noqa: F401 +except ImportError: + pass + +try: + import envpool.vizdoom.registration # noqa: F401 +except ImportError: + pass diff --git a/envpool/minigrid/BUILD b/envpool/minigrid/BUILD new file mode 100644 index 00000000..1ae60e13 --- /dev/null +++ b/envpool/minigrid/BUILD @@ -0,0 +1,86 @@ +# Copyright 2023 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@pip_requirements//:requirements.bzl", "requirement") +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "minigrid_env", + srcs = [ + "impl/minigrid_empty_env.cc", + "impl/minigrid_env.cc", + ], + hdrs = [ + "empty.h", + "impl/minigrid_empty_env.h", + "impl/minigrid_env.h", + "impl/utils.h", + ], + deps = [ + "//envpool/core:async_envpool", + ], +) + +pybind_extension( + name = "minigrid_envpool", + srcs = ["minigrid.cc"], + deps = [ + ":minigrid_env", + "//envpool/core:py_envpool", + ], +) + +py_library( + name = "minigrid", + srcs = ["__init__.py"], + data = [":minigrid_envpool.so"], + deps = ["//envpool/python:api"], +) + +py_library( + name = "minigrid_registration", + srcs = ["registration.py"], + deps = [ + "//envpool:registration", + ], +) + +py_test( + name = "minigrid_align_test", + size = "enormous", + srcs = ["minigrid_align_test.py"], + deps = [ + ":minigrid", + ":minigrid_registration", + requirement("absl-py"), + requirement("gym"), + requirement("numpy"), + requirement("minigrid"), + ], +) + +py_test( + name = "minigrid_deterministic_test", + size = "enormous", + srcs = ["minigrid_deterministic_test.py"], + deps = [ + ":minigrid", + ":minigrid_registration", + requirement("absl-py"), + requirement("gym"), + requirement("numpy"), + ], +) diff --git a/envpool/minigrid/__init__.py b/envpool/minigrid/__init__.py new file mode 100644 index 00000000..13f727b6 --- /dev/null +++ b/envpool/minigrid/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2023 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Minigrid env in EnvPool.""" + +from envpool.python.api import py_env + +from .minigrid_envpool import _EmptyEnvPool, _EmptyEnvSpec + +(EmptyEnvSpec, EmptyDMEnvPool, EmptyGymEnvPool, + EmptyGymnasiumEnvPool) = py_env(_EmptyEnvSpec, _EmptyEnvPool) + +__all__ = [ + "EmptyEmvSpec", + "EmptyDMEnvPool", + "EmptyGymEnvPool", + "EmptyGymnasiumEnvPool", +] diff --git a/envpool/minigrid/empty.h b/envpool/minigrid/empty.h new file mode 100644 index 00000000..9910d5e2 --- /dev/null +++ b/envpool/minigrid/empty.h @@ -0,0 +1,92 @@ +/* + * Copyright 2023 Garena Online Private Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ENVPOOL_MINIGRID_EMPTY_H_ +#define ENVPOOL_MINIGRID_EMPTY_H_ + +#include + +#include "envpool/core/async_envpool.h" +#include "envpool/core/env.h" +#include "envpool/minigrid/impl/minigrid_empty_env.h" +#include "envpool/minigrid/impl/minigrid_env.h" + +namespace minigrid { + +class EmptyEnvFns { + public: + static decltype(auto) DefaultConfig() { + return MakeDict("size"_.Bind(8), + "agent_start_pos"_.Bind(std::pair(1, 1)), + "agent_start_dir"_.Bind(0), "agent_view_size"_.Bind(7)); + } + template + static decltype(auto) StateSpec(const Config& conf) { + int agent_view_size = conf["agent_view_size"_]; + int size = conf["size"_]; + return MakeDict("obs:direction"_.Bind(Spec({-1}, {0, 3})), + "obs:image"_.Bind(Spec( + {agent_view_size, agent_view_size, 3}, {0, 255})), + "info:agent_pos"_.Bind(Spec({2}, {0, size}))); + } + template + static decltype(auto) ActionSpec(const Config& conf) { + return MakeDict("action"_.Bind(Spec({-1}, {0, 6}))); + } +}; + +using EmptyEnvSpec = EnvSpec; +using FrameSpec = Spec; + +class EmptyEnv : public Env, public MiniGridEmptyEnv { + public: + EmptyEnv(const Spec& spec, int env_id) + : Env(spec, env_id), + MiniGridEmptyEnv(spec.config["size"_], spec.config["agent_start_pos"_], + spec.config["agent_start_dir"_], + spec.config["max_episode_steps"_], + spec.config["agent_view_size"_]) { + gen_ref_ = &gen_; + } + + bool IsDone() override { return done_; } + + void Reset() override { + MiniGridReset(); + WriteState(0.0); + } + + void Step(const Action& action) override { + int act = action["action"_]; + WriteState(MiniGridStep(static_cast(act))); + } + + private: + void WriteState(float reward) { + State state = Allocate(); + GenImage(state["obs:image"_]); + state["obs:direction"_] = agent_dir_; + state["reward"_] = reward; + state["info:agent_pos"_](0) = agent_pos_.first; + state["info:agent_pos"_](1) = agent_pos_.second; + } +}; + +using EmptyEnvPool = AsyncEnvPool; + +} // namespace minigrid + +#endif // ENVPOOL_MINIGRID_EMPTY_H_ diff --git a/envpool/minigrid/impl/minigrid_empty_env.cc b/envpool/minigrid/impl/minigrid_empty_env.cc new file mode 100644 index 00000000..8da7e5f1 --- /dev/null +++ b/envpool/minigrid/impl/minigrid_empty_env.cc @@ -0,0 +1,64 @@ +// Copyright 2023 Garena Online Private Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "envpool/minigrid/impl/minigrid_empty_env.h" + +#include +#include + +namespace minigrid { + +MiniGridEmptyEnv::MiniGridEmptyEnv(int size, + std::pair agent_start_pos, + int agent_start_dir, int max_steps, + int agent_view_size) { + width_ = size; + height_ = size; + agent_start_pos_ = agent_start_pos; + agent_start_dir_ = agent_start_dir; + see_through_walls_ = true; + max_steps_ = max_steps; + agent_view_size_ = agent_view_size; +} + +void MiniGridEmptyEnv::GenGrid() { + grid_.clear(); + for (int i = 0; i < height_; ++i) { + std::vector temp_vec(width_); + for (int j = 0; j < width_; ++j) { + temp_vec[j] = WorldObj(kEmpty); + } + grid_.emplace_back(temp_vec); + } + // generate the surrounding walls + for (int i = 0; i < width_; ++i) { + grid_[0][i] = WorldObj(kWall, kGrey); + grid_[height_ - 1][i] = WorldObj(kWall, kGrey); + } + for (int i = 0; i < height_; ++i) { + grid_[i][0] = WorldObj(kWall, kGrey); + grid_[i][width_ - 1] = WorldObj(kWall, kGrey); + } + // place a goal square in the bottom-right corner + grid_[height_ - 2][width_ - 2] = WorldObj(kGoal, kGreen); + // place the agent + if (agent_start_pos_.first == -1) { + PlaceAgent(1, 1, width_ - 2, height_ - 2); + } else { + agent_pos_ = agent_start_pos_; + agent_dir_ = agent_start_dir_; + } +} + +} // namespace minigrid diff --git a/envpool/minigrid/impl/minigrid_empty_env.h b/envpool/minigrid/impl/minigrid_empty_env.h new file mode 100644 index 00000000..6a2b8bc1 --- /dev/null +++ b/envpool/minigrid/impl/minigrid_empty_env.h @@ -0,0 +1,35 @@ +/* + * Copyright 2023 Garena Online Private Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ENVPOOL_MINIGRID_IMPL_MINIGRID_EMPTY_ENV_H_ +#define ENVPOOL_MINIGRID_IMPL_MINIGRID_EMPTY_ENV_H_ + +#include + +#include "envpool/minigrid/impl/minigrid_env.h" + +namespace minigrid { + +class MiniGridEmptyEnv : public MiniGridEnv { + public: + MiniGridEmptyEnv(int size, std::pair agent_start_pos, + int agent_start_dir, int max_steps, int agent_view_size); + void GenGrid() override; +}; + +} // namespace minigrid + +#endif // ENVPOOL_MINIGRID_IMPL_MINIGRID_EMPTY_ENV_H_ diff --git a/envpool/minigrid/impl/minigrid_env.cc b/envpool/minigrid/impl/minigrid_env.cc new file mode 100644 index 00000000..c2ac26e6 --- /dev/null +++ b/envpool/minigrid/impl/minigrid_env.cc @@ -0,0 +1,241 @@ +// Copyright 2023 Garena Online Private Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * Note: + * The grid layout for this implementation is: + * + * 0 -------------> x (width_) + * | + * | grid[y][x] -> (x, y) + * | + * v + * y (height_) + */ + +#include "envpool/minigrid/impl/minigrid_env.h" + +#include + +namespace minigrid { + +void MiniGridEnv::MiniGridReset() { + GenGrid(); + step_count_ = 0; + done_ = false; + CHECK_GE(agent_pos_.first, 0); + CHECK_GE(agent_pos_.second, 0); + CHECK_GE(agent_dir_, 0); + CHECK(grid_[agent_pos_.second][agent_pos_.first].CanOverlap()); + carrying_ = WorldObj(kEmpty); +} + +float MiniGridEnv::MiniGridStep(Act act) { + step_count_ += 1; + float reward = 0.0; + // Get the position in front of the agent + std::pair fwd_pos = agent_pos_; + switch (agent_dir_) { + case 0: + fwd_pos.first += 1; + break; + case 1: + fwd_pos.second += 1; + break; + case 2: + fwd_pos.first -= 1; + break; + case 3: + fwd_pos.second -= 1; + break; + default: + CHECK(false); + break; + } + CHECK_GE(fwd_pos.first, 0); + CHECK(fwd_pos.first < width_); + CHECK_GE(fwd_pos.second, 0); + CHECK(fwd_pos.second < height_); + // Get the forward cell object + if (act == kLeft) { + agent_dir_ -= 1; + if (agent_dir_ < 0) { + agent_dir_ += 4; + } + } else if (act == kRight) { + agent_dir_ = (agent_dir_ + 1) % 4; + } else if (act == kForward) { + if (grid_[fwd_pos.second][fwd_pos.first].CanOverlap()) { + agent_pos_ = fwd_pos; + } + if (grid_[fwd_pos.second][fwd_pos.first].GetType() == kGoal) { + done_ = true; + reward = 1 - 0.9 * (static_cast(step_count_) / max_steps_); + } else if (grid_[fwd_pos.second][fwd_pos.first].GetType() == kLava) { + done_ = true; + } + } else if (act == kPickup) { + if (carrying_.GetType() == kEmpty && + grid_[fwd_pos.second][fwd_pos.first].CanPickup()) { + carrying_ = grid_[fwd_pos.second][fwd_pos.first]; + grid_[fwd_pos.second][fwd_pos.first] = WorldObj(kEmpty); + } + } else if (act == kDrop) { + if (carrying_.GetType() != kEmpty && + grid_[fwd_pos.second][fwd_pos.first].GetType() == kEmpty) { + grid_[fwd_pos.second][fwd_pos.first] = carrying_; + carrying_ = WorldObj(kEmpty); + } + } else if (act == kToggle) { + WorldObj obj = grid_[fwd_pos.second][fwd_pos.first]; + if (obj.GetType() == kDoor) { + if (obj.GetDoorLocked()) { + // If the agent has the right key to open the door + if (carrying_.GetType() == kKey && + carrying_.GetColor() == obj.GetColor()) { + grid_[fwd_pos.second][fwd_pos.first].SetDoorOpen(true); + } + } else { + grid_[fwd_pos.second][fwd_pos.first].SetDoorOpen(!obj.GetDoorOpen()); + } + } else if (obj.GetType() == kBox) { + // WARNING: this is MESSY!!! + auto* contains = grid_[fwd_pos.second][fwd_pos.first].GetContains(); + if (contains != nullptr) { + grid_[fwd_pos.second][fwd_pos.first] = *contains; + grid_[fwd_pos.second][fwd_pos.first].SetContains( + contains->GetContains()); + contains->SetContains(nullptr); + delete contains; + } else { + grid_[fwd_pos.second][fwd_pos.first] = WorldObj(kEmpty); + } + } + } else if (act != kDone) { + CHECK(false); + } + if (step_count_ >= max_steps_) { + done_ = true; + } + return reward; +} + +void MiniGridEnv::PlaceAgent(int start_x, int start_y, int end_x, int end_y) { + // Place an object at an empty position in the grid + end_x = (end_x == -1) ? width_ - 1 : end_x; + end_y = (end_y == -1) ? height_ - 1 : end_y; + CHECK(start_x <= end_x && start_y <= end_y); + std::uniform_int_distribution<> x_dist(start_x, end_x); + std::uniform_int_distribution<> y_dist(start_y, end_y); + while (true) { + int x = x_dist(*gen_ref_); + int y = y_dist(*gen_ref_); + if (grid_[y][x].GetType() != kEmpty) { + continue; + } + agent_pos_.first = x; + agent_pos_.second = y; + break; + } + // Randomly select a direction + if (agent_start_dir_ == -1) { + std::uniform_int_distribution<> dir_dist(0, 3); + agent_dir_ = dir_dist(*gen_ref_); + } +} + +void MiniGridEnv::GenImage(const Array& obs) { + // Get the extents of the square set of tiles visible to the agent + // Note: the bottom extent indices are not include in the set + int top_x; + int top_y; + if (agent_dir_ == 0) { + top_x = agent_pos_.first; + top_y = agent_pos_.second - (agent_view_size_ / 2); + } else if (agent_dir_ == 1) { + top_x = agent_pos_.first - (agent_view_size_ / 2); + top_y = agent_pos_.second; + } else if (agent_dir_ == 2) { + top_x = agent_pos_.first - agent_view_size_ + 1; + top_y = agent_pos_.second - (agent_view_size_ / 2); + } else if (agent_dir_ == 3) { + top_x = agent_pos_.first - (agent_view_size_ / 2); + top_y = agent_pos_.second - agent_view_size_ + 1; + } else { + CHECK(false); + } + + // Generate the sub-grid observed by the agent + std::vector> agent_view_grid; + for (int i = 0; i < agent_view_size_; ++i) { + std::vector temp_vec; + for (int j = 0; j < agent_view_size_; ++j) { + int x = top_x + j; + int y = top_y + i; + if (x >= 0 && x < width_ && y >= 0 && y < height_) { + temp_vec.emplace_back(WorldObj(grid_[y][x].GetType())); + } else { + temp_vec.emplace_back(WorldObj(kWall)); + } + } + agent_view_grid.emplace_back(temp_vec); + } + // Rotate the agent view grid to relatively facing up + for (int i = 0; i < agent_dir_ + 1; ++i) { + // Rotate counter-clockwise + std::vector> copy_grid = + agent_view_grid; // This is a deep copy + for (int y = 0; y < agent_view_size_; ++y) { + for (int x = 0; x < agent_view_size_; ++x) { + copy_grid[agent_view_size_ - 1 - x][y] = agent_view_grid[y][x]; + } + } + agent_view_grid = copy_grid; + } + // Process occluders and visibility + // Note that this incurs some performance cost + int agent_pos_x = agent_view_size_ / 2; + int agent_pos_y = agent_view_size_ - 1; + std::vector> vis_mask(agent_view_size_, + std::vector(agent_view_size_)); + for (auto& row : vis_mask) { + std::fill(row.begin(), row.end(), 0); + } + if (!see_through_walls_) { + // TODO(siping): Process_vis + vis_mask[agent_pos_y][agent_pos_x] = true; + } else { + for (auto& row : vis_mask) { + std::fill(row.begin(), row.end(), 1); + } + } + // Let the agent see what it's carrying + if (carrying_.GetType() != kEmpty) { + agent_view_grid[agent_pos_y][agent_pos_x] = carrying_; + } else { + agent_view_grid[agent_pos_y][agent_pos_x] = WorldObj(kEmpty); + } + for (int y = 0; y < agent_view_size_; ++y) { + for (int x = 0; x < agent_view_size_; ++x) { + if (vis_mask[y][x]) { + // Transpose to align with the python library + obs(x, y, 0) = static_cast(agent_view_grid[y][x].GetType()); + obs(x, y, 1) = static_cast(agent_view_grid[y][x].GetColor()); + obs(x, y, 2) = static_cast(agent_view_grid[y][x].GetState()); + } + } + } +} + +} // namespace minigrid diff --git a/envpool/minigrid/impl/minigrid_env.h b/envpool/minigrid/impl/minigrid_env.h new file mode 100644 index 00000000..fadf8de0 --- /dev/null +++ b/envpool/minigrid/impl/minigrid_env.h @@ -0,0 +1,58 @@ +/* + * Copyright 2023 Garena Online Private Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ENVPOOL_MINIGRID_IMPL_MINIGRID_ENV_H_ +#define ENVPOOL_MINIGRID_IMPL_MINIGRID_ENV_H_ + +#include +#include +#include + +#include "envpool/core/array.h" +#include "envpool/minigrid/impl/utils.h" + +namespace minigrid { + +class MiniGridEnv { + protected: + int width_; + int height_; + int max_steps_{100}; + int step_count_{0}; + int agent_view_size_{7}; + bool see_through_walls_{false}; + bool done_{true}; + std::pair agent_start_pos_; + std::pair agent_pos_; + int agent_start_dir_; + int agent_dir_; + std::mt19937* gen_ref_; + std::vector> grid_; + WorldObj carrying_; + + public: + MiniGridEnv() { carrying_ = WorldObj(kEmpty); } + void MiniGridReset(); + float MiniGridStep(Act act); + void PlaceAgent(int start_x = 0, int start_y = 0, int end_x = -1, + int end_y = -1); + void GenImage(const Array& obs); + virtual void GenGrid() {} +}; + +} // namespace minigrid + +#endif // ENVPOOL_MINIGRID_IMPL_MINIGRID_ENV_H_ diff --git a/envpool/minigrid/impl/utils.h b/envpool/minigrid/impl/utils.h new file mode 100644 index 00000000..328146fa --- /dev/null +++ b/envpool/minigrid/impl/utils.h @@ -0,0 +1,154 @@ +/* + * Copyright 2023 Garena Online Private Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ENVPOOL_MINIGRID_IMPL_UTILS_H_ +#define ENVPOOL_MINIGRID_IMPL_UTILS_H_ + +#include + +namespace minigrid { + +enum Act { + // Turn left, turn right, move forward + kLeft = 0, + kRight = 1, + kForward = 2, + // Pick up an object + kPickup = 3, + // Drop an object + kDrop = 4, + // Toggle/activate an object + kToggle = 5, + // Done completing task + kDone = 6 +}; + +enum Color { + kRed = 0, + kGreen = 1, + kBlue = 2, + kPurple = 3, + kYellow = 4, + kGrey = 5, + kUnassigned = 6 +}; + +enum Type { + kUnseen = 0, + kEmpty = 1, + kWall = 2, + kFloor = 3, + kDoor = 4, + kKey = 5, + kBall = 6, + kBox = 7, + kGoal = 8, + kLava = 9, + kAgent = 10 +}; + +// constants +static const std::unordered_map kCanSeeBehind{ + {kEmpty, true}, {kWall, false}, {kGoal, true}, + {kFloor, true}, {kLava, true}, {kKey, true}, + {kBall, true}, {kDoor, true}, {kBox, true}}; +static const std::unordered_map kCanOverlap{ + {kEmpty, true}, {kWall, false}, {kGoal, true}, + {kFloor, true}, {kLava, true}, {kKey, false}, + {kBall, false}, {kDoor, true}, {kBox, false}}; +static const std::unordered_map kCanPickup{ + {kEmpty, false}, {kWall, false}, {kGoal, false}, + {kFloor, false}, {kLava, false}, {kKey, true}, + {kBall, true}, {kDoor, false}, {kBox, true}}; + +// object class + +class WorldObj { + private: + Type type_; + Color color_; + WorldObj* contains_; + bool door_open_{true}; // this variable only makes sence when type_ == kDoor + bool door_locked_{ + false}; // this variable only makes sence when type_ == kDoor + + public: + explicit WorldObj(Type type = kEmpty, Color color = kUnassigned, + WorldObj* contains = nullptr) + : type_(type), contains_(contains) { + if (color == kUnassigned) { + switch (type) { + case kEmpty: + case kLava: + color_ = kRed; + break; + case kWall: + color_ = kGrey; + break; + case kGoal: + color_ = kGreen; + break; + case kKey: + case kBall: + case kFloor: + color_ = kBlue; + break; + default: + CHECK(false); + break; + } + } else { + color_ = color; + } + } + ~WorldObj() { delete contains_; } + [[nodiscard]] bool CanSeeBehind() const { + return door_open_ && kCanSeeBehind.at(type_); + } + [[nodiscard]] bool CanOverlap() const { + return door_open_ && kCanOverlap.at(type_); + } + [[nodiscard]] bool CanPickup() const { return kCanPickup.at(type_); } + [[nodiscard]] bool GetDoorOpen() const { return door_open_; } + void SetDoorOpen(bool flag) { door_open_ = flag; } + [[nodiscard]] bool GetDoorLocked() const { return door_locked_; } + void SetDoorLocker(bool flag) { door_locked_ = flag; } + Type GetType() { return type_; } + Color GetColor() { return color_; } + int GetState() { + if (type_ != kDoor) { + return 0; + } + if (door_locked_) { + return 2; + } + if (door_open_) { + return 0; + } + return 1; + } + WorldObj* GetContains() { return contains_; } + void SetContains(WorldObj* contains) { + if (contains != nullptr && type_ != kBox) { + CHECK(false); + } + contains_ = contains; + } +}; + +} // namespace minigrid + +#endif // ENVPOOL_MINIGRID_IMPL_UTILS_H_ diff --git a/envpool/minigrid/minigrid.cc b/envpool/minigrid/minigrid.cc new file mode 100644 index 00000000..d0af3d07 --- /dev/null +++ b/envpool/minigrid/minigrid.cc @@ -0,0 +1,21 @@ +// Copyright 2021 Garena Online Private Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "envpool/core/py_envpool.h" +#include "envpool/minigrid/empty.h" + +using EmptyEnvSpec = PyEnvSpec; +using EmptyEnvPool = PyEnvPool; + +PYBIND11_MODULE(minigrid_envpool, m) { REGISTER(m, EmptyEnvSpec, EmptyEnvPool) } diff --git a/envpool/minigrid/minigrid_align_test.py b/envpool/minigrid/minigrid_align_test.py new file mode 100644 index 00000000..6790fa7b --- /dev/null +++ b/envpool/minigrid/minigrid_align_test.py @@ -0,0 +1,105 @@ +# Copyright 2023 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for minigrid environments check.""" + +import time +from typing import Any + +import gymnasium as gym +import minigrid # noqa: F401 +import numpy as np +from absl import logging +from absl.testing import absltest + +import envpool.minigrid.registration # noqa: F401 +from envpool.registration import make_gym + + +class _MiniGridEnvPoolAlignTest(absltest.TestCase): + + def check_spec( + self, spec0: gym.spaces.Space, spec1: gym.spaces.Space + ) -> None: + self.assertEqual(spec0.dtype, spec1.dtype) + if isinstance(spec0, gym.spaces.Discrete): + self.assertEqual(spec0.n, spec1.n) + elif isinstance(spec0, gym.spaces.Box): + np.testing.assert_allclose(spec0.low, spec1.low) + np.testing.assert_allclose(spec0.high, spec1.high) + + def run_align_check( + self, + task_id: str, + num_envs: int = 1, + total: int = 10000, + **kwargs: Any, + ) -> None: + env0 = gym.make(task_id) + env1 = make_gym(task_id, num_envs=num_envs, seed=0, **kwargs) + self.check_spec( + env0.observation_space["direction"], env1.observation_space["direction"] + ) + self.check_spec( + env0.observation_space["image"], env1.observation_space["image"] + ) + self.check_spec(env0.action_space, env1.action_space) + done0 = True + acts = [] + total_time_envpool = 0.0 + total_time_gym = 0.0 + for _ in range(total): + act = env0.action_space.sample() + acts.append(act) + start = time.time() + obs1, rew1, term1, trunc1, info1 = env1.step(np.array([act])) + end = time.time() + total_time_envpool += end - start + start = time.time() + if done0: + obs0, info0 = env0.reset() + auto_reset = True + term0 = trunc0 = False + env0.unwrapped.agent_pos = info1["agent_pos"][0] + env0.unwrapped.agent_dir = obs1["direction"][0] + else: + obs0, rew0, term0, trunc0, info0 = env0.step(act) + auto_reset = False + end = time.time() + total_time_gym += end - start + self.assertEqual(obs0["image"].shape, (7, 7, 3)) + self.assertEqual(obs1["image"].shape, (num_envs, 7, 7, 3)) + done0 = term0 | trunc0 + done1 = term1 | trunc1 + if not auto_reset: + np.testing.assert_allclose(obs0["direction"], obs1["direction"][0]) + np.testing.assert_allclose(obs0["image"], obs1["image"][0]) + np.testing.assert_allclose(rew0, rew1[0], rtol=1e-6) + np.testing.assert_allclose(done0, done1[0]) + np.testing.assert_allclose( + env0.unwrapped.agent_pos, info1["agent_pos"][0] + ) + logging.info(f"{total_time_envpool=}") + logging.info(f"{total_time_gym=}") + + def test_empty(self) -> None: + self.run_align_check("MiniGrid-Empty-5x5-v0") + self.run_align_check("MiniGrid-Empty-6x6-v0") + self.run_align_check("MiniGrid-Empty-8x8-v0") + self.run_align_check("MiniGrid-Empty-16x16-v0") + self.run_align_check("MiniGrid-Empty-Random-5x5-v0") + self.run_align_check("MiniGrid-Empty-Random-6x6-v0") + + +if __name__ == "__main__": + absltest.main() diff --git a/envpool/minigrid/minigrid_deterministic_test.py b/envpool/minigrid/minigrid_deterministic_test.py new file mode 100644 index 00000000..2b2928dd --- /dev/null +++ b/envpool/minigrid/minigrid_deterministic_test.py @@ -0,0 +1,61 @@ +# Copyright 2023 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for minigrid environments check.""" + +from typing import Any + +import numpy as np +from absl.testing import absltest + +import envpool.minigrid.registration # noqa: F401 +from envpool.registration import make_gym + + +class _MiniGridEnvPoolDeterministicTest(absltest.TestCase): + + def run_deterministic_check( + self, + task_id: str, + num_envs: int = 4, + total: int = 5000, + seed: int = 1, + **kwargs: Any, + ) -> None: + env0 = make_gym(task_id, num_envs=num_envs, seed=0, **kwargs) + env1 = make_gym(task_id, num_envs=num_envs, seed=0, **kwargs) + env2 = make_gym(task_id, num_envs=num_envs, seed=1, **kwargs) + act_space = env0.action_space + act_space.seed(seed) + same_count = 0 + for _ in range(total): + action = np.array([act_space.sample() for _ in range(num_envs)]) + obs0, rew0, terminated, truncated, info0 = env0.step(action) + obs1, rew1, terminated, truncated, info1 = env1.step(action) + obs2, rew2, terminated, truncated, info2 = env2.step(action) + np.testing.assert_allclose(obs0["image"], obs1["image"]) + np.testing.assert_allclose(obs0["direction"], obs1["direction"]) + # TODO: this may fail because the available state in minigrid env + # is limited + same_count += np.allclose(obs0["image"], obs2["image"]) and np.allclose( + obs0["direction"], obs2["direction"] + ) + assert same_count == 0, f"{same_count=}" + + def test_empty(self) -> None: + self.run_deterministic_check("MiniGrid-Empty-Random-5x5-v0") + self.run_deterministic_check("MiniGrid-Empty-Random-6x6-v0") + + +if __name__ == "__main__": + absltest.main() diff --git a/envpool/minigrid/registration.py b/envpool/minigrid/registration.py new file mode 100644 index 00000000..03d586f8 --- /dev/null +++ b/envpool/minigrid/registration.py @@ -0,0 +1,86 @@ +# Copyright 2023 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Minigrid env registration.""" + +from envpool.registration import register + +register( + task_id="MiniGrid-Empty-5x5-v0", + import_path="envpool.minigrid", + spec_cls="EmptyEnvSpec", + dm_cls="EmptyDMEnvPool", + gym_cls="EmptyGymEnvPool", + gymnasium_cls="EmptyGymnasiumEnvPool", + max_episode_steps=100, + size=5, +) + +register( + task_id="MiniGrid-Empty-Random-5x5-v0", + import_path="envpool.minigrid", + spec_cls="EmptyEnvSpec", + dm_cls="EmptyDMEnvPool", + gym_cls="EmptyGymEnvPool", + gymnasium_cls="EmptyGymnasiumEnvPool", + max_episode_steps=100, + size=5, + agent_start_pos=(-1, -1), + agent_start_dir=-1, +) + +register( + task_id="MiniGrid-Empty-6x6-v0", + import_path="envpool.minigrid", + spec_cls="EmptyEnvSpec", + dm_cls="EmptyDMEnvPool", + gym_cls="EmptyGymEnvPool", + gymnasium_cls="EmptyGymnasiumEnvPool", + max_episode_steps=144, + size=6, +) + +register( + task_id="MiniGrid-Empty-Random-6x6-v0", + import_path="envpool.minigrid", + spec_cls="EmptyEnvSpec", + dm_cls="EmptyDMEnvPool", + gym_cls="EmptyGymEnvPool", + gymnasium_cls="EmptyGymnasiumEnvPool", + max_episode_steps=144, + size=6, + agent_start_pos=(-1, -1), + agent_start_dir=-1, +) + +register( + task_id="MiniGrid-Empty-8x8-v0", + import_path="envpool.minigrid", + spec_cls="EmptyEnvSpec", + dm_cls="EmptyDMEnvPool", + gym_cls="EmptyGymEnvPool", + gymnasium_cls="EmptyGymnasiumEnvPool", + max_episode_steps=256, + size=8, +) + +register( + task_id="MiniGrid-Empty-16x16-v0", + import_path="envpool.minigrid", + spec_cls="EmptyEnvSpec", + dm_cls="EmptyDMEnvPool", + gym_cls="EmptyGymEnvPool", + gymnasium_cls="EmptyGymnasiumEnvPool", + max_episode_steps=1024, + size=16, +) diff --git a/envpool/mujoco/dmc/__init__.py b/envpool/mujoco/dmc/__init__.py index 1754d110..0000b9a2 100644 --- a/envpool/mujoco/dmc/__init__.py +++ b/envpool/mujoco/dmc/__init__.py @@ -47,100 +47,128 @@ ) from envpool.python.api import py_env -DmcAcrobotEnvSpec, DmcAcrobotDMEnvPool, DmcAcrobotGymEnvPool = py_env( - _DmcAcrobotEnvSpec, _DmcAcrobotEnvPool -) -DmcBallInCupEnvSpec, DmcBallInCupDMEnvPool, DmcBallInCupGymEnvPool = py_env( - _DmcBallInCupEnvSpec, _DmcBallInCupEnvPool -) -DmcCartpoleEnvSpec, DmcCartpoleDMEnvPool, DmcCartpoleGymEnvPool = py_env( - _DmcCartpoleEnvSpec, _DmcCartpoleEnvPool -) -DmcCheetahEnvSpec, DmcCheetahDMEnvPool, DmcCheetahGymEnvPool = py_env( - _DmcCheetahEnvSpec, _DmcCheetahEnvPool -) -DmcFingerEnvSpec, DmcFingerDMEnvPool, DmcFingerGymEnvPool = py_env( - _DmcFingerEnvSpec, _DmcFingerEnvPool -) -DmcFishEnvSpec, DmcFishDMEnvPool, DmcFishGymEnvPool = py_env( - _DmcFishEnvSpec, _DmcFishEnvPool -) -DmcHopperEnvSpec, DmcHopperDMEnvPool, DmcHopperGymEnvPool = py_env( - _DmcHopperEnvSpec, _DmcHopperEnvPool -) -DmcHumanoidEnvSpec, DmcHumanoidDMEnvPool, DmcHumanoidGymEnvPool = py_env( - _DmcHumanoidEnvSpec, _DmcHumanoidEnvPool -) +( + DmcAcrobotEnvSpec, DmcAcrobotDMEnvPool, DmcAcrobotGymEnvPool, + DmcAcrobotGymnasiumEnvPool +) = py_env(_DmcAcrobotEnvSpec, _DmcAcrobotEnvPool) +( + DmcBallInCupEnvSpec, DmcBallInCupDMEnvPool, DmcBallInCupGymEnvPool, + DmcBallInCupGymnasiumEnvPool +) = py_env(_DmcBallInCupEnvSpec, _DmcBallInCupEnvPool) +( + DmcCartpoleEnvSpec, DmcCartpoleDMEnvPool, DmcCartpoleGymEnvPool, + DmcCartpoleGymnasiumEnvPool +) = py_env(_DmcCartpoleEnvSpec, _DmcCartpoleEnvPool) +( + DmcCheetahEnvSpec, DmcCheetahDMEnvPool, DmcCheetahGymEnvPool, + DmcCheetahGymnasiumEnvPool +) = py_env(_DmcCheetahEnvSpec, _DmcCheetahEnvPool) +( + DmcFingerEnvSpec, DmcFingerDMEnvPool, DmcFingerGymEnvPool, + DmcFingerGymnasiumEnvPool +) = py_env(_DmcFingerEnvSpec, _DmcFingerEnvPool) +(DmcFishEnvSpec, DmcFishDMEnvPool, DmcFishGymEnvPool, + DmcFishGymnasiumEnvPool) = py_env(_DmcFishEnvSpec, _DmcFishEnvPool) +( + DmcHopperEnvSpec, DmcHopperDMEnvPool, DmcHopperGymEnvPool, + DmcHopperGymnasiumEnvPool +) = py_env(_DmcHopperEnvSpec, _DmcHopperEnvPool) +( + DmcHumanoidEnvSpec, DmcHumanoidDMEnvPool, DmcHumanoidGymEnvPool, + DmcHumanoidGymnasiumEnvPool +) = py_env(_DmcHumanoidEnvSpec, _DmcHumanoidEnvPool) ( DmcHumanoidCMUEnvSpec, DmcHumanoidCMUDMEnvPool, DmcHumanoidCMUGymEnvPool, + DmcHumanoidCMUGymnasiumEnvPool, ) = py_env(_DmcHumanoidCMUEnvSpec, _DmcHumanoidCMUEnvPool) ( DmcManipulatorEnvSpec, DmcManipulatorDMEnvPool, DmcManipulatorGymEnvPool, + DmcManipulatorGymnasiumEnvPool, ) = py_env(_DmcManipulatorEnvSpec, _DmcManipulatorEnvPool) -DmcPendulumEnvSpec, DmcPendulumDMEnvPool, DmcPendulumGymEnvPool = py_env( - _DmcPendulumEnvSpec, _DmcPendulumEnvPool -) -DmcPointMassEnvSpec, DmcPointMassDMEnvPool, DmcPointMassGymEnvPool = py_env( - _DmcPointMassEnvSpec, _DmcPointMassEnvPool -) -DmcReacherEnvSpec, DmcReacherDMEnvPool, DmcReacherGymEnvPool = py_env( - _DmcReacherEnvSpec, _DmcReacherEnvPool -) -DmcSwimmerEnvSpec, DmcSwimmerDMEnvPool, DmcSwimmerGymEnvPool = py_env( - _DmcSwimmerEnvSpec, _DmcSwimmerEnvPool -) -DmcWalkerEnvSpec, DmcWalkerDMEnvPool, DmcWalkerGymEnvPool = py_env( - _DmcWalkerEnvSpec, _DmcWalkerEnvPool -) +( + DmcPendulumEnvSpec, DmcPendulumDMEnvPool, DmcPendulumGymEnvPool, + DmcPendulumGymnasiumEnvPool +) = py_env(_DmcPendulumEnvSpec, _DmcPendulumEnvPool) +( + DmcPointMassEnvSpec, DmcPointMassDMEnvPool, DmcPointMassGymEnvPool, + DmcPointMassGymnasiumEnvPool +) = py_env(_DmcPointMassEnvSpec, _DmcPointMassEnvPool) +( + DmcReacherEnvSpec, DmcReacherDMEnvPool, DmcReacherGymEnvPool, + DmcReacherGymnasiumEnvPool +) = py_env(_DmcReacherEnvSpec, _DmcReacherEnvPool) +( + DmcSwimmerEnvSpec, DmcSwimmerDMEnvPool, DmcSwimmerGymEnvPool, + DmcSwimmerGymnasiumEnvPool +) = py_env(_DmcSwimmerEnvSpec, _DmcSwimmerEnvPool) +( + DmcWalkerEnvSpec, DmcWalkerDMEnvPool, DmcWalkerGymEnvPool, + DmcWalkerGymnasiumEnvPool +) = py_env(_DmcWalkerEnvSpec, _DmcWalkerEnvPool) __all__ = [ "DmcAcrobotEnvSpec", "DmcAcrobotDMEnvPool", "DmcAcrobotGymEnvPool", + "DmcAcrobotGymnasiumEnvPool", "DmcBallInCupEnvSpec", "DmcBallInCupDMEnvPool", "DmcBallInCupGymEnvPool", + "DmcBallInCupGymnasiumEnvPool", "DmcCartpoleEnvSpec", "DmcCartpoleDMEnvPool", "DmcCartpoleGymEnvPool", + "DmcCartpoleGymnasiumEnvPool", "DmcCheetahEnvSpec", "DmcCheetahDMEnvPool", "DmcCheetahGymEnvPool", + "DmcCheetahGymnasiumEnvPool", "DmcFingerEnvSpec", "DmcFingerDMEnvPool", "DmcFingerGymEnvPool", + "DmcFingerGymnasiumEnvPool", "DmcFishEnvSpec", "DmcFishDMEnvPool", "DmcFishGymEnvPool", + "DmcFishGymnasiumEnvPool", "DmcHopperEnvSpec", "DmcHopperDMEnvPool", "DmcHopperGymEnvPool", + "DmcHopperGymnasiumEnvPool", "DmcHumanoidEnvSpec", "DmcHumanoidDMEnvPool", "DmcHumanoidGymEnvPool", + "DmcHumanoidGymnasiumEnvPool", "DmcHumanoidCMUEnvSpec", "DmcHumanoidCMUDMEnvPool", "DmcHumanoidCMUGymEnvPool", + "DmcHumanoidCMUGymnasiumEnvPool", "DmcManipulatorEnvSpec", "DmcManipulatorDMEnvPool", "DmcManipulatorGymEnvPool", + "DmcManipulatorGymnasiumEnvPool", "DmcPendulumEnvSpec", "DmcPendulumDMEnvPool", "DmcPendulumGymEnvPool", + "DmcPendulumGymnasiumEnvPool", "DmcPointMassEnvSpec", "DmcPointMassDMEnvPool", "DmcPointMassGymEnvPool", + "DmcPointMassGymnasiumEnvPool", "DmcReacherEnvSpec", "DmcReacherDMEnvPool", "DmcReacherGymEnvPool", + "DmcReacherGymnasiumEnvPool", "DmcSwimmerEnvSpec", "DmcSwimmerDMEnvPool", "DmcSwimmerGymEnvPool", + "DmcSwimmerGymnasiumEnvPool", "DmcWalkerEnvSpec", "DmcWalkerDMEnvPool", "DmcWalkerGymEnvPool", + "DmcWalkerGymnasiumEnvPool", ] diff --git a/envpool/mujoco/dmc/manipulator.h b/envpool/mujoco/dmc/manipulator.h index 39db29c5..d508cbfc 100644 --- a/envpool/mujoco/dmc/manipulator.h +++ b/envpool/mujoco/dmc/manipulator.h @@ -173,7 +173,7 @@ class ManipulatorEnv : public Env, public MujocoEnv { while (penetrating) { for (std::size_t i = 0; i < kArmJoints.size(); ++i) { int id_joint = id_arm_joints_[i]; - bool is_limited = model_->jnt_limited[id_joint] == 1 ? true : false; + bool is_limited = model_->jnt_limited[id_joint] == 1; mjtNum lower = is_limited ? model_->jnt_range[id_joint * 2 + 0] : -M_PI; mjtNum upper = is_limited ? model_->jnt_range[id_joint * 2 + 1] : M_PI; data_->qpos[id_arm_qpos_[i]] = RandUniform(lower, upper)(gen_); diff --git a/envpool/mujoco/dmc/mujoco_dmc_suite_align_test.py b/envpool/mujoco/dmc/mujoco_dmc_suite_align_test.py index 5089ad5b..0f399739 100644 --- a/envpool/mujoco/dmc/mujoco_dmc_suite_align_test.py +++ b/envpool/mujoco/dmc/mujoco_dmc_suite_align_test.py @@ -150,7 +150,7 @@ def run_align_check( done = ts0.step_type == dm_env.StepType.LAST o0, o1 = ts0.observation, ts1.observation for k in obs_spec: - np.testing.assert_allclose(o0[k], getattr(o1, k)[0]) + np.testing.assert_allclose(o0[k], getattr(o1, k)[0], atol=1e-6) np.testing.assert_allclose(ts0.step_type, ts1.step_type[0]) np.testing.assert_allclose(ts0.reward, ts1.reward[0], atol=1e-8) np.testing.assert_allclose(ts0.discount, ts1.discount[0]) diff --git a/envpool/mujoco/dmc/registration.py b/envpool/mujoco/dmc/registration.py index 7d8c9535..93b91b81 100644 --- a/envpool/mujoco/dmc/registration.py +++ b/envpool/mujoco/dmc/registration.py @@ -13,14 +13,8 @@ # limitations under the License. """Mujoco dm_control suite env registration.""" -import os - from envpool.registration import register -base_path = os.path.abspath( - os.path.join(os.path.dirname(__file__), "..", "..") -) - # from suite.BENCHMARKING dmc_mujoco_envs = [ ("acrobot", "swingup", 1000), @@ -71,7 +65,7 @@ spec_cls=f"Dmc{domain_name}EnvSpec", dm_cls=f"Dmc{domain_name}DMEnvPool", gym_cls=f"Dmc{domain_name}GymEnvPool", - base_path=base_path, + gymnasium_cls=f"Dmc{domain_name}GymnasiumEnvPool", task_name=task, max_episode_steps=max_episode_steps, ) diff --git a/envpool/mujoco/gym/__init__.py b/envpool/mujoco/gym/__init__.py index ac1fdc45..a60cd2ee 100644 --- a/envpool/mujoco/gym/__init__.py +++ b/envpool/mujoco/gym/__init__.py @@ -39,29 +39,33 @@ ) from envpool.python.api import py_env -GymAntEnvSpec, GymAntDMEnvPool, GymAntGymEnvPool = py_env( - _GymAntEnvSpec, _GymAntEnvPool -) +(GymAntEnvSpec, GymAntDMEnvPool, GymAntGymEnvPool, + GymAntGymnasiumEnvPool) = py_env(_GymAntEnvSpec, _GymAntEnvPool) ( GymHalfCheetahEnvSpec, GymHalfCheetahDMEnvPool, GymHalfCheetahGymEnvPool, + GymHalfCheetahGymnasiumEnvPool, ) = py_env(_GymHalfCheetahEnvSpec, _GymHalfCheetahEnvPool) -GymHopperEnvSpec, GymHopperDMEnvPool, GymHopperGymEnvPool = py_env( - _GymHopperEnvSpec, _GymHopperEnvPool -) -GymHumanoidEnvSpec, GymHumanoidDMEnvPool, GymHumanoidGymEnvPool = py_env( - _GymHumanoidEnvSpec, _GymHumanoidEnvPool -) +( + GymHopperEnvSpec, GymHopperDMEnvPool, GymHopperGymEnvPool, + GymHopperGymnasiumEnvPool +) = py_env(_GymHopperEnvSpec, _GymHopperEnvPool) +( + GymHumanoidEnvSpec, GymHumanoidDMEnvPool, GymHumanoidGymEnvPool, + GymHumanoidGymnasiumEnvPool +) = py_env(_GymHumanoidEnvSpec, _GymHumanoidEnvPool) ( GymHumanoidStandupEnvSpec, GymHumanoidStandupDMEnvPool, GymHumanoidStandupGymEnvPool, + GymHumanoidStandupGymnasiumEnvPool, ) = py_env(_GymHumanoidStandupEnvSpec, _GymHumanoidStandupEnvPool) ( GymInvertedDoublePendulumEnvSpec, GymInvertedDoublePendulumDMEnvPool, GymInvertedDoublePendulumGymEnvPool, + GymInvertedDoublePendulumGymnasiumEnvPool, ) = py_env( _GymInvertedDoublePendulumEnvSpec, _GymInvertedDoublePendulumEnvPool ) @@ -69,52 +73,68 @@ GymInvertedPendulumEnvSpec, GymInvertedPendulumDMEnvPool, GymInvertedPendulumGymEnvPool, + GymInvertedPendulumGymnasiumEnvPool, ) = py_env(_GymInvertedPendulumEnvSpec, _GymInvertedPendulumEnvPool) -GymPusherEnvSpec, GymPusherDMEnvPool, GymPusherGymEnvPool = py_env( - _GymPusherEnvSpec, _GymPusherEnvPool -) -GymReacherEnvSpec, GymReacherDMEnvPool, GymReacherGymEnvPool = py_env( - _GymReacherEnvSpec, _GymReacherEnvPool -) -GymSwimmerEnvSpec, GymSwimmerDMEnvPool, GymSwimmerGymEnvPool = py_env( - _GymSwimmerEnvSpec, _GymSwimmerEnvPool -) -GymWalker2dEnvSpec, GymWalker2dDMEnvPool, GymWalker2dGymEnvPool = py_env( - _GymWalker2dEnvSpec, _GymWalker2dEnvPool -) +( + GymPusherEnvSpec, GymPusherDMEnvPool, GymPusherGymEnvPool, + GymPusherGymnasiumEnvPool +) = py_env(_GymPusherEnvSpec, _GymPusherEnvPool) +( + GymReacherEnvSpec, GymReacherDMEnvPool, GymReacherGymEnvPool, + GymReacherGymnasiumEnvPool +) = py_env(_GymReacherEnvSpec, _GymReacherEnvPool) +( + GymSwimmerEnvSpec, GymSwimmerDMEnvPool, GymSwimmerGymEnvPool, + GymSwimmerGymnasiumEnvPool +) = py_env(_GymSwimmerEnvSpec, _GymSwimmerEnvPool) +( + GymWalker2dEnvSpec, GymWalker2dDMEnvPool, GymWalker2dGymEnvPool, + GymWalker2dGymnasiumEnvPool +) = py_env(_GymWalker2dEnvSpec, _GymWalker2dEnvPool) __all__ = [ "GymAntEnvSpec", "GymAntDMEnvPool", "GymAntGymEnvPool", + "GymnasiumAntGymEnvPool", "GymHalfCheetahEnvSpec", "GymHalfCheetahDMEnvPool", "GymHalfCheetahGymEnvPool", + "GymHalfCheetahGymnasiumEnvPool", "GymHopperEnvSpec", "GymHopperDMEnvPool", "GymHopperGymEnvPool", + "GymHopperGymnasiumEnvPool", "GymHumanoidEnvSpec", "GymHumanoidDMEnvPool", "GymHumanoidGymEnvPool", + "GymHumanoidGymnasiumEnvPool", "GymHumanoidStandupEnvSpec", "GymHumanoidStandupDMEnvPool", "GymHumanoidStandupGymEnvPool", + "GymHumanoidStandupGymnasiumEnvPool", "GymInvertedDoublePendulumEnvSpec", "GymInvertedDoublePendulumDMEnvPool", "GymInvertedDoublePendulumGymEnvPool", + "GymInvertedDoublePendulumGymnasiumEnvPool", "GymInvertedPendulumEnvSpec", "GymInvertedPendulumDMEnvPool", "GymInvertedPendulumGymEnvPool", + "GymInvertedPendulumGymnasiumEnvPool", "GymPusherEnvSpec", "GymPusherDMEnvPool", "GymPusherGymEnvPool", + "GymPusherGymnasiumEnvPool", "GymReacherEnvSpec", "GymReacherDMEnvPool", "GymReacherGymEnvPool", + "GymReacherGymnasiumEnvPool", "GymSwimmerEnvSpec", "GymSwimmerDMEnvPool", "GymSwimmerGymEnvPool", + "GymSwimmerGymnasiumEnvPool", "GymWalker2dEnvSpec", "GymWalker2dDMEnvPool", "GymWalker2dGymEnvPool", + "GymWalker2dGymnasiumEnvPool", ] diff --git a/envpool/mujoco/gym/mujoco_env.h b/envpool/mujoco/gym/mujoco_env.h index cf70a8c8..03db2b6f 100644 --- a/envpool/mujoco/gym/mujoco_env.h +++ b/envpool/mujoco/gym/mujoco_env.h @@ -38,7 +38,7 @@ class MujocoEnv { int frame_skip_; bool post_constraint_; int max_episode_steps_, elapsed_step_; - bool done_; + bool done_{true}; public: MujocoEnv(const std::string& xml, int frame_skip, bool post_constraint, @@ -54,8 +54,7 @@ class MujocoEnv { frame_skip_(frame_skip), post_constraint_(post_constraint), max_episode_steps_(max_episode_steps), - elapsed_step_(max_episode_steps + 1), - done_(true) { + elapsed_step_(max_episode_steps + 1) { std::memcpy(init_qpos_, data_->qpos, sizeof(mjtNum) * model_->nq); std::memcpy(init_qvel_, data_->qvel, sizeof(mjtNum) * model_->nv); } diff --git a/envpool/mujoco/gym/registration.py b/envpool/mujoco/gym/registration.py index bde22411..25364c56 100644 --- a/envpool/mujoco/gym/registration.py +++ b/envpool/mujoco/gym/registration.py @@ -13,14 +13,8 @@ # limitations under the License. """Mujoco gym env registration.""" -import os - from envpool.registration import register -base_path = os.path.abspath( - os.path.join(os.path.dirname(__file__), "..", "..") -) - gym_mujoco_envs = [ ("Ant", "v3", False, 1000), ("Ant", "v4", True, 1000), @@ -56,7 +50,7 @@ spec_cls=f"Gym{task}EnvSpec", dm_cls=f"Gym{task}DMEnvPool", gym_cls=f"Gym{task}GymEnvPool", - base_path=base_path, + gymnasium_cls=f"Gym{task}GymnasiumEnvPool", post_constraint=post_constraint, max_episode_steps=max_episode_steps, **extra_args, diff --git a/envpool/pip.bzl b/envpool/pip.bzl index 8549b8af..b6323108 100644 --- a/envpool/pip.bzl +++ b/envpool/pip.bzl @@ -23,6 +23,8 @@ def workspace(): pip_install( name = "pip_requirements", python_interpreter = "python3", + # default timeout value is 600, change it if you failed. + # timeout = 3600, quiet = False, requirements = "@envpool//third_party/pip_requirements:requirements.txt", # extra_pip_args = ["--extra-index-url", "https://mirrors.aliyun.com/pypi/simple"], diff --git a/envpool/procgen/BUILD b/envpool/procgen/BUILD new file mode 100644 index 00000000..a2ff24c8 --- /dev/null +++ b/envpool/procgen/BUILD @@ -0,0 +1,85 @@ +# Copyright 2023 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@pip_requirements//:requirements.bzl", "requirement") +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package(default_visibility = ["//visibility:public"]) + +genrule( + name = "gen_procgen_assets", + srcs = ["@procgen//:procgen_assets"], + outs = ["assets"], + cmd = "mkdir -p $(OUTS) && cp -r $(SRCS) $(OUTS)", +) + +cc_library( + name = "procgen_env", + hdrs = ["procgen_env.h"], + data = [ + ":gen_procgen_assets", + ], + deps = [ + "//envpool/core:async_envpool", + "@procgen", + ], +) + +cc_test( + name = "procgen_env_test", + srcs = ["procgen_env_test.cc"], + deps = [ + ":procgen_env", + "@com_google_googletest//:gtest_main", + ], +) + +pybind_extension( + name = "procgen_envpool", + srcs = ["procgen_envpool.cc"], + linkopts = [ + "-ldl", + ], + deps = [ + ":procgen_env", + "//envpool/core:py_envpool", + ], +) + +py_library( + name = "procgen", + srcs = ["__init__.py"], + data = [":procgen_envpool.so"], + deps = ["//envpool/python:api"], +) + +py_library( + name = "procgen_registration", + srcs = ["registration.py"], + deps = [ + "//envpool:registration", + ], +) + +py_test( + name = "procgen_test", + srcs = ["procgen_test.py"], + deps = [ + ":procgen", + ":procgen_registration", + requirement("numpy"), + requirement("absl-py"), + requirement("gym"), + ], +) diff --git a/envpool/procgen/__init__.py b/envpool/procgen/__init__.py new file mode 100644 index 00000000..23238277 --- /dev/null +++ b/envpool/procgen/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2023 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Procgen env Init.""" +from envpool.python.api import py_env + +from .procgen_envpool import _ProcgenEnvPool, _ProcgenEnvSpec + +( + ProcgenEnvSpec, + ProcgenDMEnvPool, + ProcgenGymEnvPool, + ProcgenGymnasiumEnvPool, +) = py_env(_ProcgenEnvSpec, _ProcgenEnvPool) + +__all__ = [ + "ProcgenEnvSpec", + "ProcgenDMEnvPool", + "ProcgenGymEnvPool", + "ProcgenGymnasiumEnvPool", +] diff --git a/envpool/procgen/procgen_env.h b/envpool/procgen/procgen_env.h new file mode 100644 index 00000000..e2d43df9 --- /dev/null +++ b/envpool/procgen/procgen_env.h @@ -0,0 +1,213 @@ +/* + * Copyright 2023 Garena Online Private Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ENVPOOL_PROCGEN_PROCGEN_ENV_H_ +#define ENVPOOL_PROCGEN_PROCGEN_ENV_H_ + +#include +#include +#include +#include +#include + +#include "envpool/core/async_envpool.h" +#include "envpool/core/env.h" +#include "game.h" + +namespace procgen { + +/* + All the procgen's games have the same observation buffer size, 64 x 64 pixels + x 3 colors (RGB) there are 15 possible action buttoms and observation is RGB + 32 or RGB 888, + QT build needs: sudo apt update && sudo apt install qtdeclarative5-dev + */ +static const int kRes = 64; +static std::once_flag procgen_global_init_flag; + +void ProcgenGlobalInit(std::string path) { + if (global_resource_root.empty()) { + global_resource_root = std::move(path); + images_load(); + } +} + +// https://github.com/openai/procgen/blob/0.10.7/procgen/src/vecgame.cpp#L156 +std::size_t HashStrUint32(const std::string& str) { + std::size_t hash = 0x811c9dc5; + std::size_t prime = 0x1000193; + for (uint8_t value : str) { + hash ^= value; + hash *= prime; + } + return hash; +} + +class ProcgenEnvFns { + public: + static decltype(auto) DefaultConfig() { + return MakeDict( + "env_name"_.Bind(std::string("bigfish")), "channel_first"_.Bind(true), + "num_levels"_.Bind(0), "start_level"_.Bind(0), + "use_sequential_levels"_.Bind(false), "center_agent"_.Bind(true), + "use_backgrounds"_.Bind(true), "use_monochrome_assets"_.Bind(false), + "restrict_themes"_.Bind(false), "use_generated_assets"_.Bind(false), + "paint_vel_info"_.Bind(false), "use_easy_jump"_.Bind(false), + "distribution_mode"_.Bind(1)); + } + + template + static decltype(auto) StateSpec(const Config& conf) { + // The observation is RGB 64 x 64 x 3 + return MakeDict( + "obs"_.Bind(Spec(conf["channel_first"_] + ? std::vector{3, kRes, kRes} + : std::vector{kRes, kRes, 3}, + {0, 255})), + "info:prev_level_seed"_.Bind(Spec({-1})), + "info:prev_level_complete"_.Bind(Spec({-1})), + "info:level_seed"_.Bind(Spec({-1}))); + } + + template + static decltype(auto) ActionSpec(const Config& conf) { + // 15 action buttons in total, ranging from 0 to 14 + return MakeDict("action"_.Bind(Spec({-1}, {0, 14}))); + } +}; + +using ProcgenEnvSpec = EnvSpec; +using FrameSpec = Spec; + +class ProcgenEnv : public Env { + protected: + std::shared_ptr game_; + std::string env_name_; + bool channel_first_; + // buffer used by game + FrameSpec obs_spec_; + Array obs_; + float reward_; + int level_seed_, prev_level_seed_; + uint8_t done_{1}, prev_level_complete_; + + public: + ProcgenEnv(const Spec& spec, int env_id) + : Env(spec, env_id), + env_name_(spec.config["env_name"_]), + channel_first_(spec.config["channel_first"_]), + obs_spec_({kRes, kRes, 3}), + obs_(obs_spec_) { + /* Initialize the single game we are holding in this EnvPool environment + * It depends on some default setting along with the config map passed in + * We mostly follow how it's done in the vector environment at Procgen and + * translate it into single one. + * https://github.com/openai/procgen/blob/0.10.7/procgen/src/vecgame.cpp#L312 + */ + std::call_once(procgen_global_init_flag, ProcgenGlobalInit, + spec.config["base_path"_] + "/procgen/assets/"); + // CHECK_NE(globalGameRegistry, nullptr); + // game_ = globalGameRegistry->at(env_name_)(); + game_ = make_game(spec.config["env_name"_]); + CHECK_EQ(game_->game_name, env_name_); + game_->level_seed_rand_gen.seed(seed_); + int num_levels = spec.config["num_levels"_]; + int start_level = spec.config["start_level"_]; + if (num_levels <= 0) { + game_->level_seed_low = 0; + game_->level_seed_high = std::numeric_limits::max(); + } else { + game_->level_seed_low = start_level; + game_->level_seed_high = start_level + num_levels; + } + game_->game_n = env_id; + if (game_->fixed_asset_seed == 0) { + game_->fixed_asset_seed = static_cast(HashStrUint32(env_name_)); + } + + // buffers for the game to outwrite observations each step + game_->reward_ptr = &reward_; + game_->first_ptr = &done_; + game_->obs_bufs.emplace_back(static_cast(obs_.Data())); + game_->info_bufs.emplace_back(static_cast(&level_seed_)); + game_->info_bufs.emplace_back(static_cast(&prev_level_seed_)); + game_->info_bufs.emplace_back(static_cast(&prev_level_complete_)); + game_->info_name_to_offset["level_seed"] = 0; + game_->info_name_to_offset["prev_level_seed"] = 1; + game_->info_name_to_offset["prev_level_complete"] = 2; + // game options + game_->options.use_easy_jump = spec.config["use_easy_jump"_]; + game_->options.paint_vel_info = spec.config["paint_vel_info"_]; + game_->options.use_generated_assets = spec.config["use_generated_assets"_]; + game_->options.use_monochrome_assets = + spec.config["use_monochrome_assets"_]; + game_->options.restrict_themes = spec.config["restrict_themes"_]; + game_->options.use_backgrounds = spec.config["use_backgrounds"_]; + game_->options.center_agent = spec.config["center_agent"_]; + game_->options.use_sequential_levels = + spec.config["use_sequential_levels"_]; + game_->options.distribution_mode = + static_cast(spec.config["distribution_mode"_]); + game_->game_init(); + } + + void Reset() override { + game_->step_data.done = false; + game_->step_data.reward = 0.0; + game_->step_data.level_complete = false; + game_->reset(); + game_->observe(); + WriteObs(); + } + + void Step(const Action& action) override { + game_->action = action["action"_]; + game_->step(); + WriteObs(); + } + + bool IsDone() override { return done_ != 0; } + + private: + void WriteObs() { + State state = Allocate(); + if (channel_first_) { + // convert from HWC to CHW + auto* data = static_cast(state["obs"_].Data()); + auto* buffer = static_cast(obs_.Data()); + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < kRes; ++j) { + for (int k = 0; k < kRes; ++k) { + data[i * kRes * kRes + j * kRes + k] = + buffer[j * kRes * 3 + k * 3 + i]; + } + } + } + } else { + state["obs"_].Assign(obs_); + } + state["reward"_] = reward_; + state["info:prev_level_seed"_] = prev_level_seed_; + state["info:prev_level_complete"_] = prev_level_complete_; + state["info:level_seed"_] = level_seed_; + } +}; + +using ProcgenEnvPool = AsyncEnvPool; + +} // namespace procgen + +#endif // ENVPOOL_PROCGEN_PROCGEN_ENV_H_ diff --git a/envpool/procgen/procgen_env_test.cc b/envpool/procgen/procgen_env_test.cc new file mode 100644 index 00000000..744bda8e --- /dev/null +++ b/envpool/procgen/procgen_env_test.cc @@ -0,0 +1,66 @@ +// Copyright 2023 Garena Online Private Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "envpool/procgen/procgen_env.h" + +#include +#include + +#include + +using ProcgenState = procgen::ProcgenEnv::State; +using ProcgenAction = procgen::ProcgenEnv::Action; + +TEST(PRocgenEnvTest, BasicStep) { + std::srand(std::time(nullptr)); + auto config = procgen::ProcgenEnvSpec::kDefaultConfig; + std::size_t batch = 4; + config["num_envs"_] = batch; + config["batch_size"_] = batch; + config["seed"_] = 0; + config["env_name"_] = "coinrun"; + int total_iter = 10000; + procgen::ProcgenEnvSpec spec(config); + procgen::ProcgenEnvPool envpool(spec); + Array all_env_ids(Spec({static_cast(batch)})); + for (std::size_t i = 0; i < batch; ++i) { + all_env_ids[i] = i; + } + envpool.Reset(all_env_ids); + std::vector raw_action(3); + ProcgenAction action(&raw_action); + for (int i = 0; i < total_iter; ++i) { + auto state_vec = envpool.Recv(); + ProcgenState state(&state_vec); + EXPECT_EQ(state["obs"_].Shape(), + std::vector({batch, 3, 64, 64})); + uint8_t* data = static_cast(state["obs"_].Data()); + int index = 0; + for (std::size_t j = 0; j < batch; ++j) { + // ensure there's no black screen in each frame + int sum = 0; + for (int k = 0; k < 64 * 64 * 3; ++k) { + sum += data[index++]; + } + EXPECT_NE(sum, 0) << i << " " << j; + } + action["env_id"_] = state["info:env_id"_]; + action["players.env_id"_] = state["info:env_id"_]; + action["action"_] = Array(Spec({static_cast(batch)})); + for (std::size_t j = 0; j < batch; ++j) { + action["action"_][j] = std::rand() % 15; + } + envpool.Send(action); + } +} diff --git a/envpool/procgen/procgen_envpool.cc b/envpool/procgen/procgen_envpool.cc new file mode 100644 index 00000000..2cdc6e9b --- /dev/null +++ b/envpool/procgen/procgen_envpool.cc @@ -0,0 +1,23 @@ +// Copyright 2023 Garena Online Private Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "envpool/core/py_envpool.h" +#include "envpool/procgen/procgen_env.h" + +using ProcgenEnvSpec = PyEnvSpec; +using ProcgenEnvPool = PyEnvPool; + +PYBIND11_MODULE(procgen_envpool, m) { + REGISTER(m, ProcgenEnvSpec, ProcgenEnvPool) +} diff --git a/envpool/procgen/procgen_test.py b/envpool/procgen/procgen_test.py new file mode 100644 index 00000000..0ed1fc85 --- /dev/null +++ b/envpool/procgen/procgen_test.py @@ -0,0 +1,94 @@ +# Copyright 2023 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for Procgen environments.""" + +# import cv2 +import numpy as np +from absl import logging +from absl.testing import absltest + +from envpool.procgen.registration import distribution, procgen_game_config +from envpool.registration import make_gym + + +class _ProcgenEnvPoolTest(absltest.TestCase): + + def deterministic_check( + self, task_id: str, num_envs: int = 4, total: int = 200 + ) -> None: + logging.info(f"deterministic check for {task_id}") + env0 = make_gym(task_id, num_envs=num_envs, seed=0) + env1 = make_gym(task_id, num_envs=num_envs, seed=0) + env2 = make_gym(task_id, num_envs=num_envs, seed=1) + act_space = env0.action_space + for _ in range(total): + action = np.array([act_space.sample() for _ in range(num_envs)]) + obs0 = env0.step(action)[0] + obs1 = env1.step(action)[0] + obs2 = env2.step(action)[0] + np.testing.assert_allclose(obs0, obs1) + self.assertFalse(np.allclose(obs0, obs2)) + + def test_deterministic(self) -> None: + for env_name, _, dist_mode in procgen_game_config: + for dist_value in dist_mode: + task_id = f"{env_name.capitalize()}{distribution[dist_value]}-v0" + self.deterministic_check(task_id) + + def test_align(self) -> None: + task_id = "CoinrunHard-v0" + seed = 0 + env = make_gym(task_id, seed=seed, channel_first=False) + env.action_space.seed(seed) + done = [False] + cnt = sum_reward = sum_obs = 0 + while not done[0]: + cnt += 1 + act = env.action_space.sample() + obs, rew, term, trunc, info = env.step(np.array([act])) + sum_obs = obs[0].astype(int) + sum_obs + done = term | trunc + sum_reward += rew[0] + # cv2.imwrite(f"/tmp/envpool/{cnt}.png", obs[0]) + # print(f"{cnt=} {obs.sum()=} {done=} {rew=} {info=}") + self.assertEqual(sum_reward, 10) + self.assertEqual(rew[0], 10) + self.assertEqual(cnt, 645) + self.assertEqual(info["level_seed"][0], 209652397) + self.assertEqual(info["prev_level_complete"][0], 1) + pixel_mean_ref = [196.86093636, 144.85448235, 95.27605529] + pixel_mean = (sum_obs / cnt).mean(axis=0).mean(axis=0) # type: ignore + np.testing.assert_allclose(pixel_mean, pixel_mean_ref) + + def test_channel_first( + self, + task_id: str = "CoinrunHard-v0", + seed: int = 0, + total: int = 1000, + ) -> None: + env1 = make_gym(task_id, seed=seed, channel_first=True) + env2 = make_gym(task_id, seed=seed, channel_first=False) + self.assertEqual(env1.observation_space.shape, (3, 64, 64)) + self.assertEqual(env2.observation_space.shape, (64, 64, 3)) + for _ in range(total): + act = env1.action_space.sample() + obs1 = env1.step(np.array([act]))[0][0] + obs2 = env2.step(np.array([act]))[0][0] + self.assertEqual(obs1.shape, (3, 64, 64)) + self.assertEqual(obs2.shape, (64, 64, 3)) + np.testing.assert_allclose(obs1, obs2.transpose(2, 0, 1)) + + +if __name__ == "__main__": + absltest.main() diff --git a/envpool/procgen/registration.py b/envpool/procgen/registration.py new file mode 100644 index 00000000..23639b41 --- /dev/null +++ b/envpool/procgen/registration.py @@ -0,0 +1,58 @@ +# Copyright 2023 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Procgen env registration.""" +from envpool.registration import register + +# 16 games in Procgen +# https://github.com/openai/procgen/blob/0.10.7/procgen/src/game.cpp#L56-L66 +procgen_game_config = [ + ("bigfish", 6000, [0, 1]), + ("bossfight", 4000, [0, 1]), + ("caveflyer", 1000, [0, 1, 10]), + ("chaser", 1000, [0, 1, 2]), + ("climber", 1000, [0, 1]), + ("coinrun", 1000, [0, 1]), + ("dodgeball", 1000, [0, 1, 2, 10]), + ("fruitbot", 1000, [0, 1]), + ("heist", 1000, [0, 1, 10]), + ("jumper", 1000, [0, 1, 10]), + ("leaper", 500, [0, 1, 2]), + ("maze", 500, [0, 1, 10]), + ("miner", 1000, [0, 1, 10]), + ("ninja", 1000, [0, 1]), + ("plunder", 4000, [0, 1]), + ("starpilot", 1000, [0, 1, 2]), +] + +distribution = { + 0: "Easy", + 1: "Hard", + 2: "Extreme", + 10: "Memory", +} + +for env_name, timeout, dist_mode in procgen_game_config: + for dist_value in dist_mode: + dist_name = distribution[dist_value] + register( + task_id=f"{env_name.capitalize()}{dist_name}-v0", + import_path="envpool.procgen", + spec_cls="ProcgenEnvSpec", + dm_cls="ProcgenDMEnvPool", + gym_cls="ProcgenGymEnvPool", + gymnasium_cls="ProcgenGymnasiumEnvPool", + env_name=env_name, + distribution_mode=dist_value, + max_episode_steps=timeout, + ) diff --git a/envpool/python/BUILD b/envpool/python/BUILD index ef0ec376..f95b1573 100644 --- a/envpool/python/BUILD +++ b/envpool/python/BUILD @@ -38,9 +38,10 @@ py_library( name = "data", srcs = ["data.py"], deps = [ - requirement("treevalue"), + requirement("optree"), requirement("dm-env"), requirement("gym"), + requirement("gymnasium"), requirement("numpy"), ":protocol", ], @@ -52,6 +53,7 @@ py_library( deps = [ requirement("dm-env"), requirement("gym"), + requirement("gymnasium"), ":data", ":protocol", ":utils", @@ -62,7 +64,7 @@ py_library( name = "envpool", srcs = ["envpool.py"], deps = [ - requirement("treevalue"), + requirement("optree"), requirement("dm-env"), requirement("numpy"), requirement("packaging"), @@ -95,7 +97,7 @@ py_library( name = "dm_envpool", srcs = ["dm_envpool.py"], deps = [ - requirement("treevalue"), + requirement("optree"), requirement("dm-env"), requirement("numpy"), ":data", @@ -109,7 +111,7 @@ py_library( name = "gym_envpool", srcs = ["gym_envpool.py"], deps = [ - requirement("treevalue"), + requirement("optree"), requirement("dm-env"), requirement("gym"), requirement("numpy"), @@ -120,6 +122,21 @@ py_library( ], ) +py_library( + name = "gymnasium_envpool", + srcs = ["gymnasium_envpool.py"], + deps = [ + requirement("optree"), + requirement("dm-env"), + requirement("gymnasium"), + requirement("numpy"), + ":data", + ":envpool", + ":lax", + ":utils", + ], +) + py_library( name = "api", srcs = ["api.py"], @@ -127,6 +144,7 @@ py_library( ":dm_envpool", ":env_spec", ":gym_envpool", + ":gymnasium_envpool", ":protocol", ], ) diff --git a/envpool/python/api.py b/envpool/python/api.py index 1646e753..1431b6ac 100644 --- a/envpool/python/api.py +++ b/envpool/python/api.py @@ -18,12 +18,13 @@ from .dm_envpool import DMEnvPoolMeta from .env_spec import EnvSpecMeta from .gym_envpool import GymEnvPoolMeta +from .gymnasium_envpool import GymnasiumEnvPoolMeta from .protocol import EnvPool, EnvSpec def py_env( envspec: Type[EnvSpec], envpool: Type[EnvPool] -) -> Tuple[Type[EnvSpec], Type[EnvPool], Type[EnvPool]]: +) -> Tuple[Type[EnvSpec], Type[EnvPool], Type[EnvPool], Type[EnvPool]]: """Initialize EnvPool for users.""" # remove the _ prefix added when registering cpp class via pybind spec_name = envspec.__name__[1:] @@ -32,4 +33,7 @@ def py_env( EnvSpecMeta(spec_name, (envspec,), {}), # type: ignore[return-value] DMEnvPoolMeta(pool_name.replace("EnvPool", "DMEnvPool"), (envpool,), {}), GymEnvPoolMeta(pool_name.replace("EnvPool", "GymEnvPool"), (envpool,), {}), + GymnasiumEnvPoolMeta( + pool_name.replace("EnvPool", "GymnasiumEnvPool"), (envpool,), {} + ), ) diff --git a/envpool/python/data.py b/envpool/python/data.py index 73a32f1b..cae23440 100644 --- a/envpool/python/data.py +++ b/envpool/python/data.py @@ -18,8 +18,10 @@ import dm_env import gym +import gymnasium import numpy as np -import treevalue +import optree +from optree import PyTreeSpec from .protocol import ArraySpec @@ -106,8 +108,27 @@ def gym_spec_transform( ) -def dm_structure(root_name: str, - keys: List[str]) -> List[Tuple[List[str], int]]: +def gymnasium_spec_transform( + name: str, spec: ArraySpec, spec_type: str +) -> gymnasium.Space: + """Transform ArraySpec to gymnasium.Env compatible spaces.""" + if np.prod(np.abs(spec.shape)) == 1 and \ + np.isclose(spec.minimum, 0) and spec.maximum < ACTION_THRESHOLD: + # special treatment for discrete action space + discrete_range = int(spec.maximum - spec.minimum + 1) + return gymnasium.spaces.Discrete(n=discrete_range, start=int(spec.minimum)) + return gymnasium.spaces.Box( + shape=[s for s in spec.shape if s != -1], + dtype=spec.dtype, + low=spec.minimum, + high=spec.maximum, + ) + + +def dm_structure( + root_name: str, + keys: List[str], +) -> Tuple[List[Tuple[int, ...]], List[int], PyTreeSpec]: """Convert flat keys into tree structure for namedtuple construction.""" new_keys = [] for key in keys: @@ -117,13 +138,19 @@ def dm_structure(root_name: str, key = key.replace("obs:", f"{root_name}:") # compatible with to_namedtuple new_keys.append(key.replace(":", ".")) dict_tree = to_nested_dict(dict(zip(new_keys, list(range(len(new_keys)))))) - tree_pairs = treevalue.flatten(treevalue.TreeValue(dict_tree)) - return tree_pairs + structure = to_namedtuple(root_name, dict_tree) + paths, indices, treespec = optree.tree_flatten_with_path(structure) + return paths, indices, treespec -def gym_structure(keys: List[str]) -> List[Tuple[List[str], int]]: +def gym_structure( + keys: List[str] +) -> Tuple[List[Tuple[str, ...]], List[int], PyTreeSpec]: """Convert flat keys into tree structure for dict construction.""" keys = [k.replace(":", ".") for k in keys] - structure = to_nested_dict(dict(zip(keys, list(range(len(keys)))))) - tree_pairs = treevalue.flatten(treevalue.TreeValue(structure)) - return tree_pairs + dict_tree = to_nested_dict(dict(zip(keys, list(range(len(keys)))))) + paths, indices, treespec = optree.tree_flatten_with_path(dict_tree) + return paths, indices, treespec + + +gymnasium_structure = gym_structure diff --git a/envpool/python/dm_envpool.py b/envpool/python/dm_envpool.py index fbefd510..18fa3216 100644 --- a/envpool/python/dm_envpool.py +++ b/envpool/python/dm_envpool.py @@ -18,7 +18,7 @@ import dm_env import numpy as np -import treevalue +import optree from dm_env import TimeStep from .data import dm_structure @@ -50,6 +50,7 @@ def __new__(cls: Any, name: str, parents: Tuple, attrs: Dict) -> Any: base = parents[0] try: from .lax import XlaMixin + parents = ( base, DMEnvPoolMixin, EnvPoolMixin, XlaMixin, dm_env.Environment ) @@ -68,8 +69,7 @@ def _xla(self: Any) -> None: check_key_duplication(name, "state", state_keys) check_key_duplication(name, "action", action_keys) - tree_pairs = dm_structure("State", state_keys) - state_idx = list(zip(*tree_pairs))[-1] + state_paths, state_idx, treepsec = dm_structure("State", state_keys) def _to_dm( self: Any, @@ -77,10 +77,8 @@ def _to_dm( reset: bool, return_info: bool, ) -> TimeStep: - values = map(lambda i: state_values[i], state_idx) - state = treevalue.unflatten( - [(path, vi) for (path, _), vi in zip(tree_pairs, values)] - ) + values = (state_values[i] for i in state_idx) + state = optree.tree_unflatten(treepsec, values) timestep = TimeStep( step_type=state.step_type, observation=state.State, diff --git a/envpool/python/env_spec.py b/envpool/python/env_spec.py index 18ee269b..e7f47719 100644 --- a/envpool/python/env_spec.py +++ b/envpool/python/env_spec.py @@ -20,10 +20,12 @@ import dm_env import gym +import gymnasium from .data import ( dm_spec_transform, gym_spec_transform, + gymnasium_spec_transform, to_namedtuple, to_nested_dict, ) @@ -165,6 +167,61 @@ def action_space(self: EnvSpec) -> Union[gym.Space, Dict[str, Any]]: } return to_nested_dict(spec, gym.spaces.Dict) + @property + def gymnasium_observation_space( + self: EnvSpec + ) -> Union[gymnasium.Space, Dict[str, Any]]: + """Convert internal state_spec to gymnasium.Env compatible format. + + Returns: + observation_space: A dict (maybe nested) that contains all keys + that start with ``obs`` with their corresponding specs. + + Note: + If only one key starts with ``obs``, it returns that space instead of + all for simplicity. + """ + spec = self.state_array_spec + spec = { + k.replace("obs:", ""): + gymnasium_spec_transform(k.replace(":", ".").split(".")[-1], v, "obs") + for k, v in spec.items() + if k.startswith("obs") + } + if len(spec) == 1: + return list(spec.values())[0] + return to_nested_dict(spec, gymnasium.spaces.Dict) + + @property + def gymnasium_action_space( + self: EnvSpec + ) -> Union[gymnasium.Space, Dict[str, Any]]: + """Convert internal action_spec to gymnasium.Env compatible format. + + Returns: + action_space: A dict (maybe nested) that contains key-value paired + corresponding specs. + + Note: + If the original action_spec has a length of 3 ("env_id", + "players.env_id", *), it returns the last space instead of all for + simplicity. + """ + spec = self.action_array_spec + if len(spec) == 3: + # only env_id, players.env_id, action + spec.pop("env_id") + spec.pop("players.env_id") + return gymnasium_spec_transform( + list(spec.keys())[0], + list(spec.values())[0], "act" + ) + spec = { + k: gymnasium_spec_transform(k.split(".")[-1], v, "act") + for k, v in spec.items() + } + return to_nested_dict(spec, gymnasium.spaces.Dict) + def __repr__(self: EnvSpec) -> str: """Prettify debug info.""" config_info = pprint.pformat(self.config)[6:] diff --git a/envpool/python/envpool.py b/envpool/python/envpool.py index 2a8d9097..7b4ada4e 100644 --- a/envpool/python/envpool.py +++ b/envpool/python/envpool.py @@ -19,7 +19,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np -import treevalue +import optree from dm_env import TimeStep from .protocol import EnvPool, EnvSpec @@ -59,9 +59,8 @@ def _from( ) -> List[np.ndarray]: """Convert action to C++-acceptable format.""" if isinstance(action, dict): - atree = treevalue.TreeValue(action) - alist = treevalue.flatten(atree) - adict = {".".join(k): v for k, v in alist} + paths, values, _ = optree.tree_flatten_with_path(action) + adict = {'.'.join(p): v for p, v in zip(paths, values)} else: # only 3 keys in action_keys if not hasattr(self, "_last_action_type"): self._last_action_type = self._spec._action_spec[-1][0] @@ -108,7 +107,8 @@ def seed( """Set the seed for all environments (abandoned).""" warnings.warn( "The `seed` function in envpool is abandoned. " - "You can set seed by envpool.make(..., seed=seed) instead." + "You can set seed by envpool.make(..., seed=seed) instead.", + stacklevel=2 ) def send( diff --git a/envpool/python/gym_envpool.py b/envpool/python/gym_envpool.py index 8fa93c80..ad915c78 100644 --- a/envpool/python/gym_envpool.py +++ b/envpool/python/gym_envpool.py @@ -18,7 +18,7 @@ import gym import numpy as np -import treevalue +import optree from packaging import version from .data import gym_structure @@ -52,6 +52,7 @@ def __new__(cls: Any, name: str, parents: Tuple, attrs: Dict) -> Any: base = parents[0] try: from .lax import XlaMixin + parents = (base, GymEnvPoolMixin, EnvPoolMixin, XlaMixin, gym.Env) except ImportError: @@ -68,31 +69,28 @@ def _xla(self: Any) -> None: check_key_duplication(name, "state", state_keys) check_key_duplication(name, "action", action_keys) - tree_pairs = gym_structure(state_keys) - state_idx = list(zip(*tree_pairs))[-1] + state_paths, state_idx, treepsec = gym_structure(state_keys) new_gym_api = version.parse(gym.__version__) >= version.parse("0.26.0") def _to_gym( self: Any, state_values: List[np.ndarray], reset: bool, return_info: bool ) -> Union[Any, Tuple[Any, Any], Tuple[Any, np.ndarray, np.ndarray, Any], - Tuple[Any, np.ndarray, np.ndarray, np.ndarray, Any]]: - values = map(lambda i: state_values[i], state_idx) - state = treevalue.unflatten( - [(path, vi) for (path, _), vi in zip(tree_pairs, values)] - ) + Tuple[Any, np.ndarray, np.ndarray, np.ndarray, Any],]: + values = (state_values[i] for i in state_idx) + state = optree.tree_unflatten(treepsec, values) if reset and not (return_info or new_gym_api): - return state.obs - info = treevalue.jsonify(state.info) + return state["obs"] + info = state["info"] if not new_gym_api: - info["TimeLimit.truncated"] = state.trunc - info["elapsed_step"] = state.elapsed_step + info["TimeLimit.truncated"] = state["trunc"] + info["elapsed_step"] = state["elapsed_step"] if reset: - return state.obs, info + return state["obs"], info if new_gym_api: - terminated = state.done & ~state.trunc - return state.obs, state.reward, terminated, state.trunc, info - return state.obs, state.reward, state.done, info + terminated = state["done"] & ~state["trunc"] + return state["obs"], state["reward"], terminated, state["trunc"], info + return state["obs"], state["reward"], state["done"], state["trunc"], info attrs["_to"] = _to_gym subcls = super().__new__(cls, name, parents, attrs) diff --git a/envpool/python/gymnasium_envpool.py b/envpool/python/gymnasium_envpool.py new file mode 100644 index 00000000..f7d72f8b --- /dev/null +++ b/envpool/python/gymnasium_envpool.py @@ -0,0 +1,97 @@ +# Copyright 2021 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""EnvPool meta class for gymnasium.Env API.""" + +from abc import ABC, ABCMeta +from typing import Any, Dict, List, Tuple, Union + +import gymnasium +import numpy as np +import optree + +from .data import gymnasium_structure +from .envpool import EnvPoolMixin +from .utils import check_key_duplication + + +class GymnasiumEnvPoolMixin(ABC): + """Special treatment for gymnasim API.""" + + @property + def observation_space(self: Any) -> Union[gymnasium.Space, Dict[str, Any]]: + """Observation space from EnvSpec.""" + if not hasattr(self, "_gym_observation_space"): + self._gym_observation_space = self.spec.gymnasium_observation_space + return self._gym_observation_space + + @property + def action_space(self: Any) -> Union[gymnasium.Space, Dict[str, Any]]: + """Action space from EnvSpec.""" + if not hasattr(self, "_gym_action_space"): + self._gym_action_space = self.spec.gymnasium_action_space + return self._gym_action_space + + +class GymnasiumEnvPoolMeta(ABCMeta, gymnasium.Env.__class__): + """Additional wrapper for EnvPool gymnasium.Env API.""" + + def __new__(cls: Any, name: str, parents: Tuple, attrs: Dict) -> Any: + """Check internal config and initialize data format convertion.""" + base = parents[0] + try: + from .lax import XlaMixin + + parents = ( + base, GymnasiumEnvPoolMixin, EnvPoolMixin, XlaMixin, gymnasium.Env + ) + except ImportError: + + def _xla(self: Any) -> None: + raise RuntimeError( + "XLA is disabled. To enable XLA please install jax." + ) + + attrs["xla"] = _xla + parents = (base, GymnasiumEnvPoolMixin, EnvPoolMixin, gymnasium.Env) + + state_keys = base._state_keys + action_keys = base._action_keys + check_key_duplication(name, "state", state_keys) + check_key_duplication(name, "action", action_keys) + + state_paths, state_idx, treepsec = gymnasium_structure(state_keys) + + def _to_gymnasium( + self: Any, state_values: List[np.ndarray], reset: bool, return_info: bool + ) -> Union[Any, Tuple[Any, Any], Tuple[Any, np.ndarray, np.ndarray, Any], + Tuple[Any, np.ndarray, np.ndarray, np.ndarray, Any],]: + values = (state_values[i] for i in state_idx) + state = optree.tree_unflatten(treepsec, values) + info = state["info"] + info["elapsed_step"] = state["elapsed_step"] + if reset: + return state["obs"], info + terminated = state["done"] & ~state["trunc"] + return state["obs"], state["reward"], terminated, state["trunc"], info + + attrs["_to"] = _to_gymnasium + subcls = super().__new__(cls, name, parents, attrs) + + def init(self: Any, spec: Any) -> None: + """Set self.spec to EnvSpecMeta.""" + super(subcls, self).__init__(spec) + self.spec = spec + + setattr(subcls, "__init__", init) # noqa: B010 + return subcls diff --git a/envpool/registration.py b/envpool/registration.py index e1757e30..af61b270 100644 --- a/envpool/registration.py +++ b/envpool/registration.py @@ -14,11 +14,14 @@ """Global env registry.""" import importlib +import os from typing import Any, Dict, List, Tuple import gym from packaging import version +base_path = os.path.abspath(os.path.dirname(__file__)) + class EnvRegistry: """A collection of available envs.""" @@ -30,14 +33,17 @@ def __init__(self) -> None: def register( self, task_id: str, import_path: str, spec_cls: str, dm_cls: str, - gym_cls: str, **kwargs: Any + gym_cls: str, gymnasium_cls: str, **kwargs: Any ) -> None: """Register EnvSpec and EnvPool in global EnvRegistry.""" assert task_id not in self.specs + if "base_path" not in kwargs: + kwargs["base_path"] = base_path self.specs[task_id] = (import_path, spec_cls, kwargs) self.envpools[task_id] = { "dm": (import_path, dm_cls), - "gym": (import_path, gym_cls) + "gym": (import_path, gym_cls), + "gymnasium": (import_path, gymnasium_cls) } def make(self, task_id: str, env_type: str, **kwargs: Any) -> Any: @@ -54,7 +60,7 @@ def make(self, task_id: str, env_type: str, **kwargs: Any) -> Any: assert task_id in self.specs, \ f"{task_id} is not supported, `envpool.list_all_envs()` may help." - assert env_type in ["dm", "gym"] + assert env_type in ["dm", "gym", "gymnasium"] spec = self.make_spec(task_id, **kwargs) import_path, envpool_cls = self.envpools[task_id][env_type] @@ -68,6 +74,10 @@ def make_gym(self, task_id: str, **kwargs: Any) -> Any: """Make gym.Env compatible envpool.""" return self.make(task_id, "gym", **kwargs) + def make_gymnasium(self, task_id: str, **kwargs: Any) -> Any: + """Make gymnasium.Env compatible envpool.""" + return self.make(task_id, "gymnasium", **kwargs) + def make_spec(self, task_id: str, **make_kwargs: Any) -> Any: """Make EnvSpec.""" import_path, spec_cls, kwargs = self.specs[task_id] @@ -100,5 +110,6 @@ def list_all_envs(self) -> List[str]: make = registry.make make_dm = registry.make_dm make_gym = registry.make_gym +make_gymnasium = registry.make_gymnasium make_spec = registry.make_spec list_all_envs = registry.list_all_envs diff --git a/envpool/toy_text/__init__.py b/envpool/toy_text/__init__.py index 8564981e..47b357f7 100644 --- a/envpool/toy_text/__init__.py +++ b/envpool/toy_text/__init__.py @@ -30,45 +30,53 @@ _TaxiEnvSpec, ) -CatchEnvSpec, CatchDMEnvPool, CatchGymEnvPool = py_env( - _CatchEnvSpec, _CatchEnvPool -) +(CatchEnvSpec, CatchDMEnvPool, CatchGymEnvPool, + CatchGymnasiumEnvPool) = py_env(_CatchEnvSpec, _CatchEnvPool) -FrozenLakeEnvSpec, FrozenLakeDMEnvPool, FrozenLakeGymEnvPool = py_env( - _FrozenLakeEnvSpec, _FrozenLakeEnvPool -) +( + FrozenLakeEnvSpec, FrozenLakeDMEnvPool, FrozenLakeGymEnvPool, + FrozenLakeGymnasiumEnvPool +) = py_env(_FrozenLakeEnvSpec, _FrozenLakeEnvPool) -TaxiEnvSpec, TaxiDMEnvPool, TaxiGymEnvPool = py_env(_TaxiEnvSpec, _TaxiEnvPool) +(TaxiEnvSpec, TaxiDMEnvPool, TaxiGymEnvPool, + TaxiGymnasiumEnvPool) = py_env(_TaxiEnvSpec, _TaxiEnvPool) -NChainEnvSpec, NChainDMEnvPool, NChainGymEnvPool = py_env( - _NChainEnvSpec, _NChainEnvPool -) +(NChainEnvSpec, NChainDMEnvPool, NChainGymEnvPool, + NChainGymnasiumEnvPool) = py_env(_NChainEnvSpec, _NChainEnvPool) -CliffWalkingEnvSpec, CliffWalkingDMEnvPool, CliffWalkingGymEnvPool = py_env( - _CliffWalkingEnvSpec, _CliffWalkingEnvPool -) +( + CliffWalkingEnvSpec, CliffWalkingDMEnvPool, CliffWalkingGymEnvPool, + CliffWalkingGymnasiumEnvPool +) = py_env(_CliffWalkingEnvSpec, _CliffWalkingEnvPool) -BlackjackEnvSpec, BlackjackDMEnvPool, BlackjackGymEnvPool = py_env( - _BlackjackEnvSpec, _BlackjackEnvPool -) +( + BlackjackEnvSpec, BlackjackDMEnvPool, BlackjackGymEnvPool, + BlackjackGymnasiumEnvPool +) = py_env(_BlackjackEnvSpec, _BlackjackEnvPool) __all__ = [ "CatchEnvSpec", "CatchDMEnvPool", "CatchGymEnvPool", + "CatchGymnasiumEnvPool", "FrozenLakeEnvSpec", "FrozenLakeDMEnvPool", "FrozenLakeGymEnvPool", + "FrozenLakeGymnasiumEnvPool", "TaxiEnvSpec", "TaxiDMEnvPool", "TaxiGymEnvPool", + "TaxiGymnasiumEnvPool", "NChainEnvSpec", "NChainDMEnvPool", "NChainGymEnvPool", + "NChainGymnasiumEnvPool", "CliffWalkingEnvSpec", "CliffWalkingDMEnvPool", "CliffWalkingGymEnvPool", + "CliffWalkingGymnasiumEnvPool", "BlackjackEnvSpec", "BlackjackDMEnvPool", "BlackjackGymEnvPool", + "BlackjackGymnasiumEnvPool", ] diff --git a/envpool/toy_text/blackjack.h b/envpool/toy_text/blackjack.h index fd7b1f65..db7461df 100644 --- a/envpool/toy_text/blackjack.h +++ b/envpool/toy_text/blackjack.h @@ -51,15 +51,14 @@ class BlackjackEnv : public Env { bool natural_, sab_; std::vector player_, dealer_; std::uniform_int_distribution<> dist_; - bool done_; + bool done_{true}; public: BlackjackEnv(const Spec& spec, int env_id) : Env(spec, env_id), natural_(spec.config["natural"_]), sab_(spec.config["sab"_]), - dist_(1, 13), - done_(true) {} + dist_(1, 13) {} bool IsDone() override { return done_; } diff --git a/envpool/toy_text/catch.h b/envpool/toy_text/catch.h index 83e8b860..97a84305 100644 --- a/envpool/toy_text/catch.h +++ b/envpool/toy_text/catch.h @@ -48,15 +48,14 @@ class CatchEnv : public Env { protected: int x_, y_, height_, width_, paddle_; std::uniform_int_distribution<> dist_; - bool done_; + bool done_{true}; public: CatchEnv(const Spec& spec, int env_id) : Env(spec, env_id), height_(spec.config["height"_]), width_(spec.config["width"_]), - dist_(0, width_ - 1), - done_(true) {} + dist_(0, width_ - 1) {} bool IsDone() override { return done_; } diff --git a/envpool/toy_text/cliffwalking.h b/envpool/toy_text/cliffwalking.h index 5cd18dec..be1dfc6e 100644 --- a/envpool/toy_text/cliffwalking.h +++ b/envpool/toy_text/cliffwalking.h @@ -47,11 +47,11 @@ using CliffWalkingEnvSpec = EnvSpec; class CliffWalkingEnv : public Env { protected: int x_, y_; - bool done_; + bool done_{true}; public: CliffWalkingEnv(const Spec& spec, int env_id) - : Env(spec, env_id), done_(true) {} + : Env(spec, env_id) {} bool IsDone() override { return done_; } diff --git a/envpool/toy_text/frozen_lake.h b/envpool/toy_text/frozen_lake.h index 1197e642..efa47dc8 100644 --- a/envpool/toy_text/frozen_lake.h +++ b/envpool/toy_text/frozen_lake.h @@ -51,7 +51,7 @@ class FrozenLakeEnv : public Env { protected: int x_, y_, size_, max_episode_steps_, elapsed_step_; std::uniform_int_distribution<> dist_; - bool done_; + bool done_{true}; std::vector map_; public: @@ -59,8 +59,7 @@ class FrozenLakeEnv : public Env { : Env(spec, env_id), size_(spec.config["size"_]), max_episode_steps_(spec.config["max_episode_steps"_]), - dist_(-1, 1), - done_(true) { + dist_(-1, 1) { if (size_ != 8) { map_ = std::vector({"SFFF", "FHFH", "FFFH", "HFFG"}); } else { diff --git a/envpool/toy_text/nchain.h b/envpool/toy_text/nchain.h index 0ac4732e..5ba2a991 100644 --- a/envpool/toy_text/nchain.h +++ b/envpool/toy_text/nchain.h @@ -48,14 +48,13 @@ class NChainEnv : public Env { protected: int s_, max_episode_steps_, elapsed_step_; std::uniform_real_distribution<> dist_; - bool done_; + bool done_{true}; public: NChainEnv(const Spec& spec, int env_id) : Env(spec, env_id), max_episode_steps_(spec.config["max_episode_steps"_]), - dist_(0, 1), - done_(true) {} + dist_(0, 1) {} bool IsDone() override { return done_; } diff --git a/envpool/toy_text/registration.py b/envpool/toy_text/registration.py index 206842d9..04188b65 100644 --- a/envpool/toy_text/registration.py +++ b/envpool/toy_text/registration.py @@ -21,6 +21,7 @@ spec_cls="CatchEnvSpec", dm_cls="CatchDMEnvPool", gym_cls="CatchGymEnvPool", + gymnasium_cls="CatchGymnasiumEnvPool", height=10, width=5, ) @@ -31,6 +32,7 @@ spec_cls="FrozenLakeEnvSpec", dm_cls="FrozenLakeDMEnvPool", gym_cls="FrozenLakeGymEnvPool", + gymnasium_cls="FrozenLakeGymnasiumEnvPool", size=4, max_episode_steps=100, reward_threshold=0.7, @@ -42,6 +44,7 @@ spec_cls="FrozenLakeEnvSpec", dm_cls="FrozenLakeDMEnvPool", gym_cls="FrozenLakeGymEnvPool", + gymnasium_cls="FrozenLakeGymnasiumEnvPool", size=8, max_episode_steps=200, reward_threshold=0.85, @@ -53,6 +56,7 @@ spec_cls="TaxiEnvSpec", dm_cls="TaxiDMEnvPool", gym_cls="TaxiGymEnvPool", + gymnasium_cls="TaxiGymnasiumEnvPool", max_episode_steps=200, reward_threshold=8.0, ) @@ -63,6 +67,7 @@ spec_cls="NChainEnvSpec", dm_cls="NChainDMEnvPool", gym_cls="NChainGymEnvPool", + gymnasium_cls="NChainGymnasiumEnvPool", max_episode_steps=1000, ) @@ -72,6 +77,7 @@ spec_cls="CliffWalkingEnvSpec", dm_cls="CliffWalkingDMEnvPool", gym_cls="CliffWalkingGymEnvPool", + gymnasium_cls="CliffWalkingGymnasiumEnvPool", ) register( @@ -80,6 +86,7 @@ spec_cls="BlackjackEnvSpec", dm_cls="BlackjackDMEnvPool", gym_cls="BlackjackGymEnvPool", + gymnasium_cls="BlackjackGymnasiumEnvPool", sab=True, natural=False, ) diff --git a/envpool/toy_text/taxi.h b/envpool/toy_text/taxi.h index d09fcf7e..3beca4d9 100644 --- a/envpool/toy_text/taxi.h +++ b/envpool/toy_text/taxi.h @@ -50,7 +50,7 @@ class TaxiEnv : public Env { protected: int x_, y_, s_, t_, max_episode_steps_, elapsed_step_; std::uniform_int_distribution<> dist_car_, dist_loc_; - bool done_; + bool done_{true}; std::vector> loc_; std::vector map_, loc_map_; @@ -60,7 +60,6 @@ class TaxiEnv : public Env { max_episode_steps_(spec.config["max_episode_steps"_]), dist_car_(0, 3), dist_loc_(0, 4), - done_(true), loc_({{0, 0}, {0, 4}, {4, 0}, {4, 3}}), map_({"|:|::|", "|:|::|", "|::::|", "||:|:|", "||:|:|"}), loc_map_({"0 1", " ", " ", " ", "2 3 "}) {} diff --git a/envpool/vizdoom/__init__.py b/envpool/vizdoom/__init__.py index 57468db8..fb913ce3 100644 --- a/envpool/vizdoom/__init__.py +++ b/envpool/vizdoom/__init__.py @@ -17,12 +17,12 @@ from .vizdoom_envpool import _VizdoomEnvPool, _VizdoomEnvSpec -VizdoomEnvSpec, VizdoomDMEnvPool, VizdoomGymEnvPool = py_env( - _VizdoomEnvSpec, _VizdoomEnvPool -) +(VizdoomEnvSpec, VizdoomDMEnvPool, VizdoomGymEnvPool, + VizdoomGymnasiumEnvPool) = py_env(_VizdoomEnvSpec, _VizdoomEnvPool) __all__ = [ "VizdoomEnvSpec", "VizdoomDMEnvPool", "VizdoomGymEnvPool", + "VizdoomGymnasiumEnvPool", ] diff --git a/envpool/vizdoom/registration.py b/envpool/vizdoom/registration.py index e07c1a18..2af2688c 100644 --- a/envpool/vizdoom/registration.py +++ b/envpool/vizdoom/registration.py @@ -16,9 +16,8 @@ import os from typing import List -from envpool.registration import register +from envpool.registration import base_path, register -base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) maps_path = os.path.join(base_path, "vizdoom", "maps") @@ -44,7 +43,7 @@ def _vizdoom_game_list() -> List[str]: spec_cls="VizdoomEnvSpec", dm_cls="VizdoomDMEnvPool", gym_cls="VizdoomGymEnvPool", - base_path=base_path, + gymnasium_cls="VizdoomGymnasiumEnvPool", cfg_path=cfg_path, wad_path=wad_path, max_episode_steps=525, diff --git a/envpool/vizdoom/vizdoom_env.h b/envpool/vizdoom/vizdoom_env.h index a579e782..8f4cb191 100644 --- a/envpool/vizdoom/vizdoom_env.h +++ b/envpool/vizdoom/vizdoom_env.h @@ -143,11 +143,11 @@ class VizdoomEnv : public Env { std::deque stack_buf_; std::string lmp_dir_; bool save_lmp_, episodic_life_, use_combined_action_, use_inter_area_resize_; - bool done_; + bool done_{true}; int max_episode_steps_, elapsed_step_, stack_num_, frame_skip_, - episode_count_, channel_; + episode_count_{0}, channel_; int deathcount_idx_, hitcount_idx_, damagecount_idx_; // bugged var - double last_deathcount_, last_hitcount_, last_damagecount_; + double last_deathcount_{0}, last_hitcount_{0}, last_damagecount_{0}; int selected_weapon_, selected_weapon_count_, weapon_duration_; std::vector action_set_; std::vector