Skip to content

Commit

Permalink
Entity observer fixes (#191)
Browse files Browse the repository at this point in the history
* few fixes for entity observers

* fixing python tests
  • Loading branch information
Bam4d authored Apr 29, 2022
1 parent 593fdfc commit 7da7498
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 107 deletions.
34 changes: 14 additions & 20 deletions python/tests/entity_observer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,6 @@ def test_entity_observations(test_name):
env = build_test_env(test_name, "tests/gdy/test_entity_observer.yaml", global_observer_type=gd.ObserverType.NONE,
player_observer_type=gd.ObserverType.ENTITY)

global_variables = env.game.get_global_variable_names()
object_variable_map = env.game.get_object_variable_map()

assert global_variables == ["_steps", "test_global_variable"]
assert object_variable_map["entity_1"] == ["entity_1_variable"]
assert object_variable_map["entity_2"] == ["entity_2_variable"]

obs, reward, done, info = env.step(0)
entities = obs["Entities"]
entity_ids = obs["Ids"]
Expand All @@ -44,15 +37,16 @@ def test_entity_observations(test_name):
assert len(entity_1s) == 1
assert len(entity_1_ids) == 1
assert len(entity_1s[0]) == 3
assert entity_1s[0][0] == 2
assert entity_1s[0][1] == 2
assert entity_1s[0][0] == 1
assert entity_1s[0][1] == 1

entity_2s = entities["entity_2"]
entity_2_ids = entity_ids["entity_2"]
assert len(entity_2s) == 2
assert len(entity_2_ids) == 2
assert len(entity_2s) == 1
assert len(entity_2_ids) == 1
assert len(entity_2s[0]) == 3
assert len(entity_2s[1]) == 3
assert entity_2s[0][0] == 1
assert entity_2s[0][1] == 2

actor_masks = obs["ActorMasks"]
actor_ids = obs["ActorIds"]
Expand All @@ -75,29 +69,27 @@ def test_entity_observations_multi_agent(test_name):
global_observer_type=gd.ObserverType.NONE,
player_observer_type=["EntityObserverOne", "EntityObserverTwo"])

global_variables = env.game.get_global_variable_names()
object_variable_map = env.game.get_object_variable_map()

assert global_variables == ["_steps", "test_global_variable"]
assert object_variable_map["entity_1"] == ["entity_1_variable"]
assert object_variable_map["entity_2"] == ["entity_2_variable"]

player_1_space = env.player_observation_space[0].features
player_2_space = env.player_observation_space[1].features

assert player_1_space["entity_1"] == ["x", "y", "z", "playerId", "entity_1_variable"]
assert player_1_space["entity_2"] == ["x", "y", "z", "ox", "oy", "entity_2_variable"]
assert player_1_space["__global__"] == ["test_perplayer_variable", "test_global_variable"]

assert player_2_space["entity_1"] == ["x", "y", "z"]
assert player_2_space["entity_2"] == ["x", "y", "z"]
assert player_2_space["__global__"] == ["test_global_variable"]

obs, reward, done, info = env.step([0, 0])

player_1_obs = obs[0]

player_1_entities = player_1_obs["Entities"]
player_1_entity_ids = player_1_obs["Ids"]
player_1_locations = player_1_obs["Locations"]

p1_globals = player_1_entities["__global__"]
assert np.all(p1_globals[0] == [12.0, 0.0])

p1_entity_1s = player_1_entities["entity_1"]
p1_entity_1_ids = player_1_entity_ids["entity_1"]
Expand All @@ -117,7 +109,9 @@ def test_entity_observations_multi_agent(test_name):

player_2_entities = player_2_obs["Entities"]
player_2_entity_ids = player_2_obs["Ids"]
player_2_locations = player_1_obs["Locations"]

p2_globals = player_2_entities["__global__"]
assert np.all(p2_globals[0] == [0])

