From e2db04627aebeba74d5ff77592f9abec306c33c0 Mon Sep 17 00:00:00 2001 From: Bam4d Date: Thu, 2 Dec 2021 18:28:01 +0000 Subject: [PATCH] creating entity observations in c++ --- bindings/python.cpp | 12 ++ bindings/wrapper/EntityObserverWrapper.cpp | 140 +++++++++++++++++++++ bindings/wrapper/GameWrapper.cpp | 57 +++------ src/Griddly/Core/Grid.cpp | 8 +- src/Griddly/Core/Grid.hpp | 8 +- 5 files changed, 186 insertions(+), 39 deletions(-) create mode 100644 bindings/wrapper/EntityObserverWrapper.cpp diff --git a/bindings/python.cpp b/bindings/python.cpp index a08b2ce72..8dcf79948 100644 --- a/bindings/python.cpp +++ b/bindings/python.cpp @@ -88,12 +88,24 @@ PYBIND11_MODULE(python_griddly, m) { // Get list of possible variable names, ordered by ID game_process.def("get_object_variable_names", &Py_GameWrapper::getObjectVariableNames); + // Get a mapping of objects to their variable names + game_process.def("get_object_variable_map", &Py_GameWrapper::getObjectVariableMap); + + // Get a list of the global variable names + game_process.def("get_global_variable_names", &Py_GameWrapper::getGlobalVariableNames); + // Get a list of the events that have happened in the game up to this point game_process.def("get_history", &Py_GameWrapper::getHistory, py::arg("purge")=true); // Release resources for vulkan stuff game_process.def("release", &Py_GameWrapper::release); + // Create an entity observer given a configuration of the entities and the custom variables that we want to view in the features + game_process.def("get_entity_observer", &Py_GameWrapper::createEntityObserver, py::arg("config")=py::dict()); + + py::class_> entityObserver(m, "EntityObserver"); + entityObserver.def("observe", &Py_EntityObserverWrapper::observe); + py::class_> player(m, "Player"); player.def("step", &Py_StepPlayerWrapper::stepSingle); player.def("step_multi", &Py_StepPlayerWrapper::stepMulti); diff --git a/bindings/wrapper/EntityObserverWrapper.cpp b/bindings/wrapper/EntityObserverWrapper.cpp new file mode 100644 index 000000000..e44cae56b --- /dev/null +++ b/bindings/wrapper/EntityObserverWrapper.cpp @@ -0,0 +1,140 @@ +#pragma once + +#include + +#include "../../src/Griddly/Core/TurnBasedGameProcess.hpp" +#include "NumpyWrapper.cpp" +#include "StepPlayerWrapper.cpp" + +namespace griddly { + +class Py_EntityObserverWrapper { + public: + Py_EntityObserverWrapper(py::dict entityObserverConfig, std::shared_ptr gdyFactory, std::shared_ptr gameProcess) : gameProcess_(gameProcess), gdyFactory_(gdyFactory) { + spdlog::debug("Created entity observer."); + + if (entityObserverConfig.contains("VariableMapping")) { + entityVariableMapping_ = entityObserverConfig["variableMapping"].cast>>(); + } else { + entityVariableMapping_ = gameProcess_->getGrid()->getObjectVariableMap(); + } + + for (auto entityVariables : entityVariableMapping_) { + for (auto variableName : entityVariables.second) { + spdlog::debug("Entity {0}, will include variable {1} in entity observations.", entityVariables.first, variableName); + } + } + } + + py::dict observe(int playerId) { + py::dict observation; + + auto entityObservationsAndIds = buildEntityObservations(playerId); + auto actionsAndMasks = buildEntityMasks(playerId); + + observation["Entities"] = entityObservationsAndIds["Entities"]; + observation["EntityIds"] = entityObservationsAndIds["EntityIds"]; + observation["EntityMasks"] = actionsAndMasks; + + return observation; + } + + private: + // Build entity masks (for transformer architectures) + py::dict buildEntityMasks(int playerId) const { + std::map>> entityMasks; + std::map> entityIds; + + std::unordered_set allAvailableActionNames; + + py::dict entitiesAndMasks; + + auto grid = gameProcess_->getGrid(); + + for (auto actionNamesAtLocation : gameProcess_->getAvailableActionNames(playerId)) { + auto location = actionNamesAtLocation.first; + auto actionNames = actionNamesAtLocation.second; + + auto locationVec = glm::ivec2{location[0], location[1]}; + + for (auto actionName : actionNames) { + spdlog::debug("[{0}] available at location [{1}, {2}]", actionName, location.x, location.y); + + auto actionInputsDefinitions = gdyFactory_->getActionInputsDefinitions(); + std::vector mask(actionInputsDefinitions[actionName].inputMappings.size() + 1); + mask[0] = 1; // NOP is always available + + auto objectAtLocation = grid->getObject(location); + auto entityId = std::hash>()(objectAtLocation); + auto actionIdsForName = gameProcess_->getAvailableActionIdsAtLocation(locationVec, actionName); + + for (auto id : actionIdsForName) { + mask[id] = 1; + } + + entityMasks[actionName].push_back(mask); + entityIds[actionName].push_back(entityId); + + allAvailableActionNames.insert(actionName); + } + } + + for (auto actionName : allAvailableActionNames) { + py::dict entitiesAndMasksForAction; + entitiesAndMasksForAction["EntityIds"] = entityIds[actionName]; + entitiesAndMasksForAction["Masks"] = entityMasks[actionName]; + + entitiesAndMasks[actionName.c_str()] = entitiesAndMasksForAction; + } + + return entitiesAndMasks; + } + + // Build entity observations (for transformer architectures) + py::dict buildEntityObservations(int playerId) const { + py::dict entityObservationsAndIds; + + std::map>> entityObservations; + std::vector entityIds; + + auto grid = gameProcess_->getGrid(); + + for (auto object : grid->getObjects()) { + auto name = object->getObjectName(); + auto location = object->getLocation(); + auto orientation = object->getObjectOrientation().getUnitVector(); + auto objectPlayerId = object->getPlayerId(); + auto zIdx = object->getZIdx(); + + auto featureVariables = entityVariableMapping_.at(name); + + auto numVariables = featureVariables.size(); + auto numFeatures = 5 + numVariables; + + std::vector featureVector(numFeatures); + featureVector[0] = static_cast(location[0]); + featureVector[1] = static_cast(location[1]); + featureVector[2] = static_cast(zIdx); + featureVector[3] = static_cast(orientation[0] + 2 * orientation[1]); + featureVector[4] = static_cast(objectPlayerId); + for (int32_t i = 0; i < numVariables; i++) { + auto variableValue = *object->getVariableValue(featureVariables[i]); + featureVector[5 + i] = static_cast(variableValue); + } + + entityObservations[name].push_back(featureVector); + + entityIds.push_back(std::hash>()(object)); + } + + entityObservationsAndIds["Entities"] = entityObservations; + entityObservationsAndIds["EntityIds"] = entityIds; + + return entityObservationsAndIds; + } + + std::unordered_map> entityVariableMapping_; + const std::shared_ptr gdyFactory_; + const std::shared_ptr gameProcess_; +}; +} // namespace griddly \ No newline at end of file diff --git a/bindings/wrapper/GameWrapper.cpp b/bindings/wrapper/GameWrapper.cpp index fe636db7a..e6991b47a 100644 --- a/bindings/wrapper/GameWrapper.cpp +++ b/bindings/wrapper/GameWrapper.cpp @@ -5,6 +5,7 @@ #include "../../src/Griddly/Core/TurnBasedGameProcess.hpp" #include "NumpyWrapper.cpp" #include "StepPlayerWrapper.cpp" +#include "EntityObserverWrapper.cpp" namespace griddly { @@ -133,42 +134,6 @@ class Py_GameWrapper { return valid_action_trees; } - // Build entity masks (for transformer architectures) - py::dict buildEntityMasks(int playerId) const { - - } - - // Build entity observations (for transformer architectures) - py::dict buildEntityObservation(int playerId) const { - - // entity_observation = defaultdict(list) - // entity_ids = set() - // for i, object in enumerate(self._current_g_state["Objects"]): - // name = object["Name"] - // location = object["Location"] - // variables = object["Variables"] - - // # entity_ids.add(f'{location[0]},{location[1]}') - // # TODO: currently entity ids are a bit meaningless, but they have to be int or things break deeper down - // entity_ids.add(i) - - // feature_vec = np.zeros( - // len(self._obs_space.entities[name].features), dtype=np.float32 - // ) - // feature_vec[0] = location[0] - // feature_vec[1] = location[1] - // feature_vec[2] = 0 - // feature_vec[3] = orientation_feature(object["Orientation"]) - // feature_vec[4] = object["PlayerId"] - // for i, variable_name in enumerate(self._env.variable_names): - // feature_vec[5 + i] = ( - // variables[variable_name] if variable_name in variables else 0 - // ) - - // entity_observation[name].append(feature_vec) - - } - py::dict getAvailableActionNames(int playerId) const { auto availableActionNames = gameProcess_->getAvailableActionNames(playerId); @@ -393,6 +358,20 @@ class Py_GameWrapper { return py_state; } + std::vector getGlobalVariableNames() const { + std::vector globalVariableNames; + auto globalVariables = gameProcess_->getGrid()->getGlobalVariables(); + + for (auto globalVariableIt : globalVariables) { + globalVariableNames.push_back(globalVariableIt.first); + } + return globalVariableNames; + } + + py::dict getObjectVariableMap() const { + return py::cast(gameProcess_->getGrid()->getObjectVariableMap()); + } + py::dict getGlobalVariables(std::vector variables) const { py::dict py_globalVariables; auto globalVariables = gameProcess_->getGrid()->getGlobalVariables(); @@ -455,7 +434,11 @@ class Py_GameWrapper { } std::vector getObjectVariableNames() { - return gameProcess_->getGrid()->getObjectVariableNames(); + return gameProcess_->getGrid()->getAllObjectVariableNames(); + } + + std::shared_ptr createEntityObserver(py::dict entityObserverConfig) { + return std::make_shared(Py_EntityObserverWrapper(entityObserverConfig, gdyFactory_, gameProcess_)); } private: diff --git a/src/Griddly/Core/Grid.cpp b/src/Griddly/Core/Grid.cpp index 1c9f508b1..1ef36e24b 100644 --- a/src/Griddly/Core/Grid.cpp +++ b/src/Griddly/Core/Grid.cpp @@ -463,7 +463,7 @@ const std::vector Grid::getObjectNames() const { return orderedNames; } -const std::vector Grid::getObjectVariableNames() const { +const std::vector Grid::getAllObjectVariableNames() const { auto namesCount = objectVariableIds_.size(); std::vector orderedNames(namesCount); @@ -476,12 +476,18 @@ const std::vector Grid::getObjectVariableNames() const { return orderedNames; } +const std::unordered_map> Grid::getObjectVariableMap() const { + return objectVariableMap_; +} + void Grid::initObject(std::string objectName, std::vector variableNames) { objectIds_.insert({objectName, objectIds_.size()}); for (auto& variableName : variableNames) { objectVariableIds_.insert({variableName, objectVariableIds_.size()}); } + + objectVariableMap_[objectName] = variableNames; } std::unordered_map> Grid::getObjectCounter(std::string objectName) { diff --git a/src/Griddly/Core/Grid.hpp b/src/Griddly/Core/Grid.hpp index 0ad6fc3cb..eaaefb8aa 100644 --- a/src/Griddly/Core/Grid.hpp +++ b/src/Griddly/Core/Grid.hpp @@ -133,7 +133,12 @@ class Grid : public std::enable_shared_from_this { /** * Gets an ordered list of objectVariableNames */ - virtual const std::vector getObjectVariableNames() const; + virtual const std::vector getAllObjectVariableNames() const; + + /** + * Get a mapping of objects to their defined variables + */ + virtual const std::unordered_map> getObjectVariableMap() const; /** * Gets an ordered list of objectNames @@ -180,6 +185,7 @@ class Grid : public std::enable_shared_from_this { std::unordered_map objectIds_; std::unordered_map objectVariableIds_; + std::unordered_map> objectVariableMap_; std::unordered_set> objects_; std::unordered_map occupiedLocations_; std::unordered_map>> objectCounters_;