From 884c577cf9cb22e7f30b792302ce388c26a94607 Mon Sep 17 00:00:00 2001 From: KuechA <31155350+KuechA@users.noreply.github.com> Date: Wed, 30 Oct 2024 18:27:13 +0100 Subject: [PATCH] Add Comprehensions of Lists, Sets and Maps and generator expressions (#1786) * Add comprehension expression * Initial python translation for listcomp * SetComp and DictComp in python frontend * First simple tests * test in main * Try to add DFG edges * Fix not implemented error * Fix more bugs * Also handle GeneratorExp, add some documentation. * Extract nested class to own file * Fix bug, aggregate predicates * Remove unnecessary changes * Specify idea for EOG * Fake higher test coverage * More testing * More tests * Fix error from renaming * Handle the comprehension expression in the control flow sensitive DFG * Adding alternatives to EOG for collection comprehension and fixing syntax error in comprehension expression * Adding alternative that properly depicts generator behavior * Small fix * Alternative for ComprehensionExpression * Fix * Adding EOG handling for ComprehensionExpression and CollectionComprehension * Add test and fix EOG pass implementation * Allow th addition to something that holds arguments and something that holds statements * Remove useless stuff from ControlflowSensitiveDFGPass * Make non-optional things non-optional * Fix test * Remove condition to reduce code which needs coverage * More tests * Update stuff * review * review * generator type --------- Co-authored-by: Konrad Weiss --- .../aisec/cpg/graph/ExpressionBuilder.kt | 19 ++ .../aisec/cpg/graph/builder/Fluent.kt | 41 +++ .../aisec/cpg/graph/edges/ast/AstEdge.kt | 4 +- .../edges/collections/UnwrappedEdgeList.kt | 7 +- .../expressions/CollectionComprehension.kt | 118 +++++++++ .../expressions/ComprehensionExpression.kt | 125 ++++++++++ .../cpg/passes/ControlFlowSensitiveDFGPass.kt | 35 +++ .../de/fraunhofer/aisec/cpg/passes/DFGPass.kt | 30 +++ .../cpg/passes/EvaluationOrderGraphPass.kt | 81 +++++- .../de/fraunhofer/aisec/cpg/GraphExamples.kt | 34 +++ .../aisec/cpg/graph/ExpressionBuilderTest.kt | 1 + .../fraunhofer/aisec/cpg/graph/FluentTest.kt | 170 +++++++++++++ .../collections/UnwrappedEdgeListTest.kt | 3 + .../passes/EvaluationOrderGraphPassTest.kt | 130 ++++++++++ .../cpg/frontends/python/ExpressionHandler.kt | 122 +++++++-- .../frontends/python/ExpressionHandlerTest.kt | 235 ++++++++++++++++++ .../test/resources/python/comprehension.py | 24 ++ docs/docs/CPG/specs/dfg.md | 38 +++ docs/docs/CPG/specs/eog.md | 49 ++++ 19 files changed, 1241 insertions(+), 25 deletions(-) create mode 100644 cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/statements/expressions/CollectionComprehension.kt create mode 100644 cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/statements/expressions/ComprehensionExpression.kt create mode 100644 cpg-language-python/src/test/resources/python/comprehension.py diff --git a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/ExpressionBuilder.kt b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/ExpressionBuilder.kt index af68b0f6e8..b49b683752 100644 --- a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/ExpressionBuilder.kt +++ b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/ExpressionBuilder.kt @@ -31,6 +31,7 @@ import de.fraunhofer.aisec.cpg.graph.NodeBuilder.log import de.fraunhofer.aisec.cpg.graph.edges.flows.ContextSensitiveDataflow import de.fraunhofer.aisec.cpg.graph.statements.expressions.* import de.fraunhofer.aisec.cpg.graph.statements.expressions.AssignExpression +import de.fraunhofer.aisec.cpg.graph.statements.expressions.CollectionComprehension import de.fraunhofer.aisec.cpg.graph.types.ProblemType import de.fraunhofer.aisec.cpg.graph.types.Type @@ -525,6 +526,24 @@ fun MetadataProvider.newInitializerListExpression( return node } +@JvmOverloads +fun MetadataProvider.newComprehensionExpression(rawNode: Any? = null): ComprehensionExpression { + val node = ComprehensionExpression() + node.applyMetadata(this, EMPTY_NAME, rawNode, true) + + log(node) + return node +} + +@JvmOverloads +fun MetadataProvider.newCollectionComprehension(rawNode: Any? = null): CollectionComprehension { + val node = CollectionComprehension() + node.applyMetadata(this, EMPTY_NAME, rawNode, true) + + log(node) + return node +} + /** * Creates a new [TypeExpression]. The [MetadataProvider] receiver will be used to fill different * meta-data using [Node.applyMetadata]. Calling this extension function outside of Kotlin requires diff --git a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/builder/Fluent.kt b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/builder/Fluent.kt index 1ea29ec589..7beb6cf634 100644 --- a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/builder/Fluent.kt +++ b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/builder/Fluent.kt @@ -32,6 +32,7 @@ import de.fraunhofer.aisec.cpg.graph.declarations.* import de.fraunhofer.aisec.cpg.graph.scopes.RecordScope import de.fraunhofer.aisec.cpg.graph.statements.* import de.fraunhofer.aisec.cpg.graph.statements.expressions.* +import de.fraunhofer.aisec.cpg.graph.statements.expressions.CollectionComprehension import de.fraunhofer.aisec.cpg.graph.types.FunctionType import de.fraunhofer.aisec.cpg.graph.types.Type import de.fraunhofer.aisec.cpg.graph.types.UnknownType @@ -325,6 +326,46 @@ fun LanguageFrontend<*, *>.subscriptExpr( return node } +context(Holder) +fun LanguageFrontend<*, *>.listComp( + init: (CollectionComprehension.() -> Unit)? = null +): CollectionComprehension { + val node = newCollectionComprehension() + + if (init != null) { + init(node) + } + + // Only add this to an argument holder if the nearest holder is an argument holder + val holder = this@Holder + if (holder is StatementHolder) { + holder += node + } else if (holder is ArgumentHolder) { + holder += node + } + + return node +} + +context(Holder) +fun LanguageFrontend<*, *>.compExpr( + init: (ComprehensionExpression.() -> Unit)? = null +): ComprehensionExpression { + val node = newComprehensionExpression() + + if (init != null) { + init(node) + } + + // Only add this to an argument holder if the nearest holder is an argument holder + val holder = this@Holder + if (holder is ArgumentHolder) { + holder += node + } + + return node +} + /** * Creates a new [DeclarationStatement] in the Fluent Node DSL and adds it to the * [StatementHolder.statements] of the nearest enclosing [StatementHolder]. The [init] block can be diff --git a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/edges/ast/AstEdge.kt b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/edges/ast/AstEdge.kt index adec108640..ebdc6b69e5 100644 --- a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/edges/ast/AstEdge.kt +++ b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/edges/ast/AstEdge.kt @@ -48,7 +48,7 @@ fun Node.astEdgesOf( } /** - * Creates an single optional [AstEdge] starting from this node (wrapped in a [EdgeSingletonList] + * Creates a single optional [AstEdge] starting from this node (wrapped in a [EdgeSingletonList] * container). */ fun Node.astOptionalEdgeOf( @@ -64,7 +64,7 @@ fun Node.astOptionalEdgeOf( } /** - * Creates an single [AstEdge] starting from this node (wrapped in a [EdgeSingletonList] container). + * Creates a single [AstEdge] starting from this node (wrapped in a [EdgeSingletonList] container). */ fun Node.astEdgeOf( of: NodeType, diff --git a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/edges/collections/UnwrappedEdgeList.kt b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/edges/collections/UnwrappedEdgeList.kt index e46f14c9b7..eeecc35138 100644 --- a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/edges/collections/UnwrappedEdgeList.kt +++ b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/edges/collections/UnwrappedEdgeList.kt @@ -70,7 +70,12 @@ class UnwrappedEdgeList>( } override fun subList(fromIndex: Int, toIndex: Int): MutableList { - TODO("Not yet implemented") + return if (list.outgoing) { + list.subList(fromIndex, toIndex).map { it.end }.toMutableList() + } else { + @Suppress("UNCHECKED_CAST") + list.subList(fromIndex, toIndex).map { it.start as NodeType }.toMutableList() + } } override fun get(index: Int): NodeType { diff --git a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/statements/expressions/CollectionComprehension.kt b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/statements/expressions/CollectionComprehension.kt new file mode 100644 index 0000000000..e51e306f3d --- /dev/null +++ b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/statements/expressions/CollectionComprehension.kt @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2024, Fraunhofer AISEC. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * $$$$$$\ $$$$$$$\ $$$$$$\ + * $$ __$$\ $$ __$$\ $$ __$$\ + * $$ / \__|$$ | $$ |$$ / \__| + * $$ | $$$$$$$ |$$ |$$$$\ + * $$ | $$ ____/ $$ |\_$$ | + * $$ | $$\ $$ | $$ | $$ | + * \$$$$$ |$$ | \$$$$$ | + * \______/ \__| \______/ + * + */ +package de.fraunhofer.aisec.cpg.graph.statements.expressions + +import de.fraunhofer.aisec.cpg.graph.* +import de.fraunhofer.aisec.cpg.graph.edges.ast.astEdgeOf +import de.fraunhofer.aisec.cpg.graph.edges.ast.astEdgesOf +import de.fraunhofer.aisec.cpg.graph.edges.unwrapping +import de.fraunhofer.aisec.cpg.graph.statements.Statement +import java.util.Objects +import org.apache.commons.lang3.builder.ToStringBuilder +import org.neo4j.ogm.annotation.Relationship + +/** + * Represent a list/set/map comprehension or similar expression. It contains four major components: + * The statement, the variable, the iterable and a predicate which are combined to something like + * `[statement(variable) : variable in iterable if predicate(variable)]`. + * + * Some languages provide a way to have multiple variables, iterables and predicates. For this + * reason, we represent the `variable, iterable and predicate in its own class + * [ComprehensionExpression]. + */ +class CollectionComprehension : Expression(), ArgumentHolder { + + @Relationship("COMPREHENSION_EXPRESSIONS") + var comprehensionExpressionEdges = astEdgesOf() + /** + * This field contains one or multiple [ComprehensionExpression]s. + * + * Note: Instead of having a list here, we could also enforce that the frontend nests the + * expressions in a meaningful way (in particular this would help us to satisfy dependencies + * between the comprehensions' variables). + */ + var comprehensionExpressions by + unwrapping(CollectionComprehension::comprehensionExpressionEdges) + + @Relationship("STATEMENT") + var statementEdge = + astEdgeOf( + ProblemExpression("No statement provided but is required in ${this::class}") + ) + /** + * This field contains the statement which is applied to each element of the input for which the + * predicate returned `true`. + */ + var statement by unwrapping(CollectionComprehension::statementEdge) + + override fun toString() = + ToStringBuilder(this, TO_STRING_STYLE) + .appendSuper(super.toString()) + .append("statement", statement) + .append("comprehensions", comprehensionExpressions) + .toString() + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is CollectionComprehension) return false + return super.equals(other) && + statement == other.statement && + comprehensionExpressions == other.comprehensionExpressions + } + + override fun hashCode() = Objects.hash(super.hashCode(), statement, comprehensionExpressions) + + override fun addArgument(expression: Expression) { + if (this.statement is ProblemExpression) { + this.statement = expression + } else if (expression is ComprehensionExpression) { + this.comprehensionExpressions += expression + } + } + + override fun replaceArgument(old: Expression, new: Expression): Boolean { + if (this.statement == old) { + this.statement = new + return true + } + if (new !is ComprehensionExpression) return false + var changedSomething = false + val newCompExp = + this.comprehensionExpressions.map { + if (it == old) { + changedSomething = true + new + } else it + } + this.comprehensionExpressions.clear() + this.comprehensionExpressions.addAll(newCompExp) + return changedSomething + } + + override fun hasArgument(expression: Expression): Boolean { + return this.statement == expression || expression in this.comprehensionExpressions + } +} diff --git a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/statements/expressions/ComprehensionExpression.kt b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/statements/expressions/ComprehensionExpression.kt new file mode 100644 index 0000000000..aeb24c0bf2 --- /dev/null +++ b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/statements/expressions/ComprehensionExpression.kt @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2024, Fraunhofer AISEC. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * $$$$$$\ $$$$$$$\ $$$$$$\ + * $$ __$$\ $$ __$$\ $$ __$$\ + * $$ / \__|$$ | $$ |$$ / \__| + * $$ | $$$$$$$ |$$ |$$$$\ + * $$ | $$ ____/ $$ |\_$$ | + * $$ | $$\ $$ | $$ | $$ | + * \$$$$$ |$$ | \$$$$$ | + * \______/ \__| \______/ + * + */ +package de.fraunhofer.aisec.cpg.graph.statements.expressions + +import de.fraunhofer.aisec.cpg.graph.AccessValues +import de.fraunhofer.aisec.cpg.graph.ArgumentHolder +import de.fraunhofer.aisec.cpg.graph.edges.ast.astEdgeOf +import de.fraunhofer.aisec.cpg.graph.edges.ast.astOptionalEdgeOf +import de.fraunhofer.aisec.cpg.graph.edges.unwrapping +import de.fraunhofer.aisec.cpg.graph.statements.Statement +import java.util.Objects +import org.apache.commons.lang3.builder.ToStringBuilder +import org.neo4j.ogm.annotation.Relationship + +/** This class holds the variable, iterable and predicate of the [CollectionComprehension]. */ +class ComprehensionExpression : Expression(), ArgumentHolder { + @Relationship("VARIABLE") + var variableEdge = + astEdgeOf( + of = ProblemExpression("Missing variableEdge in ${this::class}"), + onChanged = { _, new -> + val end = new?.end + if (end is Reference) { + end.access = AccessValues.WRITE + } + } + ) + + /** + * This field contains the iteration variable of the comprehension. It can be either a new + * variable declaration or a reference (probably to a new variable). + */ + var variable by unwrapping(ComprehensionExpression::variableEdge) + + @Relationship("ITERABLE") + var iterableEdge = + astEdgeOf(ProblemExpression("Missing iterable in ${this::class}")) + + /** This field contains the iteration subject of the loop. */ + var iterable by unwrapping(ComprehensionExpression::iterableEdge) + + @Relationship("PREDICATE") var predicateEdge = astOptionalEdgeOf() + + /** + * This field contains the predicate which has to hold to evaluate `statement(variable)` and + * include it in the result. + */ + var predicate by unwrapping(ComprehensionExpression::predicateEdge) + + override fun toString() = + ToStringBuilder(this, TO_STRING_STYLE) + .appendSuper(super.toString()) + .append("variable", variable) + .append("iterable", iterable) + .append("predicate", predicate) + .toString() + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is ComprehensionExpression) return false + return super.equals(other) && + variable == other.variable && + iterable == other.iterable && + predicate == other.predicate + } + + override fun hashCode() = Objects.hash(super.hashCode(), variable, iterable, predicate) + + override fun addArgument(expression: Expression) { + if (this.variable is ProblemExpression) { + this.variable = expression + } else if (this.iterable is ProblemExpression) { + this.iterable = expression + } else { + this.predicate = expression + } + } + + override fun replaceArgument(old: Expression, new: Expression): Boolean { + if (this.variable == old) { + this.variable = new + return true + } + + if (this.iterable == old) { + this.iterable = new + return true + } + + if (this.predicate == old) { + this.predicate = new + return true + } + return false + } + + override fun hasArgument(expression: Expression): Boolean { + return this.variable == expression || + this.iterable == expression || + expression == this.predicate + } +} diff --git a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/ControlFlowSensitiveDFGPass.kt b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/ControlFlowSensitiveDFGPass.kt index 6883252120..e9b06fc1ab 100644 --- a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/ControlFlowSensitiveDFGPass.kt +++ b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/ControlFlowSensitiveDFGPass.kt @@ -361,6 +361,41 @@ open class ControlFlowSensitiveDFGPass(ctx: TranslationContext) : EOGStarterPass // the other steps state.push(currentNode, it) } + } else if (currentNode is ComprehensionExpression) { + val writtenTo = + when (val variable = currentNode.variable) { + is DeclarationStatement -> { + if (variable.isSingleDeclaration()) { + variable.singleDeclaration + } else { + log.error( + "Cannot handle multiple declarations in the ComprehensionExpresdsion: Node $currentNode" + ) + null + } + } + else -> currentNode.variable + } + // We wrote something to this variable declaration + writtenTo?.let { + writtenDeclaration = + when (writtenTo) { + is Declaration -> writtenTo + is Reference -> writtenTo.refersTo + else -> { + log.error( + "The variable of type ${writtenTo.javaClass} is not yet supported in the ComprehensionExpression" + ) + null + } + } + + state.push(writtenTo, PowersetLattice(identitySetOf(currentNode.iterable))) + // Add the variable declaration (or the reference) to the list of previous + // write nodes in this path + state.declarationsState[writtenDeclaration] = + PowersetLattice(identitySetOf(writtenTo)) + } } else if (currentNode is ForEachStatement && currentNode.variable != null) { // The VariableDeclaration in the ForEachStatement doesn't have an initializer, so // the "normal" case won't work. We handle this case separately here... diff --git a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/DFGPass.kt b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/DFGPass.kt index da018184f7..7e0759d654 100644 --- a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/DFGPass.kt +++ b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/DFGPass.kt @@ -105,6 +105,8 @@ class DFGPass(ctx: TranslationContext) : ComponentPass(ctx) { ) { when (node) { // Expressions + is CollectionComprehension -> handleCollectionComprehension(node) + is ComprehensionExpression -> handleComprehensionExpression(node) is CallExpression -> handleCallExpression(node, inferDfgForUnresolvedSymbols) is CastExpression -> handleCastExpression(node) is BinaryOperator -> handleBinaryOp(node, parent) @@ -139,6 +141,34 @@ class DFGPass(ctx: TranslationContext) : ComponentPass(ctx) { } } + /** + * Handles a collection comprehension. The data flow from + * `comprehension.comprehensionExpressions[i]` to `comprehension.comprehensionExpressions[i+1]` + * and for the last `comprehension.comprehensionExpressions[i]`, it flows to the + * `comprehension.statement`. + */ + protected fun handleCollectionComprehension(comprehension: CollectionComprehension) { + if (comprehension.comprehensionExpressions.isNotEmpty()) { + comprehension.comprehensionExpressions + .subList(0, comprehension.comprehensionExpressions.size - 1) + .forEachIndexed { i, expr -> + expr.nextDFG += comprehension.comprehensionExpressions[i + 1] + } + comprehension.comprehensionExpressions.last().nextDFG += comprehension.statement + } + comprehension.prevDFG += comprehension.statement + } + + /** + * The iterable flows to the variable which flows into the whole expression together with the + * predicate(s). + */ + protected fun handleComprehensionExpression(comprehension: ComprehensionExpression) { + comprehension.iterable.nextDFG += comprehension.variable + comprehension.prevDFG += comprehension.variable + comprehension.predicate?.let { comprehension.prevDFG += it } + } + /** Handle a [ThrowStatement]. The exception and parent exception flow into the node. */ protected fun handleThrowStatement(node: ThrowStatement) { node.exception?.let { node.prevDFGEdges += it } diff --git a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/EvaluationOrderGraphPass.kt b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/EvaluationOrderGraphPass.kt index cccd4e3d6b..103a5ecc52 100644 --- a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/EvaluationOrderGraphPass.kt +++ b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/EvaluationOrderGraphPass.kt @@ -184,6 +184,12 @@ open class EvaluationOrderGraphPass(ctx: TranslationContext) : TranslationUnitPa map[DefaultStatement::class.java] = { handleDefault(it) } map[TypeIdExpression::class.java] = { handleDefault(it) } map[Reference::class.java] = { handleDefault(it) } + map[CollectionComprehension::class.java] = { + handleCollectionComprehension(it as CollectionComprehension) + } + map[ComprehensionExpression::class.java] = { + handleComprehensionExpression(it as ComprehensionExpression) + } map[LambdaExpression::class.java] = { handleLambdaExpression(it as LambdaExpression) } map[LookupScopeStatement::class.java] = { handleLookupScopeStatement(it as LookupScopeStatement) @@ -413,7 +419,7 @@ open class EvaluationOrderGraphPass(ctx: TranslationContext) : TranslationUnitPa // find out for java, but impossible for c++) // evaluate the call target first, optional base should be the callee or in its subtree - node.callee?.let { handleEOG(it) } + handleEOG(node.callee) // then the arguments for (arg in node.arguments) { @@ -856,7 +862,7 @@ open class EvaluationOrderGraphPass(ctx: TranslationContext) : TranslationUnitPa * Connects the current EOG leaf nodes to the last stacked node, e.g. loop head, and removes the * nodes. * - * @param loopScope the loop scope + * @param loopStatement the loop statement */ protected fun handleContainedBreaksAndContinues(loopStatement: LoopStatement) { // Breaks are connected to the NEXT EOG node and therefore temporarily stored after the loop @@ -888,11 +894,7 @@ open class EvaluationOrderGraphPass(ctx: TranslationContext) : TranslationUnitPa } /** - * Builds an EOG edge from prev to next. 'eogDirection' defines how the node instances save the - * references constituting the edge. 'FORWARD': only the nodes nextEOG member contains - * references, an points to the next nodes. 'BACKWARD': only the nodes prevEOG member contains - * references and points to the previous nodes. 'BIDIRECTIONAL': nextEOG and prevEOG contain - * references and point to the previous and the next nodes. + * Builds an EOG edge from prev to next. * * @param prev the previous node * @param next the next node @@ -943,6 +945,60 @@ open class EvaluationOrderGraphPass(ctx: TranslationContext) : TranslationUnitPa handleContainedBreaksAndContinues(node) } + private fun handleComprehensionExpression(node: ComprehensionExpression) { + handleEOG(node.iterable) + // When the iterable contains another element, the variable is evaluated with the + // nextElement. Therefore, we add a "true" edge. + nextEdgeBranch = true + handleEOG(node.variable) + handleEOG(node.predicate) + attachToEOG(node) + + // If the conditions evaluated to false, we need to retrieve the next element, therefore + // evaluating the iterable + drawEOGToEntriesOf(currentPredecessors, node.iterable, branchLabel = false) + + // If an element was found that fulfills the condition, we move forward + nextEdgeBranch = true + } + + private fun handleCollectionComprehension(node: CollectionComprehension) { + // Process the comprehension expressions from 0 to n and connect the EOG of i to i+1. + var prevComprehensionExpression: ComprehensionExpression? = null + var noMoreElementsInCollection = listOf() + node.comprehensionExpressions.forEach { + handleEOG(it) + + val noMoreElements = SubgraphWalker.getEOGPathEdges(it.iterable).exits + + // [ComprehensionExpression] yields no more elements => EOG:false + val prevComp = prevComprehensionExpression + if (prevComp == null) { + // We handle the EOG:false edges of the outermost comprehensionExpression later, + // they continue the + // path of execution when no more elements are yielded + noMoreElementsInCollection = noMoreElements + } else { + drawEOGToEntriesOf(noMoreElements, prevComp.iterable, branchLabel = false) + } + prevComprehensionExpression = it + + // [ComprehensionExpression] yields and element => EOG:true + nextEdgeBranch = true + } + + handleEOG(node.statement) + // After evaluating the statement we + node.comprehensionExpressions.last().let { + drawEOGToEntriesOf(currentPredecessors, it.iterable) + } + currentPredecessors.clear() + currentPredecessors.addAll(noMoreElementsInCollection) + nextEdgeBranch = + false // This path is followed when the comprehensions yield no more elements + attachToEOG(node) + } + protected fun handleForEachStatement(node: ForEachStatement) { handleEOG(node.iterable) handleEOG(node.variable) @@ -1172,7 +1228,7 @@ open class EvaluationOrderGraphPass(ctx: TranslationContext) : TranslationUnitPa else -> { LOGGER.error( "Currently the component {} does not have a defined loop start.", - this?.javaClass + this.javaClass ) ArrayList() } @@ -1218,4 +1274,13 @@ open class EvaluationOrderGraphPass(ctx: TranslationContext) : TranslationUnitPa else -> false } } + + fun drawEOGToEntriesOf(from: List, toEntriesOf: Node?, branchLabel: Boolean? = null) { + val tmpBranchLabel = nextEdgeBranch + branchLabel?.let { nextEdgeBranch = it } + SubgraphWalker.getEOGPathEdges(toEntriesOf).entries.forEach { entrance -> + addMultipleIncomingEOGEdges(from, entrance) + } + nextEdgeBranch = tmpBranchLabel + } } diff --git a/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/GraphExamples.kt b/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/GraphExamples.kt index db0679a64d..4ccb5af4f0 100644 --- a/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/GraphExamples.kt +++ b/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/GraphExamples.kt @@ -243,6 +243,40 @@ class GraphExamples { } } + fun getNestedComprehensionExpressions( + config: TranslationConfiguration = + TranslationConfiguration.builder() + .defaultPasses() + .registerLanguage(TestLanguage(".")) + .build() + ) = + testFrontend(config).build { + translationResult { + translationUnit("whileWithBreakAndElse.py") { + record("someRecord") { + method("func") { + body { + call("preComprehensions") + listComp { + ref("i") + compExpr { + ref("i") + ref("someIterable") + } + compExpr { + ref("j") + ref("i") + ref("j") gt literal(5, t("int")) + } + } + call("postComprehensions") + } + } + } + } + } + } + fun testFrontend(config: TranslationConfiguration): TestLanguageFrontend { val ctx = TranslationContext(config, ScopeManager(), TypeManager()) val language = config.languages.filterIsInstance().first() diff --git a/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/graph/ExpressionBuilderTest.kt b/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/graph/ExpressionBuilderTest.kt index 2756e92e7a..9ac9659d46 100644 --- a/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/graph/ExpressionBuilderTest.kt +++ b/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/graph/ExpressionBuilderTest.kt @@ -25,6 +25,7 @@ */ package de.fraunhofer.aisec.cpg.graph +import de.fraunhofer.aisec.cpg.graph.builder.plus import de.fraunhofer.aisec.cpg.graph.declarations.FieldDeclaration import de.fraunhofer.aisec.cpg.graph.edges.flows.CallingContextIn import de.fraunhofer.aisec.cpg.graph.edges.flows.ContextSensitiveDataflow diff --git a/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/graph/FluentTest.kt b/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/graph/FluentTest.kt index 28ffebd6a2..5bbb2fc77e 100644 --- a/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/graph/FluentTest.kt +++ b/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/graph/FluentTest.kt @@ -25,7 +25,9 @@ */ package de.fraunhofer.aisec.cpg.graph +import de.fraunhofer.aisec.cpg.frontends.TestLanguage import de.fraunhofer.aisec.cpg.frontends.TestLanguageFrontend +import de.fraunhofer.aisec.cpg.frontends.testFrontend import de.fraunhofer.aisec.cpg.graph.builder.* import de.fraunhofer.aisec.cpg.graph.declarations.VariableDeclaration import de.fraunhofer.aisec.cpg.graph.scopes.BlockScope @@ -35,8 +37,11 @@ import de.fraunhofer.aisec.cpg.graph.statements.DeclarationStatement import de.fraunhofer.aisec.cpg.graph.statements.IfStatement import de.fraunhofer.aisec.cpg.graph.statements.ReturnStatement import de.fraunhofer.aisec.cpg.graph.statements.expressions.* +import de.fraunhofer.aisec.cpg.graph.statements.expressions.CollectionComprehension +import de.fraunhofer.aisec.cpg.passes.ControlDependenceGraphPass import de.fraunhofer.aisec.cpg.passes.EvaluationOrderGraphPass import de.fraunhofer.aisec.cpg.passes.ImportResolver +import de.fraunhofer.aisec.cpg.passes.ProgramDependenceGraphPass import de.fraunhofer.aisec.cpg.passes.SymbolResolver import de.fraunhofer.aisec.cpg.test.* import kotlin.test.* @@ -179,4 +184,169 @@ class FluentTest { assertRefersTo(ref, variable) assertFullName("SomeClass::func", mce) } + + @Test + fun testCollectionComprehensions() { + val result = + testFrontend { + it.registerLanguage(TestLanguage(".")) + it.defaultPasses() + it.registerPass() + it.registerPass() + } + .build { + translationResult { + translationUnit("File") { + function("main", t("list")) { + param("argc", t("int")) + body { + declare { + variable("some") { + listComp { + ref("i") + compExpr { + ref("i") + ref("someIterable") + ref("i") gt literal(5, t("int")) + } + } + } + } + + returnStmt { ref("some") } + } + } + } + } + } + + val listComp = result.variables["some"]?.initializer + assertIs(listComp) + print(listComp.toString()) // This is only here to get a better test coverage + print( + listComp.comprehensionExpressions.firstOrNull()?.toString() + ) // This is only here to get a better test coverage + assertIs(listComp.statement) + assertLocalName("i", listComp.statement) + assertEquals(1, listComp.comprehensionExpressions.size) + val compExpr = listComp.comprehensionExpressions.single() + assertIs(compExpr) + assertIs(compExpr.variable) + assertLocalName("i", compExpr.variable) + assertIs(compExpr.iterable) + assertLocalName("someIterable", compExpr.iterable) + assertNotNull(compExpr.predicate) + } + + @Test + fun testCollectionComprehensionsWithDeclaration() { + val result = + testFrontend { + it.registerLanguage(TestLanguage(".")) + it.defaultPasses() + it.registerPass() + it.registerPass() + } + .build { + translationResult { + translationUnit("File") { + function("main", t("list")) { + param("argc", t("int")) + body { + declare { + variable("some") { + listComp { + ref("i") + compExpr { + this.variable = declare { variable("i") } + ref("someIterable") + ref("i") gt literal(5, t("int")) + } + } + } + } + + returnStmt { ref("some") } + } + } + } + } + } + + val listComp = result.variables["some"]?.initializer + assertIs(listComp) + print(listComp.toString()) // This is only here to get a better test coverage + print( + listComp.comprehensionExpressions.firstOrNull()?.toString() + ) // This is only here to get a better test coverage + assertIs(listComp.statement) + assertLocalName("i", listComp.statement) + assertEquals(1, listComp.comprehensionExpressions.size) + val compExpr = listComp.comprehensionExpressions.single() + assertIs(compExpr) + val variableDecl = compExpr.variable + assertIs(variableDecl) + assertLocalName("i", variableDecl.singleDeclaration) + assertIs(compExpr.iterable) + assertLocalName("someIterable", compExpr.iterable) + assertNotNull(compExpr.predicate) + } + + @Test + fun testCollectionComprehensionsWithTwoDeclarations() { + val result = + testFrontend { + it.registerLanguage(TestLanguage(".")) + it.defaultPasses() + it.registerPass() + it.registerPass() + } + .build { + translationResult { + translationUnit("File") { + function("main", t("list")) { + param("argc", t("int")) + body { + declare { + variable("some") { + listComp { + ref("i") + compExpr { + this.variable = declare { + variable("i") + variable("y") + } + ref("someIterable") + ref("i") gt literal(5, t("int")) + } + } + } + } + + returnStmt { ref("some") } + } + } + } + } + } + + val listComp = result.variables["some"]?.initializer + assertIs(listComp) + print(listComp.toString()) // This is only here to get a better test coverage + print( + listComp.comprehensionExpressions.firstOrNull()?.toString() + ) // This is only here to get a better test coverage + assertIs(listComp.statement) + assertLocalName("i", listComp.statement) + assertEquals(1, listComp.comprehensionExpressions.size) + val compExpr = listComp.comprehensionExpressions.single() + assertIs(compExpr) + val variableDecl = compExpr.variable + assertIs(variableDecl) + assertLocalName("i", variableDecl.declarations[0]) + assertLocalName("y", variableDecl.declarations[1]) + assertIs(compExpr.iterable) + assertLocalName("someIterable", compExpr.iterable) + assertNotNull(compExpr.predicate) + } } diff --git a/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/graph/edges/collections/UnwrappedEdgeListTest.kt b/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/graph/edges/collections/UnwrappedEdgeListTest.kt index 2d318ae5a6..9723546e78 100644 --- a/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/graph/edges/collections/UnwrappedEdgeListTest.kt +++ b/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/graph/edges/collections/UnwrappedEdgeListTest.kt @@ -53,6 +53,9 @@ class UnwrappedEdgeListTest { assertEquals(1, node2.prevEOGEdges.size) assertEquals(1, node3.prevEOGEdges.size) assertEquals(1, node3.prevEOG.size) + + assertEquals(listOf(node2, node3), node1.nextEOG.subList(0, 2)) + assertEquals(listOf(node1), node3.prevEOG.subList(0, 1)) } } diff --git a/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/passes/EvaluationOrderGraphPassTest.kt b/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/passes/EvaluationOrderGraphPassTest.kt index ba200df7eb..ae6ff36a0a 100644 --- a/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/passes/EvaluationOrderGraphPassTest.kt +++ b/cpg-core/src/test/kotlin/de/fraunhofer/aisec/cpg/passes/EvaluationOrderGraphPassTest.kt @@ -27,6 +27,7 @@ package de.fraunhofer.aisec.cpg.passes import de.fraunhofer.aisec.cpg.GraphExamples import de.fraunhofer.aisec.cpg.graph.* +import de.fraunhofer.aisec.cpg.graph.statements.expressions.CollectionComprehension import de.fraunhofer.aisec.cpg.helpers.Util import kotlin.test.Test import kotlin.test.assertNotNull @@ -217,4 +218,133 @@ class EvaluationOrderGraphPassTest { cr = Util.Connect.SUBTREE ) } + + @Test + fun testCollectionComprehensionStatement() { + val compExample = GraphExamples.getNestedComprehensionExpressions() + + val listComp = compExample.allChildren().first() + assertNotNull(listComp) + + val preCall = compExample.calls["preComprehensions"] + assertNotNull(preCall) + + val postCall = compExample.calls["postComprehensions"] + assertNotNull(postCall) + + assertTrue { listComp.comprehensionExpressions.size == 2 } + + val outerComprehensionExpression = listComp.comprehensionExpressions.first() + assertNotNull(outerComprehensionExpression) + + val innerComprehensionExpression = listComp.comprehensionExpressions.last() + assertNotNull(innerComprehensionExpression) + + assertTrue( + Util.eogConnect( + en = Util.Edge.EXITS, + n = preCall, + refs = listOf(listComp), + cr = Util.Connect.SUBTREE + ) + ) + assertTrue( + Util.eogConnect( + en = Util.Edge.EXITS, + n = listComp, + refs = listOf(postCall), + cr = Util.Connect.SUBTREE + ) + ) + assertTrue( + Util.eogConnect( + en = Util.Edge.EXITS, + n = outerComprehensionExpression, + refs = + listOf( + innerComprehensionExpression, + listComp, + outerComprehensionExpression.variable + ), + cr = Util.Connect.SUBTREE + ) + ) + assertTrue( + Util.eogConnect( + q = Util.Quantifier.ANY, + en = Util.Edge.EXITS, + n = outerComprehensionExpression, + refs = + listOf( + innerComprehensionExpression, + ), + cr = Util.Connect.SUBTREE, + predicate = { it.branch == true } + ) + ) + + assertTrue( + Util.eogConnect( + q = Util.Quantifier.ANY, + en = Util.Edge.EXITS, + n = outerComprehensionExpression, + refs = listOf(listComp), + cr = Util.Connect.SUBTREE, + predicate = { it.branch == false } + ) + ) + + assertTrue( + Util.eogConnect( + en = Util.Edge.EXITS, + n = innerComprehensionExpression, + refs = listOf(outerComprehensionExpression, listComp.statement), + cr = Util.Connect.SUBTREE + ) + ) + + assertTrue( + Util.eogConnect( + q = Util.Quantifier.ANY, + en = Util.Edge.EXITS, + n = innerComprehensionExpression, + refs = listOf(listComp.statement), + cr = Util.Connect.SUBTREE, + predicate = { it.branch == true } + ) + ) + + assertTrue( + Util.eogConnect( + q = Util.Quantifier.ANY, + en = Util.Edge.EXITS, + n = innerComprehensionExpression, + refs = listOf(outerComprehensionExpression), + cr = Util.Connect.SUBTREE, + predicate = { it.branch == false } + ) + ) + + assertTrue( + Util.eogConnect( + q = Util.Quantifier.ANY, + en = Util.Edge.EXITS, + n = outerComprehensionExpression.iterable, + refs = listOf(outerComprehensionExpression.variable), + cr = Util.Connect.SUBTREE, + predicate = { it.branch == true } + ) + ) + + assertTrue( + Util.eogConnect( + q = Util.Quantifier.ANY, + en = Util.Edge.EXITS, + n = innerComprehensionExpression.iterable, + refs = listOf(innerComprehensionExpression.variable), + cr = Util.Connect.SUBTREE, + predicate = { it.branch == true } + ) + ) + } } diff --git a/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/ExpressionHandler.kt b/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/ExpressionHandler.kt index 951d01cf62..20dd22abb3 100644 --- a/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/ExpressionHandler.kt +++ b/cpg-language-python/src/main/kotlin/de/fraunhofer/aisec/cpg/frontends/python/ExpressionHandler.kt @@ -29,6 +29,7 @@ import de.fraunhofer.aisec.cpg.graph.* import de.fraunhofer.aisec.cpg.graph.declarations.ImportDeclaration import de.fraunhofer.aisec.cpg.graph.declarations.MethodDeclaration import de.fraunhofer.aisec.cpg.graph.statements.expressions.* +import de.fraunhofer.aisec.cpg.graph.statements.expressions.CollectionComprehension import jep.python.PyObject class ExpressionHandler(frontend: PythonLanguageFrontend) : @@ -69,10 +70,10 @@ class ExpressionHandler(frontend: PythonLanguageFrontend) : is Python.AST.JoinedStr -> handleJoinedStr(node) is Python.AST.Starred -> handleStarred(node) is Python.AST.NamedExpr -> handleNamedExpr(node) - is Python.AST.GeneratorExp, - is Python.AST.ListComp, - is Python.AST.SetComp, - is Python.AST.DictComp, + is Python.AST.ListComp -> handleListComprehension(node) + is Python.AST.SetComp -> handleSetComprehension(node) + is Python.AST.DictComp -> handleDictComprehension(node) + is Python.AST.GeneratorExp -> handleGeneratorExp(node) is Python.AST.Await, is Python.AST.Yield, is Python.AST.YieldFrom -> @@ -82,6 +83,86 @@ class ExpressionHandler(frontend: PythonLanguageFrontend) : ) } } + /** + * Translates a Python + * [`comprehension`](https://docs.python.org/3/library/ast.html#ast.comprehension) into a + * [ComprehensionExpression]. + * + * Connects multiple predicates by `and`. + */ + private fun handleComprehension(node: Python.AST.comprehension): ComprehensionExpression { + return newComprehensionExpression(rawNode = node).apply { + variable = handle(node.target) + iterable = handle(node.iter) + val predicates = node.ifs.map { handle(it) } + if (predicates.size == 1) { + predicate = predicates.single() + } else if (predicates.size > 1) { + predicate = + joinListWithBinOp(operatorCode = "and", nodes = predicates, rawNode = node) + } + if (node.is_async != 0L) + additionalProblems += + newProblemExpression( + "Node marked as is_async but we don't support this yet", + rawNode = node + ) + } + } + + /** + * Translates a Python + * [`GeneratorExp`](https://docs.python.org/3/library/ast.html#ast.GeneratorExp) into a + * [CollectionComprehension]. + */ + private fun handleGeneratorExp(node: Python.AST.GeneratorExp): CollectionComprehension { + return newCollectionComprehension(rawNode = node).apply { + statement = handle(node.elt) + comprehensionExpressions += node.generators.map { handleComprehension(it) } + type = objectType("Generator") + } + } + + /** + * Translates a Python [`ListComp`](https://docs.python.org/3/library/ast.html#ast.ListComp) + * into a [CollectionComprehension]. + */ + private fun handleListComprehension(node: Python.AST.ListComp): CollectionComprehension { + return newCollectionComprehension(rawNode = node).apply { + statement = handle(node.elt) + comprehensionExpressions += node.generators.map { handleComprehension(it) } + type = objectType("list") // TODO: Replace this once we have dedicated types + } + } + + /** + * Translates a Python [`SetComp`](https://docs.python.org/3/library/ast.html#ast.SetComp) into + * a [CollectionComprehension]. + */ + private fun handleSetComprehension(node: Python.AST.SetComp): CollectionComprehension { + return newCollectionComprehension(rawNode = node).apply { + this.statement = handle(node.elt) + this.comprehensionExpressions += node.generators.map { handleComprehension(it) } + this.type = objectType("set") // TODO: Replace this once we have dedicated types + } + } + + /** + * Translates a Python [`DictComp`](https://docs.python.org/3/library/ast.html#ast.DictComp) + * into a [CollectionComprehension]. + */ + private fun handleDictComprehension(node: Python.AST.DictComp): CollectionComprehension { + return newCollectionComprehension(rawNode = node).apply { + this.statement = + newKeyValueExpression( + key = handle(node.key), + value = handle(node.value), + rawNode = node + ) + this.comprehensionExpressions += node.generators.map { handleComprehension(it) } + this.type = objectType("dict") // TODO: Replace this once we have dedicated types + } + } /** * Translates a Python [`NamedExpr`](https://docs.python.org/3/library/ast.html#ast.NamedExpr) @@ -147,15 +228,28 @@ class ExpressionHandler(frontend: PythonLanguageFrontend) : } else if (values.size == 1) { values.first() } else { - val lastTwo = newBinaryOperator("+", rawNode = node) - lastTwo.rhs = values.last() - lastTwo.lhs = values[values.size - 2] - values.subList(0, values.size - 2).foldRight(lastTwo) { newVal, start -> - val nextValue = newBinaryOperator("+") - nextValue.rhs = start - nextValue.lhs = newVal - nextValue - } + joinListWithBinOp(operatorCode = "+", nodes = values, rawNode = node) + } + } + + /** + * Joins the [nodes] with a [BinaryOperator] with the [operatorCode]. Nests the whole thing, + * where the first element in [nodes] is the lhs of the root of the tree of binary operators. + * The last operands are further down the tree. + */ + private fun joinListWithBinOp( + operatorCode: String, + nodes: List, + rawNode: Python.AST.AST? = null + ): BinaryOperator { + val lastTwo = newBinaryOperator(operatorCode, rawNode = rawNode) + lastTwo.rhs = nodes.last() + lastTwo.lhs = nodes[nodes.size - 2] + return nodes.subList(0, nodes.size - 2).foldRight(lastTwo) { newVal, start -> + val nextValue = newBinaryOperator(operatorCode) + nextValue.rhs = start + nextValue.lhs = newVal + nextValue } } @@ -405,7 +499,7 @@ class ExpressionHandler(frontend: PythonLanguageFrontend) : frontend.scopeManager.currentScope ?.lookupSymbol(name.localName, replaceImports = false) ?.filterIsInstance() - return decl?.isNotEmpty() ?: false + return decl?.isNotEmpty() == true } private fun handleName(node: Python.AST.Name): Expression { diff --git a/cpg-language-python/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/python/ExpressionHandlerTest.kt b/cpg-language-python/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/python/ExpressionHandlerTest.kt index 51da0d504d..2f81c21425 100644 --- a/cpg-language-python/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/python/ExpressionHandlerTest.kt +++ b/cpg-language-python/src/test/kotlin/de/fraunhofer/aisec/cpg/frontends/python/ExpressionHandlerTest.kt @@ -26,7 +26,13 @@ package de.fraunhofer.aisec.cpg.frontends.python import de.fraunhofer.aisec.cpg.graph.* +import de.fraunhofer.aisec.cpg.graph.statements.expressions.AssignExpression import de.fraunhofer.aisec.cpg.graph.statements.expressions.BinaryOperator +import de.fraunhofer.aisec.cpg.graph.statements.expressions.Block +import de.fraunhofer.aisec.cpg.graph.statements.expressions.CallExpression +import de.fraunhofer.aisec.cpg.graph.statements.expressions.CollectionComprehension +import de.fraunhofer.aisec.cpg.graph.statements.expressions.KeyValueExpression +import de.fraunhofer.aisec.cpg.graph.statements.expressions.Reference import de.fraunhofer.aisec.cpg.test.analyze import de.fraunhofer.aisec.cpg.test.assertLiteralValue import de.fraunhofer.aisec.cpg.test.assertLocalName @@ -34,6 +40,235 @@ import java.nio.file.Path import kotlin.test.* class ExpressionHandlerTest { + @Test + fun testListComprehensions() { + val topLevel = Path.of("src", "test", "resources", "python") + val result = + analyze(listOf(topLevel.resolve("comprehension.py").toFile()), topLevel, true) { + it.registerLanguage() + } + assertNotNull(result) + val listComp = result.functions["listComp"] + assertNotNull(listComp) + + val body = listComp.body + assertIs(body) + val singleWithIfAssignment = body.statements[0] + assertIs(singleWithIfAssignment) + val singleWithIf = singleWithIfAssignment.rhs[0] + assertIs(singleWithIf) + assertIs(singleWithIf.statement) + assertEquals(1, singleWithIf.comprehensionExpressions.size) + assertLocalName("i", singleWithIf.comprehensionExpressions[0].variable) + assertIs(singleWithIf.comprehensionExpressions[0].iterable) + assertLocalName("x", singleWithIf.comprehensionExpressions[0].iterable) + val ifPredicate = singleWithIf.comprehensionExpressions[0].predicate + assertIs(ifPredicate) + assertEquals("==", ifPredicate.operatorCode) + + val singleWithoutIfAssignment = body.statements[1] + assertIs(singleWithoutIfAssignment) + val singleWithoutIf = singleWithoutIfAssignment.rhs[0] + assertIs(singleWithoutIf) + assertIs(singleWithoutIf.statement) + assertEquals(1, singleWithoutIf.comprehensionExpressions.size) + assertLocalName("i", singleWithoutIf.comprehensionExpressions[0].variable) + assertIs(singleWithoutIf.comprehensionExpressions[0].iterable) + assertLocalName("x", singleWithoutIf.comprehensionExpressions[0].iterable) + assertNull(singleWithoutIf.comprehensionExpressions[0].predicate) + + val singleWithDoubleIfAssignment = body.statements[2] + assertIs(singleWithDoubleIfAssignment) + val singleWithDoubleIf = singleWithDoubleIfAssignment.rhs[0] + assertIs(singleWithDoubleIf) + assertIs(singleWithDoubleIf.statement) + assertEquals(1, singleWithDoubleIf.comprehensionExpressions.size) + assertLocalName("i", singleWithDoubleIf.comprehensionExpressions[0].variable) + assertIs(singleWithDoubleIf.comprehensionExpressions[0].iterable) + assertLocalName("x", singleWithDoubleIf.comprehensionExpressions[0].iterable) + val doubleIfPredicate = singleWithDoubleIf.comprehensionExpressions[0].predicate + assertIs(doubleIfPredicate) + assertEquals("and", doubleIfPredicate.operatorCode) + + val doubleAssignment = body.statements[3] as? AssignExpression + assertIs(doubleAssignment) + val double = doubleAssignment.rhs[0] as? CollectionComprehension + assertNotNull(double) + assertIs(double.statement) + assertEquals(2, double.comprehensionExpressions.size) + // TODO: Add tests on the comprehension expressions + } + + @Test + fun testSetComprehensions() { + val topLevel = Path.of("src", "test", "resources", "python") + val result = + analyze(listOf(topLevel.resolve("comprehension.py").toFile()), topLevel, true) { + it.registerLanguage() + } + assertNotNull(result) + val listComp = result.functions["setComp"] + assertNotNull(listComp) + + val body = listComp.body as? Block + assertNotNull(body) + val singleWithIfAssignment = body.statements[0] + assertIs(singleWithIfAssignment) + val singleWithIf = singleWithIfAssignment.rhs[0] + assertIs(singleWithIf) + assertIs(singleWithIf.statement) + assertEquals(1, singleWithIf.comprehensionExpressions.size) + assertLocalName("i", singleWithIf.comprehensionExpressions[0].variable) + assertIs(singleWithIf.comprehensionExpressions[0].iterable) + assertLocalName("x", singleWithIf.comprehensionExpressions[0].iterable) + val ifPredicate = singleWithIf.comprehensionExpressions[0].predicate + assertIs(ifPredicate) + assertEquals("==", ifPredicate.operatorCode) + + val singleWithoutIfAssignment = body.statements[1] + assertIs(singleWithoutIfAssignment) + val singleWithoutIf = singleWithoutIfAssignment.rhs[0] + assertIs(singleWithoutIf) + assertIs(singleWithoutIf.statement) + assertEquals(1, singleWithoutIf.comprehensionExpressions.size) + assertLocalName("i", singleWithoutIf.comprehensionExpressions[0].variable) + assertIs(singleWithoutIf.comprehensionExpressions[0].iterable) + assertLocalName("x", singleWithoutIf.comprehensionExpressions[0].iterable) + assertNull(singleWithoutIf.comprehensionExpressions[0].predicate) + + val singleWithDoubleIfAssignment = body.statements[2] + assertIs(singleWithDoubleIfAssignment) + val singleWithDoubleIf = singleWithDoubleIfAssignment.rhs[0] + assertIs(singleWithDoubleIf) + assertIs(singleWithDoubleIf.statement) + assertEquals(1, singleWithDoubleIf.comprehensionExpressions.size) + assertLocalName("i", singleWithDoubleIf.comprehensionExpressions[0].variable) + assertIs(singleWithDoubleIf.comprehensionExpressions[0].iterable) + assertLocalName("x", singleWithDoubleIf.comprehensionExpressions[0].iterable) + val doubleIfPredicate = singleWithDoubleIf.comprehensionExpressions[0].predicate + assertIs(doubleIfPredicate) + assertEquals("and", doubleIfPredicate.operatorCode) + + val doubleAssignment = body.statements[3] + assertIs(doubleAssignment) + val double = doubleAssignment.rhs[0] + assertIs(double) + assertIs(double.statement) + assertEquals(2, double.comprehensionExpressions.size) + } + + @Test + fun testDictComprehensions() { + val topLevel = Path.of("src", "test", "resources", "python") + val result = + analyze(listOf(topLevel.resolve("comprehension.py").toFile()), topLevel, true) { + it.registerLanguage() + } + assertNotNull(result) + val listComp = result.functions["dictComp"] + assertNotNull(listComp) + + val body = listComp.body as? Block + assertNotNull(body) + val singleWithIfAssignment = body.statements[0] + assertIs(singleWithIfAssignment) + val singleWithIf = singleWithIfAssignment.rhs[0] + assertIs(singleWithIf) + var statement = singleWithIf.statement + assertIs(statement) + assertIs(statement.key) + assertLocalName("i", statement.key) + assertIs(statement.value) + assertEquals(1, singleWithIf.comprehensionExpressions.size) + assertLocalName("i", singleWithIf.comprehensionExpressions[0].variable) + assertIs(singleWithIf.comprehensionExpressions[0].iterable) + assertLocalName("x", singleWithIf.comprehensionExpressions[0].iterable) + val ifPredicate = singleWithIf.comprehensionExpressions[0].predicate + assertIs(ifPredicate) + assertEquals("==", ifPredicate.operatorCode) + + val singleWithoutIfAssignment = body.statements[1] + assertIs(singleWithoutIfAssignment) + val singleWithoutIf = singleWithoutIfAssignment.rhs[0] + assertIs(singleWithoutIf) + statement = singleWithIf.statement + assertIs(statement) + assertIs(statement.key) + assertLocalName("i", statement.key) + assertIs(statement.value) + assertEquals(1, singleWithoutIf.comprehensionExpressions.size) + assertLocalName("i", singleWithoutIf.comprehensionExpressions[0].variable) + assertIs(singleWithoutIf.comprehensionExpressions[0].iterable) + assertLocalName("x", singleWithoutIf.comprehensionExpressions[0].iterable) + assertNull(singleWithoutIf.comprehensionExpressions[0].predicate) + + val singleWithDoubleIfAssignment = body.statements[2] + assertIs(singleWithDoubleIfAssignment) + val singleWithDoubleIf = singleWithDoubleIfAssignment.rhs[0] + assertIs(singleWithDoubleIf) + statement = singleWithIf.statement + assertIs(statement) + assertIs(statement.key) + assertLocalName("i", statement.key) + assertIs(statement.value) + assertEquals(1, singleWithDoubleIf.comprehensionExpressions.size) + assertLocalName("i", singleWithDoubleIf.comprehensionExpressions[0].variable) + assertIs(singleWithDoubleIf.comprehensionExpressions[0].iterable) + assertLocalName("x", singleWithDoubleIf.comprehensionExpressions[0].iterable) + val doubleIfPredicate = singleWithDoubleIf.comprehensionExpressions[0].predicate + assertIs(doubleIfPredicate) + assertEquals("and", doubleIfPredicate.operatorCode) + + val doubleAssignment = body.statements[3] as? AssignExpression + assertIs(doubleAssignment) + val double = doubleAssignment.rhs[0] as? CollectionComprehension + assertNotNull(double) + statement = singleWithIf.statement + assertIs(statement) + assertIs(statement.key) + assertLocalName("i", statement.key) + assertIs(statement.value) + assertEquals(2, double.comprehensionExpressions.size) + } + + @Test + fun testGeneratorExpr() { + val topLevel = Path.of("src", "test", "resources", "python") + val result = + analyze(listOf(topLevel.resolve("comprehension.py").toFile()), topLevel, true) { + it.registerLanguage() + } + assertNotNull(result) + val listComp = result.functions["generator"] + assertNotNull(listComp) + + val body = listComp.body as? Block + assertNotNull(body) + val singleWithIfAssignment = body.statements[0] + assertIs(singleWithIfAssignment) + val singleWithIf = singleWithIfAssignment.rhs[0] + assertIs(singleWithIf) + assertIs(singleWithIf.statement) + assertEquals(1, singleWithIf.comprehensionExpressions.size) + assertLocalName("i", singleWithIf.comprehensionExpressions[0].variable) + assertIs(singleWithIf.comprehensionExpressions[0].iterable) + assertLocalName("range", singleWithIf.comprehensionExpressions[0].iterable) + val ifPredicate = singleWithIf.comprehensionExpressions[0].predicate + assertIs(ifPredicate) + assertEquals("==", ifPredicate.operatorCode) + + val singleWithoutIfAssignment = body.statements[1] + assertIs(singleWithoutIfAssignment) + val singleWithoutIf = singleWithoutIfAssignment.rhs[0] + assertIs(singleWithoutIf) + assertIs(singleWithoutIf.statement) + assertEquals(1, singleWithoutIf.comprehensionExpressions.size) + assertLocalName("i", singleWithoutIf.comprehensionExpressions[0].variable) + assertIs(singleWithIf.comprehensionExpressions[0].iterable) + assertLocalName("range", singleWithIf.comprehensionExpressions[0].iterable) + assertNull(singleWithoutIf.comprehensionExpressions[0].predicate) + } + @Test fun testBoolOps() { val topLevel = Path.of("src", "test", "resources", "python") diff --git a/cpg-language-python/src/test/resources/python/comprehension.py b/cpg-language-python/src/test/resources/python/comprehension.py new file mode 100644 index 0000000000..7d6f2568ad --- /dev/null +++ b/cpg-language-python/src/test/resources/python/comprehension.py @@ -0,0 +1,24 @@ +def foo(arg): + return 7 + +def listComp(x, y): + a = [foo(i) for i in x if i == 10] + b = [foo(i) for i in x] + c = {foo(i) for i in x if i == 10 if i < 20} + d = [foo(i) for z in y if z in x for i in z if i == 10 ] + +def setComp(x, y): + a = {foo(i) for i in x if i == 10} + b = {foo(i) for i in x} + c = {foo(i) for i in x if i == 10 if i < 20} + d = {foo(i) for z in y if z in x for i in z if i == 10 } + +def dictComp(x, y): + a = {i: foo(i) for i in x if i == 10} + b = {i: foo(i) for i in x} + c = {i: foo(i) for i in x if i == 10 if i < 20} + d = {i: foo(i) for z in y if z in x for i in z if i == 10 } + +def generator(x, y): + a = (i**2 for i in range(10) if i == 10) + b = (i**2 for i in range(10)) \ No newline at end of file diff --git a/docs/docs/CPG/specs/dfg.md b/docs/docs/CPG/specs/dfg.md index d5db7894a7..00b25e56fe 100755 --- a/docs/docs/CPG/specs/dfg.md +++ b/docs/docs/CPG/specs/dfg.md @@ -177,6 +177,44 @@ Scheme: arrayExpression -.- node; ``` +## CollectionComprehension + +Interesting fields: + +* `comprehensionExpressions: List`: The list of expressions which are iterated over. +* `statement: Statement`: The statement which returns the data. + +The data of `comprehensionExpressions[i]` flow to `comprehensionExpressions[i+1]` and the last item in `comprehensionExpressions` flows to `statement`. + +Scheme: +```mermaid + flowchart LR + comp["for all 0 <= i < comprehensionExpressions-1: comprehensionExpressions[i]"] -- DFG --> comp1["comprehensionExpressions[i+1]"] -- DFG --> stmt["statement"] -- DFG --> node([CollectionComprehension]); + node -.- comp; + node -.- comp1; + node -.- stmt; +``` + +## CollectionComprehension.ComprehensionExpression + +Interesting fields: + +* `predicates: List`: A list of conditions which have to hold to process the variable in the result. +* `iterable: Statement`: The statement which iterates over something. +* `variable: Statement`: The variable which holds the individual elements in the iterable. + +The data of `iterable` flow to `variable` which flows to the whole node. Also, all `predicates` flow to the whole node. + +Scheme: +```mermaid + flowchart LR + pred["for all i: predicates[i]"] -- DFG --> stmt["statement"] -- DFG --> node([CollectionComprehension.ComprehensionExpression]); + iterable["iterable"] -- DFG --> var["variable"]; + var -- DFG --> node; + node -.- pred; + node -.- var; + node -.- iterable; +``` ## ConditionalExpression diff --git a/docs/docs/CPG/specs/eog.md b/docs/docs/CPG/specs/eog.md index d50463781e..84f341767b 100644 --- a/docs/docs/CPG/specs/eog.md +++ b/docs/docs/CPG/specs/eog.md @@ -645,6 +645,55 @@ flowchart LR parent -.-> child3 ``` +## CollectionComprehension +This node iterates through a collection of elements via `comprehensionExpression` and applies `statement` to the elements. + +Interesting fields: + +* `comprehensionExpressions: List`: The part which iterates through all elements of the collection and filter them. +* `statement: Statement`: The operation applied to each element iterated over. + +Scheme: +```mermaid +flowchart LR + classDef outer fill:#fff,stroke:#ddd,stroke-dasharray:5 5; + prev:::outer --EOG--> child1["comprehensionExpressions[0]"] + child1 --EOG:true--> child2["comprehensionExpressions[n]"] + child2 --EOG:true--> child3["statement"] + child2 --EOG:false--> child1["comprehensionExpressions[0]"] + child1 --EOG:false--> parent(["CollectionComprehension"]) + child3 --EOG--> child2 + parent --EOG--> next:::outer + parent -.-> child3 + parent -.-> child2 + parent -.-> child1 +``` + +## ComprehensionExpression +This node iterates through a collection of elements of `iterable`, keeps the element in `variable` and evaluates an optional `predicate`. + +Interesting fields: + +* `iterable: Statement`: The part which iterates through all elements of the collection (or similar). +* `variable: Statement`: The variable holding each element in the iterable. +* `predicate: Statement`: A condition which determines if we consider this variable further or if we fetch the next element. + +Scheme: +```mermaid +flowchart LR + classDef outer fill:#fff,stroke:#ddd,stroke-dasharray:5 5; + prev:::outer --EOG--> child1["iterable"] + child1 --EOG:true--> child2["variable"] + child2 --EOG--> child3["predicate"] + child3 --EOG--> parent(["ComprehensionExpression"]) + parent --EOG:true--> enter:::outer + parent --EOG:false--> child1 + child1 --EOG:false--> exit:::outer + parent -.-> child3 + parent -.-> child2 + parent -.-> child1 +``` + ## WhileStatement This is a classic while loop where the condition is evaluated before every loop iteration.