From 4d58f8fa34cd9ac2e339ac246dcd3d8235a1d2ec Mon Sep 17 00:00:00 2001 From: David Baker Effendi Date: Mon, 23 Oct 2023 15:55:45 +0200 Subject: [PATCH] [pysrc2cpg] Model Field-like Behaviour of Module Variables (#3750) * [pysrc2cpg] Model Field-like Behaviour of Module Variables The motivation for this PR is that module-level variables in Python can be accessed from other files through an implicit global scope. This behaviour is more akin to fields than to identifiers, and should be modelled closer to that direction. Given the caveat that the data-flow engine sees fields a bit differently, especially if the flow goes across closures, this aims to be the least intrusive change to achieve this result. ### Main changes * Create a module-level identifier assigned to the module's type ref to be used as the base of module-level member field accesses. * At module-level identifier STOREs, create a block that writes to the module field, then aliases the result to the identifier, e.g., `x = 1` becomes ```python x = { <module>.x = 1 tmp = <module>.x tmp } ``` * Added helper methods, e.g., `isModuleContext` to `ContextStack` * Separated `VariableReference` class into `Reference` trait shared between `VariableReference` and `FieldReference` * During `createMemberLinks`, `FieldReference` calls will have a `REF` edge to the member they reference. TODO: Recover these `REF` edges during the type recovery pass. ## Misc Changes * Fix `pysrc2cpg/ImportResolver` bug when an `__init__.py` entity is imported from a non-sibling module. 
* Added shortcut in `PythonTypeRecovery.visitStatementsInBlock` to detect these strong-update blocks and "re-sugar" them to skip propagating the type via the member * Removed Python slicer tests; I would rather be intentional in how I support Python usage slicing at a later stage * Added member traversal handling for recovered base types * Fixed JSSRC import * Added test for data-flow, found some bug about how starting points are handled for members * Removed desugaring being part of the flows so that the flows are the same as before and give no fluctuations * Fixed assignment code from being `<module>.<module>.x` --- .../language/ExtendedCfgNode.scala | 18 ++- .../queryengine/SourcesToStartingPoints.scala | 15 ++- .../passes/JavaScriptTypeRecovery.scala | 2 + .../io/joern/pysrc2cpg/ContextStack.scala | 98 +++++++++----- .../joern/pysrc2cpg/ImportResolverPass.scala | 17 ++- .../io/joern/pysrc2cpg/NodeBuilder.scala | 8 +- .../io/joern/pysrc2cpg/PythonAstVisitor.scala | 32 ++++- .../pysrc2cpg/PythonAstVisitorHelpers.scala | 69 ++++++++-- .../joern/pysrc2cpg/PythonTypeRecovery.scala | 30 ++++- .../joern/pysrc2cpg/cpg/AssignCpgTests.scala | 21 +-- .../io/joern/pysrc2cpg/cpg/CallCpgTests.scala | 3 +- .../cpg/ModuleFunctionCpgTests.scala | 10 +- .../cpg/VariableReferencingCpgTests.scala | 12 +- .../pysrc2cpg/dataflow/DataFlowTests.scala | 30 ++++- .../passes/TypeRecoveryPassTests.scala | 20 ++- .../pysrc2cpg/slicing/PyUsageSliceTests.scala | 51 ------- .../x2cpg/passes/frontend/XTypeRecovery.scala | 127 ++++++++++++------ 17 files changed, 381 insertions(+), 182 deletions(-) delete mode 100644 joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/slicing/PyUsageSliceTests.scala diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/ExtendedCfgNode.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/ExtendedCfgNode.scala index 70fb3db99e5c..525173f71a10 100644 --- 
a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/ExtendedCfgNode.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/ExtendedCfgNode.scala @@ -37,6 +37,18 @@ class ExtendedCfgNode(val traversal: Iterator[CfgNode]) extends AnyVal { reachedSources.cast[NodeType] } + /** Trims out path elements that represent module-level variable assignment desugaring + */ + private def trimModuleVariableDesugarBlock(path: Vector[PathElement]): Vector[PathElement] = { + val moduleKeywords = Seq("", ":program") + val codeMatcher = s"(${moduleKeywords.mkString("|")}).*" + path.filterNot { x => + x.node match + case block: Block => block.astChildren.code(codeMatcher).nonEmpty + case astNode => astNode.inAssignment.code(codeMatcher).nonEmpty + } + } + def reachableByFlows[A](sourceTrav: IterableOnce[A], sourceTravs: IterableOnce[A]*)(implicit context: EngineContext ): Iterator[Path] = { @@ -51,8 +63,10 @@ class ExtendedCfgNode(val traversal: Iterator[CfgNode]) extends AnyVal { if (first.isDefined && !first.get.visible && !startingPoints.contains(first.get.node)) { None } else { - val visiblePathElements = result.path.filter(x => startingPoints.contains(x.node) || x.visible) - Some(Path(removeConsecutiveDuplicates(visiblePathElements.map(_.node)))) + val visiblePathElements = result.path.filter(x => startingPoints.contains(x.node) || x.visible) + val trimmedOfModuleDesugaring = trimModuleVariableDesugarBlock(visiblePathElements) + val deDuplicated = removeConsecutiveDuplicates(trimmedOfModuleDesugaring.map(_.node)) + Some(Path(deDuplicated)) } } .filter(_.isDefined) diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/SourcesToStartingPoints.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/SourcesToStartingPoints.scala index e2e855abd3fa..6eca82e8f122 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/SourcesToStartingPoints.scala +++ 
b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/SourcesToStartingPoints.scala @@ -4,8 +4,8 @@ import io.joern.dataflowengineoss.globalFromLiteral import io.joern.x2cpg.Defines import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.allAssignmentTypes import io.shiftleft.semanticcpg.utils.MemberAccess.isFieldAccess import org.slf4j.LoggerFactory @@ -121,6 +121,8 @@ class SourceToStartingPoints(src: StoredNode) extends RecursiveTask[List[CfgNode x.argument(2).isFieldIdentifier.canonicalNameExact(identifier.name) case fieldIdentifier: FieldIdentifier => x.argument(2).isFieldIdentifier.canonicalNameExact(fieldIdentifier.canonicalName) + case member: Member => + x.argument(2).isFieldIdentifier.canonicalNameExact(member.name) case _ => Iterator.empty } } @@ -174,7 +176,9 @@ class SourceToStartingPoints(src: StoredNode) extends RecursiveTask[List[CfgNode .or( _.method.nameExact(Defines.StaticInitMethodName, Defines.ConstructorMethodName, "__init__"), // in language such as Python, where assignments for members can be directly under a type decl - _.method.typeDecl + _.method.typeDecl, + // for Python, we have moved to replacing strong updates of module-level variables with their members + _.target.isCall.nameExact(Operators.fieldAccess).argument(1).isIdentifier.name("") ) .target .flatMap { @@ -209,6 +213,11 @@ class SourceToStartingPoints(src: StoredNode) extends RecursiveTask[List[CfgNode private def targetsToClassIdentifierPair(targets: List[AstNode]): List[(TypeDecl, AstNode)] = { targets.flatMap { + case expr: FieldIdentifier => + expr.method.typeDecl.map { typeDecl => (typeDecl, expr) } ++ + expr.inCall.fieldAccess.referencedMember.flatMap { 
member => + member.typeDecl.map { typeDecl => (typeDecl, member) } + } case expr: Expression => expr.method.typeDecl.map { typeDecl => (typeDecl, expr) } case member: Member => diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/JavaScriptTypeRecovery.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/JavaScriptTypeRecovery.scala index a5c1ac99d18f..59173e96c2b6 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/JavaScriptTypeRecovery.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/JavaScriptTypeRecovery.scala @@ -33,6 +33,8 @@ private class JavaScriptTypeRecovery(cpg: Cpg, state: XTypeRecoveryState) extend private class RecoverForJavaScriptFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, state: XTypeRecoveryState) extends RecoverForXCompilationUnit[File](cpg, cu, builder, state) { + import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt + override protected val pathSep = ':' /** A heuristic method to determine if a call is a constructor or not. 
diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/ContextStack.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/ContextStack.scala index b1ec59b4181e..248d23bc7267 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/ContextStack.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/ContextStack.scala @@ -1,15 +1,14 @@ package io.joern.pysrc2cpg import io.joern.pysrc2cpg.ContextStack.transferLineColInfo -import io.joern.pysrc2cpg.memop._ +import io.joern.pysrc2cpg.memop.* import io.shiftleft.codepropertygraph.generated.nodes -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import org.slf4j.LoggerFactory import scala.collection.mutable object ContextStack { - private val logger = LoggerFactory.getLogger(getClass) def transferLineColInfo(src: NewIdentifier, tgt: NewLocal): Unit = { src.lineNumber match { @@ -23,7 +22,6 @@ object ContextStack { } class ContextStack { - import ContextStack.logger private trait Context { val astParent: nodes.NewNode @@ -70,6 +68,12 @@ class ContextStack { var lambdaCounter: Int = 0 ) extends Context {} + private trait Reference { + def identifier: nodes.NewIdentifier | nodes.NewFieldIdentifier + + def stack: List[Context] + } + private case class VariableReference( identifier: nodes.NewIdentifier, memOp: MemoryOperation, @@ -79,10 +83,13 @@ class ContextStack { // instances because the changes in the variable // maps need to be in sync. 
stack: List[Context] - ) + ) extends Reference + + private case class FieldReference(identifier: nodes.NewFieldIdentifier, fieldAccess: NewCall, stack: List[Context]) + extends Reference private var stack = List[Context]() - private val variableReferences = mutable.ArrayBuffer.empty[VariableReference] + private val variableReferences = mutable.ArrayBuffer.empty[Reference] private var moduleMethodContext = Option.empty[MethodContext] private var fileNamespaceBlock = Option.empty[nodes.NewNamespaceBlock] private val fileNamespaceBlockOrder = new AutoIncIndex(1) @@ -135,6 +142,10 @@ class ContextStack { variableReferences.append(VariableReference(identifier, memOp, stack)) } + def addFieldReference(identifier: nodes.NewFieldIdentifier, fieldAccess: NewCall): Unit = { + variableReferences.append(FieldReference(identifier, fieldAccess, stack)) + } + def getAndIncLambdaCounter(): Int = { val result = stack.head.lambdaCounter stack.head.lambdaCounter += 1 @@ -145,6 +156,9 @@ class ContextStack { contextStack.find(_.isInstanceOf[MethodContext]).get.asInstanceOf[MethodContext] } + def findEnclosingMethod(): Option[NewMethod] = + stack.find(_.isInstanceOf[MethodContext]).map(_.astParent).collect { case x: NewMethod => x } + def findEnclosingTypeDecl(): Option[NewNode] = { stack.find(_.isInstanceOf[ClassContext]) match { case Some(classContext: ClassContext) => @@ -160,13 +174,14 @@ class ContextStack { createRefEdge: (nodes.NewNode, nodes.NewNode) => Unit, createCaptureEdge: (nodes.NewNode, nodes.NewNode) => Unit ): Unit = { + val identifierReferences = variableReferences.collect { case x: VariableReference => x } // Before we do any linking, we iterate over all variable references and // create a variable in the module method context for each global variable // with a store operation on it. // This is necessary because there might be load/delete operations // referencing the global variable which are syntactically before the store // operations. 
- variableReferences.foreach { case VariableReference(identifier, memOp, contextStack) => + identifierReferences.foreach { case VariableReference(identifier, memOp, contextStack) => val name = identifier.name if ( memOp == Store && @@ -183,7 +198,7 @@ class ContextStack { // Variable references processing needs to be ordered by context depth in // order to make sure that variables captured into deeper nested contexts // are already created. - val sortedVariableRefs = variableReferences.sortBy(_.stack.size) + val sortedVariableRefs = identifierReferences.sortBy(_.stack.size) sortedVariableRefs.foreach { case VariableReference(identifier, memOp, contextStack) => val name = identifier.name // Store and delete operations look up variable only in method scope. @@ -252,40 +267,53 @@ class ContextStack { } } - /** Assignments to variables on the module-level may be exported to other modules and behave as inter-procedurally - * global variables. - * @param lhs - * the LHS node of an assignment - */ - def considerAsGlobalVariable(lhs: NewNode): Unit = { - lhs match { - case n: NewIdentifier if findEnclosingMethodContext(stack).scopeName.contains("") => - addGlobalVariable(n.name) - case _ => - } - } - /** For module-methods, the variables of this method can be imported into other modules which resembles behaviour much * like fields/members. This inter-procedural accessibility should be marked via the module's type decl node. 
*/ - def createMemberLinks(moduleTypeDecl: NewTypeDecl, astEdgeLinker: (NewNode, NewNode, Int) => Unit): Unit = { - val globalVarsForEnclMethod = findEnclosingMethodContext(stack).globalVariables - variableReferences + def createMemberLinks( + moduleTypeDecl: NewTypeDecl, + astEdgeLinker: (NewNode, NewNode, Int) => Unit, + refEdgeLinker: (NewNode, NewNode) => Unit + ): Unit = { + val globalVariables = findEnclosingMethodContext(stack).globalVariables + val globalVarReferences = variableReferences + .filter(_.identifier match + case x: nodes.NewIdentifier => globalVariables.contains(x.name) + case x: nodes.NewFieldIdentifier => globalVariables.contains(x.canonicalName) + ) + val members = globalVarReferences .map(_.identifier) - .filter(i => globalVarsForEnclMethod.contains(i.name)) .sortBy(i => (i.lineNumber, i.columnNumber)) - .distinctBy(_.name) - .map(i => + .distinctBy { + case x: nodes.NewIdentifier => x.name + case x: nodes.NewFieldIdentifier => x.canonicalName + } + .map { i => + val name = i match + case x: nodes.NewIdentifier => x.name + case x: nodes.NewFieldIdentifier => x.canonicalName + + val dynamicTypeHintFullName = i match + case x: nodes.NewIdentifier => x.dynamicTypeHintFullName + case _: nodes.NewFieldIdentifier => Seq.empty + NewMember() - .name(i.name) + .name(name) .typeFullName(Constants.ANY) - .dynamicTypeHintFullName(i.dynamicTypeHintFullName) + .dynamicTypeHintFullName(dynamicTypeHintFullName) .lineNumber(i.lineNumber) .columnNumber(i.columnNumber) - .code(i.name) - ) + .code(name) + } .zipWithIndex - .foreach { case (m, idx) => astEdgeLinker(m, moduleTypeDecl, idx + 1) } + .map { case (m, idx) => + astEdgeLinker(m, moduleTypeDecl, idx + 1) + m + } + + globalVarReferences.collect { case x: FieldReference => x }.foreach { fi => + members.find(_.name == fi.identifier.canonicalName).foreach(member => refEdgeLinker(member, fi.fieldAccess)) + } } private def linkLocalOrCapturing( @@ -428,4 +456,10 @@ class ContextStack { }) } + def 
isModuleContext: Boolean = { + stack.headOption match + case Some(methodContext: MethodContext) => methodContext.scopeName.contains("") + case _ => false + } + } diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/ImportResolverPass.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/ImportResolverPass.scala index 98e3676c02ff..fed0e0035789 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/ImportResolverPass.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/ImportResolverPass.scala @@ -1,15 +1,14 @@ package io.joern.pysrc2cpg -import better.files.{File => BFile} -import io.joern.x2cpg.passes.frontend.ImportsPass._ +import better.files.File as BFile +import io.joern.x2cpg.passes.frontend.ImportsPass.* import io.joern.x2cpg.passes.frontend.XImportResolverPass import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* -import java.io.{File => JFile} +import java.io.File as JFile import java.util.regex.{Matcher, Pattern} -import better.files.File as BFile class ImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) { @@ -127,7 +126,11 @@ class ImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) { val pyFile = BFile(codeRoot) / s"$path.py" fileOrDir match { case f if f.isDirectory && !pyFile.exists => - Seq(s"${path.replaceAll("\\.", sep)}${java.io.File.separator}$expEntity.py:").toResolvedImport(cpg) + val namespace = path.replaceAll("\\.", sep) + val module = s"$expEntity.py:" + val initSubmodule = s"__init__.py:.$expEntity" + Seq(s"$namespace${JFile.separator}$module", s"$namespace${JFile.separator}$initSubmodule") + .toResolvedImport(cpg) case f if f.isDirectory && (f / s"$expEntity.py").exists => Seq(s"${(f / 
s"$expEntity.py").pathAsString.stripPrefix(codeRoot)}:").toResolvedImport(cpg) case _ => diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/NodeBuilder.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/NodeBuilder.scala index 03eec5b1ae20..16127346742d 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/NodeBuilder.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/NodeBuilder.scala @@ -198,12 +198,16 @@ class NodeBuilder(diffGraph: DiffGraphBuilder) { addNodeToDiff(returnNode) } - def identifierNode(name: String, lineAndColumn: LineAndColumn): nodes.NewIdentifier = { + def identifierNode( + name: String, + lineAndColumn: LineAndColumn, + typeFullName: String = Constants.ANY + ): nodes.NewIdentifier = { val identifierNode = nodes .NewIdentifier() .code(name) .name(name) - .typeFullName(Constants.ANY) + .typeFullName(typeFullName) .lineNumber(lineAndColumn.line) .columnNumber(lineAndColumn.column) addNodeToDiff(identifierNode) diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonAstVisitor.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonAstVisitor.scala index bbde1699d3cd..48bf976be3cb 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonAstVisitor.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonAstVisitor.scala @@ -5,7 +5,7 @@ import io.joern.pysrc2cpg.memop.* import io.joern.pythonparser.ast import io.joern.x2cpg.ValidationMode import io.shiftleft.codepropertygraph.generated.* -import io.shiftleft.codepropertygraph.generated.nodes.{NewNode, NewTypeDecl} +import io.shiftleft.codepropertygraph.generated.nodes.{NewMethod, NewNode, NewTypeDecl} import overflowdb.BatchedUpdate.DiffGraphBuilder import scala.collection.mutable @@ -32,7 +32,7 @@ class PythonAstVisitor(relFileName: String, protected val nodeToCode: NodeToCode protected val contextStack = 
new ContextStack() - private var memOpMap: AstNodeToMemoryOperationMap = _ + protected var memOpMap: AstNodeToMemoryOperationMap = _ private val members = mutable.Map.empty[NewTypeDecl, List[String]] @@ -97,7 +97,10 @@ class PythonAstVisitor(relFileName: String, protected val nodeToCode: NodeToCode methodFullName, Some(""), parameterProvider = () => MethodParameters.empty(), - bodyProvider = () => createBuiltinIdentifiers(memOpCalculator.names) ++ module.stmts.map(convert), + bodyProvider = () => + createBuiltinIdentifiers(memOpCalculator.names) + ++ createModuleIdentifier(methodFullName) + ++ module.stmts.map(convert), returns = None, isAsync = false, methodRefNode = None, @@ -167,6 +170,20 @@ class PythonAstVisitor(relFileName: String, protected val nodeToCode: NodeToCode result } + /** Creates the base for the implicit field accesses of module-level variable references. + * @param moduleFullName + * the target module. + * @return + * the assignment of the module identifier. + */ + private def createModuleIdentifier(moduleFullName: String): Iterable[nodes.NewNode] = { + val lineAndColumn = LineAndColumn(1, 1, 1, 1) + // Create implicit identifier + val moduleIdentifier = createIdentifierNode("", Store, lineAndColumn) + val moduleTypeRef = createTypeRef("", moduleFullName, lineAndColumn) + Seq(createAssignment(moduleIdentifier, moduleTypeRef, lineAndColumn)) + } + private def unhandled(node: ast.iast with ast.iattributes): NewNode = { val unhandledAsUnknown = true if (unhandledAsUnknown) { @@ -377,7 +394,7 @@ class PythonAstVisitor(relFileName: String, protected val nodeToCode: NodeToCode // For every method that is a module, the local variables can be imported by other modules. 
This behaviour is // much like fields so they are to be linked as fields to this method type - if (name == "") contextStack.createMemberLinks(typeDeclNode, edgeBuilder.astEdge) + if (name == "") contextStack.createMemberLinks(typeDeclNode, edgeBuilder.astEdge, edgeBuilder.refEdge) contextStack.pop() edgeBuilder.astEdge(typeDeclNode, contextStack.astParent, contextStack.order.getAndInc) @@ -1863,11 +1880,12 @@ class PythonAstVisitor(relFileName: String, protected val nodeToCode: NodeToCode def convert(name: ast.Name): nodes.NewNode = { val memoryOperation = memOpMap.get(name).get - val identifier = createIdentifierNode(name.id, memoryOperation, lineAndColOf(name)) if (contextStack.isClassContext && memoryOperation == Store) { - createAndRegisterMember(identifier.name, lineAndColOf(name)) + createAndRegisterMember(name.id, lineAndColOf(name)) + createIdentifierNode(name.id, memoryOperation, lineAndColOf(name)) + } else { + createIdentifierNode(name.id, memoryOperation, lineAndColOf(name)) } - identifier } // TODO test diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonAstVisitorHelpers.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonAstVisitorHelpers.scala index 4e6dcdc8233f..44d0fee91c29 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonAstVisitorHelpers.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonAstVisitorHelpers.scala @@ -2,7 +2,8 @@ package io.joern.pysrc2cpg import io.joern.pysrc2cpg.memop.{Load, MemoryOperation, Store} import io.joern.pythonparser.ast -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.joern.pythonparser.ast.Name +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators} import scala.collection.immutable.{::, Nil} @@ -102,9 +103,48 @@ trait PythonAstVisitorHelpers { this: PythonAstVisitor => ) { // No 
lowering or wrapping in a block is required if we have a single target and // no decomposition. - val targetNode = convert(targets.head) - - Iterable.single(createAssignment(targetNode, valueNode, lineAndColumn)) + val targetAstNode = targets.head + val memoryOperation = memOpMap.get(targetAstNode).get + + targetAstNode match + case name: Name if contextStack.isModuleContext && memoryOperation == Store => + // Module-level variables are implicitly interprocedurally accessible and should + // be treated as members. The code below ensures a strong update on the member and + // then aliases onto an identifier, e.g., `x = 1` becomes + /* + x = { + .x = 1 + tmp = .x + tmp + } + */ + contextStack.addGlobalVariable(name.id) + + val tmpVariableName = getUnusedName() + val stmts = Iterable( + createAssignment(generateModuleFieldAccess(name, Store), valueNode, lineAndColumn), + createAssignment( + createIdentifierNode(tmpVariableName, Store, lineAndColumn), + generateModuleFieldAccess(name, Load), + lineAndColumn + ), + createIdentifierNode(tmpVariableName, Load, lineAndColumn) + ).collect { + // Signal that these are module-related de-sugaring operations + case x: NewIdentifier => x.code(s".${x.code}") + case x: NewCall if !x.code.startsWith("") => x.code(s".${x.code}") + case x => x + } + val targetNode = convert(targetAstNode) + // Generate a block, and make it look like the "sugared" code that is being parsed + val block = createBlock(stmts, lineAndColumn) + .asInstanceOf[NewBlock] + .code(s"${codeOf(targetNode)} = ${codeOf(valueNode)}") + val desugaredModuleAssign = createAssignment(targetNode, block, lineAndColumn).asInstanceOf[NewCall] + // Get rid of code that looks like `targetNode = targetNode = value` + desugaredModuleAssign.code(desugaredModuleAssign.code.split("=").tail.mkString("=").trim) + Iterable.single(desugaredModuleAssign) + case _ => Iterable.single(createAssignment(convert(targetAstNode), valueNode, lineAndColumn)) } else { // Lowering of x, (y,z) = a = 
b = c: // Note: No surrounding block is created. This is the duty of the caller. @@ -169,6 +209,17 @@ trait PythonAstVisitorHelpers { this: PythonAstVisitor => result } + private def createModuleIdentifier(node: ast.iexpr, memoryOperation: MemoryOperation): NewNode = + contextStack + .findEnclosingMethod() + .map(m => createIdentifierNode("", memoryOperation, lineAndColOf(node), m.fullName)) + .getOrElse(createIdentifierNode("", memoryOperation, lineAndColOf(node), Constants.ANY)) + + private def generateModuleFieldAccess(name: Name, memoryOperation: MemoryOperation) = { + val moduleIdentifier = createModuleIdentifier(name, memoryOperation) + createFieldAccess(moduleIdentifier, name.id, lineAndColOf(name)) + } + protected def createComprehensionLowering( tmpVariableName: String, containerInitAssignNode: NewNode, @@ -437,9 +488,6 @@ trait PythonAstVisitorHelpers { this: PythonAstVisitor => val callNode = nodeBuilder.callNode(code, Operators.assignment, DispatchTypes.STATIC_DISPATCH, lineAndColumn) addAstChildrenAsArguments(callNode, 1, lhsNode, rhsNode) - // Do not include imports or function pointers - if (!codeOf(rhsNode).startsWith("import(") && codeOf(rhsNode) != s"def ${codeOf(lhsNode)}(...)") - contextStack.considerAsGlobalVariable(lhsNode) callNode } @@ -474,9 +522,10 @@ trait PythonAstVisitorHelpers { this: PythonAstVisitor => protected def createIdentifierNode( name: String, memOp: MemoryOperation, - lineAndColumn: LineAndColumn + lineAndColumn: LineAndColumn, + typeFullName: String = Constants.ANY ): NewIdentifier = { - val identifierNode = nodeBuilder.identifierNode(name, lineAndColumn) + val identifierNode = nodeBuilder.identifierNode(name, lineAndColumn, typeFullName) contextStack.addVariableReference(identifierNode, memOp) identifierNode } @@ -513,6 +562,8 @@ trait PythonAstVisitorHelpers { this: PythonAstVisitor => val code = codeOf(baseNode) + "." 
+ codeOf(fieldIdNode) val callNode = nodeBuilder.callNode(code, Operators.fieldAccess, DispatchTypes.STATIC_DISPATCH, lineAndColumn) + contextStack.addFieldReference(fieldIdNode, callNode) + addAstChildrenAsArguments(callNode, 1, baseNode, fieldIdNode) callNode } diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeRecovery.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeRecovery.scala index a84a423baafa..a37d14191a3e 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeRecovery.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeRecovery.scala @@ -1,10 +1,10 @@ package io.joern.pysrc2cpg -import io.joern.x2cpg.passes.frontend._ +import io.joern.x2cpg.passes.frontend.* import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{Operators, PropertyNames} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.FieldAccess import overflowdb.BatchedUpdate.DiffGraphBuilder @@ -111,6 +111,16 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder } } + override def visitStatementsInBlock(b: Block, assignmentTarget: Option[Identifier]): Set[String] = { + if (b.inAssignment.nonEmpty && b.expressionDown.assignment.argument(1).fieldAccess.code(".*").nonEmpty) { + super.visitStatementsInBlock(b, assignmentTarget) + // Shortcut the actual value of the module access + visitAssignmentArguments(List(b.inAssignment.target.head, b.expressionDown.assignment.head.source)) + } else { + super.visitStatementsInBlock(b, assignmentTarget) + } + } + override def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): 
Set[String] = { val constructorPaths = symbolTable.get(c).map(_.stripSuffix(s"${pathSep}__init__")) associateTypes(i, constructorPaths) @@ -182,7 +192,8 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder } } - override protected def postSetTypeInformation(): Unit = + override protected def postSetTypeInformation(): Unit = { + super.postSetTypeInformation() cu.typeDecl .map(t => t -> t.inheritsFromTypeFullName.partition(itf => symbolTable.contains(LocalVar(itf)))) .foreach { case (t, (identifierTypes, otherTypes)) => @@ -193,6 +204,7 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder builder.setNodeProperty(t, PropertyNames.INHERITS_FROM_TYPE_FULL_NAME, resolvedTypes) } } + } override def prepopulateSymbolTable(): Unit = { cu.ast.isMethodRef.where(_.astSiblings.isIdentifier.nameExact("classmethod")).referencedMethod.foreach { @@ -216,4 +228,14 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder .headOption .getOrElse(super.visitIdentifierAssignedToTypeRef(i, t, rec)) + override protected def handlePotentialFunctionPointer( + funcPtr: Expression, + baseTypes: Set[String], + funcName: String, + baseName: Option[String] + ): Unit = { + if (funcName != "") + super.handlePotentialFunctionPointer(funcPtr, baseTypes, funcName, baseName) + } + } diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/AssignCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/AssignCpgTests.scala index dca961f63eef..24cf0edf6828 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/AssignCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/AssignCpgTests.scala @@ -1,8 +1,8 @@ package io.joern.pysrc2cpg.cpg -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators, nodes} import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ 
+import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators, nodes} +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers @@ -11,8 +11,8 @@ class AssignCpgTests extends AnyFreeSpec with Matchers { lazy val cpg = Py2CpgTestContext.buildCpg("""x = 2""".stripMargin) "test assignment node properties" in { - val assignCall = cpg.call.methodFullName(Operators.assignment).head - assignCall.code shouldBe "x = 2" + val _ :: assignCall :: _ = cpg.call.methodFullName(Operators.assignment).l: @unchecked + assignCall.code shouldBe ".x = 2" assignCall.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH assignCall.lineNumber shouldBe Some(1) assignCall.columnNumber shouldBe Some(1) @@ -24,12 +24,15 @@ class AssignCpgTests extends AnyFreeSpec with Matchers { .astChildren .order(1) .isIdentifier - .head - .code shouldBe "x" + .code + .contains("x") shouldBe true cpg.call .methodFullName(Operators.assignment) .astChildren .order(2) + .astChildren + .order(1) + .astChildren .isLiteral .head .code shouldBe "2" @@ -38,9 +41,9 @@ class AssignCpgTests extends AnyFreeSpec with Matchers { "test assignment node arguments" in { cpg.call .methodFullName(Operators.assignment) + .last .argument .argumentIndex(1) - .isIdentifier .head .code shouldBe "x" cpg.call @@ -150,7 +153,7 @@ class AssignCpgTests extends AnyFreeSpec with Matchers { lazy val cpg = Py2CpgTestContext.buildCpg("""x: y = z""".stripMargin) "test assignment node properties" in { - val assignCall = cpg.call.methodFullName(Operators.assignment).head + val _ :: assignCall :: _ = cpg.call.methodFullName(Operators.assignment).l: @unchecked assignCall.code shouldBe "x = z" assignCall.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH assignCall.lineNumber shouldBe Some(1) @@ -160,6 +163,7 @@ class AssignCpgTests extends AnyFreeSpec with Matchers { "test assignment node ast children" in { cpg.call .methodFullName(Operators.assignment) + .last 
.astChildren .order(1) .isIdentifier @@ -177,6 +181,7 @@ class AssignCpgTests extends AnyFreeSpec with Matchers { "test assignment node arguments" in { cpg.call .methodFullName(Operators.assignment) + .last .argument .argumentIndex(1) .isIdentifier diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CallCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CallCpgTests.scala index 7ec13f219950..503becf26f54 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CallCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CallCpgTests.scala @@ -2,8 +2,7 @@ package io.joern.pysrc2cpg.cpg import io.joern.pysrc2cpg.PySrc2CpgFixture import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.semanticcpg.language._ -import overflowdb.traversal.NodeOps +import io.shiftleft.semanticcpg.language.* import java.io.File diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ModuleFunctionCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ModuleFunctionCpgTests.scala index 57c274c7502d..8dcc9f66ac69 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ModuleFunctionCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/ModuleFunctionCpgTests.scala @@ -1,7 +1,8 @@ package io.joern.pysrc2cpg.cpg import io.joern.pysrc2cpg.Py2CpgTestContext -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.Call +import io.shiftleft.semanticcpg.language.* import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers @@ -16,9 +17,8 @@ class ModuleFunctionCpgTests extends AnyFreeSpec with Matchers { } "test method body" in { - val topLevelExprs = cpg.method.fullName("test.py:").topLevelExpressions.l - topLevelExprs.size shouldBe 1 - topLevelExprs.isCall.head.code 
shouldBe "pass" - topLevelExprs.isCall.head.methodFullName shouldBe ".pass" + val _ :: (topLevelExprs: Call) :: Nil = cpg.method.fullName("test.py:").topLevelExpressions.l: @unchecked + topLevelExprs.code shouldBe "pass" + topLevelExprs.methodFullName shouldBe ".pass" } } diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/VariableReferencingCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/VariableReferencingCpgTests.scala index 9da932fb2023..eadf68b0023b 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/VariableReferencingCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/VariableReferencingCpgTests.scala @@ -81,8 +81,10 @@ class VariableReferencingCpgTests extends AnyFreeSpec with Matchers { } "test identifiers in line 1 and 3 reference to module method x variable" in { + val memberXNode = cpg.typeDecl.nameExact("").member.name("x").head + cpg.fieldAccess.where(_.fieldIdentifier.canonicalNameExact("x")).lineNumber(1).member.head shouldBe memberXNode + val localXNode = cpg.method.name("").block.local.name("x").head - cpg.identifier("x").lineNumber(1).refsTo.head shouldBe localXNode cpg.identifier("x").lineNumber(3).refsTo.head shouldBe localXNode } @@ -96,6 +98,10 @@ class VariableReferencingCpgTests extends AnyFreeSpec with Matchers { f.lineNumber shouldBe Some(3) f.columnNumber shouldBe Some(1) + val xField = cpg.fieldAccess.where(_.fieldIdentifier.canonicalNameExact("x")).head + xField.lineNumber shouldBe Some(1) + xField.lineNumber shouldBe Some(1) + val x = cpg.local("x").filterNot(_.definingBlock.astParent.isMethod.isEmpty).head val y = cpg.local("y").filterNot(_.definingBlock.astParent.isMethod.isEmpty).head x.lineNumber shouldBe Some(1) @@ -119,8 +125,10 @@ class VariableReferencingCpgTests extends AnyFreeSpec with Matchers { } "test identifiers in line 1 and 3 reference to module method x variable" in { + val memberXNode = 
cpg.typeDecl.nameExact("").member.name("x").head + cpg.fieldAccess.where(_.fieldIdentifier.canonicalNameExact("x")).lineNumber(1).member.head shouldBe memberXNode + val localXNode = cpg.method.name("").block.local.name("x").head - cpg.identifier("x").lineNumber(1).refsTo.head shouldBe localXNode cpg.identifier("x").lineNumber(3).refsTo.head shouldBe localXNode } diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/dataflow/DataFlowTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/dataflow/DataFlowTests.scala index 72e61c5a0b60..89b7a2603c4d 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/dataflow/DataFlowTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/dataflow/DataFlowTests.scala @@ -5,8 +5,7 @@ import io.joern.dataflowengineoss.semanticsloader.FlowSemantic import io.joern.pysrc2cpg.PySrc2CpgFixture import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{Literal, Member, Method} -import io.shiftleft.semanticcpg.language._ -import org.scalatest.Ignore +import io.shiftleft.semanticcpg.language.* import java.io.File @@ -302,7 +301,8 @@ class DataFlowTests extends PySrc2CpgFixture(withOssDataflow = true) { "models.py" ) - val List(method: Method) = cpg.identifier.name("foo").inAssignment.source.isCall.callee.l + val List(method: Method) = + cpg.identifier.name("foo").inAssignment.source.isBlock.ast.isCall.nameExact("Foo").callee.l method.fullName shouldBe "models.py:.Foo.__init__" val List(typeDeclFullName) = method.typeDecl.fullName.l typeDeclFullName shouldBe "models.py:.Foo" @@ -321,7 +321,8 @@ class DataFlowTests extends PySrc2CpgFixture(withOssDataflow = true) { "models.py" ) - val List(method: Method) = cpg.identifier.name("foo").inAssignment.source.isCall.callee.l + val List(method: Method) = + cpg.identifier.name("foo").inAssignment.source.isBlock.ast.isCall.nameExact("Foo").callee.l method.fullName shouldBe 
"models.py:.Foo.__init__" val List(typeDeclFullName) = method.typeDecl.fullName.l typeDeclFullName shouldBe "models.py:.Foo" @@ -589,4 +590,25 @@ class RegexDefinedFlowsDataFlowTests flows.size shouldBe 2 } + "flow across interprocedural module variables" in { + val cpg: Cpg = code( + """ + |a = 42 + |""".stripMargin, + "foo.py" + ) + .moreCode( + """ + |import foo + | + |print(foo.a) + |""".stripMargin, + "bar.py" + ) + + val source = cpg.literal("42").l + val sink = cpg.call.code("print.*").argument.l + sink.reachableByFlows(source).size shouldBe 1 + } + } diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala index 4416eb4d4572..ccb0fee59191 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala @@ -3,6 +3,7 @@ package io.joern.pysrc2cpg.passes import io.joern.pysrc2cpg.PySrc2CpgFixture import io.joern.x2cpg.passes.frontend.ImportsPass.* import io.joern.x2cpg.passes.frontend.{ImportsPass, XTypeHintCallLinker} +import io.shiftleft.codepropertygraph.generated.nodes.Local import io.shiftleft.semanticcpg.language.* import java.io.File @@ -214,6 +215,23 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "bar.py" ).cpg + "be able to traverse from `foo.[x|y|db]` to its members" in { + val fields = cpg.fieldAccess.where(_.fieldIdentifier.canonicalName("x", "y", "db")).l + val List(mDB, mX, mY) = fields.referencedMember.dedup.sortBy(_.name).l + + mDB.name shouldBe "db" + mDB.typeFullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy" + mDB.typeDecl.fullName shouldBe "foo.py:" + + mX.name shouldBe "x" + mX.typeFullName shouldBe "__builtin.int" + mX.typeDecl.fullName shouldBe "foo.py:" + + mY.name shouldBe "y" + mY.typeFullName 
shouldBe "__builtin.str" + mY.typeDecl.fullName shouldBe "foo.py:" + } + "resolve correct imports via tag nodes" in { val List(foo1: UnknownMethod, foo2: UnknownTypeDecl) = cpg.file(".*foo.py").ast.isCall.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked @@ -452,7 +470,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { val Some(selfFindFound) = cpg.typeDecl(".*InstallationsDAO.*").ast.isCall.name("find_one").headOption: @unchecked selfFindFound.dynamicTypeHintFullName shouldBe Seq( "__builtin.None.find_one", - "pymongo.py:.MongoClient.__init__...find_one" + "pymongo.py:.MongoClient.__init__....find_one" ) } diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/slicing/PyUsageSliceTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/slicing/PyUsageSliceTests.scala deleted file mode 100644 index 890679d66bdc..000000000000 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/slicing/PyUsageSliceTests.scala +++ /dev/null @@ -1,51 +0,0 @@ -package io.joern.pysrc2cpg.slicing - -import io.joern.dataflowengineoss.slicing.* -import io.joern.pysrc2cpg.PySrc2CpgFixture -import io.shiftleft.codepropertygraph.generated.Operators - -class PyUsageSliceTests extends PySrc2CpgFixture { - - private val config = UsagesConfig(excludeOperatorCalls = true).withParallelism(1) - - "extracting a usage slice from basic objects" should { - - lazy val cpg = code(""" - |from flask_sqlalchemy import SQLAlchemy - | - |x = 1 - |y = "test" - |db = SQLAlchemy() - | - |db.createTable() - |db.deleteTable() - |""".stripMargin) - - val programSlice = UsageSlicing.calculateUsageSlice(cpg, config.copy(excludeOperatorCalls = false)) - - "should successfully extract 'db' usages" in { - val slice = programSlice.objectSlices.head.slices.head - slice.targetObj shouldBe LocalDef("db", "flask_sqlalchemy.py:.SQLAlchemy", Some(6), Some(1)) - slice.definedBy shouldBe Option( - CallDef("SQLAlchemy", 
"ANY", Some("flask_sqlalchemy.py:.SQLAlchemy.__init__"), Some(6), Some(6)) - ) - - val inv1 = slice.invokedCalls.find(_.callName == "createTable").get - val inv2 = slice.invokedCalls.find(_.callName == "deleteTable").get - - inv1.resolvedMethod shouldBe Some("flask_sqlalchemy.py:.SQLAlchemy.createTable") - inv1.paramTypes shouldBe List.empty - inv1.returnType shouldBe "ANY" - - inv2.resolvedMethod shouldBe Some("flask_sqlalchemy.py:.SQLAlchemy.deleteTable") - inv2.paramTypes shouldBe List.empty - inv2.returnType shouldBe "ANY" - - val List(arg1) = slice.argToCalls - - arg1.callName shouldBe Operators.assignment - } - - } - -} diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala index 96f1fdc6eaff..a1568ec72dfe 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala @@ -18,6 +18,7 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.annotation.tailrec import scala.collection.concurrent.TrieMap import scala.collection.mutable +import scala.util.matching.Regex /** @param iterations * the number of iterations to run. 
@@ -66,6 +67,8 @@ abstract class XTypeRecoveryPass[CompilationUnitType <: AstNode]( config: XTypeRecoveryConfig = XTypeRecoveryConfig() ) extends CpgPass(cpg) { + import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt + override def run(builder: BatchedUpdate.DiffGraphBuilder): Unit = if (config.iterations > 0) { val stopEarly = new AtomicBoolean(false) @@ -76,8 +79,11 @@ abstract class XTypeRecoveryPass[CompilationUnitType <: AstNode]( generateRecoveryPass(newState).createAndApply() } // If dummy values are enabled and we are stopping early, we need one more round to propagate these dummy values - if (stopEarly.get() && config.enabledDummyTypes) + if (stopEarly.get() && config.enabledDummyTypes) { generateRecoveryPass(state.copy(currentIteration = config.iterations - 1)).createAndApply() + } + + postTypeRecoveryAndPropagation(builder) } finally { state.clear() } @@ -85,6 +91,31 @@ abstract class XTypeRecoveryPass[CompilationUnitType <: AstNode]( protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[CompilationUnitType] + /** A hook for the end of the type recovery and propagation. 
+ */ + protected def postTypeRecoveryAndPropagation(builder: DiffGraphBuilder): Unit = { + linkMembersToTheirRefs(builder) + } + + private def linkMembersToTheirRefs(builder: DiffGraphBuilder): Unit = { + import XTypeRecovery.unknownTypePattern + // Set all now-typed fieldAccess calls to their referencing members (if they exist) + cpg.fieldAccess + .where( + _.and( + _.not(_.referencedMember), + _.argument(1).isIdentifier.typeFullNameNot(unknownTypePattern.pattern.pattern()) + ) + ) + .foreach { fieldAccess => + cpg.typeDecl + .fullNameExact(fieldAccess.argument(1).getKnownTypes.toSeq: _*) + .member + .nameExact(fieldAccess.fieldIdentifier.canonicalName.toSeq: _*) + .foreach(builder.addEdge(fieldAccess, _, EdgeTypes.REF)) + } + } + } trait TypeRecoveryParserConfig[R <: X2CpgConfig[R]] { this: R => @@ -167,6 +198,8 @@ object XTypeRecovery { val DummyIndexAccess = "" private lazy val DummyTokens: Set[String] = Set(DummyReturnType, DummyMemberLoad, DummyIndexAccess) + val unknownTypePattern: Regex = s"(i?)(UNKNOWN|ANY|${Defines.UnresolvedNamespace}).*".r + def dummyMemberType(prefix: String, memberName: String, sep: Char = '.'): String = s"$prefix$sep$DummyMemberLoad($memberName)" @@ -201,6 +234,26 @@ object XTypeRecovery { ) } + // The below are convenience calls for accessing type properties, one day when this pass uses `Tag` nodes instead of + // the symbol table then perhaps this would work out better + implicit class AllNodeTypesFromNodeExt(x: StoredNode) { + def allTypes: Iterator[String] = (x.property(PropertyNames.TYPE_FULL_NAME, "ANY") +: x.property( + PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, + Seq.empty + )).iterator + + def getKnownTypes: Set[String] = { + x.allTypes.toSet.filterNot(unknownTypePattern.matches) + } + } + + implicit class AllNodeTypesFromIteratorExt(x: Iterator[StoredNode]) { + def allTypes: Iterator[String] = x.flatMap(_.allTypes) + + def getKnownTypes: Set[String] = + x.allTypes.toSet.filterNot(unknownTypePattern.matches) + } + } /** 
Performs type recovery from the root of a compilation unit level @@ -221,6 +274,8 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( state: XTypeRecoveryState ) extends RecursiveTask[Boolean] { + import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt + protected val logger: Logger = LoggerFactory.getLogger(getClass) /** Stores type information for local structures that live within this compilation unit, e.g. local variables. @@ -309,6 +364,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( alias <- i.importedAs } { import io.joern.x2cpg.passes.frontend.ImportsPass.* + import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromIteratorExt ResolvedImport.tagToResolvedImport(resolvedImport).foreach { case ResolvedMethod(fullName, alias, receiver, _) => @@ -344,21 +400,21 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( * @param a * assignment call pointer. */ - protected def visitAssignments(a: Assignment): Set[String] = { - a.argumentOut.l match { - case List(i: Identifier, b: Block) => visitIdentifierAssignedToBlock(i, b) - case List(i: Identifier, c: Call) => visitIdentifierAssignedToCall(i, c) - case List(x: Identifier, y: Identifier) => visitIdentifierAssignedToIdentifier(x, y) - case List(i: Identifier, l: Literal) if state.isFirstIteration => visitIdentifierAssignedToLiteral(i, l) - case List(i: Identifier, m: MethodRef) => visitIdentifierAssignedToMethodRef(i, m) - case List(i: Identifier, t: TypeRef) => visitIdentifierAssignedToTypeRef(i, t) - case List(c: Call, i: Identifier) => visitCallAssignedToIdentifier(c, i) - case List(x: Call, y: Call) => visitCallAssignedToCall(x, y) - case List(c: Call, l: Literal) if state.isFirstIteration => visitCallAssignedToLiteral(c, l) - case List(c: Call, m: MethodRef) => visitCallAssignedToMethodRef(c, m) - case List(c: Call, b: Block) => visitCallAssignedToBlock(c, b) - case _ => Set.empty - } + protected def 
visitAssignments(a: Assignment): Set[String] = visitAssignmentArguments(a.argumentOut.l) + + protected def visitAssignmentArguments(args: List[AstNode]): Set[String] = args match { + case List(i: Identifier, b: Block) => visitIdentifierAssignedToBlock(i, b) + case List(i: Identifier, c: Call) => visitIdentifierAssignedToCall(i, c) + case List(x: Identifier, y: Identifier) => visitIdentifierAssignedToIdentifier(x, y) + case List(i: Identifier, l: Literal) if state.isFirstIteration => visitIdentifierAssignedToLiteral(i, l) + case List(i: Identifier, m: MethodRef) => visitIdentifierAssignedToMethodRef(i, m) + case List(i: Identifier, t: TypeRef) => visitIdentifierAssignedToTypeRef(i, t) + case List(c: Call, i: Identifier) => visitCallAssignedToIdentifier(c, i) + case List(x: Call, y: Call) => visitCallAssignedToCall(x, y) + case List(c: Call, l: Literal) if state.isFirstIteration => visitCallAssignedToLiteral(c, l) + case List(c: Call, m: MethodRef) => visitCallAssignedToMethodRef(c, m) + case List(c: Call, b: Block) => visitCallAssignedToBlock(c, b) + case _ => Set.empty } /** Visits an identifier being assigned to the result of some operation. 
@@ -426,13 +482,15 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( protected def setCallMethodFullNameFromBase(c: Call): Set[String] = { val recTypes = c.argument.headOption .map { - case x: Call if x.typeFullName != "ANY" => Set(x.typeFullName) + case x: Call if x.typeFullName != "ANY" => + Set(x.typeFullName) case x: Call => cpg.method.fullNameExact(c.methodFullName).methodReturn.typeFullNameNot("ANY").typeFullName.toSet match { case xs if xs.nonEmpty => xs case _ => symbolTable.get(x).map(t => Seq(t, XTypeRecovery.DummyReturnType).mkString(pathSep.toString)) } - case x => symbolTable.get(x) + case x => + symbolTable.get(x) } .getOrElse(Set.empty[String]) val callTypes = recTypes.map(_.concat(s"$pathSep${c.name}")) @@ -505,11 +563,14 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( // We have been able to resolve the type inter-procedurally associateTypes(i, globalTypes) } else if (baseTypes.nonEmpty) { + lazy val existingMembers = cpg.typeDecl.fullNameExact(baseTypes.toSeq: _*).member.nameExact(fieldName) if (baseTypes.equals(symbolTable.get(LocalVar(fieldFullName)))) { associateTypes(i, baseTypes) - } else { + } else if (existingMembers.isEmpty) { // If not available, use a dummy variable that can be useful for call matching associateTypes(i, baseTypes.map(t => XTypeRecovery.dummyMemberType(t, fieldName, pathSep))) + } else { + Set.empty } } else { // Assign dummy @@ -611,7 +672,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( */ protected def getTypesFromCall(c: Call): Set[String] = c.name match { case Operators.fieldAccess => symbolTable.get(LocalVar(getFieldName(new FieldAccess(c)))) - case _ if symbolTable.contains(c) => symbolTable.get(c) + case _ if symbolTable.contains(c) => methodReturnValues(symbolTable.get(c).toSeq) case Operators.indexAccess => getIndexAccessTypes(c) case n => logger.debug(s"Unknown RHS call type '$n' @ ${debugLocation(c)}") @@ -694,6 +755,8 @@ 
abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( // TODO: Handle this case better val callCode = if (c.code.contains("(")) c.code.substring(c.code.indexOf("(")) else c.code XTypeRecovery.dummyMemberType(callCode, f.canonicalName, pathSep) + case ::(_: TypeRef, ::(f: FieldIdentifier, _)) => + f.canonicalName case xs => logger.warn(s"Unhandled field structure ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(fa)}") wrapName("") @@ -986,7 +1049,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( /** In the case this field access is a function pointer, we would want to make sure this has a method ref. */ - private def handlePotentialFunctionPointer( + protected def handlePotentialFunctionPointer( funcPtr: Expression, baseTypes: Set[String], funcName: String, @@ -1158,26 +1221,4 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( */ protected def postSetTypeInformation(): Unit = {} - private val unknownTypePattern = s"(i?)(UNKNOWN|ANY|${Defines.UnresolvedNamespace}).*".r - - // The below are convenience calls for accessing type properties, one day when this pass uses `Tag` nodes instead of - // the symbol table then perhaps this would work out better - implicit class AllNodeTypesFromNodeExt(x: StoredNode) { - def allTypes: Iterator[String] = (x.property(PropertyNames.TYPE_FULL_NAME, "ANY") +: x.property( - PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, - Seq.empty - )).iterator - - def getKnownTypes: Set[String] = { - x.allTypes.toSet.filterNot(unknownTypePattern.matches) - } - } - - implicit class AllNodeTypesFromIteratorExt(x: Iterator[StoredNode]) { - def allTypes: Iterator[String] = x.flatMap(_.allTypes) - - def getKnownTypes: Set[String] = - x.allTypes.toSet.filterNot(unknownTypePattern.matches) - } - }