Skip to content

Commit

Permalink
creating entity observations in c++
Browse files Browse the repository at this point in the history
  • Loading branch information
Bam4d committed Dec 2, 2021
1 parent f47354a commit e2db046
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 39 deletions.
12 changes: 12 additions & 0 deletions bindings/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,24 @@ PYBIND11_MODULE(python_griddly, m) {
// Get list of possible variable names, ordered by ID
game_process.def("get_object_variable_names", &Py_GameWrapper::getObjectVariableNames);

// Get a mapping of objects to their variable names
game_process.def("get_object_variable_map", &Py_GameWrapper::getObjectVariableMap);

// Get a list of the global variable names
game_process.def("get_global_variable_names", &Py_GameWrapper::getGlobalVariableNames);

// Get a list of the events that have happened in the game up to this point
game_process.def("get_history", &Py_GameWrapper::getHistory, py::arg("purge")=true);

// Release resources for vulkan stuff
game_process.def("release", &Py_GameWrapper::release);

// Create an entity observer given a configuration of the entities and the custom variables that we want to view in the features
game_process.def("get_entity_observer", &Py_GameWrapper::createEntityObserver, py::arg("config")=py::dict());

py::class_<Py_EntityObserverWrapper, std::shared_ptr<Py_EntityObserverWrapper>> entityObserver(m, "EntityObserver");
entityObserver.def("observe", &Py_EntityObserverWrapper::observe);

py::class_<Py_StepPlayerWrapper, std::shared_ptr<Py_StepPlayerWrapper>> player(m, "Player");
player.def("step", &Py_StepPlayerWrapper::stepSingle);
player.def("step_multi", &Py_StepPlayerWrapper::stepMulti);
Expand Down
140 changes: 140 additions & 0 deletions bindings/wrapper/EntityObserverWrapper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
#pragma once

#include <spdlog/spdlog.h>

#include "../../src/Griddly/Core/TurnBasedGameProcess.hpp"
#include "NumpyWrapper.cpp"
#include "StepPlayerWrapper.cpp"

namespace griddly {

class Py_EntityObserverWrapper {
public:
Py_EntityObserverWrapper(py::dict entityObserverConfig, std::shared_ptr<GDYFactory> gdyFactory, std::shared_ptr<GameProcess> gameProcess) : gameProcess_(gameProcess), gdyFactory_(gdyFactory) {
spdlog::debug("Created entity observer.");

if (entityObserverConfig.contains("VariableMapping")) {
entityVariableMapping_ = entityObserverConfig["variableMapping"].cast<std::unordered_map<std::string, std::vector<std::string>>>();
} else {
entityVariableMapping_ = gameProcess_->getGrid()->getObjectVariableMap();
}

for (auto entityVariables : entityVariableMapping_) {
for (auto variableName : entityVariables.second) {
spdlog::debug("Entity {0}, will include variable {1} in entity observations.", entityVariables.first, variableName);
}
}
}

py::dict observe(int playerId) {
py::dict observation;

auto entityObservationsAndIds = buildEntityObservations(playerId);
auto actionsAndMasks = buildEntityMasks(playerId);

observation["Entities"] = entityObservationsAndIds["Entities"];
observation["EntityIds"] = entityObservationsAndIds["EntityIds"];
observation["EntityMasks"] = actionsAndMasks;

return observation;
}

private:
// Build entity masks (for transformer architectures)
py::dict buildEntityMasks(int playerId) const {
std::map<std::string, std::vector<std::vector<int>>> entityMasks;
std::map<std::string, std::vector<size_t>> entityIds;

std::unordered_set<std::string> allAvailableActionNames;

py::dict entitiesAndMasks;

auto grid = gameProcess_->getGrid();

for (auto actionNamesAtLocation : gameProcess_->getAvailableActionNames(playerId)) {
auto location = actionNamesAtLocation.first;
auto actionNames = actionNamesAtLocation.second;

auto locationVec = glm::ivec2{location[0], location[1]};

for (auto actionName : actionNames) {
spdlog::debug("[{0}] available at location [{1}, {2}]", actionName, location.x, location.y);

auto actionInputsDefinitions = gdyFactory_->getActionInputsDefinitions();
std::vector<int> mask(actionInputsDefinitions[actionName].inputMappings.size() + 1);
mask[0] = 1; // NOP is always available

auto objectAtLocation = grid->getObject(location);
auto entityId = std::hash<std::shared_ptr<Object>>()(objectAtLocation);
auto actionIdsForName = gameProcess_->getAvailableActionIdsAtLocation(locationVec, actionName);

for (auto id : actionIdsForName) {
mask[id] = 1;
}

entityMasks[actionName].push_back(mask);
entityIds[actionName].push_back(entityId);

allAvailableActionNames.insert(actionName);
}
}

for (auto actionName : allAvailableActionNames) {
py::dict entitiesAndMasksForAction;
entitiesAndMasksForAction["EntityIds"] = entityIds[actionName];
entitiesAndMasksForAction["Masks"] = entityMasks[actionName];

entitiesAndMasks[actionName.c_str()] = entitiesAndMasksForAction;
}

return entitiesAndMasks;
}

// Build entity observations (for transformer architectures)
py::dict buildEntityObservations(int playerId) const {
py::dict entityObservationsAndIds;

std::map<std::string, std::vector<std::vector<float>>> entityObservations;
std::vector<size_t> entityIds;

auto grid = gameProcess_->getGrid();

for (auto object : grid->getObjects()) {
auto name = object->getObjectName();
auto location = object->getLocation();
auto orientation = object->getObjectOrientation().getUnitVector();
auto objectPlayerId = object->getPlayerId();
auto zIdx = object->getZIdx();

auto featureVariables = entityVariableMapping_.at(name);

auto numVariables = featureVariables.size();
auto numFeatures = 5 + numVariables;

std::vector<float> featureVector(numFeatures);
featureVector[0] = static_cast<float>(location[0]);
featureVector[1] = static_cast<float>(location[1]);
featureVector[2] = static_cast<float>(zIdx);
featureVector[3] = static_cast<float>(orientation[0] + 2 * orientation[1]);
featureVector[4] = static_cast<float>(objectPlayerId);
for (int32_t i = 0; i < numVariables; i++) {
auto variableValue = *object->getVariableValue(featureVariables[i]);
featureVector[5 + i] = static_cast<float>(variableValue);
}

entityObservations[name].push_back(featureVector);

entityIds.push_back(std::hash<std::shared_ptr<Object>>()(object));
}

entityObservationsAndIds["Entities"] = entityObservations;
entityObservationsAndIds["EntityIds"] = entityIds;

return entityObservationsAndIds;
}

std::unordered_map<std::string, std::vector<std::string>> entityVariableMapping_;
const std::shared_ptr<GDYFactory> gdyFactory_;
const std::shared_ptr<GameProcess> gameProcess_;
};
} // namespace griddly
57 changes: 20 additions & 37 deletions bindings/wrapper/GameWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "../../src/Griddly/Core/TurnBasedGameProcess.hpp"
#include "NumpyWrapper.cpp"
#include "StepPlayerWrapper.cpp"
#include "EntityObserverWrapper.cpp"

namespace griddly {

Expand Down Expand Up @@ -133,42 +134,6 @@ class Py_GameWrapper {
return valid_action_trees;
}

// Build entity masks (for transformer architectures)
py::dict buildEntityMasks(int playerId) const {

}

// Build entity observations (for transformer architectures)
py::dict buildEntityObservation(int playerId) const {

// entity_observation = defaultdict(list)
// entity_ids = set()
// for i, object in enumerate(self._current_g_state["Objects"]):
// name = object["Name"]
// location = object["Location"]
// variables = object["Variables"]

// # entity_ids.add(f'{location[0]},{location[1]}')
// # TODO: currently entity ids are a bit meaningless, but they have to be int or things break deeper down
// entity_ids.add(i)

// feature_vec = np.zeros(
// len(self._obs_space.entities[name].features), dtype=np.float32
// )
// feature_vec[0] = location[0]
// feature_vec[1] = location[1]
// feature_vec[2] = 0
// feature_vec[3] = orientation_feature(object["Orientation"])
// feature_vec[4] = object["PlayerId"]
// for i, variable_name in enumerate(self._env.variable_names):
// feature_vec[5 + i] = (
// variables[variable_name] if variable_name in variables else 0
// )

// entity_observation[name].append(feature_vec)

}

py::dict getAvailableActionNames(int playerId) const {
auto availableActionNames = gameProcess_->getAvailableActionNames(playerId);

Expand Down Expand Up @@ -393,6 +358,20 @@ class Py_GameWrapper {
return py_state;
}

std::vector<std::string> getGlobalVariableNames() const {
std::vector<std::string> globalVariableNames;
auto globalVariables = gameProcess_->getGrid()->getGlobalVariables();

for (auto globalVariableIt : globalVariables) {
globalVariableNames.push_back(globalVariableIt.first);
}
return globalVariableNames;
}

py::dict getObjectVariableMap() const {
return py::cast(gameProcess_->getGrid()->getObjectVariableMap());
}

py::dict getGlobalVariables(std::vector<std::string> variables) const {
py::dict py_globalVariables;
auto globalVariables = gameProcess_->getGrid()->getGlobalVariables();
Expand Down Expand Up @@ -455,7 +434,11 @@ class Py_GameWrapper {
}

std::vector<std::string> getObjectVariableNames() {
return gameProcess_->getGrid()->getObjectVariableNames();
return gameProcess_->getGrid()->getAllObjectVariableNames();
}

std::shared_ptr<Py_EntityObserverWrapper> createEntityObserver(py::dict entityObserverConfig) {
return std::make_shared<Py_EntityObserverWrapper>(Py_EntityObserverWrapper(entityObserverConfig, gdyFactory_, gameProcess_));
}

private:
Expand Down
8 changes: 7 additions & 1 deletion src/Griddly/Core/Grid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ const std::vector<std::string> Grid::getObjectNames() const {
return orderedNames;
}

const std::vector<std::string> Grid::getObjectVariableNames() const {
const std::vector<std::string> Grid::getAllObjectVariableNames() const {
auto namesCount = objectVariableIds_.size();
std::vector<std::string> orderedNames(namesCount);

Expand All @@ -476,12 +476,18 @@ const std::vector<std::string> Grid::getObjectVariableNames() const {
return orderedNames;
}

const std::unordered_map<std::string, std::vector<std::string>> Grid::getObjectVariableMap() const {
return objectVariableMap_;
}

void Grid::initObject(std::string objectName, std::vector<std::string> variableNames) {
objectIds_.insert({objectName, objectIds_.size()});

for (auto& variableName : variableNames) {
objectVariableIds_.insert({variableName, objectVariableIds_.size()});
}

objectVariableMap_[objectName] = variableNames;
}

std::unordered_map<uint32_t, std::shared_ptr<int32_t>> Grid::getObjectCounter(std::string objectName) {
Expand Down
8 changes: 7 additions & 1 deletion src/Griddly/Core/Grid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,12 @@ class Grid : public std::enable_shared_from_this<Grid> {
/**
* Gets an ordered list of objectVariableNames
*/
virtual const std::vector<std::string> getObjectVariableNames() const;
virtual const std::vector<std::string> getAllObjectVariableNames() const;

/**
* Get a mapping of objects to their defined variables
*/
virtual const std::unordered_map<std::string, std::vector<std::string>> getObjectVariableMap() const;

/**
* Gets an ordered list of objectNames
Expand Down Expand Up @@ -180,6 +185,7 @@ class Grid : public std::enable_shared_from_this<Grid> {

std::unordered_map<std::string, uint32_t> objectIds_;
std::unordered_map<std::string, uint32_t> objectVariableIds_;
std::unordered_map<std::string, std::vector<std::string>> objectVariableMap_;
std::unordered_set<std::shared_ptr<Object>> objects_;
std::unordered_map<glm::ivec2, TileObjects> occupiedLocations_;
std::unordered_map<std::string, std::unordered_map<uint32_t, std::shared_ptr<int32_t>>> objectCounters_;
Expand Down

0 comments on commit e2db046

Please sign in to comment.