Skip to content
This repository has been archived by the owner on May 6, 2024. It is now read-only.

Commit

Permalink
Fix rlmain.cc.
Browse files Browse the repository at this point in the history
  • Loading branch information
Heinrich Kuttler committed Feb 10, 2022
1 parent bf81c8f commit 205740f
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 212 deletions.
12 changes: 7 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,13 @@ target_link_directories(nethack PUBLIC /usr/local/lib)

target_link_libraries(nethack PUBLIC m fcontext bz2)

# rlmain C++ (test) binary add_executable(rlmain "sys/unix/rlmain.cc")
# set_target_properties(rlmain PROPERTIES CXX_STANDARD 11)
# target_link_libraries(rlmain PUBLIC nethackdl)
# target_include_directories(rlmain PUBLIC ${NLE_INC_GEN})
# add_dependencies(rlmain util) # For pm.h.
# rlmain C++ (test) binary
add_executable(rlmain "sys/unix/rlmain.cc")
set_target_properties(rlmain PROPERTIES CXX_STANDARD 14)
target_link_libraries(rlmain PUBLIC dl)
target_include_directories(rlmain PUBLIC ${NLE_INC_GEN}
${CMAKE_CURRENT_SOURCE_DIR}/include)
add_dependencies(rlmain util) # For pm.h.

# pybind11 python library.
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third_party/pybind11)
Expand Down
8 changes: 6 additions & 2 deletions include/dloverride.h
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
/* Copyright (c) Facebook, Inc. and its affiliates. */

/* Mechanism to reset a loaded dynamic library. */
/*
* Mechanism to reset a loaded dynamic library.
*/

#pragma once

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <assert.h>
#include <stdio.h>

#include <dlfcn.h>
#include <errno.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
Expand Down
206 changes: 206 additions & 0 deletions include/nleinstance.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
/* Copyright (c) Facebook, Inc. and its affiliates. */

#pragma once

/*#define NLE_RESET_DLOPENCLOSE*/ /* Enables the old behaviour: reset by
dlclose & dl-re-open. */

#include <utility>

#ifndef NLE_RESET_DLOPENCLOSE
#include "dloverride.h"
#else
#include <dlfcn.h>
#endif

extern "C" {
#include "nleobs.h"
}

#ifndef NLE_RESET_DLOPENCLOSE
class Instance
{
public:
Instance(const std::string &dlpath, nle_obs *obs, FILE *ttyrec,
nle_seeds_init_t *seeds_init, nle_settings *settings)
: dl_(dlpath.c_str(), "nle_start")
{
dl_.for_changing_sections([&](const auto *seg) {
segments_.emplace_back(dl_.mem_size(seg));
memcpy(&segments_.back()[0], dl_.mem_addr(seg),
dl_.mem_size(seg));
});

start_ = dl_.func<void *, nle_obs *, FILE *, nle_seeds_init_t *,
nle_settings *>("nle_start");
step_ = dl_.func<void *, void *, nle_obs *>("nle_step");
end_ = dl_.func<void, void *>("nle_end");
get_seed_ =
dl_.func<void, void *, unsigned long *, unsigned long *, char *>(
"nle_get_seed");
set_seed_ =
dl_.func<void, void *, unsigned long, unsigned long, char>(
"nle_set_seed");

nle_ctx_ = start_(obs, ttyrec, seeds_init, settings);
}

~Instance()
{
if (nle_ctx_)
close();
}

void
step(nle_obs *obs)
{
nle_ctx_ = step_(nle_ctx_, obs);
}

void
reset(nle_obs *obs, FILE *ttyrec, nle_seeds_init_t *seeds_init,
nle_settings *settings)
{
end_(nle_ctx_);

auto it = segments_.begin();
dl_.for_changing_sections([&](const auto *seg) {
memcpy(dl_.mem_addr(seg), it->data(), dl_.mem_size(seg));
++it;
});
nle_ctx_ = start_(obs, ttyrec, seeds_init, settings);
}

void
close()
{
end_(nle_ctx_);
nle_ctx_ = nullptr;
}

void
get_seed(unsigned long *core, unsigned long *disp, char *reseed)
{
get_seed_(nle_ctx_, core, disp, reseed);
}

void
set_seed(unsigned long core, unsigned long disp, char reseed)
{
set_seed_(nle_ctx_, core, disp, reseed);
}

private:
DL dl_;
void *nle_ctx_{ nullptr };

void *(*start_)(nle_obs *, FILE *, nle_seeds_init_t *, nle_settings *);
void *(*step_)(void *, nle_obs *);
void (*end_)(void *);
void (*get_seed_)(void *, unsigned long *, unsigned long *, char *);
void (*set_seed_)(void *, unsigned long, unsigned long, char);

std::vector<std::vector<uint8_t> > segments_;
};
#else /* NLE_RESET_DLOPENCLOSE*/
class Instance
{
public:
Instance(const std::string &dlpath, nle_obs *obs, FILE *ttyrec,
nle_seeds_init_t *seeds_init, nle_settings *settings)
: dlpath_(dlpath)
{
init();
nle_ctx_ = start_(obs, ttyrec, seeds_init, settings);
}

~Instance()
{
close();
}

void
step(nle_obs *obs)
{
nle_ctx_ = step_(nle_ctx_, obs);
}

void
reset(nle_obs *obs, FILE *ttyrec, nle_seeds_init_t *seeds_init,
nle_settings *settings)
{
close();
init();
nle_ctx_ = start_(obs, ttyrec, seeds_init, settings);
}

void
close()
{
if (nle_ctx_)
end_(nle_ctx_);
nle_ctx_ = nullptr;

if (handle_)
if (dlclose(handle_))
throw std::runtime_error(dlerror());
handle_ = nullptr;
}

void
get_seed(unsigned long *core, unsigned long *disp, char *reseed)
{
get_seed_(nle_ctx_, core, disp, reseed);
}

void
set_seed(unsigned long core, unsigned long disp, char reseed)
{
set_seed_(nle_ctx_, core, disp, reseed);
}

private:
void
init()
{
void *handle = dlopen(dlpath_.c_str(), RTLD_LAZY | RTLD_NOLOAD);
if (handle) {
dlclose(handle);
throw std::runtime_error(dlpath_ + " is already loaded");
}
handle_ = dlopen(dlpath_.c_str(), RTLD_LAZY);
if (!handle_) {
throw std::runtime_error(dlerror());
}

start_ = (decltype(start_)) get_sym("nle_start");
step_ = (decltype(step_)) get_sym("nle_step");
end_ = (decltype(end_)) get_sym("nle_end");
get_seed_ = (decltype(get_seed_)) get_sym("nle_get_seed");
set_seed_ = (decltype(set_seed_)) get_sym("nle_set_seed");
}

void *
get_sym(const char *sym)
{
dlerror(); /* Clear.*/
void *result = dlsym(handle_, sym);
const char *error = dlerror();
if (error) {
throw std::runtime_error(error);
}
return result;
}

const std::string dlpath_;

void *handle_{ nullptr };
void *nle_ctx_{ nullptr };

void *(*start_)(nle_obs *, FILE *, nle_seeds_init_t *, nle_settings *);
void *(*step_)(void *, nle_obs *);
void (*end_)(void *);
void (*get_seed_)(void *, unsigned long *, unsigned long *, char *);
void (*set_seed_)(void *, unsigned long, unsigned long, char);
};
#endif
25 changes: 13 additions & 12 deletions sys/unix/rlmain.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
#include <termios.h>
#include <unistd.h>