p2_entity_1s = player_2_entities["entity_1"]
p2_entity_1_ids = player_2_entity_ids["entity_1"]
Expand Down
4 changes: 4 additions & 0 deletions python/tests/gdy/test_entity_observer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ Environment:
InitialValue: 12
Player:
AvatarObject: entity_1 # The player can only control a single avatar in the game
Observer:
Width: 3
Height: 3
TrackAvatar: true
Levels:
- |
. . . . .
Expand Down
8 changes: 7 additions & 1 deletion python/tests/gdy/test_entity_observer_multi_agent.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,29 @@ Environment:
Name: Test
Description: An environment only used for testing
Variables:
- Name: test_global_variable
- Name: test_perplayer_variable
InitialValue: 12
PerPlayer: true
- Name: test_global_variable
InitialValue: 0
Observers:
EntityObserverOne:
Type: ENTITY
IncludeMasks: true
IncludePlayerId: ["entity_1"]
IncludeRotation: ["entity_2"]
GlobalVariableMapping: ["test_perplayer_variable", "test_global_variable"]
VariableMapping:
entity_1: ["entity_1_variable"]
entity_2: ["entity_2_variable"]

EntityObserverTwo:
Type: ENTITY
GlobalVariableMapping: ["test_global_variable"]

Player:
Count: 2

