From 7da7498b1f920c4791942c3b73c1b3d67eb04075 Mon Sep 17 00:00:00 2001 From: Chris Bamford Date: Fri, 29 Apr 2022 09:57:09 +0100 Subject: [PATCH] Entity observer fixes (#191) * few fixes for entity observers * fixing python tests --- python/tests/entity_observer_test.py | 34 ++++---- python/tests/gdy/test_entity_observer.yaml | 4 + .../gdy/test_entity_observer_multi_agent.yaml | 8 +- src/Griddly/Core/GDY/GDYFactory.cpp | 16 ++++ src/Griddly/Core/GDY/Objects/Object.cpp | 81 ------------------- src/Griddly/Core/Observers/EntityObserver.cpp | 26 +++++- src/Griddly/Core/Observers/EntityObserver.hpp | 1 + src/Griddly/Core/Observers/Observer.cpp | 2 + .../Core/Observers/EntityObserverTest.cpp | 11 ++- 9 files changed, 76 insertions(+), 107 deletions(-) diff --git a/python/tests/entity_observer_test.py b/python/tests/entity_observer_test.py index 9f03db510..b15d0e740 100644 --- a/python/tests/entity_observer_test.py +++ b/python/tests/entity_observer_test.py @@ -27,13 +27,6 @@ def test_entity_observations(test_name): env = build_test_env(test_name, "tests/gdy/test_entity_observer.yaml", global_observer_type=gd.ObserverType.NONE, player_observer_type=gd.ObserverType.ENTITY) - global_variables = env.game.get_global_variable_names() - object_variable_map = env.game.get_object_variable_map() - - assert global_variables == ["_steps", "test_global_variable"] - assert object_variable_map["entity_1"] == ["entity_1_variable"] - assert object_variable_map["entity_2"] == ["entity_2_variable"] - obs, reward, done, info = env.step(0) entities = obs["Entities"] entity_ids = obs["Ids"] @@ -44,15 +37,16 @@ def test_entity_observations(test_name): assert len(entity_1s) == 1 assert len(entity_1_ids) == 1 assert len(entity_1s[0]) == 3 - assert entity_1s[0][0] == 2 - assert entity_1s[0][1] == 2 + assert entity_1s[0][0] == 1 + assert entity_1s[0][1] == 1 entity_2s = entities["entity_2"] entity_2_ids = entity_ids["entity_2"] - assert len(entity_2s) == 2 - assert len(entity_2_ids) == 2 + assert len(entity_2s) == 1 + assert len(entity_2_ids) == 1 assert len(entity_2s[0]) == 3 - assert len(entity_2s[1]) == 3 + assert entity_2s[0][0] == 1 + assert entity_2s[0][1] == 2 actor_masks = obs["ActorMasks"] actor_ids = obs["ActorIds"] @@ -75,21 +69,17 @@ def test_entity_observations_multi_agent(test_name): global_observer_type=gd.ObserverType.NONE, player_observer_type=["EntityObserverOne", "EntityObserverTwo"]) - global_variables = env.game.get_global_variable_names() - object_variable_map = env.game.get_object_variable_map() - - assert global_variables == ["_steps", "test_global_variable"] - assert object_variable_map["entity_1"] == ["entity_1_variable"] - assert object_variable_map["entity_2"] == ["entity_2_variable"] player_1_space = env.player_observation_space[0].features player_2_space = env.player_observation_space[1].features assert player_1_space["entity_1"] == ["x", "y", "z", "playerId", "entity_1_variable"] assert player_1_space["entity_2"] == ["x", "y", "z", "ox", "oy", "entity_2_variable"] + assert player_1_space["__global__"] == ["test_perplayer_variable", "test_global_variable"] assert player_2_space["entity_1"] == ["x", "y", "z"] assert player_2_space["entity_2"] == ["x", "y", "z"] + assert player_2_space["__global__"] == ["test_global_variable"] obs, reward, done, info = env.step([0, 0]) @@ -97,7 +87,9 @@ def test_entity_observations_multi_agent(test_name): player_1_entities = player_1_obs["Entities"] player_1_entity_ids = player_1_obs["Ids"] - player_1_locations = player_1_obs["Locations"] + + p1_globals = player_1_entities["__global__"] + assert np.all(p1_globals[0] == [12.0, 0.0]) p1_entity_1s = player_1_entities["entity_1"] p1_entity_1_ids = player_1_entity_ids["entity_1"] @@ -117,7 +109,9 @@ def test_entity_observations_multi_agent(test_name): player_2_entities = player_2_obs["Entities"] player_2_entity_ids = player_2_obs["Ids"] - player_2_locations = player_1_obs["Locations"] + + p2_globals = player_2_entities["__global__"] + assert np.all(p2_globals[0] == [0]) p2_entity_1s = player_2_entities["entity_1"] p2_entity_1_ids = player_2_entity_ids["entity_1"] diff --git a/python/tests/gdy/test_entity_observer.yaml b/python/tests/gdy/test_entity_observer.yaml index 797ca844b..e81756fa6 100644 --- a/python/tests/gdy/test_entity_observer.yaml +++ b/python/tests/gdy/test_entity_observer.yaml @@ -10,6 +10,10 @@ Environment: InitialValue: 12 Player: AvatarObject: entity_1 # The player can only control a single avatar in the game + Observer: + Width: 3 + Height: 3 + TrackAvatar: true Levels: - | . . . . . diff --git a/python/tests/gdy/test_entity_observer_multi_agent.yaml b/python/tests/gdy/test_entity_observer_multi_agent.yaml index 29913e08f..787f8b845 100644 --- a/python/tests/gdy/test_entity_observer_multi_agent.yaml +++ b/python/tests/gdy/test_entity_observer_multi_agent.yaml @@ -3,23 +3,29 @@ Environment: Name: Test Description: An environment only used for testing Variables: - - Name: test_global_variable + - Name: test_perplayer_variable InitialValue: 12 + PerPlayer: true + - Name: test_global_variable + InitialValue: 0 Observers: EntityObserverOne: Type: ENTITY IncludeMasks: true IncludePlayerId: ["entity_1"] IncludeRotation: ["entity_2"] + GlobalVariableMapping: ["test_perplayer_variable", "test_global_variable"] VariableMapping: entity_1: ["entity_1_variable"] entity_2: ["entity_2_variable"] EntityObserverTwo: Type: ENTITY + GlobalVariableMapping: ["test_global_variable"] Player: Count: 2 + Levels: - | . . E2 . . diff --git a/src/Griddly/Core/GDY/GDYFactory.cpp b/src/Griddly/Core/GDY/GDYFactory.cpp index 80022c209..1981d76ed 100644 --- a/src/Griddly/Core/GDY/GDYFactory.cpp +++ b/src/Griddly/Core/GDY/GDYFactory.cpp @@ -233,6 +233,22 @@ EntityObserverConfig GDYFactory::parseNamedEntityObserverConfig(std::string obse // Used to generate masks for entity obervers config.actionInputsDefinitions = getActionInputsDefinitions(); + auto globalVariableMappingNode = observerConfigNode["GlobalVariableMapping"]; + + if (globalVariableMappingNode.IsDefined()) { + const auto& globalEntityVariables = singleOrListNodeToList(globalVariableMappingNode); + + for (const auto& globalEntityVariable : globalEntityVariables) { + if (globalVariableDefinitions_.find(globalEntityVariable) == globalVariableDefinitions_.end()) { + std::string error = fmt::format("No global variable with name {0} in GlobalVariableMapping feature configuration.", globalEntityVariable); + spdlog::error(error); + throw std::invalid_argument(error); + } + } + + config.globalVariableMapping = globalEntityVariables; + } + auto variableMappingNodes = observerConfigNode["VariableMapping"]; if (variableMappingNodes.IsDefined()) { diff --git a/src/Griddly/Core/GDY/Objects/Object.cpp b/src/Griddly/Core/GDY/Objects/Object.cpp index 003e3d75b..775e40ce7 100644 --- a/src/Griddly/Core/GDY/Objects/Object.cpp +++ b/src/Griddly/Core/GDY/Objects/Object.cpp @@ -166,28 +166,6 @@ BehaviourCondition Object::resolveOR(const std::vector &cond }; } -// BehaviourCondition Object::instantiateCondition(std::string &commandName, YAML::Node &conditionNode) const { -// if (commandName == "eq") { -// return resolveConditionArguments([](int32_t a, int32_t b) { return a == b; }, conditionNode); -// } else if (commandName == "gt") { -// return resolveConditionArguments([](int32_t a, int32_t b) { return a > b; }, conditionNode); -// } else if (commandName == "gte") { -// return resolveConditionArguments([](int32_t a, int32_t b) { return a >= b; }, conditionNode); -// } else if (commandName == "lt") { -// return resolveConditionArguments([](int32_t a, int32_t b) { return a < b; }, conditionNode); -// } else if (commandName == "lte") { -// return resolveConditionArguments([](int32_t a, int32_t b) { return a <= b; }, conditionNode); -// } else if (commandName == "neq") { -// return resolveConditionArguments([](int32_t a, int32_t b) { return a != b; }, conditionNode); -// } else if (commandName == "and") { -// return processConditions(conditionNode, false, LogicOp::AND); -// } else if (commandName == "or") { -// return processConditions(conditionNode, false, LogicOp::OR); -// } else { -// throw std::invalid_argument(fmt::format("Unknown or badly defined condition command {0}.", commandName)); -// } -// } - BehaviourFunction Object::instantiateConditionalBehaviour(const std::string &commandName, CommandArguments &commandArguments, CommandList &subCommands) { if (subCommands.size() == 0) { return instantiateBehaviour(commandName, commandArguments); @@ -233,65 +211,6 @@ BehaviourFunction Object::instantiateConditionalBehaviour(const std::string &com }; } -// BehaviourCondition Object::processConditions(YAML::Node &conditionNodeList, bool isTopLevel, LogicOp op) const { -// // We should have a single item and not a list -// if (!conditionNodeList.IsDefined()) { -// auto line = conditionNodeList.Mark().line; -// auto errorString = fmt::format("Parse error line {0}. If statement is missing Conditions", line); -// spdlog::error(errorString); -// throw std::invalid_argument(errorString); -// } - -// if (conditionNodeList.IsMap()) { -// if (conditionNodeList.size() != 1) { -// auto line = conditionNodeList.Mark().line; -// auto errorString = fmt::format("Parse error line {0}. Conditions must contain a single top-level condition", line); -// spdlog::error(errorString); -// throw std::invalid_argument(errorString); -// } -// auto conditionNode = conditionNodeList.begin(); -// auto commandName = conditionNode->first.as(); -// return instantiateCondition(commandName, conditionNode->second); - -// } else if (conditionNodeList.IsSequence()) { -// std::vector conditionList; -// for (auto &&subConditionNode : conditionNodeList) { -// auto validatedNode = validateCommandPairNode(subConditionNode); -// auto commandName = validatedNode->first.as(); -// conditionList.push_back(instantiateCondition(commandName, validatedNode->second)); -// } -// switch (op) { -// case LogicOp::AND: { -// return resolveAnd() -// }; -// } -// break; -// case LogicOp::OR: { -// return [conditionList](const std::shared_ptr &action) -> bool { -// for (const auto &condition : conditionList) { -// if (condition(action)) { -// return true; -// } -// } -// return false; -// }; -// } -// default: { -// auto line = conditionNodeList.Mark().line; -// auto errorString = fmt::format("Parse error line {0}. A sequence of conditions must be within an AND or an OR operator.", line); -// spdlog::error(errorString); -// throw std::invalid_argument(errorString); -// } -// } -// } -// else { -// auto line = conditionNodeList.Mark().line; -// auto errorString = fmt::format("Conditions must be a map or a list", line); -// spdlog::error(errorString); -// throw std::invalid_argument(errorString); -// } -// } // namespace griddly - /** * @brief executes a list of behaviour functions and accumulates the rewards * diff --git a/src/Griddly/Core/Observers/EntityObserver.cpp b/src/Griddly/Core/Observers/EntityObserver.cpp index 41515ac70..46d68caa1 100644 --- a/src/Griddly/Core/Observers/EntityObserver.cpp +++ b/src/Griddly/Core/Observers/EntityObserver.cpp @@ -15,6 +15,10 @@ void EntityObserver::init(EntityObserverConfig& config) { } } + if(config.globalVariableMapping.size() > 0) { + entityFeatures_.insert({"__global__", config.globalVariableMapping}); + } + // Precalclate offsets for entity configurations for (const auto& objectName : config_.objectNames) { @@ -100,7 +104,6 @@ glm::ivec2 EntityObserver::resolveLocation(const glm::ivec2& location) const { auto resolvedLocation = location - glm::ivec2{observableGrid.left, observableGrid.bottom}; if (doTrackAvatar_) { - const auto& avatarLocation = avatarObject_->getLocation(); const auto& avatarDirection = avatarObject_->getObjectOrientation().getDirection(); if (config_.rotateWithAvatar) { @@ -132,6 +135,25 @@ void EntityObserver::buildObservations(EntityObservations& entityObservations) { const auto& observableGrid = getObservableGrid(); + // Build global entity + if(config_.globalVariableMapping.size() > 0) { + std::vector globalFeatureVector(config_.globalVariableMapping.size()); + uint32_t featureIdx = 0; + const auto& globalVariables = grid_->getGlobalVariables(); + for(const auto& globalVariableName : config_.globalVariableMapping) { + + const auto& globalVariableValues = globalVariables.at(globalVariableName); + if(globalVariableValues.size() == 1) { + globalFeatureVector[featureIdx++] = static_cast(*globalVariableValues.at(0)); + } else { + globalFeatureVector[featureIdx++] = static_cast(*globalVariableValues.at(config_.playerId)); + } + } + + entityObservations.observations["__global__"].push_back(globalFeatureVector); + entityObservations.ids["__global__"].push_back(0); + } + for (const auto& object : grid_->getObjects()) { const auto& name = object->getObjectName(); auto location = object->getLocation(); @@ -240,7 +262,7 @@ std::unordered_map> EntityObserver:: std::vector EntityObserver::getAvailableActionIdsAtLocation(glm::ivec2 location, std::string actionName) const { auto srcObject = grid_->getObject(location); - spdlog::debug("Getting available actionIds for action [{}] at location [{0},{1}]", actionName, location.x, location.y); + spdlog::debug("Getting available actionIds for action [{0}] at location [{1},{2}]", actionName, location.x, location.y); std::vector availableActionIds{}; if (srcObject) { diff --git a/src/Griddly/Core/Observers/EntityObserver.hpp b/src/Griddly/Core/Observers/EntityObserver.hpp index 15d23219b..f81b71013 100644 --- a/src/Griddly/Core/Observers/EntityObserver.hpp +++ b/src/Griddly/Core/Observers/EntityObserver.hpp @@ -20,6 +20,7 @@ struct EntityObserverConfig : public ObserverConfig { std::unordered_set includePlayerId{}; std::unordered_set includeRotation{}; std::unordered_map> entityVariableMapping{}; + std::vector globalVariableMapping{}; std::unordered_map actionInputsDefinitions{}; std::vector objectNames{}; bool includeMasks = false; diff --git a/src/Griddly/Core/Observers/Observer.cpp b/src/Griddly/Core/Observers/Observer.cpp index 86dc1d99b..69d0ecdea 100644 --- a/src/Griddly/Core/Observers/Observer.cpp +++ b/src/Griddly/Core/Observers/Observer.cpp @@ -28,6 +28,8 @@ void Observer::reset() { doTrackAvatar_ = avatarObject_ != nullptr && config_.trackAvatar; + spdlog::debug("Tracking avatar: {0}", doTrackAvatar_ ? "Yes":"No"); + // if the observer is "READY", then it has already been initialized once, so keep it in the ready state, we're just resetting it. observerState_ = observerState_ == ObserverState::READY ? ObserverState::READY : ObserverState::RESET; } diff --git a/tests/src/Griddly/Core/Observers/EntityObserverTest.cpp b/tests/src/Griddly/Core/Observers/EntityObserverTest.cpp index e0d4a8337..fb7008ac0 100644 --- a/tests/src/Griddly/Core/Observers/EntityObserverTest.cpp +++ b/tests/src/Griddly/Core/Observers/EntityObserverTest.cpp @@ -69,7 +69,6 @@ void runEntityObserverTest(EntityObserverConfig observerConfig, entityObserver->reset(); const auto& updateEntityObservations = entityObserver->update(); - const auto& entityVariableMapping = entityObserver->getEntityVariableMapping(); ASSERT_EQ(updateEntityObservations.observations.size(), expectedEntityObservervations.observations.size()); @@ -90,8 +89,10 @@ void runEntityObserverTest(EntityObserverConfig observerConfig, auto expectedEntityLocation = glm::ivec2{updateObservations[i][0], updateObservations[i][1]}; auto updateId = updateEntityObservations.ids.at(entityName)[i]; - ASSERT_EQ(updateEntityObservations.locations.at(updateId)[0], expectedEntityLocation[0]); - ASSERT_EQ(updateEntityObservations.locations.at(updateId)[1], expectedEntityLocation[1]); + if (entityName != "__global__") { + ASSERT_EQ(updateEntityObservations.locations.at(updateId)[0], expectedEntityLocation[0]); + ASSERT_EQ(updateEntityObservations.locations.at(updateId)[1], expectedEntityLocation[1]); + } } } @@ -1061,6 +1062,8 @@ TEST(EntityObserverTest, entityVariableMapping) { config.includeRotation = {"mo3"}; + config.globalVariableMapping = {"lightingR", "lightingG", "lightingB"}; + config.entityVariableMapping = { {"mo2", {"health", "max_health"}}, {"mo3", {"health", "max_health"}}}; @@ -1069,6 +1072,8 @@ TEST(EntityObserverTest, entityVariableMapping) { // "x", "y", "z", "ox", "oy", "player_id" expectedEntityObservervations.observations = { + {"__global__", + {{50, 100, 100}}}, {"avatar", {{2, 2, 0}}}, {"mo1",