diff --git a/.github/workflows/lint_cc.yml b/.github/workflows/lint_cc.yml index f0c43209d..d0951eeaa 100644 --- a/.github/workflows/lint_cc.yml +++ b/.github/workflows/lint_cc.yml @@ -24,5 +24,5 @@ jobs: - uses: actions/checkout@v2 - uses: DoozyX/clang-format-lint-action@v0.12 with: - source: 'win/rl src/nle.c sys/unix/nledl.c include/nle.h include/nledl.h include/nleobs.h' + source: 'win/rl src/nle.c include/nle.h include/nleinstance.h include/dloverride.h include/nleobs.h sys/unix/rlmain.cc' clangFormatVersion: 12 diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 394fcede3..3a5a2afa4 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -48,8 +48,19 @@ jobs: submodules: true - name: Install from repo in test mode run: "pip install -e '.[dev]'" + - name: Run dyld_info + run: "dyldinfo `python -c 'from nle.nethack import nethack; print(nethack.DLPATH)'` || echo 0" + - name: Run otool + run: "otool -l `python -c 'from nle.nethack import nethack; print(nethack.DLPATH)'` || echo 0" + - name: Setup upterm session + uses: lhotari/action-upterm@v1 + with: + ## limits ssh access and adds the ssh public key for the user which triggered the workflow + limit-access-to-actor: true + ## limits ssh access and adds the ssh public keys of the listed GitHub users + limit-access-to-users: heiner,tscmoo - name: Run tests - run: "python -m pytest -svx nle/tests --basetemp=nle_test_data" + run: "DYLD_PRINT_SEGMENTS=1 python -m pytest -svx nle/tests --basetemp=nle_test_data" - name: Compress test output dir if: ${{ always() }} run: | diff --git a/CMakeLists.txt b/CMakeLists.txt index b5a8a5c99..6e61d7e3e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -119,16 +119,12 @@ target_link_directories(nethack PUBLIC /usr/local/lib) target_link_libraries(nethack PUBLIC m fcontext bz2) -# dlopen wrapper library -add_library(nethackdl STATIC "sys/unix/nledl.c") -target_include_directories(nethackdl PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) -target_link_libraries(nethackdl PUBLIC dl) - # rlmain C++ (test) binary add_executable(rlmain "sys/unix/rlmain.cc") -set_target_properties(rlmain PROPERTIES CXX_STANDARD 11) -target_link_libraries(rlmain PUBLIC nethackdl) -target_include_directories(rlmain PUBLIC ${NLE_INC_GEN}) +set_target_properties(rlmain PROPERTIES CXX_STANDARD 14) +target_link_libraries(rlmain PUBLIC dl) +target_include_directories(rlmain PUBLIC ${NLE_INC_GEN} + ${CMAKE_CURRENT_SOURCE_DIR}/include) add_dependencies(rlmain util) # For pm.h. # pybind11 python library. @@ -141,7 +137,8 @@ pybind11_add_module( src/drawing.c src/objects.c $) -target_link_libraries(_pynethack PUBLIC nethackdl) +target_link_libraries(_pynethack PUBLIC dl) set_target_properties(_pynethack PROPERTIES CXX_STANDARD 14) -target_include_directories(_pynethack PUBLIC ${NLE_INC_GEN}) +target_include_directories( + _pynethack PUBLIC ${NLE_INC_GEN} ${CMAKE_CURRENT_SOURCE_DIR}/include) add_dependencies(_pynethack util) # For pm.h. diff --git a/include/dloverride.h b/include/dloverride.h new file mode 100644 index 000000000..74104a4fd --- /dev/null +++ b/include/dloverride.h @@ -0,0 +1,376 @@ +/* Copyright (c) Facebook, Inc. and its affiliates. */ + +/* + * Mechanism to reset a loaded dynamic library. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +#ifdef __linux__ +#include + +#define PAGE_SIZE 4096 +#define PAGE_MASK (PAGE_SIZE - 1) +#define PAGE_START(x) ((x) & ~PAGE_MASK) +#define PAGE_OFFSET(x) ((x) &PAGE_MASK) +#define PAGE_END(x) PAGE_START((x) + (PAGE_SIZE - 1)) + +#if __LP64__ +#define Elf_Ehdr Elf64_Ehdr +#define Elf_Phdr Elf64_Phdr + +#else /* __LP64__ */ +#define Elf_Ehdr Elf32_Ehdr +#define Elf_Phdr Elf32_Phdr +#endif /* __LP64__ */ + +#elif __APPLE__ +#include + +#if __LP64__ +#define LC_SEGMENT_COMMAND LC_SEGMENT_64 +#define MH_MAGIC_NUMBER MH_MAGIC_64 + +struct macho_header : public mach_header_64 { +}; +struct macho_segment_command : public segment_command_64 { +}; +struct macho_section : public section_64 { +}; +#else /* __LP64__ */ +#define LC_SEGMENT_COMMAND LC_SEGMENT +#define MH_MAGIC_NUMBER MH_MAGIC + +struct macho_header : public mach_header { +}; +struct macho_segment_command : public segment_command { +}; +struct macho_section : public section { +}; +#endif /* __LP64__ */ + +#endif /* __linux__, __APPLE__ */ + +#ifdef __linux__ +struct Region { + uint8_t *data; + size_t size; + bool rw; + + uint8_t * + l() const + { + return data; + } + uint8_t * + r() const + { + return data + size; + } + + bool + intersects(const Region &s) const + { + return !(r() <= s.l() || s.r() <= l()); + } + + bool + operator<(const Region &s) const + { + return data < s.data; + } +}; + +std::vector +make_disjoint(const std::vector ®ions) +{ + std::vector starts; + std::vector ends; + + std::vector result; + + for (auto it = regions.rbegin(); it != regions.rend(); ++it) { + starts.push_back(it->l()); + ends.push_back(it->r()); + } + + std::sort(starts.begin(), starts.end(), std::greater<>()); + std::sort(ends.begin(), ends.end(), std::greater<>()); + + int overlap = 1; + uint8_t *start = starts.back(); + starts.pop_back(); + uint8_t *end; + bool active; + + while (!ends.empty()) { + active = overlap > 0; + + if (!starts.empty() && starts.back() <= ends.back()) { + ++overlap; + end = starts.back(); + starts.pop_back(); + } else { + --overlap; + end = ends.back(); + ends.pop_back(); + } + + if (active && start < end) { + result.push_back(Region{ start, (size_t) (end - start), false }); + } + start = end; + } + + if (!starts.empty()) { + throw std::runtime_error("Intervals required"); + } + + return result; +} +#endif + +class DL +{ + public: + DL(const char *filename, const char *symbol) + : handle_(dlopen(filename, RTLD_LAZY)) + { + if (!handle_) { + throw std::runtime_error(std::string("dlopen failed on ") + + filename + ": " + dlerror()); + } + + void *ptr = dlsym(handle_, symbol); + if (!ptr) { + throw std::runtime_error(dlerror()); + } + Dl_info dlinfo; + if (dladdr(ptr, &dlinfo) == 0) { + throw std::runtime_error("dladdr failed"); + } + if (!dlinfo.dli_sname) { + throw std::runtime_error("No matching addr found."); + } +#ifdef __linux__ + hdr_ = (Elf_Ehdr *) dlinfo.dli_fbase; + if (memcmp(hdr_->e_ident, ELFMAG, SELFMAG) != 0) { + throw std::runtime_error("Illegal elf header"); + } + + size_t offset = ~0; + + size_t phoff = hdr_->e_phoff; + for (size_t i = 0; i != hdr_->e_phnum; ++i) { + Elf_Phdr *ph = (Elf_Phdr *) ((uint8_t *) hdr_ + phoff); + + if (ph->p_type == PT_LOAD) { + offset = std::min(offset, (size_t) ph->p_vaddr); + } + phs_.push_back(ph); + phoff += hdr_->e_phentsize; + } + + baseaddr_ = (uint8_t *) hdr_ - offset; + + std::vector overlapping_regions; + + for (const Elf_Phdr *ph : phs_) { + overlapping_regions.push_back( + Region{ baseaddr_ + ph->p_vaddr, ph->p_memsz, false }); + } + + regions_ = make_disjoint(overlapping_regions); + + for (const Elf_Phdr *ph : phs_) { + Region region{ baseaddr_ + ph->p_vaddr, ph->p_memsz, + ph->p_flags & PF_W ? true : false }; + + for (Region &s : regions_) { + if (region.intersects(s)) { + s.rw = region.rw; + } else if (region.l() < s.r()) { + break; + } + } + } + +#elif __APPLE__ + hdr_ = (macho_header *) dlinfo.dli_fbase; + if (hdr_->magic != MH_MAGIC_NUMBER) { + throw std::runtime_error( + "Illegal magic integer " + std::to_string(hdr_->magic) + + ", expected " + std::to_string(MH_MAGIC_NUMBER)); + } + if (hdr_->filetype != MH_DYLIB) { + throw std::runtime_error( + std::string("Expected MH_DYLIB file type but got " + + std::to_string(hdr_->filetype))); + } + + const load_command *cmds = (load_command *) (hdr_ + 1); + const load_command *cmd = cmds; + + for (uint32_t i = 0; i < hdr_->ncmds; ++i) { + if (cmd->cmd != LC_SEGMENT_COMMAND) + continue; + + const auto *seg = (macho_segment_command *) cmd; + + if (seg->nsects) + segs_.push_back(seg); + cmd = (const load_command *) (((uint8_t *) cmd) + cmd->cmdsize); + } + + baseaddr_ = (uint8_t *) hdr_ - segs_[0]->vmaddr; + + for (const auto *seg : segs_) { + fprintf( + stderr, + "Found a segment '%s' with %u sections at offset %#010lx. " + "cmdsize: %u. vmaddr: %#010llx. vmsize: 0x%llx. " + "Should be at %p->%p\n", + seg->segname, seg->nsects, (uint8_t *) cmd - (uint8_t *) hdr_, + cmd->cmdsize, seg->vmaddr, seg->vmsize, mem_addr(seg), + mem_addr(seg) + mem_size(seg)); + } +#endif /* __linux__, __APPLE__ */ + } + + ~DL() + { + if (handle_) + dlclose(handle_); + } + + DL(DL &&dl) noexcept + { + *this = std::move(dl); + } + DL(const DL &) = delete; + DL &operator=(const DL &) = delete; + DL & + operator=(DL &&dl) noexcept + { + if (this == &dl) + return *this; + if (handle_) + dlclose(handle_); + handle_ = std::exchange(dl.handle_, nullptr); +#ifdef __linux__ + phs_ = std::move(dl.phs_); + regions_ = std::move(dl.regions_); +#elif __APPLE__ + segs_ = std::move(dl.segs_); +#endif + hdr_ = dl.hdr_; + baseaddr_ = dl.baseaddr_; + return *this; + } + +#ifdef __linux__ + bool + is_rw(const Elf_Phdr *ph) const + { + return ph->p_type == PT_LOAD && ph->p_flags & PF_R + && ph->p_flags & PF_W; + } + bool + is_rw(const Region ®ion) const + { + return region.rw; + } +#elif __APPLE__ + bool + is_rw(const macho_segment_command *seg) const + { + return strcmp(seg->segname, SEG_DATA) == 0; + } +#endif + + template + void + for_rw_regions(F &&f) + { +#ifdef __linux__ + for (const auto ®ion : regions_) { +#elif __APPLE__ + for (const auto ®ion : segs_) { +#endif + if (is_rw(region)) + f(region); + } + } + + template + auto + func(const char *symbol) -> decltype((T(*)(Ts...)) nullptr) + { + void *ptr = dlsym(handle_, symbol); + if (!ptr) { + throw std::runtime_error(dlerror()); + } + return (T(*)(Ts...)) ptr; + } + +#ifdef __linux__ + uint8_t * + mem_addr(const Elf_Phdr *ph) const + { + size_t start = (size_t) (baseaddr_ + ph->p_vaddr); + return (uint8_t *) start; + } + size_t + mem_size(const Elf_Phdr *ph) const + { + return ph->p_memsz; + } + uint8_t * + mem_addr(const Region ®ion) const + { + return region.data; + } + size_t + mem_size(const Region ®ion) const + { + return region.size; + } +#elif __APPLE__ + uint8_t * + mem_addr(const macho_segment_command *seg) const + { + return baseaddr_ + seg->vmaddr; + } + size_t + mem_size(const macho_segment_command *seg) const + { + return seg->vmsize; + } +#endif + + private: + void *handle_{ nullptr }; +#ifdef __linux__ + const Elf_Ehdr *hdr_; + std::vector phs_; + std::vector regions_; +#elif __APPLE__ + std::vector segs_; + const macho_header *hdr_; +#endif + uint8_t *baseaddr_; +}; diff --git a/include/nledl.h b/include/nledl.h deleted file mode 100644 index b55bcac37..000000000 --- a/include/nledl.h +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Wrapper for dlopen. - */ - -#ifndef NLEDL_H -#define NLEDL_H - -#include - -#include "nleobs.h" - -typedef struct nledl_ctx { - char dlpath[1024]; - void *dlhandle; - void *nle_ctx; - void *(*step)(void *, nle_obs *); - FILE *ttyrec; -} nledl_ctx; - -nledl_ctx *nle_start(const char *, nle_obs *, FILE *, nle_seeds_init_t *, - nle_settings *); -nledl_ctx *nle_step(nledl_ctx *, nle_obs *); - -void nle_reset(nledl_ctx *, nle_obs *, FILE *, nle_seeds_init_t *, - nle_settings *); -void nle_end(nledl_ctx *); - -void nle_set_seed(nledl_ctx *, unsigned long, unsigned long, char); -void nle_get_seed(nledl_ctx *, unsigned long *, unsigned long *, char *); - -#endif /* NLEDL_H */ diff --git a/include/nleinstance.h b/include/nleinstance.h new file mode 100644 index 000000000..571667a7d --- /dev/null +++ b/include/nleinstance.h @@ -0,0 +1,210 @@ +/* Copyright (c) Facebook, Inc. and its affiliates. */ + +#pragma once + +/* #define NLE_RESET_DLOPENCLOSE */ /* Enables the old behaviour: reset by + dlclose & dl-re-open. */ + +#ifndef NLE_RESET_DLOPENCLOSE +#include "dloverride.h" +#else +#include +#endif + +extern "C" { +#include "nleobs.h" +} + +#ifndef NLE_RESET_DLOPENCLOSE +class Instance +{ + public: + Instance(const std::string &dlpath, nle_obs *obs, FILE *ttyrec, + nle_seeds_init_t *seeds_init, nle_settings *settings) + : dl_(dlpath.c_str(), "nle_start") + { + dl_.for_rw_regions([&](const auto ®) { + fprintf(stderr, "memcpy out of [%p, %p)\n", dl_.mem_addr(reg), + dl_.mem_addr(reg) + dl_.mem_size(reg)); + + regions_.emplace_back(dl_.mem_size(reg)); + memcpy(®ions_.back()[0], dl_.mem_addr(reg), dl_.mem_size(reg)); + fprintf(stderr, "1 memcpy done\n"); + }); + + start_ = dl_.func("nle_start"); + step_ = dl_.func("nle_step"); + end_ = dl_.func("nle_end"); + get_seed_ = + dl_.func( + "nle_get_seed"); + set_seed_ = + dl_.func( + "nle_set_seed"); + + nle_ctx_ = start_(obs, ttyrec, seeds_init, settings); + } + + ~Instance() + { + if (nle_ctx_) + close(); + } + + void + step(nle_obs *obs) + { + nle_ctx_ = step_(nle_ctx_, obs); + } + + void + reset(nle_obs *obs, FILE *ttyrec, nle_seeds_init_t *seeds_init, + nle_settings *settings) + { + end_(nle_ctx_); + + auto it = regions_.begin(); + dl_.for_rw_regions([&](const auto ®) { + fprintf(stderr, "memcpy into [%p, %p)\n", dl_.mem_addr(reg), + dl_.mem_addr(reg) + dl_.mem_size(reg)); + + memcpy(dl_.mem_addr(reg), it->data(), dl_.mem_size(reg)); + ++it; + }); + nle_ctx_ = start_(obs, ttyrec, seeds_init, settings); + } + + void + close() + { + end_(nle_ctx_); + nle_ctx_ = nullptr; + } + + void + get_seed(unsigned long *core, unsigned long *disp, char *reseed) + { + get_seed_(nle_ctx_, core, disp, reseed); + } + + void + set_seed(unsigned long core, unsigned long disp, char reseed) + { + set_seed_(nle_ctx_, core, disp, reseed); + } + + private: + DL dl_; + void *nle_ctx_{ nullptr }; + + void *(*start_)(nle_obs *, FILE *, nle_seeds_init_t *, nle_settings *); + void *(*step_)(void *, nle_obs *); + void (*end_)(void *); + void (*get_seed_)(void *, unsigned long *, unsigned long *, char *); + void (*set_seed_)(void *, unsigned long, unsigned long, char); + + std::vector > regions_; +}; +#else /* NLE_RESET_DLOPENCLOSE */ +class Instance +{ + public: + Instance(const std::string &dlpath, nle_obs *obs, FILE *ttyrec, + nle_seeds_init_t *seeds_init, nle_settings *settings) + : dlpath_(dlpath) + { + init(); + nle_ctx_ = start_(obs, ttyrec, seeds_init, settings); + } + + ~Instance() + { + close(); + } + + void + step(nle_obs *obs) + { + nle_ctx_ = step_(nle_ctx_, obs); + } + + void + reset(nle_obs *obs, FILE *ttyrec, nle_seeds_init_t *seeds_init, + nle_settings *settings) + { + close(); + init(); + nle_ctx_ = start_(obs, ttyrec, seeds_init, settings); + } + + void + close() + { + if (nle_ctx_) + end_(nle_ctx_); + nle_ctx_ = nullptr; + + if (handle_) + if (dlclose(handle_)) + throw std::runtime_error(dlerror()); + handle_ = nullptr; + } + + void + get_seed(unsigned long *core, unsigned long *disp, char *reseed) + { + get_seed_(nle_ctx_, core, disp, reseed); + } + + void + set_seed(unsigned long core, unsigned long disp, char reseed) + { + set_seed_(nle_ctx_, core, disp, reseed); + } + + private: + void + init() + { + void *handle = dlopen(dlpath_.c_str(), RTLD_LAZY | RTLD_NOLOAD); + if (handle) { + dlclose(handle); + throw std::runtime_error(dlpath_ + " is already loaded"); + } + handle_ = dlopen(dlpath_.c_str(), RTLD_LAZY); + if (!handle_) { + throw std::runtime_error(dlerror()); + } + + start_ = (decltype(start_)) get_sym("nle_start"); + step_ = (decltype(step_)) get_sym("nle_step"); + end_ = (decltype(end_)) get_sym("nle_end"); + get_seed_ = (decltype(get_seed_)) get_sym("nle_get_seed"); + set_seed_ = (decltype(set_seed_)) get_sym("nle_set_seed"); + } + + void * + get_sym(const char *sym) + { + dlerror(); /* Clear.*/ + void *result = dlsym(handle_, sym); + const char *error = dlerror(); + if (error) { + throw std::runtime_error(error); + } + return result; + } + + const std::string dlpath_; + + void *handle_{ nullptr }; + void *nle_ctx_{ nullptr }; + + void *(*start_)(nle_obs *, FILE *, nle_seeds_init_t *, nle_settings *); + void *(*step_)(void *, nle_obs *); + void (*end_)(void *); + void (*get_seed_)(void *, unsigned long *, unsigned long *, char *); + void (*set_seed_)(void *, unsigned long, unsigned long, char); +}; +#endif diff --git a/setup.py b/setup.py index e36ca12d1..e6d66b0df 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def build_extension(self, ext): hackdir_path = os.getenv("HACKDIR", output_path.joinpath("nethackdir")) os.makedirs(self.build_temp, exist_ok=True) - build_type = "Debug" if self.debug else "Release" + build_type = "Debug" # if self.debug else "Release" generator = "Ninja" if spawn.find_executable("ninja") else "Unix Makefiles" diff --git a/src/nle.c b/src/nle.c index 7d1f86af0..b372ca93c 100644 --- a/src/nle.c +++ b/src/nle.c @@ -419,12 +419,12 @@ nle_start(nle_obs *obs, FILE *ttyrec, nle_seeds_init_t *seed_init, nle_ctx_t *nle = init_nle(ttyrec, obs); nle_seeds_init = seed_init; + current_nle_ctx = nle; nle->stack = create_fcontext_stack(STACK_SIZE); nle->generatorcontext = make_fcontext(nle->stack.sptr, nle->stack.ssize, mainloop); - current_nle_ctx = nle; fcontext_transfer_t t = jump_fcontext(nle->generatorcontext, NULL); nle->generatorcontext = t.ctx; nle->done = (t.data == NULL); diff --git a/sys/unix/nledl.c b/sys/unix/nledl.c deleted file mode 100644 index 53de50263..000000000 --- a/sys/unix/nledl.c +++ /dev/null @@ -1,152 +0,0 @@ - -#include -#include -#include -#include - -#include "nledl.h" - -void -nledl_init(nledl_ctx *nledl, nle_obs *obs, nle_seeds_init_t *seed_init, - nle_settings *settings) -{ - void *handle = dlopen(nledl->dlpath, RTLD_LAZY | RTLD_NOLOAD); - if (handle) { - dlclose(handle); - fprintf(stderr, - "failure in nledl_init: library %s is already loaded\n", - nledl->dlpath); - exit(EXIT_FAILURE); - } - - nledl->dlhandle = dlopen(nledl->dlpath, RTLD_LAZY); - - if (!nledl->dlhandle) { - fprintf(stderr, "%s\n", dlerror()); - exit(EXIT_FAILURE); - } - - dlerror(); /* Clear any existing error */ - - void *(*start)(nle_obs *, FILE *, nle_seeds_init_t *, nle_settings *); - start = dlsym(nledl->dlhandle, "nle_start"); - nledl->nle_ctx = start(obs, nledl->ttyrec, seed_init, settings); - - char *error = dlerror(); - if (error != NULL) { - fprintf(stderr, "%s\n", error); - exit(EXIT_FAILURE); - } - - nledl->step = dlsym(nledl->dlhandle, "nle_step"); - - error = dlerror(); - if (error != NULL) { - fprintf(stderr, "%s\n", error); - exit(EXIT_FAILURE); - } -} - -void -nledl_close(nledl_ctx *nledl) -{ - void (*end)(void *); - - end = dlsym(nledl->dlhandle, "nle_end"); - end(nledl->nle_ctx); - - if (dlclose(nledl->dlhandle)) { - fprintf(stderr, "Error in dlclose: %s\n", dlerror()); - exit(EXIT_FAILURE); - } - - dlerror(); -} - -nledl_ctx * -nle_start(const char *dlpath, nle_obs *obs, FILE *ttyrec, - nle_seeds_init_t *seed_init, nle_settings *settings) -{ - /* TODO: Consider getting ttyrec path from caller? */ - struct nledl_ctx *nledl = malloc(sizeof(struct nledl_ctx)); - nledl->ttyrec = ttyrec; - strncpy(nledl->dlpath, dlpath, sizeof(nledl->dlpath)); - - nledl_init(nledl, obs, seed_init, settings); - return nledl; -}; - -nledl_ctx * -nle_step(nledl_ctx *nledl, nle_obs *obs) -{ - if (!nledl || !nledl->dlhandle || !nledl->nle_ctx) { - fprintf(stderr, "Illegal nledl_ctx\n"); - exit(EXIT_FAILURE); - } - - nledl->step(nledl->nle_ctx, obs); - - return nledl; -} - -/* TODO: For a standard reset, we don't need the full close in nle.c. - * E.g., we could re-use the stack buffer and the nledl_ctx. */ -void -nle_reset(nledl_ctx *nledl, nle_obs *obs, FILE *ttyrec, - nle_seeds_init_t *seed_init, nle_settings *settings) -{ - nledl_close(nledl); - /* Reset file only if not-NULL. */ - if (ttyrec) - nledl->ttyrec = ttyrec; - - // TODO: Consider refactoring nledl.h such that we expose this init - // function but drop reset. - nledl_init(nledl, obs, seed_init, settings); -} - -void -nle_end(nledl_ctx *nledl) -{ - nledl_close(nledl); - free(nledl); -} - -#ifdef NLE_ALLOW_SEEDING -void -nle_set_seed(nledl_ctx *nledl, unsigned long core, unsigned long disp, - char reseed) -{ - void (*set_seed)(void *, unsigned long, unsigned long, char); - - set_seed = dlsym(nledl->dlhandle, "nle_set_seed"); - - char *error = dlerror(); - if (error != NULL) { - fprintf(stderr, "%s\n", error); - exit(EXIT_FAILURE); - } - - set_seed(nledl->nle_ctx, core, disp, reseed); -} - -void -nle_get_seed(nledl_ctx *nledl, unsigned long *core, unsigned long *disp, - char *reseed) -{ - void (*get_seed)(void *, unsigned long *, unsigned long *, char *); - - get_seed = dlsym(nledl->dlhandle, "nle_get_seed"); - - char *error = dlerror(); - if (error != NULL) { - fprintf(stderr, "%s\n", error); - exit(EXIT_FAILURE); - } - - /* Careful here. NetHack has different ideas of what a boolean is - * than C++ (see global.h and SKIP_BOOLEAN). But one byte should be fine. - */ - get_seed(nledl->nle_ctx, core, disp, reseed); -} -#endif diff --git a/sys/unix/rlmain.cc b/sys/unix/rlmain.cc index c4efcf65f..f94ddcbb2 100644 --- a/sys/unix/rlmain.cc +++ b/sys/unix/rlmain.cc @@ -6,9 +6,10 @@ #include #include +#include "nleinstance.h" + extern "C" { #include "hack.h" -#include "nledl.h" } class ScopedTC @@ -33,7 +34,7 @@ class ScopedTC }; void -play(nledl_ctx *nle, nle_obs *obs, nle_settings *settings) +play(Instance &nle, nle_obs *obs, nle_settings *settings) { while (!obs->done) { for (int r = 0; r < ROWNO; ++r) { @@ -47,13 +48,13 @@ play(nledl_ctx *nle, nle_obs *obs, nle_settings *settings) std::cout << std::endl; read(STDIN_FILENO, &obs->action, 1); if (obs->action == 'r') - nle_reset(nle, obs, nullptr, nullptr, settings); - nle = nle_step(nle, obs); + nle.reset(obs, nullptr, nullptr, settings); + nle.step(obs); } } void -randplay(nledl_ctx *nle, nle_obs *obs) +randplay(Instance &nle, nle_obs *obs) { int actions[] = { 13, 107, 108, 106, 104, 117, 110, 98, 121, @@ -63,7 +64,7 @@ randplay(nledl_ctx *nle, nle_obs *obs) for (int i = 0; !obs->done && i < 10000; ++i) { obs->action = actions[rand() % n]; - nle = nle_step(nle, obs); + nle.step(obs); } if (!obs->done) { std::cerr << "Episode didn't end after 10000 steps, aborting." @@ -72,13 +73,13 @@ randplay(nledl_ctx *nle, nle_obs *obs) } void -randgame(nledl_ctx *nle, nle_obs *obs, const int no_episodes, +randgame(Instance &nle, nle_obs *obs, const int no_episodes, nle_settings *settings) { for (int i = 0; i < no_episodes; ++i) { randplay(nle, obs); if (i < no_episodes - 1) - nle_reset(nle, obs, nullptr, nullptr, settings); + nle.reset(obs, nullptr, nullptr, settings); } } @@ -114,18 +115,34 @@ main(int argc, char **argv) std::unique_ptr ttyrec( fopen("nle.ttyrec.bz2", "a"), fclose); - nle_settings settings; - strncpy(settings.hackdir, getenv("HACKDIR"), sizeof(settings.hackdir)); + nle_settings settings{}; + char *hackdir = getenv("HACKDIR"); + + if (hackdir) + strncpy(settings.hackdir, hackdir, sizeof(settings.hackdir)); + else + throw std::runtime_error( + "Set HACKDIR env variable to NetHack installation"); ScopedTC tc; - nledl_ctx *nle = - nle_start("libnethack.so", &obs, ttyrec.get(), nullptr, &settings); + + char *cwd = getcwd(nullptr, 0); /* glibc's allocating version. */ + std::string path(cwd); + free(cwd); + + if (path.back() != '/') + path.append("/"); + + path.append("libnethack.so"); + + Instance nle(path, &obs, ttyrec.get(), nullptr, &settings); if (argc > 1 && argv[1][0] == 'r') { randgame(nle, &obs, 3, &settings); } else { - play(nle, &obs, &settings); - nle_reset(nle, &obs, nullptr, nullptr, &settings); - play(nle, &obs, &settings); + for (int i = 0; i < 10; ++i) { + play(nle, &obs, &settings); + nle.reset(&obs, nullptr, nullptr, &settings); + } } - nle_end(nle); + nle.close(); } diff --git a/win/rl/pynethack.cc b/win/rl/pynethack.cc index bdb2c1990..b52c9fd9f 100644 --- a/win/rl/pynethack.cc +++ b/win/rl/pynethack.cc @@ -6,6 +6,8 @@ #include #include +#include "nleinstance.h" + // "digit" is declared in both Python's longintrepr.h and NetHack's extern.h. #define digit nethack_digit @@ -17,7 +19,7 @@ extern "C" { } extern "C" { -#include "nledl.h" +#include "nleobs.h" } // Undef name clashes between NetHack and Python. @@ -77,7 +79,7 @@ checked_conversion(py::handle h, const std::vector &shape) py::buffer_info buf = array.request(); - if (buf.ndim != shape.size()) + if ((size_t) buf.ndim != shape.size()) throw std::runtime_error("array has wrong number of dims"); if (!std::equal(shape.begin(), shape.end(), buf.shape.begin())) throw std::runtime_error("Array has wrong shape"); @@ -136,7 +138,7 @@ class Nethack if (obs_.done) throw std::runtime_error("Called step on finished NetHack"); obs_.action = action; - nle_ = nle_step(nle_, &obs_); + nle_->step(&obs_); } bool @@ -234,8 +236,7 @@ class Nethack close() { if (nle_) { - nle_end(nle_); - nle_ = nullptr; + nle_.reset(nullptr); } } @@ -258,7 +259,7 @@ class Nethack #ifdef NLE_ALLOW_SEEDING if (!nle_) throw std::runtime_error("set_seed called without reset()"); - nle_set_seed(nle_, core, disp, reseed); + nle_->set_seed(core, disp, reseed); #else throw std::runtime_error("Seeding not enabled"); #endif @@ -273,8 +274,7 @@ class Nethack std::tuple result; char reseed; /* NetHack's booleans are not necessarily C++ bools ... */ - nle_get_seed(nle_, &std::get<0>(result), &std::get<1>(result), - &reseed); + nle_->get_seed(&std::get<0>(result), &std::get<1>(result), &reseed); std::get<2>(result) = reseed; return result; #else @@ -310,12 +310,12 @@ class Nethack py::gil_scoped_release gil; if (!nle_) { - nle_ = - nle_start(dlpath_.c_str(), &obs_, ttyrec ? ttyrec : ttyrec_, - use_seed_init ? &seed_init_ : nullptr, &settings_); + nle_ = std::make_unique( + dlpath_, &obs_, ttyrec ? ttyrec : ttyrec_, + use_seed_init ? &seed_init_ : nullptr, &settings_); } else - nle_reset(nle_, &obs_, ttyrec, - use_seed_init ? &seed_init_ : nullptr, &settings_); + nle_->reset(&obs_, ttyrec ? ttyrec : ttyrec_, + use_seed_init ? &seed_init_ : nullptr, &settings_); use_seed_init = false; if (obs_.done) @@ -327,7 +327,7 @@ class Nethack std::vector py_buffers_; nle_seeds_init_t seed_init_; bool use_seed_init = false; - nledl_ctx *nle_ = nullptr; + std::unique_ptr nle_; std::FILE *ttyrec_ = nullptr; nle_settings settings_; }; @@ -638,7 +638,7 @@ PYBIND11_MODULE(_pynethack, m) "Argument should be between 0 and MAXMCLASSES (" + std::to_string(MAXMCLASSES) + ") but got " + std::to_string(let)); - return &def_monsyms[let]; + return &def_monsyms[(size_t) let]; }, py::return_value_policy::reference) .def_static( @@ -649,7 +649,7 @@ PYBIND11_MODULE(_pynethack, m) "Argument should be between 0 and MAXOCLASSES (" + std::to_string(MAXOCLASSES) + ") but got " + std::to_string(olet)); - return &def_oc_syms[olet]; + return &def_oc_syms[(size_t) olet]; }, py::return_value_policy::reference) .def_readonly("sym", &class_sym::sym)