Skip to content

Commit

Permalink
Improve output for MakePOMDPCanonic (#606)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexBork authored Sep 10, 2024
1 parent db61b80 commit 9fba97a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 17 deletions.
44 changes: 27 additions & 17 deletions src/storm-pomdp/transformer/MakePOMDPCanonic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,15 @@ std::string MakePOMDPCanonic<ValueType>::getStateInformation(uint64_t state) con
}
}

template<typename ValueType>
std::string MakePOMDPCanonic<ValueType>::getObservationInformation(uint32_t obs) const {
if (pomdp.hasObservationValuations()) {
return std::to_string(obs) + " " + pomdp.getObservationValuations().getStateInfo(obs);
} else {
return std::to_string(obs);
}
}

template<typename ValueType>
std::vector<uint64_t> MakePOMDPCanonic<ValueType>::computeCanonicalPermutation() const {
std::map<uint32_t, std::vector<detail::ActionIdentifier>> observationActionIdentifiers;
Expand All @@ -179,10 +188,6 @@ std::vector<uint64_t> MakePOMDPCanonic<ValueType>::computeCanonicalPermutation()
if (moreActionObservations.get(observation)) {
// We have seen this observation previously with multiple actions. Error!
// TODO provide more diagnostic information
std::string stateval = "";
if (pomdp.hasStateValuations()) {
stateval = " (" + pomdp.getStateValuations().getStateInfo(state) + ") ";
}
std::string actionval = "";
if (pomdp.hasChoiceLabeling()) {
auto labelsOfChoice = pomdp.getChoiceLabeling().getLabelsOfChoice(rowIndexFrom);
Expand All @@ -193,8 +198,8 @@ std::vector<uint64_t> MakePOMDPCanonic<ValueType>::computeCanonicalPermutation()
}
}
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException,
"Observation " << observation << " sometimes provides multiple actions, but in state " << state << stateval
<< " provides only one action " << actionval << ".");
"Observation " << getObservationInformation(observation) << " sometimes provides multiple actions, but in state "
<< getStateInformation(state) << " provides only one action " << actionval << ".");
}
oneActionObservations.set(observation);

Expand All @@ -203,17 +208,13 @@ std::vector<uint64_t> MakePOMDPCanonic<ValueType>::computeCanonicalPermutation()
} else {
if (oneActionObservations.get(observation)) {
// We have seen this observation previously with one action. Error!
std::string stateval = "";
if (pomdp.hasStateValuations()) {
stateval = " (" + pomdp.getStateValuations().getStateInfo(state) + ") ";
}
// std::string actionval= "";
// if (pomdp.hasChoiceLabeling()) {
// actionval = *pomdp.getChoiceLabeling().getLabelsOfChoice(rowIndexFrom).begin();
// }
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException,
"Observation " << observation << " sometimes provides one action, but in state " << state << stateval << " provides "
<< rowIndexTo - rowIndexFrom << " actions.");
"Observation " << getObservationInformation(observation) << " sometimes provides one action, but in state "
<< getStateInformation(state) << " provides " << rowIndexTo - rowIndexFrom << " actions.");
}
moreActionObservations.set(observation);
}
Expand Down Expand Up @@ -262,26 +263,35 @@ std::vector<uint64_t> MakePOMDPCanonic<ValueType>::computeCanonicalPermutation()
auto referenceStart = observationActionIdentifiers[observation].begin();
auto referenceEnd = observationActionIdentifiers[observation].end();
STORM_LOG_ASSERT(observationActionIdentifiers[observation].size() == pomdp.getNumberOfChoices(actionIdentifierDefinition[observation]),
"Number of actions recorded for state does not coinide with number of actions.");
"Number of actions recorded for state does not coincide with number of actions.");
if (observationActionIdentifiers[observation].size() != actionIdentifiers.size()) {
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException,
"Number of actions in state '" << getStateInformation(state) << "' (nr actions:" << actionIdentifiers.size() << ") and state '"
<< getStateInformation(actionIdentifierDefinition[observation])
<< "' (actions: " << observationActionIdentifiers[observation].size()
<< " ), both having observation " << observation << " do not match.");
<< " ), both having observation " << getObservationInformation(observation) << " do not match.");
}
if (!detail::compatibleWith(referenceStart, referenceEnd, actionIdentifiers.begin(), actionIdentifiers.end())) {
std::cout << "Observation " << observation << ": ";
std::cout << "Observation " << getObservationInformation(observation) << ": \n";
detail::actionIdentifiersToStream(std::cout, observationActionIdentifiers[observation], labelStorage);
std::cout << " according to state " << actionIdentifierDefinition[observation] << ".\n";
std::cout << "Observation " << observation << ": ";
detail::actionIdentifiersToStream(std::cout, actionIdentifiers, labelStorage);
std::cout << " according to state " << state << ".\n";
std::cout << "(Actions are given as [id (label), originId])\n";

for (auto const& ai : actionIdentifiers) {
if (labelStorage.getLabel(ai.first.choiceLabelId) == "") {
for (auto const& ai2 : observationActionIdentifiers[observation]) {
STORM_LOG_WARN_COND(ai2.choiceLabelId != ai.first.choiceLabelId,
"Actions with empty label are only considered the same action if they originate from the same command.");
}
}
}

STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException,
"Actions identifiers do not align between states \n\t"
<< getStateInformation(state) << "\nand\n\t" << getStateInformation(actionIdentifierDefinition[observation])
<< "\nboth having observation " << observation << ". See output above for more information.");
<< "\nboth having observation " << getObservationInformation(observation) << ". See output above for more information.");
}
}

Expand Down
1 change: 1 addition & 0 deletions src/storm-pomdp/transformer/MakePOMDPCanonic.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class MakePOMDPCanonic {
std::vector<uint64_t> computeCanonicalPermutation() const;
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> applyPermutationOnPomdp(std::vector<uint64_t> permutation) const;
std::string getStateInformation(uint64_t state) const;
std::string getObservationInformation(uint32_t obs) const;

storm::models::sparse::Pomdp<ValueType> const& pomdp;
};
Expand Down

0 comments on commit 9fba97a

Please sign in to comment.