diff --git a/include/arbitration_graphs/arbitrator.hpp b/include/arbitration_graphs/arbitrator.hpp index 5731b9ed..0213d01a 100644 --- a/include/arbitration_graphs/arbitrator.hpp +++ b/include/arbitration_graphs/arbitrator.hpp @@ -63,8 +63,16 @@ class Arbitrator : public Behavior { typename Behavior::Ptr behavior_; FlagsT flags_; + mutable util_caching::Cache command_; mutable util_caching::Cache verificationResult_; + SubCommandT getCommand(const Time& time) const { + if (!command_.cached(time)) { + command_.cache(time, behavior_->getCommand(time)); + } + return command_.cached(time).value(); + } + bool hasFlag(const FlagsT& flag_to_check) const { return flags_ & flag_to_check; } @@ -251,4 +259,4 @@ class Arbitrator : public Behavior { } // namespace arbitration_graphs #include "internal/arbitrator_impl.hpp" -#include "internal/arbitrator_io.hpp" \ No newline at end of file +#include "internal/arbitrator_io.hpp" diff --git a/include/arbitration_graphs/cost_arbitrator.hpp b/include/arbitration_graphs/cost_arbitrator.hpp index d5e2971a..add1a695 100644 --- a/include/arbitration_graphs/cost_arbitrator.hpp +++ b/include/arbitration_graphs/cost_arbitrator.hpp @@ -76,7 +76,7 @@ class CostArbitrator : public Arbitrator::Ptr& behavior, @@ -118,10 +118,10 @@ class CostArbitrator : public ArbitratorcostEstimator_->estimateCost(option->behavior_->getCommand(time), isActive); + cost = option->costEstimator_->estimateCost(option->getCommand(time), isActive); } else { option->behavior_->gainControl(time); - cost = option->costEstimator_->estimateCost(option->behavior_->getCommand(time), isActive); + cost = option->costEstimator_->estimateCost(option->getCommand(time), isActive); option->behavior_->loseControl(time); } option->last_estimated_cost_ = cost; @@ -139,4 +139,4 @@ class CostArbitrator : public Arbitrator Arbitrator::getAndVerifyCommand( const typename Option::Ptr& option, const Time& time) const { try { - const SubCommandT command = option->behavior_->getCommand(time); + const SubCommandT command = option->getCommand(time); const VerificationResultT verificationResult = verifier_.analyze(time, command); option->verificationResult_.cache(time, verificationResult); @@ -127,4 +127,4 @@ SubCommandT Arbitrator::g " applicable options passed the verification step!"); } -} // namespace arbitration_graphs \ No newline at end of file +} // namespace arbitration_graphs diff --git a/test/cost_arbitrator.cpp b/test/cost_arbitrator.cpp index 4cbacf70..4233b40c 100644 --- a/test/cost_arbitrator.cpp +++ b/test/cost_arbitrator.cpp @@ -116,6 +116,36 @@ TEST_F(CostArbitratorTest, BasicFunctionality) { EXPECT_EQ("high_cost", testCostArbitrator.getCommand(time)); } +TEST_F(CostArbitratorTest, CommandCaching) { + testCostArbitrator.addOption(testBehaviorLowCost, OptionFlags::NO_FLAGS, cost_estimator); + testCostArbitrator.addOption(testBehaviorLowCost, OptionFlags::NO_FLAGS, cost_estimator); + testCostArbitrator.addOption(testBehaviorHighCost, OptionFlags::NO_FLAGS, cost_estimator); + testCostArbitrator.addOption(testBehaviorMidCost, OptionFlags::NO_FLAGS, cost_estimator); + + EXPECT_TRUE(testCostArbitrator.checkInvocationCondition(time)); + EXPECT_FALSE(testCostArbitrator.checkCommitmentCondition(time)); + EXPECT_EQ(0, testBehaviorMidCost->getCommandCounter_); + + testCostArbitrator.gainControl(time); + + // Even though the cost arbitrator needs to compute the command to estimate the costs, the behaviors getCommand + // should only be called once since the result is cached + EXPECT_EQ("mid_cost", testCostArbitrator.getCommand(time)); + EXPECT_EQ(1, testBehaviorMidCost->getCommandCounter_); + EXPECT_EQ("mid_cost", testCostArbitrator.getCommand(time)); + // For a second call to getCommand, we can still use the cached command + EXPECT_EQ(1, testBehaviorMidCost->getCommandCounter_); + + time = time + Duration(1); + + // The cached command should be invalidated after the time has passed + // Therefore the behavior should be called again once for the new time + EXPECT_EQ("mid_cost", testCostArbitrator.getCommand(time)); + EXPECT_EQ(2, testBehaviorMidCost->getCommandCounter_); + EXPECT_EQ("mid_cost", testCostArbitrator.getCommand(time)); + EXPECT_EQ(2, testBehaviorMidCost->getCommandCounter_); +} + TEST_F(CostArbitratorTest, Printout) { testCostArbitrator.addOption(testBehaviorLowCost, OptionFlags::NO_FLAGS, cost_estimator); testCostArbitrator.addOption(testBehaviorLowCost, OptionFlags::NO_FLAGS, cost_estimator); diff --git a/test/dummy_types.hpp b/test/dummy_types.hpp index d0304d46..37126a5e 100644 --- a/test/dummy_types.hpp +++ b/test/dummy_types.hpp @@ -41,10 +41,10 @@ class DummyBehavior : public Behavior { using Ptr = std::shared_ptr; DummyBehavior(const bool invocation, const bool commitment, const std::string& name = "DummyBehavior") - : Behavior(name), invocationCondition_{invocation}, commitmentCondition_{commitment}, loseControlCounter_{ - 0} {}; + : Behavior(name), invocationCondition_{invocation}, commitmentCondition_{commitment} {}; DummyCommand getCommand(const Time& time) override { + getCommandCounter_++; return name_; } bool checkInvocationCondition(const Time& time) const override { @@ -59,7 +59,8 @@ class DummyBehavior : public Behavior { bool invocationCondition_; bool commitmentCondition_; - int loseControlCounter_; + int getCommandCounter_{0}; + int loseControlCounter_{0}; }; struct DummyResult { @@ -76,4 +77,4 @@ struct DummyResult { inline std::ostream& operator<<(std::ostream& out, const arbitration_graphs_tests::DummyResult& result) { out << (result.isOk() ? "is okay" : "is not okay"); return out; -} \ No newline at end of file +}