#include "nleinstance.h"

extern "C" {
#include "hack.h"
#include "nledl.h"
}

class ScopedTC
Expand All @@ -33,7 +34,7 @@ class ScopedTC
};

void
play(nledl_ctx *nle, nle_obs *obs, nle_settings *settings)
play(Instance &nle, nle_obs *obs, nle_settings *settings)
{
while (!obs->done) {
for (int r = 0; r < ROWNO; ++r) {
Expand All @@ -47,13 +48,13 @@ play(nledl_ctx *nle, nle_obs *obs, nle_settings *settings)
std::cout << std::endl;
read(STDIN_FILENO, &obs->action, 1);
if (obs->action == 'r')
nle_reset(nle, obs, nullptr, nullptr, settings);
nle = nle_step(nle, obs);
nle.reset(obs, nullptr, nullptr, settings);
nle.step(obs);
}
}

void
randplay(nledl_ctx *nle, nle_obs *obs)
randplay(Instance &nle, nle_obs *obs)
{
int actions[] = {
13, 107, 108, 106, 104, 117, 110, 98, 121,
Expand All @@ -63,7 +64,7 @@ randplay(nledl_ctx *nle, nle_obs *obs)

for (int i = 0; !obs->done && i < 10000; ++i) {
obs->action = actions[rand() % n];
nle = nle_step(nle, obs);
nle.step(obs);
}
if (!obs->done) {
std::cerr << "Episode didn't end after 10000 steps, aborting."
Expand All @@ -72,13 +73,13 @@ randplay(nledl_ctx *nle, nle_obs *obs)
}

void
randgame(nledl_ctx *nle, nle_obs *obs, const int no_episodes,
randgame(Instance &nle, nle_obs *obs, const int no_episodes,
nle_settings *settings)
{
for (int i = 0; i < no_episodes; ++i) {
randplay(nle, obs);
if (i < no_episodes - 1)
nle_reset(nle, obs, nullptr, nullptr, settings);
nle.reset(obs, nullptr, nullptr, settings);
}
}

Expand Down Expand Up @@ -118,14 +119,14 @@ main(int argc, char **argv)
strncpy(settings.hackdir, getenv("HACKDIR"), sizeof(settings.hackdir));

ScopedTC tc;
nledl_ctx *nle =
nle_start("libnethack.so", &obs, ttyrec.get(), nullptr, &settings);

Instance nle("libnethack.so", &obs, ttyrec.get(), nullptr, &settings);
if (argc > 1 && argv[1][0] == 'r') {
randgame(nle, &obs, 3, &settings);
} else {
play(nle, &obs, &settings);
nle_reset(nle, &obs, nullptr, nullptr, &settings);
nle.reset(&obs, nullptr, nullptr, &settings);
play(nle, &obs, &settings);
}
nle_end(nle);
nle.close();
}
Loading

0 comments on commit 205740f

Please sign in to comment.