Skip to content

Commit

Permalink
Merge pull request #149 from Bam4d/entity_obs
Browse files Browse the repository at this point in the history
Entity obs
  • Loading branch information
Bam4d authored Dec 2, 2021
2 parents e146261 + e2db046 commit ad6ebda
Show file tree
Hide file tree
Showing 12 changed files with 200 additions and 14 deletions.
12 changes: 12 additions & 0 deletions bindings/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,24 @@ PYBIND11_MODULE(python_griddly, m) {
// Get list of possible variable names, ordered by ID
game_process.def("get_object_variable_names", &Py_GameWrapper::getObjectVariableNames);

// Get a mapping of objects to their variable names
game_process.def("get_object_variable_map", &Py_GameWrapper::getObjectVariableMap);

// Get a list of the global variable names
game_process.def("get_global_variable_names", &Py_GameWrapper::getGlobalVariableNames);

// Get a list of the events that have happened in the game up to this point
game_process.def("get_history", &Py_GameWrapper::getHistory, py::arg("purge")=true);

// Release resources for vulkan stuff
game_process.def("release", &Py_GameWrapper::release);

// Create an entity observer given a configuration of the entities and the custom variables that we want to view in the features
game_process.def("get_entity_observer", &Py_GameWrapper::createEntityObserver, py::arg("config")=py::dict());

py::class_<Py_EntityObserverWrapper, std::shared_ptr<Py_EntityObserverWrapper>> entityObserver(m, "EntityObserver");
entityObserver.def("observe", &Py_EntityObserverWrapper::observe);

py::class_<Py_StepPlayerWrapper, std::shared_ptr<Py_StepPlayerWrapper>> player(m, "Player");
player.def("step", &Py_StepPlayerWrapper::stepSingle);
player.def("step_multi", &Py_StepPlayerWrapper::stepMulti);
Expand Down
140 changes: 140 additions & 0 deletions bindings/wrapper/EntityObserverWrapper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
#pragma once

#include <spdlog/spdlog.h>

#include "../../src/Griddly/Core/TurnBasedGameProcess.hpp"
#include "NumpyWrapper.cpp"
#include "StepPlayerWrapper.cpp"

namespace griddly {

class Py_EntityObserverWrapper {
public:
Py_EntityObserverWrapper(py::dict entityObserverConfig, std::shared_ptr<GDYFactory> gdyFactory, std::shared_ptr<GameProcess> gameProcess) : gameProcess_(gameProcess), gdyFactory_(gdyFactory) {
spdlog::debug("Created entity observer.");

if (entityObserverConfig.contains("VariableMapping")) {
entityVariableMapping_ = entityObserverConfig["variableMapping"].cast<std::unordered_map<std::string, std::vector<std::string>>>();
} else {
entityVariableMapping_ = gameProcess_->getGrid()->getObjectVariableMap();
}

for (auto entityVariables : entityVariableMapping_) {
for (auto variableName : entityVariables.second) {
spdlog::debug("Entity {0}, will include variable {1} in entity observations.", entityVariables.first, variableName);
}
}
}

py::dict observe(int playerId) {
py::dict observation;

auto entityObservationsAndIds = buildEntityObservations(playerId);
auto actionsAndMasks = buildEntityMasks(playerId);

observation["Entities"] = entityObservationsAndIds["Entities"];
observation["EntityIds"] = entityObservationsAndIds["EntityIds"];
observation["EntityMasks"] = actionsAndMasks;

return observation;
}

private:
// Build entity masks (for transformer architectures)
py::dict buildEntityMasks(int playerId) const {
std::map<std::string, std::vector<std::vector<int>>> entityMasks;
std::map<std::string, std::vector<size_t>> entityIds;

std::unordered_set<std::string> allAvailableActionNames;

py::dict entitiesAndMasks;

auto grid = gameProcess_->getGrid();

for (auto actionNamesAtLocation : gameProcess_->getAvailableActionNames(playerId)) {
auto location = actionNamesAtLocation.first;
auto actionNames = actionNamesAtLocation.second;

auto locationVec = glm::ivec2{location[0], location[1]};

for (auto actionName : actionNames) {
spdlog::debug("[{0}] available at location [{1}, {2}]", actionName, location.x, location.y);

auto actionInputsDefinitions = gdyFactory_->getActionInputsDefinitions();
std::vector<int> mask(actionInputsDefinitions[actionName].inputMappings.size() + 1);
mask[0] = 1; // NOP is always available

auto objectAtLocation = grid->getObject(location);
auto entityId = std::hash<std::shared_ptr<Object>>()(objectAtLocation);
auto actionIdsForName = gameProcess_->getAvailableActionIdsAtLocation(locationVec, actionName);

for (auto id : actionIdsForName) {
mask[id] = 1;
}

entityMasks[actionName].push_back(mask);
entityIds[actionName].push_back(entityId);

allAvailableActionNames.insert(actionName);
}
}

for (auto actionName : allAvailableActionNames) {
py::dict entitiesAndMasksForAction;
entitiesAndMasksForAction["EntityIds"] = entityIds[actionName];
entitiesAndMasksForAction["Masks"] = entityMasks[actionName];

entitiesAndMasks[actionName.c_str()] = entitiesAndMasksForAction;
}

return entitiesAndMasks;
}

// Build entity observations (for transformer architectures)
py::dict buildEntityObservations(int playerId) const {
py::dict entityObservationsAndIds;

std::map<std::string, std::vector<std::vector<float>>> entityObservations;
std::vector<size_t> entityIds;

auto grid = gameProcess_->getGrid();

for (auto object : grid->getObjects()) {
auto name = object->getObjectName();
auto location = object->getLocation();
auto orientation = object->getObjectOrientation().getUnitVector();
auto objectPlayerId = object->getPlayerId();
auto zIdx = object->getZIdx();

auto featureVariables = entityVariableMapping_.at(name);

auto numVariables = featureVariables.size();
auto numFeatures = 5 + numVariables;

std::vector<float> featureVector(numFeatures);
featureVector[0] = static_cast<float>(location[0]);
featureVector[1] = static_cast<float>(location[1]);
featureVector[2] = static_cast<float>(zIdx);
featureVector[3] = static_cast<float>(orientation[0] + 2 * orientation[1]);
featureVector[4] = static_cast<float>(objectPlayerId);
for (int32_t i = 0; i < numVariables; i++) {
auto variableValue = *object->getVariableValue(featureVariables[i]);
featureVector[5 + i] = static_cast<float>(variableValue);
}

entityObservations[name].push_back(featureVector);

entityIds.push_back(std::hash<std::shared_ptr<Object>>()(object));
}

entityObservationsAndIds["Entities"] = entityObservations;
entityObservationsAndIds["EntityIds"] = entityIds;

return entityObservationsAndIds;
}

std::unordered_map<std::string, std::vector<std::string>> entityVariableMapping_;
const std::shared_ptr<GDYFactory> gdyFactory_;
const std::shared_ptr<GameProcess> gameProcess_;
};
} // namespace griddly
21 changes: 20 additions & 1 deletion bindings/wrapper/GameWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "../../src/Griddly/Core/TurnBasedGameProcess.hpp"
#include "NumpyWrapper.cpp"
#include "StepPlayerWrapper.cpp"
#include "EntityObserverWrapper.cpp"

namespace griddly {

Expand Down Expand Up @@ -357,6 +358,20 @@ class Py_GameWrapper {
return py_state;
}

std::vector<std::string> getGlobalVariableNames() const {
std::vector<std::string> globalVariableNames;
auto globalVariables = gameProcess_->getGrid()->getGlobalVariables();

for (auto globalVariableIt : globalVariables) {
globalVariableNames.push_back(globalVariableIt.first);
}
return globalVariableNames;
}

py::dict getObjectVariableMap() const {
return py::cast(gameProcess_->getGrid()->getObjectVariableMap());
}

py::dict getGlobalVariables(std::vector<std::string> variables) const {
py::dict py_globalVariables;
auto globalVariables = gameProcess_->getGrid()->getGlobalVariables();
Expand Down Expand Up @@ -419,7 +434,11 @@ class Py_GameWrapper {
}

std::vector<std::string> getObjectVariableNames() {
return gameProcess_->getGrid()->getObjectVariableNames();
return gameProcess_->getGrid()->getAllObjectVariableNames();
}

std::shared_ptr<Py_EntityObserverWrapper> createEntityObserver(py::dict entityObserverConfig) {
return std::make_shared<Py_EntityObserverWrapper>(Py_EntityObserverWrapper(entityObserverConfig, gdyFactory_, gameProcess_));
}

private:
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/Proximity Triggers/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Step 1 - Create the lava, water and spider objects
**********************************************
Step 3 - Set up the proximity trigger for lava
Step 2 - Set up the proximity trigger for lava
**********************************************

For the lava, we want the spider to be catch fire instantly if it is next to the lava, but have a small chance of catching fire if it is close, but not right next to it.
Expand Down Expand Up @@ -113,7 +113,7 @@ Additionally you can set a ``Probability`` for an action to set how likely the a
***********************************************
Step 4 - Set up the proximity trigger for water
Step 3 - Set up the proximity trigger for water
***********************************************


Expand Down
3 changes: 2 additions & 1 deletion src/Griddly/Core/GDY/Objects/ObjectGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,8 @@ std::unordered_map<std::string, float> ObjectGenerator::getActionProbabilities()
return actionProbabilities_;
}

std::unordered_map<std::string, std::shared_ptr<ObjectDefinition>> ObjectGenerator::getObjectDefinitions() const {
// The order of object definitions needs to be consistent across levels and maps, so we have to make sure this is ordered here.
std::map<std::string, std::shared_ptr<ObjectDefinition>> ObjectGenerator::getObjectDefinitions() const {
return objectDefinitions_;
}

Expand Down
6 changes: 4 additions & 2 deletions src/Griddly/Core/GDY/Objects/ObjectGenerator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,13 @@ class ObjectGenerator : public std::enable_shared_from_this<ObjectGenerator> {
virtual std::unordered_map<std::string, ActionTriggerDefinition> getActionTriggerDefinitions() const;
virtual std::unordered_map<std::string, float> getActionProbabilities() const;

virtual std::unordered_map<std::string, std::shared_ptr<ObjectDefinition>> getObjectDefinitions() const;
virtual std::map<std::string, std::shared_ptr<ObjectDefinition>> getObjectDefinitions() const;

private:
std::unordered_map<char, std::string> objectChars_;
std::unordered_map<std::string, std::shared_ptr<ObjectDefinition>> objectDefinitions_;

// This needs to be ordered, so object types are always in a consistent order across multiple instantiations of games.
std::map<std::string, std::shared_ptr<ObjectDefinition>> objectDefinitions_;

std::string avatarObject_;

Expand Down
8 changes: 7 additions & 1 deletion src/Griddly/Core/Grid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ const std::vector<std::string> Grid::getObjectNames() const {
return orderedNames;
}

const std::vector<std::string> Grid::getObjectVariableNames() const {
const std::vector<std::string> Grid::getAllObjectVariableNames() const {
auto namesCount = objectVariableIds_.size();
std::vector<std::string> orderedNames(namesCount);

Expand All @@ -476,12 +476,18 @@ const std::vector<std::string> Grid::getObjectVariableNames() const {
return orderedNames;
}

const std::unordered_map<std::string, std::vector<std::string>> Grid::getObjectVariableMap() const {
return objectVariableMap_;
}

void Grid::initObject(std::string objectName, std::vector<std::string> variableNames) {
objectIds_.insert({objectName, objectIds_.size()});

for (auto& variableName : variableNames) {
objectVariableIds_.insert({variableName, objectVariableIds_.size()});
}

objectVariableMap_[objectName] = variableNames;
}

std::unordered_map<uint32_t, std::shared_ptr<int32_t>> Grid::getObjectCounter(std::string objectName) {
Expand Down
8 changes: 7 additions & 1 deletion src/Griddly/Core/Grid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,12 @@ class Grid : public std::enable_shared_from_this<Grid> {
/**
* Gets an ordered list of objectVariableNames
*/
virtual const std::vector<std::string> getObjectVariableNames() const;
virtual const std::vector<std::string> getAllObjectVariableNames() const;

/**
* Get a mapping of objects to their defined variables
*/
virtual const std::unordered_map<std::string, std::vector<std::string>> getObjectVariableMap() const;

/**
* Gets an ordered list of objectNames
Expand Down Expand Up @@ -180,6 +185,7 @@ class Grid : public std::enable_shared_from_this<Grid> {

std::unordered_map<std::string, uint32_t> objectIds_;
std::unordered_map<std::string, uint32_t> objectVariableIds_;
std::unordered_map<std::string, std::vector<std::string>> objectVariableMap_;
std::unordered_set<std::shared_ptr<Object>> objects_;
std::unordered_map<glm::ivec2, TileObjects> occupiedLocations_;
std::unordered_map<std::string, std::unordered_map<uint32_t, std::shared_ptr<int32_t>>> objectCounters_;
Expand Down
4 changes: 2 additions & 2 deletions tests/src/Griddly/Core/GDY/GDYFactoryTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,8 +817,8 @@ TEST(GDYFactoryTest, loadAction_src_missing) {
testBehaviourDefinition(yamlString, expectedBehaviourDefinition, true);
}

std::unordered_map<std::string, std::shared_ptr<ObjectDefinition>> mockObjectDefs(std::vector<std::string> objectNames) {
std::unordered_map<std::string, std::shared_ptr<ObjectDefinition>> mockObjectDefinitions;
std::map<std::string, std::shared_ptr<ObjectDefinition>> mockObjectDefs(std::vector<std::string> objectNames) {
std::map<std::string, std::shared_ptr<ObjectDefinition>> mockObjectDefinitions;
for (auto name : objectNames) {
ObjectDefinition objectDefinition = {
name};
Expand Down
2 changes: 1 addition & 1 deletion tests/src/Griddly/Core/GameProcessTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,7 @@ TEST(GameProcessTest, clone) {
EXPECT_CALL(*mockGridPtr, getGlobalVariables())
.WillRepeatedly(ReturnRef(globalVariables));

std::unordered_map<std::string, std::shared_ptr<ObjectDefinition>> mockObjectDefinitions = {
std::map<std::string, std::shared_ptr<ObjectDefinition>> mockObjectDefinitions = {
{"object1", std::make_shared<ObjectDefinition>(ObjectDefinition{"object1", 'a'})},
{"object2", std::make_shared<ObjectDefinition>(ObjectDefinition{"object2", 'b'})},
{"object3", std::make_shared<ObjectDefinition>(ObjectDefinition{"object3", 'c'})},
Expand Down
4 changes: 2 additions & 2 deletions tests/src/Griddly/Core/LevelGenerator/MapReaderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ using ::testing::ReturnRef;

namespace griddly {

std::unordered_map<std::string, std::shared_ptr<ObjectDefinition>> mockObjectDefinitions(std::vector<std::string> objectNames) {
std::unordered_map<std::string, std::shared_ptr<ObjectDefinition>> mockObjectDefinitions;
std::map<std::string, std::shared_ptr<ObjectDefinition>> mockObjectDefinitions(std::vector<std::string> objectNames) {
std::map<std::string, std::shared_ptr<ObjectDefinition>> mockObjectDefinitions;
for (auto name : objectNames) {
ObjectDefinition objectDefinition = {
name};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class MockObjectGenerator : public ObjectGenerator {
MOCK_METHOD(std::shared_ptr<Object>, cloneInstance, (std::shared_ptr<Object>, std::shared_ptr<Grid> grid), ());

MOCK_METHOD(std::string&, getObjectNameFromMapChar, (char character), ());
MOCK_METHOD((std::unordered_map<std::string, std::shared_ptr<ObjectDefinition>>), getObjectDefinitions, (), (const));
MOCK_METHOD((std::map<std::string, std::shared_ptr<ObjectDefinition>>), getObjectDefinitions, (), (const));

MOCK_METHOD(void, setAvatarObject, (std::string objectName), ());
};
Expand Down

0 comments on commit ad6ebda

Please sign in to comment.