Levels:
- |
. . E2 . .
Expand Down
16 changes: 16 additions & 0 deletions src/Griddly/Core/GDY/GDYFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,22 @@ EntityObserverConfig GDYFactory::parseNamedEntityObserverConfig(std::string obse
// Used to generate masks for entity obervers
config.actionInputsDefinitions = getActionInputsDefinitions();

auto globalVariableMappingNode = observerConfigNode["GlobalVariableMapping"];

if (globalVariableMappingNode.IsDefined()) {
const auto& globalEntityVariables = singleOrListNodeToList(globalVariableMappingNode);

for (const auto& globalEntityVariable : globalEntityVariables) {
if (globalVariableDefinitions_.find(globalEntityVariable) == globalVariableDefinitions_.end()) {
std::string error = fmt::format("No global variable with name {0} in GlobalVariableMapping feature configuration.", globalEntityVariable);
spdlog::error(error);
throw std::invalid_argument(error);
}
}

config.globalVariableMapping = globalEntityVariables;
}

auto variableMappingNodes = observerConfigNode["VariableMapping"];

if (variableMappingNodes.IsDefined()) {
Expand Down
81 changes: 0 additions & 81 deletions src/Griddly/Core/GDY/Objects/Object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,28 +166,6 @@ BehaviourCondition Object::resolveOR(const std::vector<BehaviourCondition> &cond
};
}

// BehaviourCondition Object::instantiateCondition(std::string &commandName, YAML::Node &conditionNode) const {
// if (commandName == "eq") {
// return resolveConditionArguments([](int32_t a, int32_t b) { return a == b; }, conditionNode);
// } else if (commandName == "gt") {
// return resolveConditionArguments([](int32_t a, int32_t b) { return a > b; }, conditionNode);
// } else if (commandName == "gte") {
// return resolveConditionArguments([](int32_t a, int32_t b) { return a >= b; }, conditionNode);
// } else if (commandName == "lt") {
// return resolveConditionArguments([](int32_t a, int32_t b) { return a < b; }, conditionNode);
// } else if (commandName == "lte") {
// return resolveConditionArguments([](int32_t a, int32_t b) { return a <= b; }, conditionNode);
// } else if (commandName == "neq") {
// return resolveConditionArguments([](int32_t a, int32_t b) { return a != b; }, conditionNode);
// } else if (commandName == "and") {
// return processConditions(conditionNode, false, LogicOp::AND);
// } else if (commandName == "or") {
// return processConditions(conditionNode, false, LogicOp::OR);
// } else {
// throw std::invalid_argument(fmt::format("Unknown or badly defined condition command {0}.", commandName));
// }
// }

BehaviourFunction Object::instantiateConditionalBehaviour(const std::string &commandName, CommandArguments &commandArguments, CommandList &subCommands) {
if (subCommands.size() == 0) {
return instantiateBehaviour(commandName, commandArguments);
Expand Down Expand Up @@ -233,65 +211,6 @@ BehaviourFunction Object::instantiateConditionalBehaviour(const std::string &com
};
}

// BehaviourCondition Object::processConditions(YAML::Node &conditionNodeList, bool isTopLevel, LogicOp op) const {
// // We should have a single item and not a list
// if (!conditionNodeList.IsDefined()) {
// auto line = conditionNodeList.Mark().line;
// auto errorString = fmt::format("Parse error line {0}. If statement is missing Conditions", line);
// spdlog::error(errorString);
// throw std::invalid_argument(errorString);
// }

// if (conditionNodeList.IsMap()) {
// if (conditionNodeList.size() != 1) {
// auto line = conditionNodeList.Mark().line;
// auto errorString = fmt::format("Parse error line {0}. Conditions must contain a single top-level condition", line);
// spdlog::error(errorString);
// throw std::invalid_argument(errorString);
// }
// auto conditionNode = conditionNodeList.begin();
// auto commandName = conditionNode->first.as<std::string>();
// return instantiateCondition(commandName, conditionNode->second);

// } else if (conditionNodeList.IsSequence()) {
// std::vector<BehaviourCondition> conditionList;
// for (auto &&subConditionNode : conditionNodeList) {
// auto validatedNode = validateCommandPairNode(subConditionNode);
// auto commandName = validatedNode->first.as<std::string>();
// conditionList.push_back(instantiateCondition(commandName, validatedNode->second));
// }
// switch (op) {
// case LogicOp::AND: {
// return resolveAnd()
// };
// }
// break;
// case LogicOp::OR: {
// return [conditionList](const std::shared_ptr<Action> &action) -> bool {
// for (const auto &condition : conditionList) {
// if (condition(action)) {
// return true;
// }
// }
// return false;
// };
// }
// default: {
// auto line = conditionNodeList.Mark().line;
// auto errorString = fmt::format("Parse error line {0}. A sequence of conditions must be within an AND or an OR operator.", line);
// spdlog::error(errorString);
// throw std::invalid_argument(errorString);
// }
// }
// }
// else {
// auto line = conditionNodeList.Mark().line;
// auto errorString = fmt::format("Conditions must be a map or a list", line);
// spdlog::error(errorString);
// throw std::invalid_argument(errorString);
// }
// } // namespace griddly

/**
* @brief executes a list of behaviour functions and accumulates the rewards
*
Expand Down
26 changes: 24 additions & 2 deletions src/Griddly/Core/Observers/EntityObserver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ void EntityObserver::init(EntityObserverConfig& config) {
}
}

if(config.globalVariableMapping.size() > 0) {
entityFeatures_.insert({"__global__", config.globalVariableMapping});
}

// Precalclate offsets for entity configurations
for (const auto& objectName : config_.objectNames) {

Expand Down Expand Up @@ -100,7 +104,6 @@ glm::ivec2 EntityObserver::resolveLocation(const glm::ivec2& location) const {
auto resolvedLocation = location - glm::ivec2{observableGrid.left, observableGrid.bottom};

if (doTrackAvatar_) {
const auto& avatarLocation = avatarObject_->getLocation();
const auto& avatarDirection = avatarObject_->getObjectOrientation().getDirection();

if (config_.rotateWithAvatar) {
Expand Down Expand Up @@ -132,6 +135,25 @@ void EntityObserver::buildObservations(EntityObservations& entityObservations) {

const auto& observableGrid = getObservableGrid();

// Build global entity
if(config_.globalVariableMapping.size() > 0) {
std::vector<float> globalFeatureVector(config_.globalVariableMapping.size());
uint32_t featureIdx = 0;
const auto& globalVariables = grid_->getGlobalVariables();
for(const auto& globalVariableName : config_.globalVariableMapping) {

const auto& globalVariableValues = globalVariables.at(globalVariableName);
if(globalVariableValues.size() == 1) {
globalFeatureVector[featureIdx++] = static_cast<float>(*globalVariableValues.at(0));
} else {
globalFeatureVector[featureIdx++] = static_cast<float>(*globalVariableValues.at(config_.playerId));
}
}

entityObservations.observations["__global__"].push_back(globalFeatureVector);
entityObservations.ids["__global__"].push_back(0);
}

for (const auto& object : grid_->getObjects()) {
const auto& name = object->getObjectName();
auto location = object->getLocation();
Expand Down Expand Up @@ -240,7 +262,7 @@ std::unordered_map<glm::ivec2, std::unordered_set<std::string>> EntityObserver::
std::vector<uint32_t> EntityObserver::getAvailableActionIdsAtLocation(glm::ivec2 location, std::string actionName) const {
auto srcObject = grid_->getObject(location);

spdlog::debug("Getting available actionIds for action [{}] at location [{0},{1}]", actionName, location.x, location.y);
spdlog::debug("Getting available actionIds for action [{0}] at location [{1},{2}]", actionName, location.x, location.y);

std::vector<uint32_t> availableActionIds{};
if (srcObject) {
Expand Down
1 change: 1 addition & 0 deletions src/Griddly/Core/Observers/EntityObserver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct EntityObserverConfig : public ObserverConfig {
std::unordered_set<std::string> includePlayerId{};
std::unordered_set<std::string> includeRotation{};
std::unordered_map<std::string, std::vector<std::string>> entityVariableMapping{};
std::vector<std::string> globalVariableMapping{};
std::unordered_map<std::string, ActionInputsDefinition> actionInputsDefinitions{};
std::vector<std::string> objectNames{};
bool includeMasks = false;
Expand Down
2 changes: 2 additions & 0 deletions src/Griddly/Core/Observers/Observer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ void Observer::reset() {

doTrackAvatar_ = avatarObject_ != nullptr && config_.trackAvatar;

spdlog::debug("Tracking avatar: {0}", doTrackAvatar_ ? "Yes":"No");

// if the observer is "READY", then it has already been initialized once, so keep it in the ready state, we're just resetting it.
observerState_ = observerState_ == ObserverState::READY ? ObserverState::READY : ObserverState::RESET;
}
Expand Down
11 changes: 8 additions & 3 deletions tests/src/Griddly/Core/Observers/EntityObserverTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ void runEntityObserverTest(EntityObserverConfig observerConfig,
entityObserver->reset();

const auto& updateEntityObservations = entityObserver->update();
const auto& entityVariableMapping = entityObserver->getEntityVariableMapping();

ASSERT_EQ(updateEntityObservations.observations.size(), expectedEntityObservervations.observations.size());

Expand All @@ -90,8 +89,10 @@ void runEntityObserverTest(EntityObserverConfig observerConfig,
auto expectedEntityLocation = glm::ivec2{updateObservations[i][0], updateObservations[i][1]};
auto updateId = updateEntityObservations.ids.at(entityName)[i];

ASSERT_EQ(updateEntityObservations.locations.at(updateId)[0], expectedEntityLocation[0]);
ASSERT_EQ(updateEntityObservations.locations.at(updateId)[1], expectedEntityLocation[1]);
if (entityName != "__global__") {
ASSERT_EQ(updateEntityObservations.locations.at(updateId)[0], expectedEntityLocation[0]);
ASSERT_EQ(updateEntityObservations.locations.at(updateId)[1], expectedEntityLocation[1]);
}
}
}

Expand Down Expand Up @@ -1061,6 +1062,8 @@ TEST(EntityObserverTest, entityVariableMapping) {

config.includeRotation = {"mo3"};

config.globalVariableMapping = {"lightingR", "lightingG", "lightingB"};

config.entityVariableMapping = {
{"mo2", {"health", "max_health"}},
{"mo3", {"health", "max_health"}}};
Expand All @@ -1069,6 +1072,8 @@ TEST(EntityObserverTest, entityVariableMapping) {

// "x", "y", "z", "ox", "oy", "player_id"
expectedEntityObservervations.observations = {
{"__global__",
{{50, 100, 100}}},
{"avatar",
{{2, 2, 0}}},
{"mo1",
Expand Down

0 comments on commit 7da7498

Please sign in to comment.