Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding option to return ram state in info #297

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions envpool/atari/atari_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class AtariEnvFns {
"img_height"_.Bind(84), "img_width"_.Bind(84),
"task"_.Bind(std::string("pong")), "full_action_space"_.Bind(false),
"repeat_action_probability"_.Bind(0.0f),
"use_inter_area_resize"_.Bind(true), "gray_scale"_.Bind(true));
"use_inter_area_resize"_.Bind(true), "gray_scale"_.Bind(true), "expose_ram"_.Bind(false));
}
template <typename Config>
static decltype(auto) StateSpec(const Config& conf) {
Expand All @@ -66,7 +66,9 @@ class AtariEnvFns {
{0, 255})),
"info:lives"_.Bind(Spec<int>({-1})),
"info:reward"_.Bind(Spec<float>({-1})),
"info:terminated"_.Bind(Spec<int>({-1}, {0, 1})));
"info:terminated"_.Bind(Spec<int>({-1}, {0, 1})),
"info:ram"_.Bind(Spec<uint8_t>({128}, {0, 255}))
);
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down Expand Up @@ -99,6 +101,7 @@ class AtariEnv : public Env<AtariEnvSpec> {
std::vector<Array> maxpool_buf_;
Array resize_img_;
std::uniform_int_distribution<> dist_noop_;
bool expose_ram_{false};
std::string rom_path_;

public:
Expand All @@ -121,6 +124,7 @@ class AtariEnv : public Env<AtariEnvSpec> {
spec.config["img_width"_]}),
resize_img_(resize_spec_),
dist_noop_(0, spec.config["noop_max"_] - 1),
expose_ram_(spec.config["expose_ram"_]),
rom_path_(GetRomPath(spec.config["base_path"_], spec.config["task"_])) {
env_->setFloat("repeat_action_probability",
spec.config["repeat_action_probability"_]);
Expand Down Expand Up @@ -247,6 +251,23 @@ class AtariEnv : public Env<AtariEnvSpec> {
.Slice(gray_scale_ ? i : i * 3, gray_scale_ ? i + 1 : (i + 1) * 3)
.Assign(stack_buf_[i]);
}
// Optionally add RAM state if expose_ram_ is true
if (expose_ram_) {
// const auto& ram = env_->getRAM(); // Get a reference to the RAM.
// const size_t ram_size = ram.size(); // Obtain the size of the RAM.
// const uint8_t* ram_data_ptr = ram.data();
// std::vector<uint8_t> ram_data(ram_data_ptr, ram_data_ptr + ram_size);
const size_t ram_size = env_->getRAM().size();
std::vector<uint8_t> ram_data(ram_size);

// Assuming getRAM().array() gives direct access to the RAM data
const uint8_t* ale_ram = env_->getRAM().array();
std::copy(ale_ram, ale_ram + ram_size, ram_data.begin());
state["info:ram"_].Assign(ale_ram, ram_size);
// for (size_t i = 0; i < ram_size; ++i) {
// state["ram"_].At(i) = ram[i]; // Directly write RAM data into state
// }
}
}

/**
Expand Down
Loading