diff --git a/CMakeLists.txt b/CMakeLists.txt index b5a8a5c99..42b2f5d92 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -119,17 +119,11 @@ target_link_directories(nethack PUBLIC /usr/local/lib) target_link_libraries(nethack PUBLIC m fcontext bz2) -# dlopen wrapper library -add_library(nethackdl STATIC "sys/unix/nledl.c") -target_include_directories(nethackdl PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) -target_link_libraries(nethackdl PUBLIC dl) - -# rlmain C++ (test) binary -add_executable(rlmain "sys/unix/rlmain.cc") -set_target_properties(rlmain PROPERTIES CXX_STANDARD 11) -target_link_libraries(rlmain PUBLIC nethackdl) -target_include_directories(rlmain PUBLIC ${NLE_INC_GEN}) -add_dependencies(rlmain util) # For pm.h. +# rlmain C++ (test) binary add_executable(rlmain "sys/unix/rlmain.cc") +# set_target_properties(rlmain PROPERTIES CXX_STANDARD 11) +# target_link_libraries(rlmain PUBLIC nethackdl) +# target_include_directories(rlmain PUBLIC ${NLE_INC_GEN}) +# add_dependencies(rlmain util) # For pm.h. # pybind11 python library. add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third_party/pybind11) @@ -141,7 +135,8 @@ pybind11_add_module( src/drawing.c src/objects.c $) -target_link_libraries(_pynethack PUBLIC nethackdl) +target_link_libraries(_pynethack PUBLIC dl) set_target_properties(_pynethack PROPERTIES CXX_STANDARD 14) -target_include_directories(_pynethack PUBLIC ${NLE_INC_GEN}) +target_include_directories( + _pynethack PUBLIC ${NLE_INC_GEN} ${CMAKE_CURRENT_SOURCE_DIR}/include) add_dependencies(_pynethack util) # For pm.h. diff --git a/include/dloverride.h b/include/dloverride.h new file mode 100644 index 000000000..d9bc727a1 --- /dev/null +++ b/include/dloverride.h @@ -0,0 +1,234 @@ +/* Copyright (c) Facebook, Inc. and its affiliates. */ + +/* Mechanism to reset a loaded dynamic library. */ + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include + +#ifdef __linux__ +#include + +#define PAGE_SIZE 4096 +#define PAGE_MASK (PAGE_SIZE - 1) +#define PAGE_START(x) ((x) & ~PAGE_MASK) +#define PAGE_OFFSET(x) ((x) &PAGE_MASK) +#define PAGE_END(x) PAGE_START((x) + (PAGE_SIZE - 1)) + +#if __LP64__ +#define Elf_Ehdr Elf64_Ehdr +#define Elf_Phdr Elf64_Phdr + +#else /* __LP64__ */ +#define Elf_Ehdr Elf32_Ehdr +#define Elf_Phdr Elf32_Phdr +#endif /* __LP64__ */ + +#elif __APPLE__ +#include + +#if __LP64__ +#define LC_SEGMENT_COMMAND LC_SEGMENT_64 +#define MH_MAGIC_NUMBER MH_MAGIC_64 + +struct macho_header : public mach_header_64 { +}; +struct macho_segment_command : public segment_command_64 { +}; +struct macho_section : public section_64 { +}; +#else /* __LP64__ */ +#define LC_SEGMENT_COMMAND LC_SEGMENT +#define MH_MAGIC_NUMBER MH_MAGIC + +struct macho_header : public mach_header { +}; +struct macho_segment_command : public segment_command { +}; +struct macho_section : public section { +}; +#endif /* __LP64__ */ + +#endif /* __linux__, __APPLE__ */ + +class DL +{ + public: + DL(const char *filename, const char *symbol) + : handle_(dlopen(filename, RTLD_LAZY)) + { + if (!handle_) { + throw std::runtime_error(std::string("dlopen failed on ") + + filename + ": " + dlerror()); + } + + void *ptr = dlsym(handle_, symbol); + if (!ptr) { + throw std::runtime_error(dlerror()); + } + Dl_info dlinfo; + if (dladdr(ptr, &dlinfo) == 0) { + throw std::runtime_error("dladdr failed"); + } + if (!dlinfo.dli_sname) { + throw std::runtime_error("No matching addr found."); + } +#ifdef __linux__ + hdr_ = (Elf_Ehdr *) dlinfo.dli_fbase; + if (memcmp(hdr_->e_ident, ELFMAG, SELFMAG) != 0) { + throw std::runtime_error("Illegal elf header"); + } + + size_t offset = ~0; + + size_t phoff = hdr_->e_phoff; + for (size_t i = 0; i != hdr_->e_phnum; ++i) { + Elf_Phdr *ph = (Elf_Phdr *) ((uint8_t *) hdr_ + phoff); + + if (ph->p_type == PT_LOAD) { + offset = std::min(offset, (size_t) ph->p_vaddr); + } + segs_.push_back(ph); + + phoff += hdr_->e_phentsize; + } + + baseaddr_ = (uint8_t *) hdr_ - PAGE_START(offset); + +#elif __APPLE__ + hdr_ = (macho_header *) dlinfo.dli_fbase; + if (hdr_->magic != MH_MAGIC_NUMBER) { + throw std::runtime_error( + "Illegal magic integer " + std::to_string(hdr_->magic) + + ", expected " + std::to_string(MH_MAGIC_NUMBER)); + } + if (hdr_->filetype != MH_DYLIB) { + throw std::runtime_error( + std::string("Expected MH_DYLIB file type but got " + + std::to_string(hdr_->filetype))); + } + + const load_command *cmds = (load_command *) (hdr_ + 1); + const load_command *cmd = cmds; + + for (uint32_t i = 0; i < hdr_->ncmds; ++i) { + if (cmd->cmd != LC_SEGMENT_COMMAND) + continue; + + const auto *seg = (macho_segment_command *) cmd; + if (seg->nsects) + segs_.push_back(seg); + cmd = (const load_command *) (((uint8_t *) cmd) + cmd->cmdsize); + } + + baseaddr_ = (uint8_t *) hdr_ - segs_[0]->vmaddr; +#endif /* __linux__, __APPLE__ */ + } + + ~DL() + { + if (handle_) + dlclose(handle_); + } + + DL(DL &&dl) noexcept + { + *this = std::move(dl); + } + DL(const DL &) = delete; + DL &operator=(const DL &) = delete; + DL & + operator=(DL &&dl) noexcept + { + if (this == &dl) + return *this; + if (handle_) + dlclose(handle_); + handle_ = std::exchange(dl.handle_, nullptr); + segs_ = std::move(dl.segs_); + hdr_ = dl.hdr_; + baseaddr_ = dl.baseaddr_; + return *this; + } + +#ifdef __linux__ + bool + is_overridable(const Elf_Phdr *ph) const + { + return ph->p_type == PT_LOAD && ph->p_flags & PF_R + && ph->p_flags & PF_W; + } +#elif __APPLE__ + bool + is_overridable(const macho_segment_command *seg) const + { + return strcmp(seg->segname, SEG_DATA) == 0; + } +#endif + + template + void + for_changing_sections(F &&f) + { + for (const auto *seg : segs_) { + if (is_overridable(seg)) + f(seg); + } + } + + template + auto + func(const char *symbol) -> decltype((T(*)(Ts...)) nullptr) + { + void *ptr = dlsym(handle_, symbol); + if (!ptr) { + throw std::runtime_error(dlerror()); + } + return (T(*)(Ts...)) ptr; + } + +#ifdef __linux__ + uint8_t * + mem_addr(const Elf_Phdr *ph) const + { + size_t start = PAGE_END((size_t) (baseaddr_ + ph->p_vaddr)); + return (uint8_t *) start; + } + size_t + mem_size(const Elf_Phdr *ph) const + { + return ph->p_memsz; + } +#elif __APPLE__ + uint8_t * + mem_addr(const macho_segment_command *seg) const + { + return baseaddr_ + seg->vmaddr; + } + size_t + mem_size(const macho_segment_command *seg) const + { + return seg->vmsize; + } +#endif + + private: + void *handle_{ nullptr }; +#ifdef __linux__ + std::vector segs_; + const Elf_Ehdr *hdr_; +#elif __APPLE__ + std::vector segs_; + const macho_header *hdr_; +#endif + uint8_t *baseaddr_; +}; diff --git a/include/nledl.h b/include/nledl.h deleted file mode 100644 index b55bcac37..000000000 --- a/include/nledl.h +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Wrapper for dlopen. - */ - -#ifndef NLEDL_H -#define NLEDL_H - -#include - -#include "nleobs.h" - -typedef struct nledl_ctx { - char dlpath[1024]; - void *dlhandle; - void *nle_ctx; - void *(*step)(void *, nle_obs *); - FILE *ttyrec; -} nledl_ctx; - -nledl_ctx *nle_start(const char *, nle_obs *, FILE *, nle_seeds_init_t *, - nle_settings *); -nledl_ctx *nle_step(nledl_ctx *, nle_obs *); - -void nle_reset(nledl_ctx *, nle_obs *, FILE *, nle_seeds_init_t *, - nle_settings *); -void nle_end(nledl_ctx *); - -void nle_set_seed(nledl_ctx *, unsigned long, unsigned long, char); -void nle_get_seed(nledl_ctx *, unsigned long *, unsigned long *, char *); - -#endif /* NLEDL_H */ diff --git a/setup.py b/setup.py index e36ca12d1..e6d66b0df 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def build_extension(self, ext): hackdir_path = os.getenv("HACKDIR", output_path.joinpath("nethackdir")) os.makedirs(self.build_temp, exist_ok=True) - build_type = "Debug" if self.debug else "Release" + build_type = "Debug" # if self.debug else "Release" generator = "Ninja" if spawn.find_executable("ninja") else "Unix Makefiles" diff --git a/src/nle.c b/src/nle.c index 7d1f86af0..b372ca93c 100644 --- a/src/nle.c +++ b/src/nle.c @@ -419,12 +419,12 @@ nle_start(nle_obs *obs, FILE *ttyrec, nle_seeds_init_t *seed_init, nle_ctx_t *nle = init_nle(ttyrec, obs); nle_seeds_init = seed_init; + current_nle_ctx = nle; nle->stack = create_fcontext_stack(STACK_SIZE); nle->generatorcontext = make_fcontext(nle->stack.sptr, nle->stack.ssize, mainloop); - current_nle_ctx = nle; fcontext_transfer_t t = jump_fcontext(nle->generatorcontext, NULL); nle->generatorcontext = t.ctx; nle->done = (t.data == NULL); diff --git a/sys/unix/nledl.c b/sys/unix/nledl.c deleted file mode 100644 index 53de50263..000000000 --- a/sys/unix/nledl.c +++ /dev/null @@ -1,152 +0,0 @@ - -#include -#include -#include -#include - -#include "nledl.h" - -void -nledl_init(nledl_ctx *nledl, nle_obs *obs, nle_seeds_init_t *seed_init, - nle_settings *settings) -{ - void *handle = dlopen(nledl->dlpath, RTLD_LAZY | RTLD_NOLOAD); - if (handle) { - dlclose(handle); - fprintf(stderr, - "failure in nledl_init: library %s is already loaded\n", - nledl->dlpath); - exit(EXIT_FAILURE); - } - - nledl->dlhandle = dlopen(nledl->dlpath, RTLD_LAZY); - - if (!nledl->dlhandle) { - fprintf(stderr, "%s\n", dlerror()); - exit(EXIT_FAILURE); - } - - dlerror(); /* Clear any existing error */ - - void *(*start)(nle_obs *, FILE *, nle_seeds_init_t *, nle_settings *); - start = dlsym(nledl->dlhandle, "nle_start"); - nledl->nle_ctx = start(obs, nledl->ttyrec, seed_init, settings); - - char *error = dlerror(); - if (error != NULL) { - fprintf(stderr, "%s\n", error); - exit(EXIT_FAILURE); - } - - nledl->step = dlsym(nledl->dlhandle, "nle_step"); - - error = dlerror(); - if (error != NULL) { - fprintf(stderr, "%s\n", error); - exit(EXIT_FAILURE); - } -} - -void -nledl_close(nledl_ctx *nledl) -{ - void (*end)(void *); - - end = dlsym(nledl->dlhandle, "nle_end"); - end(nledl->nle_ctx); - - if (dlclose(nledl->dlhandle)) { - fprintf(stderr, "Error in dlclose: %s\n", dlerror()); - exit(EXIT_FAILURE); - } - - dlerror(); -} - -nledl_ctx * -nle_start(const char *dlpath, nle_obs *obs, FILE *ttyrec, - nle_seeds_init_t *seed_init, nle_settings *settings) -{ - /* TODO: Consider getting ttyrec path from caller? */ - struct nledl_ctx *nledl = malloc(sizeof(struct nledl_ctx)); - nledl->ttyrec = ttyrec; - strncpy(nledl->dlpath, dlpath, sizeof(nledl->dlpath)); - - nledl_init(nledl, obs, seed_init, settings); - return nledl; -}; - -nledl_ctx * -nle_step(nledl_ctx *nledl, nle_obs *obs) -{ - if (!nledl || !nledl->dlhandle || !nledl->nle_ctx) { - fprintf(stderr, "Illegal nledl_ctx\n"); - exit(EXIT_FAILURE); - } - - nledl->step(nledl->nle_ctx, obs); - - return nledl; -} - -/* TODO: For a standard reset, we don't need the full close in nle.c. - * E.g., we could re-use the stack buffer and the nledl_ctx. */ -void -nle_reset(nledl_ctx *nledl, nle_obs *obs, FILE *ttyrec, - nle_seeds_init_t *seed_init, nle_settings *settings) -{ - nledl_close(nledl); - /* Reset file only if not-NULL. */ - if (ttyrec) - nledl->ttyrec = ttyrec; - - // TODO: Consider refactoring nledl.h such that we expose this init - // function but drop reset. - nledl_init(nledl, obs, seed_init, settings); -} - -void -nle_end(nledl_ctx *nledl) -{ - nledl_close(nledl); - free(nledl); -} - -#ifdef NLE_ALLOW_SEEDING -void -nle_set_seed(nledl_ctx *nledl, unsigned long core, unsigned long disp, - char reseed) -{ - void (*set_seed)(void *, unsigned long, unsigned long, char); - - set_seed = dlsym(nledl->dlhandle, "nle_set_seed"); - - char *error = dlerror(); - if (error != NULL) { - fprintf(stderr, "%s\n", error); - exit(EXIT_FAILURE); - } - - set_seed(nledl->nle_ctx, core, disp, reseed); -} - -void -nle_get_seed(nledl_ctx *nledl, unsigned long *core, unsigned long *disp, - char *reseed) -{ - void (*get_seed)(void *, unsigned long *, unsigned long *, char *); - - get_seed = dlsym(nledl->dlhandle, "nle_get_seed"); - - char *error = dlerror(); - if (error != NULL) { - fprintf(stderr, "%s\n", error); - exit(EXIT_FAILURE); - } - - /* Careful here. NetHack has different ideas of what a boolean is - * than C++ (see global.h and SKIP_BOOLEAN). But one byte should be fine. - */ - get_seed(nledl->nle_ctx, core, disp, reseed); -} -#endif diff --git a/win/rl/pynethack.cc b/win/rl/pynethack.cc index bdb2c1990..8673aeb46 100644 --- a/win/rl/pynethack.cc +++ b/win/rl/pynethack.cc @@ -6,6 +6,8 @@ #include #include +#include "dloverride.h" + // "digit" is declared in both Python's longintrepr.h and NetHack's extern.h. #define digit nethack_digit @@ -17,7 +19,7 @@ extern "C" { } extern "C" { -#include "nledl.h" +#include "nleobs.h" } // Undef name clashes between NetHack and Python. @@ -55,6 +57,197 @@ on_level(d_level *lev1, d_level *lev2) /* End of copy from dungeon.c */ #endif +/*#define NLE_RESET_DLOPENCLOSE*/ /* Enables the old behaviour: reset by + dlclose & dl-re-open. */ + +#ifndef NLE_RESET_DLOPENCLOSE +class Instance +{ + public: + Instance(const std::string &dlpath, nle_obs *obs, FILE *ttyrec, + nle_seeds_init_t *seeds_init, nle_settings *settings) + : dl_(dlpath.c_str(), "nle_start") + { + dl_.for_changing_sections([&](const auto *seg) { + segments_.emplace_back(dl_.mem_size(seg)); + memcpy(&segments_.back()[0], dl_.mem_addr(seg), + dl_.mem_size(seg)); + }); + + start_ = dl_.func("nle_start"); + step_ = dl_.func("nle_step"); + end_ = dl_.func("nle_end"); + get_seed_ = + dl_.func( + "nle_get_seed"); + set_seed_ = + dl_.func( + "nle_set_seed"); + + nle_ctx_ = start_(obs, ttyrec, seeds_init, settings); + } + + ~Instance() + { + if (nle_ctx_) + close(); + } + + void + step(nle_obs *obs) + { + nle_ctx_ = step_(nle_ctx_, obs); + } + + void + reset(nle_obs *obs, FILE *ttyrec, nle_seeds_init_t *seeds_init, + nle_settings *settings) + { + end_(nle_ctx_); + + auto it = segments_.begin(); + dl_.for_changing_sections([&](const auto *seg) { + memcpy(dl_.mem_addr(seg), it->data(), dl_.mem_size(seg)); + ++it; + }); + nle_ctx_ = start_(obs, ttyrec, seeds_init, settings); + } + + void + close() + { + end_(nle_ctx_); + nle_ctx_ = nullptr; + } + + void + get_seed(unsigned long *core, unsigned long *disp, char *reseed) + { + get_seed_(nle_ctx_, core, disp, reseed); + } + + void + set_seed(unsigned long core, unsigned long disp, char reseed) + { + set_seed_(nle_ctx_, core, disp, reseed); + } + + private: + DL dl_; + void *nle_ctx_{ nullptr }; + + void *(*start_)(nle_obs *, FILE *, nle_seeds_init_t *, nle_settings *); + void *(*step_)(void *, nle_obs *); + void (*end_)(void *); + void (*get_seed_)(void *, unsigned long *, unsigned long *, char *); + void (*set_seed_)(void *, unsigned long, unsigned long, char); + + std::vector > segments_; +}; +#else /* NLE_RESET_DLOPENCLOSE*/ +class Instance +{ + public: + Instance(const std::string &dlpath, nle_obs *obs, FILE *ttyrec, + nle_seeds_init_t *seeds_init, nle_settings *settings) + : dlpath_(dlpath) + { + init(); + nle_ctx_ = start_(obs, ttyrec, seeds_init, settings); + } + + ~Instance() + { + close(); + } + + void + step(nle_obs *obs) + { + nle_ctx_ = step_(nle_ctx_, obs); + } + + void + reset(nle_obs *obs, FILE *ttyrec, nle_seeds_init_t *seeds_init, + nle_settings *settings) + { + close(); + init(); + nle_ctx_ = start_(obs, ttyrec, seeds_init, settings); + } + + void + close() + { + if (nle_ctx_) + end_(nle_ctx_); + nle_ctx_ = nullptr; + + if (handle_) + if (dlclose(handle_)) + throw std::runtime_error(dlerror()); + handle_ = nullptr; + } + + void + get_seed(unsigned long *core, unsigned long *disp, char *reseed) + { + get_seed_(nle_ctx_, core, disp, reseed); + } + + void + set_seed(unsigned long core, unsigned long disp, char reseed) + { + set_seed_(nle_ctx_, core, disp, reseed); + } + + private: + void + init() + { + void *handle = dlopen(dlpath_.c_str(), RTLD_LAZY | RTLD_NOLOAD); + if (handle) { + dlclose(handle); + throw std::runtime_error(dlpath_ + " is already loaded"); + } + handle_ = dlopen(dlpath_.c_str(), RTLD_LAZY); + if (!handle_) { + throw std::runtime_error(dlerror()); + } + + start_ = (decltype(start_)) get_sym("nle_start"); + step_ = (decltype(step_)) get_sym("nle_step"); + end_ = (decltype(end_)) get_sym("nle_end"); + get_seed_ = (decltype(get_seed_)) get_sym("nle_get_seed"); + set_seed_ = (decltype(set_seed_)) get_sym("nle_set_seed"); + } + + void * + get_sym(const char *sym) + { + dlerror(); /* Clear.*/ + void *result = dlsym(handle_, sym); + const char *error = dlerror(); + if (error) { + throw std::runtime_error(error); + } + return result; + } + + const std::string dlpath_; + + void *handle_{ nullptr }; + void *nle_ctx_{ nullptr }; + + void *(*start_)(nle_obs *, FILE *, nle_seeds_init_t *, nle_settings *); + void *(*step_)(void *, nle_obs *); + void (*end_)(void *); + void (*get_seed_)(void *, unsigned long *, unsigned long *, char *); + void (*set_seed_)(void *, unsigned long, unsigned long, char); +}; +#endif + namespace py = pybind11; using namespace py::literals; @@ -136,7 +329,8 @@ class Nethack if (obs_.done) throw std::runtime_error("Called step on finished NetHack"); obs_.action = action; - nle_ = nle_step(nle_, &obs_); + nle_->step(&obs_); + // nle_ = nle_step(nle_, &obs_); } bool @@ -234,8 +428,7 @@ class Nethack close() { if (nle_) { - nle_end(nle_); - nle_ = nullptr; + nle_.reset(nullptr); } } @@ -258,7 +451,7 @@ class Nethack #ifdef NLE_ALLOW_SEEDING if (!nle_) throw std::runtime_error("set_seed called without reset()"); - nle_set_seed(nle_, core, disp, reseed); + nle_->set_seed(core, disp, reseed); #else throw std::runtime_error("Seeding not enabled"); #endif @@ -273,8 +466,7 @@ class Nethack std::tuple result; char reseed; /* NetHack's booleans are not necessarily C++ bools ... */ - nle_get_seed(nle_, &std::get<0>(result), &std::get<1>(result), - &reseed); + nle_->get_seed(&std::get<0>(result), &std::get<1>(result), &reseed); std::get<2>(result) = reseed; return result; #else @@ -310,12 +502,12 @@ class Nethack py::gil_scoped_release gil; if (!nle_) { - nle_ = - nle_start(dlpath_.c_str(), &obs_, ttyrec ? ttyrec : ttyrec_, - use_seed_init ? &seed_init_ : nullptr, &settings_); + nle_ = std::make_unique( + dlpath_, &obs_, ttyrec ? ttyrec : ttyrec_, + use_seed_init ? &seed_init_ : nullptr, &settings_); } else - nle_reset(nle_, &obs_, ttyrec, - use_seed_init ? &seed_init_ : nullptr, &settings_); + nle_->reset(&obs_, ttyrec ? ttyrec : ttyrec_, + use_seed_init ? &seed_init_ : nullptr, &settings_); use_seed_init = false; if (obs_.done) @@ -327,7 +519,7 @@ class Nethack std::vector py_buffers_; nle_seeds_init_t seed_init_; bool use_seed_init = false; - nledl_ctx *nle_ = nullptr; + std::unique_ptr nle_; std::FILE *ttyrec_ = nullptr; nle_settings settings_; };