Skip to content

Commit

Permalink
auth_logic AST operator== (#700)
Browse files Browse the repository at this point in the history
Closes #700

COPYBARA_INTEGRATE_REVIEW=#700 from google-research:ast-equality@aferr f7eb278
PiperOrigin-RevId: 474559929
  • Loading branch information
Andrew Ferraiuolo authored and arcs-c3po committed Sep 15, 2022
1 parent 642c151 commit 9706fa4
Show file tree
Hide file tree
Showing 7 changed files with 426 additions and 56 deletions.
9 changes: 9 additions & 0 deletions src/ir/auth_logic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ cc_library(
],
)

cc_test(
name = "ast_equality_test",
srcs = ["ast_equality_test.cc"],
deps = [
":ast",
"//src/common/testing:gtest",
],
)

cc_test(
name = "souffle_emitter_test",
srcs = ["souffle_emitter_test.cc"],
Expand Down
60 changes: 60 additions & 0 deletions src/ir/auth_logic/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class Principal {
// for debugging/testing only
std::string ToString() const { return name_; }

bool operator==(const Principal& rhs) const { return name_ == rhs.name_; }

bool operator!=(const Principal& rhs) const { return !(*this == rhs); }

private:
std::string name_;
};
Expand Down Expand Up @@ -81,6 +85,12 @@ class Attribute {
return absl::StrCat(principal_.name(), predicate_.ToString());
}

bool operator==(const Attribute& rhs) const {
return principal_ == rhs.principal_ && predicate_ == rhs.predicate_;
}

bool operator!=(const Attribute& rhs) const { return !(*this == rhs); }

private:
Principal principal_;
datalog::Predicate predicate_;
Expand Down Expand Up @@ -114,6 +124,13 @@ class CanActAs {
right_principal_.ToString());
}

bool operator==(const CanActAs& rhs) const {
return left_principal_ == rhs.left_principal_ &&
right_principal_ == rhs.right_principal_;
}

bool operator!=(const CanActAs& rhs) const { return !(*this == rhs); }

private:
Principal left_principal_;
Principal right_principal_;
Expand Down Expand Up @@ -157,6 +174,10 @@ class BaseFact {
")");
}

bool operator==(const BaseFact& rhs) const { return value_ == rhs.value_; }

bool operator!=(const BaseFact& rhs) const { return !(*this == rhs); }

private:
BaseFactVariantType value_;
};
Expand Down Expand Up @@ -197,6 +218,13 @@ class Fact {
base_fact_.ToString());
}

bool operator==(const Fact& rhs) const {
return delegation_chain_ == rhs.delegation_chain_ &&
base_fact_ == rhs.base_fact_;
}

bool operator!=(const Fact& rhs) const { return !(*this == rhs); }

private:
std::forward_list<Principal> delegation_chain_;
BaseFact base_fact_;
Expand Down Expand Up @@ -235,6 +263,14 @@ class ConditionalAssertion {
absl::StrJoin(rhs_strings, ", "));
}

bool operator==(const ConditionalAssertion& rhs) const {
return lhs_ == rhs.lhs_ && rhs_ == rhs.rhs_;
}

bool operator!=(const ConditionalAssertion& rhs) const {
return !(*this == rhs);
}

private:
Fact lhs_;
std::vector<BaseFact> rhs_;
Expand Down Expand Up @@ -272,6 +308,10 @@ class Assertion {
")");
}

bool operator==(const Assertion& rhs) const { return value_ == rhs.value_; }

bool operator!=(const Assertion& rhs) const { return !(*this == rhs); }

private:
AssertionVariantType value_;
};
Expand Down Expand Up @@ -307,6 +347,12 @@ class SaysAssertion {
absl::StrJoin(assertion_strings, "\n"), "}");
}

bool operator==(const SaysAssertion& rhs) const {
return principal_ == rhs.principal_ && assertions_ == rhs.assertions_;
}

bool operator!=(const SaysAssertion& rhs) const { return !(*this == rhs); }

private:
Principal principal_;
std::vector<Assertion> assertions_;
Expand Down Expand Up @@ -342,6 +388,13 @@ class Query {
fact_.ToString(), ")");
}

bool operator==(const Query& rhs) const {
return name_ == rhs.name_ && principal_ == rhs.principal_ &&
fact_ == rhs.fact_;
}

