Skip to content

Commit

Permalink
simplifying action interface code
Browse files Browse the repository at this point in the history
  • Loading branch information
Bam4d committed Oct 17, 2023
1 parent e239a98 commit 88d1be2
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 266 deletions.
2 changes: 0 additions & 2 deletions bindings/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ PYBIND11_MODULE(python_griddly, m) {
game_process.def("seed", &Py_GameProcess::seedRandomGenerator);

py::class_<Py_Player, std::shared_ptr<Py_Player>> player(m, "Player");
player.def("step", &Py_Player::stepSingle);
player.def("step_multi", &Py_Player::stepMulti);
player.def("observe", &Py_Player::observe);
player.def("get_observation_description", &Py_Player::getObservationDescription);

Expand Down
183 changes: 146 additions & 37 deletions bindings/wrapper/GameProcess.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ class ValidActionNode {
}
};

struct InfoAndTruncated {
py::dict info;
bool truncated;
};

class Py_GameProcess {
public:
Py_GameProcess(std::string globalObserverName, std::shared_ptr<GDYFactory> gdyFactory) : gdyFactory_(gdyFactory) {
Expand Down Expand Up @@ -194,11 +199,23 @@ class Py_GameProcess {

spdlog::debug("Dims: {0}", stepArrayInfo.ndim);

if (stepArrayInfo.ndim != 3) {
auto error = fmt::format("Invalid number of dimensions {0}, must be 3. [P,N,A] (Num Players, Number of actions, Action Size)", stepArrayInfo.ndim);
spdlog::error(error);
throw std::invalid_argument(error);
}

auto playerStride = stepArrayInfo.strides[0] / sizeof(int32_t);
auto actionArrayStride = stepArrayInfo.strides[1] / sizeof(int32_t);
auto numActionStride = stepArrayInfo.strides[1] / sizeof(int32_t);
auto actionArrayStride = stepArrayInfo.strides[2] / sizeof(int32_t);

spdlog::debug("Strides: {0}, {1}, {2}", playerStride, numActionStride, actionArrayStride);

auto playerSize = stepArrayInfo.shape[0];
auto actionSize = stepArrayInfo.shape[1];
auto numActionSize = stepArrayInfo.shape[1];
auto actionSize = stepArrayInfo.shape[2];

spdlog::debug("Shapes: {0}, {1}, {2}", playerSize, numActionSize, actionSize);

if (playerSize != playerCount_) {
auto error = fmt::format("The number of players {0} does not match the first dimension of the parallel action.", playerCount_);
Expand All @@ -223,47 +240,60 @@ class Py_GameProcess {

for (int i = 0; i < playerSize; i++) {
auto p = playerIdx[i];
std::string actionName;
std::vector<int32_t> actionArray;
auto pStr = (int32_t*)stepArrayInfo.ptr + p * playerStride;

auto pStr = p * playerStride;
bool lastPlayer = i == (playerSize - 1);
std::vector<std::shared_ptr<Action>> actions{};
for (int n = 0; n < numActionSize; n++) {
std::string actionName;
std::vector<int32_t> actionArray;
auto nStr = n * numActionStride;
auto offset = (int32_t*)stepArrayInfo.ptr + pStr + nStr;

spdlog::debug("Player {0} action {1} offset {2}", p, n, pStr + nStr);

switch (actionSize) {
case 1:
actionName = externalActionNames.at(0);
actionArray.push_back(*(offset + 0 * actionArrayStride));
break;
case 2:
actionName = externalActionNames.at(*(offset + 0 * actionArrayStride));
actionArray.push_back(*(offset + 1 * actionArrayStride));
break;
case 3:
actionName = externalActionNames.at(0);
actionArray.push_back(*(offset + 0 * actionArrayStride));
actionArray.push_back(*(offset + 1 * actionArrayStride));
actionArray.push_back(*(offset + 2 * actionArrayStride));
break;
case 4:
actionArray.push_back(*(offset + 0 * actionArrayStride));
actionArray.push_back(*(offset + 1 * actionArrayStride));
actionName = externalActionNames.at(*(offset + 2 * actionArrayStride));
actionArray.push_back(*(offset + 3 * actionArrayStride));
break;
default: {
auto error = fmt::format("Invalid action size, {0}", actionSize);
spdlog::error(error);
throw std::invalid_argument(error);
}
}

switch (actionSize) {
case 1:
actionName = externalActionNames.at(0);
actionArray.push_back(*(pStr + 0 * actionArrayStride));
break;
case 2:
actionName = externalActionNames.at(*(pStr + 0 * actionArrayStride));
actionArray.push_back(*(pStr + 1 * actionArrayStride));
break;
case 3:
actionArray.push_back(*(pStr + 0 * actionArrayStride));
actionArray.push_back(*(pStr + 1 * actionArrayStride));
actionName = externalActionNames.at(0);
actionArray.push_back(*(pStr + 2 * actionArrayStride));
break;
case 4:
actionArray.push_back(*(pStr + 0 * actionArrayStride));
actionArray.push_back(*(pStr + 1 * actionArrayStride));
actionName = externalActionNames.at(*(pStr + 2 * actionArrayStride));
actionArray.push_back(*(pStr + 3 * actionArrayStride));
break;
default: {
auto error = fmt::format("Invalid action size, {0}", actionSize);
spdlog::error(error);
throw std::invalid_argument(error);
spdlog::debug("Creating action for player {0} with name {1}", p, actionName);
auto action = buildAction(players_[p]->unwrapped(), actionName, actionArray);
if (action != nullptr) {
actions.push_back(action);
}
}
auto actionResult = players_[p]->performActions(actions, lastPlayer);
terminated = actionResult.terminated;
truncated = actionResult.truncated;

auto playerStepResult = players_[p]->stepSingle(actionName, actionArray, lastPlayer);

// playerRewards.push_back(playerStepResult[0].cast<int32_t>());
if (lastPlayer) {
terminated = playerStepResult[0].cast<bool>();
truncated = playerStepResult[1].cast<bool>();
info = playerStepResult[2];
spdlog::debug("Last player, updating ticks");
auto info_and_truncated = buildInfo(actionResult);
info = info_and_truncated.info;
truncated = info_and_truncated.truncated;
}
}

Expand Down Expand Up @@ -549,5 +579,84 @@ class Py_GameProcess {
const std::shared_ptr<GDYFactory> gdyFactory_;
uint32_t playerCount_ = 0;
std::vector<std::shared_ptr<Py_Player>> players_;

InfoAndTruncated buildInfo(ActionResult actionResult) {
py::dict py_info;
bool truncated = false;

if (actionResult.terminated) {
py::dict py_playerResults;

for (auto playerRes : actionResult.playerStates) {
std::string playerStatusString;
switch (playerRes.second) {
case TerminationState::WIN:
playerStatusString = "Win";
break;
case TerminationState::LOSE:
playerStatusString = "Lose";
break;
case TerminationState::NONE:
playerStatusString = "End";
break;
case TerminationState::TRUNCATED:
truncated = true;
playerStatusString = "Truncated";
break;
}

if (playerStatusString.size() > 0) {
py_playerResults[std::to_string(playerRes.first).c_str()] = playerStatusString;
}
}
py_info["PlayerResults"] = py_playerResults;
}

return {py_info, truncated};
}

std::shared_ptr<Action> buildAction(std::shared_ptr<Player> player, std::string actionName, std::vector<int32_t> actionArray) {
const auto& actionInputsDefinition = gdyFactory_->findActionInputsDefinition(actionName);
auto playerAvatar = player->getAvatar();
auto playerId = player->getId();

const auto& inputMappings = actionInputsDefinition.inputMappings;

if (playerAvatar != nullptr) {
auto actionId = actionArray[0];

if (inputMappings.find(actionId) == inputMappings.end()) {
return nullptr;
}

const auto& mapping = inputMappings.at(actionId);
const auto& vectorToDest = mapping.vectorToDest;
const auto& orientationVector = mapping.orientationVector;
const auto& metaData = mapping.metaData;
const auto& action = std::make_shared<Action>(Action(gameProcess_->getGrid(), actionName, playerId, 0, metaData));
action->init(playerAvatar, vectorToDest, orientationVector, actionInputsDefinition.relative);

return action;
} else {
glm::ivec2 sourceLocation = {actionArray[0], actionArray[1]};

auto actionId = actionArray[2];

if (inputMappings.find(actionId) == inputMappings.end()) {
return nullptr;
}

const auto& mapping = inputMappings.at(actionId);
const auto& vector = mapping.vectorToDest;
const auto& orientationVector = mapping.orientationVector;
const auto& metaData = mapping.metaData;
glm::ivec2 destinationLocation = sourceLocation + vector;

auto action = std::make_shared<Action>(Action(gameProcess_->getGrid(), actionName, playerId, 0, metaData));
action->init(sourceLocation, destinationLocation);

return action;
}
}
};
} // namespace griddly
Loading

0 comments on commit 88d1be2

Please sign in to comment.