From 88d1be2e555309b697e26dbc9eee9b74987eaf8f Mon Sep 17 00:00:00 2001 From: Bam4d Date: Tue, 17 Oct 2023 23:17:35 +0100 Subject: [PATCH] simplifying action interface code --- bindings/python.cpp | 2 - bindings/wrapper/GameProcess.cpp | 183 ++++++++++++++++++++++++------- bindings/wrapper/Player.cpp | 174 +---------------------------- deps/conanfile.txt | 2 +- python/griddly/__init__.py | 3 - python/griddly/gd/__init__.py | 2 +- python/griddly/gym.py | 68 +++--------- 7 files changed, 168 insertions(+), 266 deletions(-) diff --git a/bindings/python.cpp b/bindings/python.cpp index 7dc30b26..a3a42f98 100644 --- a/bindings/python.cpp +++ b/bindings/python.cpp @@ -107,8 +107,6 @@ PYBIND11_MODULE(python_griddly, m) { game_process.def("seed", &Py_GameProcess::seedRandomGenerator); py::class_> player(m, "Player"); - player.def("step", &Py_Player::stepSingle); - player.def("step_multi", &Py_Player::stepMulti); player.def("observe", &Py_Player::observe); player.def("get_observation_description", &Py_Player::getObservationDescription); diff --git a/bindings/wrapper/GameProcess.cpp b/bindings/wrapper/GameProcess.cpp index f20a18d8..38c7972d 100644 --- a/bindings/wrapper/GameProcess.cpp +++ b/bindings/wrapper/GameProcess.cpp @@ -29,6 +29,11 @@ class ValidActionNode { } }; +struct InfoAndTruncated { + py::dict info; + bool truncated; +}; + class Py_GameProcess { public: Py_GameProcess(std::string globalObserverName, std::shared_ptr gdyFactory) : gdyFactory_(gdyFactory) { @@ -194,11 +199,23 @@ class Py_GameProcess { spdlog::debug("Dims: {0}", stepArrayInfo.ndim); + if (stepArrayInfo.ndim != 3) { + auto error = fmt::format("Invalid number of dimensions {0}, must be 3. [P,N,A] (Num Players, Number of actions, Action Size)", stepArrayInfo.ndim); + spdlog::error(error); + throw std::invalid_argument(error); + } + auto playerStride = stepArrayInfo.strides[0] / sizeof(int32_t); - auto actionArrayStride = stepArrayInfo.strides[1] / sizeof(int32_t); + auto numActionStride = stepArrayInfo.strides[1] / sizeof(int32_t); + auto actionArrayStride = stepArrayInfo.strides[2] / sizeof(int32_t); + + spdlog::debug("Strides: {0}, {1}, {2}", playerStride, numActionStride, actionArrayStride); auto playerSize = stepArrayInfo.shape[0]; - auto actionSize = stepArrayInfo.shape[1]; + auto numActionSize = stepArrayInfo.shape[1]; + auto actionSize = stepArrayInfo.shape[2]; + + spdlog::debug("Shapes: {0}, {1}, {2}", playerSize, numActionSize, actionSize); if (playerSize != playerCount_) { auto error = fmt::format("The number of players {0} does not match the first dimension of the parallel action.", playerCount_); @@ -223,47 +240,60 @@ class Py_GameProcess { for (int i = 0; i < playerSize; i++) { auto p = playerIdx[i]; - std::string actionName; - std::vector actionArray; - auto pStr = (int32_t*)stepArrayInfo.ptr + p * playerStride; - + auto pStr = p * playerStride; bool lastPlayer = i == (playerSize - 1); + std::vector> actions{}; + for (int n = 0; n < numActionSize; n++) { + std::string actionName; + std::vector actionArray; + auto nStr = n * numActionStride; + auto offset = (int32_t*)stepArrayInfo.ptr + pStr + nStr; + + spdlog::debug("Player {0} action {1} offset {2}", p, n, pStr + nStr); + + switch (actionSize) { + case 1: + actionName = externalActionNames.at(0); + actionArray.push_back(*(offset + 0 * actionArrayStride)); + break; + case 2: + actionName = externalActionNames.at(*(offset + 0 * actionArrayStride)); + actionArray.push_back(*(offset + 1 * actionArrayStride)); + break; + case 3: + actionName = externalActionNames.at(0); + actionArray.push_back(*(offset + 0 * actionArrayStride)); + actionArray.push_back(*(offset + 1 * actionArrayStride)); + actionArray.push_back(*(offset + 2 * actionArrayStride)); + break; + case 4: + actionArray.push_back(*(offset + 0 * actionArrayStride)); + actionArray.push_back(*(offset + 1 * actionArrayStride)); + actionName = externalActionNames.at(*(offset + 2 * actionArrayStride)); + actionArray.push_back(*(offset + 3 * actionArrayStride)); + break; + default: { + auto error = fmt::format("Invalid action size, {0}", actionSize); + spdlog::error(error); + throw std::invalid_argument(error); + } + } - switch (actionSize) { - case 1: - actionName = externalActionNames.at(0); - actionArray.push_back(*(pStr + 0 * actionArrayStride)); - break; - case 2: - actionName = externalActionNames.at(*(pStr + 0 * actionArrayStride)); - actionArray.push_back(*(pStr + 1 * actionArrayStride)); - break; - case 3: - actionArray.push_back(*(pStr + 0 * actionArrayStride)); - actionArray.push_back(*(pStr + 1 * actionArrayStride)); - actionName = externalActionNames.at(0); - actionArray.push_back(*(pStr + 2 * actionArrayStride)); - break; - case 4: - actionArray.push_back(*(pStr + 0 * actionArrayStride)); - actionArray.push_back(*(pStr + 1 * actionArrayStride)); - actionName = externalActionNames.at(*(pStr + 2 * actionArrayStride)); - actionArray.push_back(*(pStr + 3 * actionArrayStride)); - break; - default: { - auto error = fmt::format("Invalid action size, {0}", actionSize); - spdlog::error(error); - throw std::invalid_argument(error); + spdlog::debug("Creating action for player {0} with name {1}", p, actionName); + auto action = buildAction(players_[p]->unwrapped(), actionName, actionArray); + if (action != nullptr) { + actions.push_back(action); } } + auto actionResult = players_[p]->performActions(actions, lastPlayer); + terminated = actionResult.terminated; + truncated = actionResult.truncated; - auto playerStepResult = players_[p]->stepSingle(actionName, actionArray, lastPlayer); - - // playerRewards.push_back(playerStepResult[0].cast()); if (lastPlayer) { - terminated = playerStepResult[0].cast(); - truncated = playerStepResult[1].cast(); - info = playerStepResult[2]; + spdlog::debug("Last player, updating ticks"); + auto info_and_truncated = buildInfo(actionResult); + info = info_and_truncated.info; + truncated = info_and_truncated.truncated; } } @@ -549,5 +579,84 @@ class Py_GameProcess { const std::shared_ptr gdyFactory_; uint32_t playerCount_ = 0; std::vector> players_; + + InfoAndTruncated buildInfo(ActionResult actionResult) { + py::dict py_info; + bool truncated = false; + + if (actionResult.terminated) { + py::dict py_playerResults; + + for (auto playerRes : actionResult.playerStates) { + std::string playerStatusString; + switch (playerRes.second) { + case TerminationState::WIN: + playerStatusString = "Win"; + break; + case TerminationState::LOSE: + playerStatusString = "Lose"; + break; + case TerminationState::NONE: + playerStatusString = "End"; + break; + case TerminationState::TRUNCATED: + truncated = true; + playerStatusString = "Truncated"; + break; + } + + if (playerStatusString.size() > 0) { + py_playerResults[std::to_string(playerRes.first).c_str()] = playerStatusString; + } + } + py_info["PlayerResults"] = py_playerResults; + } + + return {py_info, truncated}; + } + + std::shared_ptr buildAction(std::shared_ptr player, std::string actionName, std::vector actionArray) { + const auto& actionInputsDefinition = gdyFactory_->findActionInputsDefinition(actionName); + auto playerAvatar = player->getAvatar(); + auto playerId = player->getId(); + + const auto& inputMappings = actionInputsDefinition.inputMappings; + + if (playerAvatar != nullptr) { + auto actionId = actionArray[0]; + + if (inputMappings.find(actionId) == inputMappings.end()) { + return nullptr; + } + + const auto& mapping = inputMappings.at(actionId); + const auto& vectorToDest = mapping.vectorToDest; + const auto& orientationVector = mapping.orientationVector; + const auto& metaData = mapping.metaData; + const auto& action = std::make_shared(Action(gameProcess_->getGrid(), actionName, playerId, 0, metaData)); + action->init(playerAvatar, vectorToDest, orientationVector, actionInputsDefinition.relative); + + return action; + } else { + glm::ivec2 sourceLocation = {actionArray[0], actionArray[1]}; + + auto actionId = actionArray[2]; + + if (inputMappings.find(actionId) == inputMappings.end()) { + return nullptr; + } + + const auto& mapping = inputMappings.at(actionId); + const auto& vector = mapping.vectorToDest; + const auto& orientationVector = mapping.orientationVector; + const auto& metaData = mapping.metaData; + glm::ivec2 destinationLocation = sourceLocation + vector; + + auto action = std::make_shared(Action(gameProcess_->getGrid(), actionName, playerId, 0, metaData)); + action->init(sourceLocation, destinationLocation); + + return action; + } + } }; } // namespace griddly \ No newline at end of file diff --git a/bindings/wrapper/Player.cpp b/bindings/wrapper/Player.cpp index 9bf98f1d..9ea31352 100644 --- a/bindings/wrapper/Player.cpp +++ b/bindings/wrapper/Player.cpp @@ -36,184 +36,14 @@ class Py_Player { return wrapObservation(player_->getObserver()); } - py::tuple stepMulti(py::buffer stepArray, bool updateTicks) { - auto externalActionNames = gdyFactory_->getExternalActionNames(); - auto gameProcess = player_->getGameProcess(); - - if (gameProcess != nullptr && !gameProcess->isInitialized()) { - throw std::invalid_argument("Cannot send player commands when game has not been initialized."); - } - - auto stepArrayInfo = stepArray.request(); - if (stepArrayInfo.format != "l" && stepArrayInfo.format != "i") { - auto error = fmt::format("Invalid data type {0}, must be an integer.", stepArrayInfo.format); - spdlog::error(error); - throw std::invalid_argument(error); - } - - auto actionStride = stepArrayInfo.strides[0] / sizeof(int32_t); - auto actionArrayStride = stepArrayInfo.strides[1] / sizeof(int32_t); - - auto actionCount = stepArrayInfo.shape[0]; - auto actionSize = stepArrayInfo.shape[1]; - - spdlog::debug("action stride: {0}", actionStride); - spdlog::debug("action array stride: {0}", actionArrayStride); - spdlog::debug("action count: {0}", actionCount); - spdlog::debug("action size: {0}", actionSize); - - std::vector> actions; - for (int a = 0; a < actionCount; a++) { - std::string actionName; - std::vector actionArray; - auto pStr = (int32_t*)stepArrayInfo.ptr + a * actionStride; - - switch (actionSize) { - case 1: - actionName = externalActionNames.at(0); - actionArray.push_back(*(pStr + 0 * actionArrayStride)); - break; - case 2: - actionName = externalActionNames.at(*(pStr + 0 * actionArrayStride)); - actionArray.push_back(*(pStr + 1 * actionArrayStride)); - break; - case 3: - actionArray.push_back(*(pStr + 0 * actionArrayStride)); - actionArray.push_back(*(pStr + 1 * actionArrayStride)); - actionName = externalActionNames.at(0); - actionArray.push_back(*(pStr + 2 * actionArrayStride)); - break; - case 4: - actionArray.push_back(*(pStr + 0 * actionArrayStride)); - actionArray.push_back(*(pStr + 1 * actionArrayStride)); - actionName = externalActionNames.at(*(pStr + 2 * actionArrayStride)); - actionArray.push_back(*(pStr + 3 * actionArrayStride)); - break; - default: { - auto error = fmt::format("Invalid action size, {0}", actionSize); - spdlog::error(error); - throw std::invalid_argument(error); - } - } - - auto action = buildAction(actionName, actionArray); - if (action != nullptr) { - actions.push_back(action); - } - } - - auto actionResult = player_->performActions(actions, updateTicks); - auto info_and_truncated = buildInfo(actionResult); - auto info = info_and_truncated[0]; - auto truncated = info_and_truncated[1]; - auto rewards = gameProcess_->getAccumulatedRewards(player_->getId()); - return py::make_tuple(rewards, actionResult.terminated, truncated, info); - } - - py::tuple stepSingle(std::string actionName, std::vector actionArray, bool updateTicks) { - if (gameProcess_ != nullptr && !gameProcess_->isInitialized()) { - throw std::invalid_argument("Cannot send player commands when game has not been initialized."); - } - - auto action = buildAction(actionName, actionArray); - - ActionResult actionResult; - if (action != nullptr) { - actionResult = player_->performActions({action}, updateTicks); - } else { - actionResult = player_->performActions({}, updateTicks); - } - - auto info_and_truncated = buildInfo(actionResult); - auto info = info_and_truncated[0]; - auto truncated = info_and_truncated[1]; - - return py::make_tuple(actionResult.terminated, truncated, info); + ActionResult performActions(std::vector> actions, bool updateTicks) { + return player_->performActions(actions, updateTicks); } private: const std::shared_ptr player_; const std::shared_ptr gdyFactory_; const std::shared_ptr gameProcess_; - - py::tuple buildInfo(ActionResult actionResult) { - py::dict py_info; - bool truncated = false; - - if (actionResult.terminated) { - py::dict py_playerResults; - - for (auto playerRes : actionResult.playerStates) { - std::string playerStatusString; - switch (playerRes.second) { - case TerminationState::WIN: - playerStatusString = "Win"; - break; - case TerminationState::LOSE: - playerStatusString = "Lose"; - break; - case TerminationState::NONE: - playerStatusString = "End"; - break; - case TerminationState::TRUNCATED: - truncated = true; - playerStatusString = "Truncated"; - break; - } - - if (playerStatusString.size() > 0) { - py_playerResults[std::to_string(playerRes.first).c_str()] = playerStatusString; - } - } - py_info["PlayerResults"] = py_playerResults; - } - - return py::make_tuple(py_info, truncated); - } - - std::shared_ptr buildAction(std::string actionName, std::vector actionArray) { - const auto& actionInputsDefinition = gdyFactory_->findActionInputsDefinition(actionName); - auto playerAvatar = player_->getAvatar(); - auto playerId = player_->getId(); - - const auto& inputMappings = actionInputsDefinition.inputMappings; - - if (playerAvatar != nullptr) { - auto actionId = actionArray[0]; - - if (inputMappings.find(actionId) == inputMappings.end()) { - return nullptr; - } - - const auto& mapping = inputMappings.at(actionId); - const auto& vectorToDest = mapping.vectorToDest; - const auto& orientationVector = mapping.orientationVector; - const auto& metaData = mapping.metaData; - const auto& action = std::make_shared(Action(gameProcess_->getGrid(), actionName, playerId, 0, metaData)); - action->init(playerAvatar, vectorToDest, orientationVector, actionInputsDefinition.relative); - - return action; - } else { - glm::ivec2 sourceLocation = {actionArray[0], actionArray[1]}; - - auto actionId = actionArray[2]; - - if (inputMappings.find(actionId) == inputMappings.end()) { - return nullptr; - } - - const auto& mapping = inputMappings.at(actionId); - const auto& vector = mapping.vectorToDest; - const auto& orientationVector = mapping.orientationVector; - const auto& metaData = mapping.metaData; - glm::ivec2 destinationLocation = sourceLocation + vector; - - auto action = std::make_shared(Action(gameProcess_->getGrid(), actionName, playerId, 0, metaData)); - action->init(sourceLocation, destinationLocation); - - return action; - } - } }; } // namespace griddly \ No newline at end of file diff --git a/deps/conanfile.txt b/deps/conanfile.txt index ec574b05..576c146b 100644 --- a/deps/conanfile.txt +++ b/deps/conanfile.txt @@ -7,7 +7,7 @@ yaml-cpp/0.6.3 spdlog/1.9.2 stb/20200203 volk/1.3.243.0 -boost/1.82.0 +boost/1.83.0 [generators] diff --git a/python/griddly/__init__.py b/python/griddly/__init__.py index 9c76a660..802eb6bb 100644 --- a/python/griddly/__init__.py +++ b/python/griddly/__init__.py @@ -1,9 +1,6 @@ import os -from typing import Any, Dict - import yaml -from griddly import gd from griddly.gym import GymWrapperFactory diff --git a/python/griddly/gd/__init__.py b/python/griddly/gd/__init__.py index 989e1b9d..75de9899 100644 --- a/python/griddly/gd/__init__.py +++ b/python/griddly/gd/__init__.py @@ -8,7 +8,7 @@ sys.path.extend([libs_path]) -debug_path = os.path.join(module_path, "../../Debug/bin") +debug_path = os.path.join(module_path, "../../../Debug/bin") sys.path.extend([debug_path]) # Load the binary diff --git a/python/griddly/gym.py b/python/griddly/gym.py index a77ebc1b..cccef3f3 100644 --- a/python/griddly/gym.py +++ b/python/griddly/gym.py @@ -376,68 +376,36 @@ def step( # type: ignore Step for a particular player in the environment """ - player_id = 0 reward: Union[List[int], int] - if self.player_count == 1: - action = np.array(action, dtype=np.int32).reshape(1, -1, len(self.action_space_parts)) + ragged_actions = [] + max_num_actions = 1 - max_num_actions = 0 - for a in action: - if len(action) > max_num_actions: - max_num_actions = len(action) + if self.player_count == 1: + ragged_actions.append(np.array(action, dtype=np.int32).reshape(-1, len(self.action_space_parts))) + max_num_actions = ragged_actions[0].shape[0] + else: + for p in range(self.player_count): + if isinstance(action, list): + ragged_actions.append(np.array(action[p], dtype=np.int32).reshape(-1, len(self.action_space_parts))) + else: + ragged_actions.append(np.array(action, dtype=np.int32).reshape(-1, len(self.action_space_parts))) + + if ragged_actions[p].shape[0] > max_num_actions: + max_num_actions = ragged_actions[p].shape[0] action_data = np.zeros((self.player_count, max_num_actions, len(self.action_space_parts)), dtype=np.int32) for p in range(self.player_count): - for i, a in enumerate(action[p]): + for i, a in enumerate(ragged_actions[p]): action_data[p, i] = a reward, done, truncated, info = self.game.step_parallel(action_data) - # Simple agents executing single actions or multiple actions in a single time step - # if self.player_count == 1: - - # action_data = np.array(action, dtype=np.int32).reshape(-1, len(self.action_space_parts)) - - # reward, done, truncated, info = self._players[player_id].step_multi( - # action_data, True - # ) - - # else: - - # processed_actions = [] - # multi_action = False - - # # Replace any None actions with a zero action - # for a in action: - # processed_action = ( - # np.array(a, dtype=np.int32).reshape(-1, len(self.action_space_parts)) - # if a is not None - # else np.zeros((1, len(self.action_space_parts)), dtype=np.int32) - # ) - # processed_actions.append(processed_action) - # if len(processed_action.shape) > 1 and processed_action.shape[0] > 1: - # multi_action = True - - # if not self.has_avatar and multi_action: - # # Multiple agents that can perform multiple actions in parallel - # # Used in RTS games - # reward = [] - # for p in range(self.player_count): - # final = p == self.player_count - 1 - # rew, done, truncated, info = self._players[p].step_multi( - # action, final - # ) - # reward.append(rew) - - # # Multiple agents executing actions in parallel - # # Used in multi-agent environments - # else: - # action_data = np.array(processed_actions, dtype=np.int32) - # action_data = action_data.reshape(self.player_count, -1) - # reward, done, truncated, info = self.game.step_parallel(action_data) + # Compatibility with gymnasium + if self.player_count == 1: + reward = reward[0] # In the case where the environment is cloned, but no step has happened to replace the last obs, # we can do that here