bool operator!=(const Query& rhs) const { return !(*this == rhs); }

private:
std::string name_;
Principal principal_;
Expand Down Expand Up @@ -405,6 +458,13 @@ class Program {
absl::StrJoin(query_strings, "\n"), ")");
}

bool operator==(const Program& rhs) const {
return relation_declarations_ == rhs.relation_declarations_ &&
says_assertions_ == rhs.says_assertions_ && queries_ == rhs.queries_;
}

bool operator!=(const Program& rhs) const { return !(*this == rhs); }

private:
std::vector<datalog::RelationDeclaration> relation_declarations_;
std::vector<SaysAssertion> says_assertions_;
Expand Down
95 changes: 48 additions & 47 deletions src/ir/auth_logic/ast_construction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,28 +42,31 @@ namespace {
using auth_logic_cc_generator::AuthLogicLexer;
using auth_logic_cc_generator::AuthLogicParser;

std::string SanitizeMultilineStringForPredicateArgument(const std::string predicate_argument) {
return absl::StrReplaceAll(predicate_argument, {{"\n", ""}, {R"(""")", R"(")"}});
std::string SanitizeMultilineStringForPredicateArgument(
const std::string predicate_argument) {
return absl::StrReplaceAll(predicate_argument,
{{"\n", ""}, {R"(""")", R"(")"}});
}

static Principal ConstructPrincipal(
AuthLogicParser::PrincipalContext& principal_context) {
return Principal(
principal_context.STRING_LITERAL()->getText());
return Principal(principal_context.STRING_LITERAL()->getText());
}

static datalog::Predicate ConstructPredicate(
AuthLogicParser::PredicateContext& predicate_context) {
std::vector<std::string> args = utils::MapIter<std::string>(
predicate_context.pred_arg(), [](auto* pred_arg_context) {
if (auto* multi_line_predicate_argument_context =
dynamic_cast<AuthLogicParser::MultilineConstantContext*>(CHECK_NOTNULL(pred_arg_context)))
return SanitizeMultilineStringForPredicateArgument(multi_line_predicate_argument_context->getText());
dynamic_cast<AuthLogicParser::MultilineConstantContext*>(
CHECK_NOTNULL(pred_arg_context)))
return SanitizeMultilineStringForPredicateArgument(
multi_line_predicate_argument_context->getText());
return pred_arg_context->getText();
});
datalog::Sign sign = predicate_context.NEG() == nullptr
? datalog::Sign::kPositive
: datalog::Sign::kNegated;
? datalog::Sign::kPositive
: datalog::Sign::kNegated;
return datalog::Predicate(predicate_context.VARIABLE()->getText(),
std::move(args), sign);
}
Expand All @@ -74,16 +77,16 @@ static BaseFact ConstructVerbphrase(
if (auto* predphrase_context =
dynamic_cast<AuthLogicParser::PredphraseContext*>(
&verbphrase_context)) { // PredicateContext
return BaseFact(
Attribute(left_principal, ConstructPredicate(*CHECK_NOTNULL(
predphrase_context->predicate()))));
return BaseFact(Attribute(
left_principal,
ConstructPredicate(*CHECK_NOTNULL(predphrase_context->predicate()))));
}
auto& act_as_phrase_context =
*CHECK_NOTNULL(dynamic_cast<AuthLogicParser::ActsAsPhraseContext*>(
&verbphrase_context)); // ActsAsPhraseContext
return BaseFact(
CanActAs(left_principal, ConstructPrincipal(*CHECK_NOTNULL(
act_as_phrase_context.principal()))));
return BaseFact(CanActAs(
left_principal,
ConstructPrincipal(*CHECK_NOTNULL(act_as_phrase_context.principal()))));
}

static BaseFact ConstructFlatFact(
Expand All @@ -103,25 +106,26 @@ static BaseFact ConstructFlatFact(

static BaseFact ConstructRvalue(
AuthLogicParser::RvalueContext& rvalue_context) {
if (auto* flat_fact_rvalue_context =
dynamic_cast<AuthLogicParser::FlatFactRvalueContext*>(
&rvalue_context)) {
return ConstructFlatFact(
*CHECK_NOTNULL(flat_fact_rvalue_context->flatFact()));
}
auto& binop_rvalue_context = *CHECK_NOTNULL(
dynamic_cast<AuthLogicParser::BinopRvalueContext*>(
&rvalue_context));
std::vector<std::string> args = utils::MapIter<std::string>(
binop_rvalue_context.pred_arg(), [](auto* pred_arg_context) {
if (auto* multi_line_predicate_argument_context =
dynamic_cast<AuthLogicParser::MultilineConstantContext*>(CHECK_NOTNULL(pred_arg_context)))
return SanitizeMultilineStringForPredicateArgument(multi_line_predicate_argument_context->getText());
if (auto* flat_fact_rvalue_context =
dynamic_cast<AuthLogicParser::FlatFactRvalueContext*>(
&rvalue_context)) {
return ConstructFlatFact(
*CHECK_NOTNULL(flat_fact_rvalue_context->flatFact()));
}
auto& binop_rvalue_context = *CHECK_NOTNULL(
dynamic_cast<AuthLogicParser::BinopRvalueContext*>(&rvalue_context));
std::vector<std::string> args = utils::MapIter<std::string>(
binop_rvalue_context.pred_arg(), [](auto* pred_arg_context) {
if (auto* multi_line_predicate_argument_context =
dynamic_cast<AuthLogicParser::MultilineConstantContext*>(
CHECK_NOTNULL(pred_arg_context)))
return SanitizeMultilineStringForPredicateArgument(
multi_line_predicate_argument_context->getText());
return pred_arg_context->getText();
});
return BaseFact(
datalog::Predicate(binop_rvalue_context.binop()->getText(),
std::move(args), datalog::Sign::kPositive));
});
return BaseFact(datalog::Predicate(binop_rvalue_context.binop()->getText(),
std::move(args),
datalog::Sign::kPositive));
}

// Fact corresponds to a delagation chain of principal and a Basefact.
Expand All @@ -133,9 +137,9 @@ static Fact ConstructFact(std::forward_list<Principal> delegation_chain,
if (auto* flat_fact_fact_context =
dynamic_cast<AuthLogicParser::FlatFactFactContext*>(
&fact_context)) { // FlatFactFactContext
return Fact(std::move(delegation_chain),
ConstructFlatFact(
*CHECK_NOTNULL(flat_fact_fact_context->flatFact())));
return Fact(
std::move(delegation_chain),
ConstructFlatFact(*CHECK_NOTNULL(flat_fact_fact_context->flatFact())));
}
auto& can_say_fact_context = *CHECK_NOTNULL(
dynamic_cast<AuthLogicParser::CanSayFactContext*>(&fact_context));
Expand Down Expand Up @@ -213,8 +217,7 @@ static Assertion ConstructAssertion(
return ConstructRvalue(*CHECK_NOTNULL(rvalue_context));
});
return Assertion(ConditionalAssertion(
ConstructFact({},
*CHECK_NOTNULL(horn_clause_assertion_context.fact())),
ConstructFact({}, *CHECK_NOTNULL(horn_clause_assertion_context.fact())),
std::move(base_facts)));
}

Expand Down Expand Up @@ -245,19 +248,17 @@ static SaysAssertion ConstructSaysAssertion(

static Program ConstructProgram(
AuthLogicParser::ProgramContext& program_context) {
auto relation_declarations =
utils::MapIter<datalog::RelationDeclaration>(
program_context.relationDeclaration(),
[](AuthLogicParser::RelationDeclarationContext*
relation_declaration_context) {
return ConstructRelationDeclaration(
*CHECK_NOTNULL(relation_declaration_context));
});
auto relation_declarations = utils::MapIter<datalog::RelationDeclaration>(
program_context.relationDeclaration(),
[](AuthLogicParser::RelationDeclarationContext*
relation_declaration_context) {
return ConstructRelationDeclaration(
*CHECK_NOTNULL(relation_declaration_context));
});
std::vector<SaysAssertion> says_assertions = utils::MapIter<SaysAssertion>(
program_context.saysAssertion(),
[](AuthLogicParser::SaysAssertionContext* says_assertion_context) {
return ConstructSaysAssertion(
*CHECK_NOTNULL(says_assertion_context));
return ConstructSaysAssertion(*CHECK_NOTNULL(says_assertion_context));
});
std::vector<Query> queries = utils::MapIter<Query>(
program_context.query(),
Expand Down
Loading

0 comments on commit 9706fa4

Please sign in to comment.