Skip to content

Commit

Permalink
[pysrc2cpg] Model Field-like Behaviour of Module Variables (joernio#3750
Browse files Browse the repository at this point in the history
)

* [pysrc2cpg] Model Field-like Behaviour of Module Variables
The motivation for this PR is that module-level variables in Python can be accessed by other files by some implicit global scope. This behaviour is more akin to fields rather than identifiers, and should be modelled closer to that direction.

Given the caveat that the data-flow engine sees fields a bit differently, especially if the flow goes across closures, this aims to be the least intrusive change to achieve this result.

### Main changes
* Create a module-level identifier assigned to the module's type ref to be used as the base of module-level member field accesses.
* At module-level identifier STOREs, create a block that writes to the module field, then aliases the result to the identifier, e.g., `x = 1` becomes
```python
x = {
      <module>.x = 1
      tmp = <module>.x
      tmp
    }
```
* Added helper methods, e.g., `isModuleContext` to `ContextStack`
* Separated `VariableReference` class into `Reference` trait shared between `VariableReference` and `FieldReference`
* During `createMemberLinks`, `FieldReference` calls will have a `REF` edge to their referencing member. TODO: Recover these `REF` edges during the type recovery pass.
*

## Misc Changes
* Fix `pysrc2cpg/ImportResolver` bug when an `__init__.py` entity is imported from a non-sibling module.
* Added shortcut in `PythonTypeRecovery.visitStatementsInBlock` to detect these strong-update blocks and "re-sugar" them to skip propagating the type via member
* Removed Python slicer tests, I would rather be intentional in how I support Python usage slicing at a later stage

* Added member traversal handling for recovered base types

* Fixed JSSRC import

* Added test for data-flow, found some bug about how starting points are handled for members

* Removed desugaring being part of the flows so that the flows are the same as before and give no fluctuations

* Fixed assignment code from being `<module>.<module>.x`
  • Loading branch information
DavidBakerEffendi authored Oct 23, 2023
1 parent 9717b71 commit 4d58f8f
Show file tree
Hide file tree
Showing 17 changed files with 381 additions and 182 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ class ExtendedCfgNode(val traversal: Iterator[CfgNode]) extends AnyVal {
reachedSources.cast[NodeType]
}

/** Trims out path elements that represent module-level variable assignment desugaring
*/
private def trimModuleVariableDesugarBlock(path: Vector[PathElement]): Vector[PathElement] = {
val moduleKeywords = Seq("<module>", ":program")
val codeMatcher = s"(${moduleKeywords.mkString("|")}).*"
path.filterNot { x =>
x.node match
case block: Block => block.astChildren.code(codeMatcher).nonEmpty
case astNode => astNode.inAssignment.code(codeMatcher).nonEmpty
}
}

def reachableByFlows[A](sourceTrav: IterableOnce[A], sourceTravs: IterableOnce[A]*)(implicit
context: EngineContext
): Iterator[Path] = {
Expand All @@ -51,8 +63,10 @@ class ExtendedCfgNode(val traversal: Iterator[CfgNode]) extends AnyVal {
if (first.isDefined && !first.get.visible && !startingPoints.contains(first.get.node)) {
None
} else {
val visiblePathElements = result.path.filter(x => startingPoints.contains(x.node) || x.visible)
Some(Path(removeConsecutiveDuplicates(visiblePathElements.map(_.node))))
val visiblePathElements = result.path.filter(x => startingPoints.contains(x.node) || x.visible)
val trimmedOfModuleDesugaring = trimModuleVariableDesugarBlock(visiblePathElements)
val deDuplicated = removeConsecutiveDuplicates(trimmedOfModuleDesugaring.map(_.node))
Some(Path(deDuplicated))
}
}
.filter(_.isDefined)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import io.joern.dataflowengineoss.globalFromLiteral
import io.joern.x2cpg.Defines
import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.codepropertygraph.generated.Operators
import io.shiftleft.codepropertygraph.generated.nodes._
import io.shiftleft.semanticcpg.language._
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.semanticcpg.language.*
import io.shiftleft.semanticcpg.language.operatorextension.allAssignmentTypes
import io.shiftleft.semanticcpg.utils.MemberAccess.isFieldAccess
import org.slf4j.LoggerFactory
Expand Down Expand Up @@ -121,6 +121,8 @@ class SourceToStartingPoints(src: StoredNode) extends RecursiveTask[List[CfgNode
x.argument(2).isFieldIdentifier.canonicalNameExact(identifier.name)
case fieldIdentifier: FieldIdentifier =>
x.argument(2).isFieldIdentifier.canonicalNameExact(fieldIdentifier.canonicalName)
case member: Member =>
x.argument(2).isFieldIdentifier.canonicalNameExact(member.name)
case _ => Iterator.empty
}
}
Expand Down Expand Up @@ -174,7 +176,9 @@ class SourceToStartingPoints(src: StoredNode) extends RecursiveTask[List[CfgNode
.or(
_.method.nameExact(Defines.StaticInitMethodName, Defines.ConstructorMethodName, "__init__"),
// in language such as Python, where assignments for members can be directly under a type decl
_.method.typeDecl
_.method.typeDecl,
// for Python, we have moved to replacing strong updates of module-level variables with their members
_.target.isCall.nameExact(Operators.fieldAccess).argument(1).isIdentifier.name("<module>")
)
.target
.flatMap {
Expand Down Expand Up @@ -209,6 +213,11 @@ class SourceToStartingPoints(src: StoredNode) extends RecursiveTask[List[CfgNode

private def targetsToClassIdentifierPair(targets: List[AstNode]): List[(TypeDecl, AstNode)] = {
targets.flatMap {
case expr: FieldIdentifier =>
expr.method.typeDecl.map { typeDecl => (typeDecl, expr) } ++
expr.inCall.fieldAccess.referencedMember.flatMap { member =>
member.typeDecl.map { typeDecl => (typeDecl, member) }
}
case expr: Expression =>
expr.method.typeDecl.map { typeDecl => (typeDecl, expr) }
case member: Member =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ private class JavaScriptTypeRecovery(cpg: Cpg, state: XTypeRecoveryState) extend
private class RecoverForJavaScriptFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, state: XTypeRecoveryState)
extends RecoverForXCompilationUnit[File](cpg, cu, builder, state) {

import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt

override protected val pathSep = ':'

/** A heuristic method to determine if a call is a constructor or not.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package io.joern.pysrc2cpg

import io.joern.pysrc2cpg.ContextStack.transferLineColInfo
import io.joern.pysrc2cpg.memop._
import io.joern.pysrc2cpg.memop.*
import io.shiftleft.codepropertygraph.generated.nodes
import io.shiftleft.codepropertygraph.generated.nodes._
import io.shiftleft.codepropertygraph.generated.nodes.*
import org.slf4j.LoggerFactory

import scala.collection.mutable

object ContextStack {
private val logger = LoggerFactory.getLogger(getClass)

def transferLineColInfo(src: NewIdentifier, tgt: NewLocal): Unit = {
src.lineNumber match {
Expand All @@ -23,7 +22,6 @@ object ContextStack {
}

class ContextStack {
import ContextStack.logger

private trait Context {
val astParent: nodes.NewNode
Expand Down Expand Up @@ -70,6 +68,12 @@ class ContextStack {
var lambdaCounter: Int = 0
) extends Context {}

private trait Reference {
def identifier: nodes.NewIdentifier | nodes.NewFieldIdentifier

def stack: List[Context]
}

private case class VariableReference(
identifier: nodes.NewIdentifier,
memOp: MemoryOperation,
Expand All @@ -79,10 +83,13 @@ class ContextStack {
// instances because the changes in the variable
// maps need to be in sync.
stack: List[Context]
)
) extends Reference

private case class FieldReference(identifier: nodes.NewFieldIdentifier, fieldAccess: NewCall, stack: List[Context])
extends Reference

private var stack = List[Context]()
private val variableReferences = mutable.ArrayBuffer.empty[VariableReference]
private val variableReferences = mutable.ArrayBuffer.empty[Reference]
private var moduleMethodContext = Option.empty[MethodContext]
private var fileNamespaceBlock = Option.empty[nodes.NewNamespaceBlock]
private val fileNamespaceBlockOrder = new AutoIncIndex(1)
Expand Down Expand Up @@ -135,6 +142,10 @@ class ContextStack {
variableReferences.append(VariableReference(identifier, memOp, stack))
}

def addFieldReference(identifier: nodes.NewFieldIdentifier, fieldAccess: NewCall): Unit = {
variableReferences.append(FieldReference(identifier, fieldAccess, stack))
}

def getAndIncLambdaCounter(): Int = {
val result = stack.head.lambdaCounter
stack.head.lambdaCounter += 1
Expand All @@ -145,6 +156,9 @@ class ContextStack {
contextStack.find(_.isInstanceOf[MethodContext]).get.asInstanceOf[MethodContext]
}

def findEnclosingMethod(): Option[NewMethod] =
stack.find(_.isInstanceOf[MethodContext]).map(_.astParent).collect { case x: NewMethod => x }

def findEnclosingTypeDecl(): Option[NewNode] = {
stack.find(_.isInstanceOf[ClassContext]) match {
case Some(classContext: ClassContext) =>
Expand All @@ -160,13 +174,14 @@ class ContextStack {
createRefEdge: (nodes.NewNode, nodes.NewNode) => Unit,
createCaptureEdge: (nodes.NewNode, nodes.NewNode) => Unit
): Unit = {
val identifierReferences = variableReferences.collect { case x: VariableReference => x }
// Before we do any linking, we iterate over all variable references and
// create a variable in the module method context for each global variable
// with a store operation on it.
// This is necessary because there might be load/delete operations
// referencing the global variable which are syntactically before the store
// operations.
variableReferences.foreach { case VariableReference(identifier, memOp, contextStack) =>
identifierReferences.foreach { case VariableReference(identifier, memOp, contextStack) =>
val name = identifier.name
if (
memOp == Store &&
Expand All @@ -183,7 +198,7 @@ class ContextStack {
// Variable references processing needs to be ordered by context depth in
// order to make sure that variables captured into deeper nested contexts
// are already created.
val sortedVariableRefs = variableReferences.sortBy(_.stack.size)
val sortedVariableRefs = identifierReferences.sortBy(_.stack.size)
sortedVariableRefs.foreach { case VariableReference(identifier, memOp, contextStack) =>
val name = identifier.name
// Store and delete operations look up variable only in method scope.
Expand Down Expand Up @@ -252,40 +267,53 @@ class ContextStack {
}
}

/** Assignments to variables on the module-level may be exported to other modules and behave as inter-procedurally
* global variables.
* @param lhs
* the LHS node of an assignment
*/
def considerAsGlobalVariable(lhs: NewNode): Unit = {
lhs match {
case n: NewIdentifier if findEnclosingMethodContext(stack).scopeName.contains("<module>") =>
addGlobalVariable(n.name)
case _ =>
}
}

/** For module-methods, the variables of this method can be imported into other modules which resembles behaviour much
* like fields/members. This inter-procedural accessibility should be marked via the module's type decl node.
*/
def createMemberLinks(moduleTypeDecl: NewTypeDecl, astEdgeLinker: (NewNode, NewNode, Int) => Unit): Unit = {
val globalVarsForEnclMethod = findEnclosingMethodContext(stack).globalVariables
variableReferences
def createMemberLinks(
moduleTypeDecl: NewTypeDecl,
astEdgeLinker: (NewNode, NewNode, Int) => Unit,
refEdgeLinker: (NewNode, NewNode) => Unit
): Unit = {
val globalVariables = findEnclosingMethodContext(stack).globalVariables
val globalVarReferences = variableReferences
.filter(_.identifier match
case x: nodes.NewIdentifier => globalVariables.contains(x.name)
case x: nodes.NewFieldIdentifier => globalVariables.contains(x.canonicalName)
)
val members = globalVarReferences
.map(_.identifier)
.filter(i => globalVarsForEnclMethod.contains(i.name))
.sortBy(i => (i.lineNumber, i.columnNumber))
.distinctBy(_.name)
.map(i =>
.distinctBy {
case x: nodes.NewIdentifier => x.name
case x: nodes.NewFieldIdentifier => x.canonicalName
}
.map { i =>
val name = i match
case x: nodes.NewIdentifier => x.name
case x: nodes.NewFieldIdentifier => x.canonicalName

val dynamicTypeHintFullName = i match
case x: nodes.NewIdentifier => x.dynamicTypeHintFullName
case _: nodes.NewFieldIdentifier => Seq.empty

NewMember()
.name(i.name)
.name(name)
.typeFullName(Constants.ANY)
.dynamicTypeHintFullName(i.dynamicTypeHintFullName)
.dynamicTypeHintFullName(dynamicTypeHintFullName)
.lineNumber(i.lineNumber)
.columnNumber(i.columnNumber)
.code(i.name)
)
.code(name)
}
.zipWithIndex
.foreach { case (m, idx) => astEdgeLinker(m, moduleTypeDecl, idx + 1) }
.map { case (m, idx) =>
astEdgeLinker(m, moduleTypeDecl, idx + 1)
m
}

globalVarReferences.collect { case x: FieldReference => x }.foreach { fi =>
members.find(_.name == fi.identifier.canonicalName).foreach(member => refEdgeLinker(member, fi.fieldAccess))
}
}

private def linkLocalOrCapturing(
Expand Down Expand Up @@ -428,4 +456,10 @@ class ContextStack {
})
}

def isModuleContext: Boolean = {
stack.headOption match
case Some(methodContext: MethodContext) => methodContext.scopeName.contains("<module>")
case _ => false
}

}
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package io.joern.pysrc2cpg

import better.files.{File => BFile}
import io.joern.x2cpg.passes.frontend.ImportsPass._
import better.files.File as BFile
import io.joern.x2cpg.passes.frontend.ImportsPass.*
import io.joern.x2cpg.passes.frontend.XImportResolverPass
import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.codepropertygraph.generated.nodes._
import io.shiftleft.semanticcpg.language._
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.semanticcpg.language.*

import java.io.{File => JFile}
import java.io.File as JFile
import java.util.regex.{Matcher, Pattern}
import better.files.File as BFile

class ImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) {

Expand Down Expand Up @@ -127,7 +126,11 @@ class ImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) {
val pyFile = BFile(codeRoot) / s"$path.py"
fileOrDir match {
case f if f.isDirectory && !pyFile.exists =>
Seq(s"${path.replaceAll("\\.", sep)}${java.io.File.separator}$expEntity.py:<module>").toResolvedImport(cpg)
val namespace = path.replaceAll("\\.", sep)
val module = s"$expEntity.py:<module>"
val initSubmodule = s"__init__.py:<module>.$expEntity"
Seq(s"$namespace${JFile.separator}$module", s"$namespace${JFile.separator}$initSubmodule")
.toResolvedImport(cpg)
case f if f.isDirectory && (f / s"$expEntity.py").exists =>
Seq(s"${(f / s"$expEntity.py").pathAsString.stripPrefix(codeRoot)}:<module>").toResolvedImport(cpg)
case _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,16 @@ class NodeBuilder(diffGraph: DiffGraphBuilder) {
addNodeToDiff(returnNode)
}

def identifierNode(name: String, lineAndColumn: LineAndColumn): nodes.NewIdentifier = {
def identifierNode(
name: String,
lineAndColumn: LineAndColumn,
typeFullName: String = Constants.ANY
): nodes.NewIdentifier = {
val identifierNode = nodes
.NewIdentifier()
.code(name)
.name(name)
.typeFullName(Constants.ANY)
.typeFullName(typeFullName)
.lineNumber(lineAndColumn.line)
.columnNumber(lineAndColumn.column)
addNodeToDiff(identifierNode)
Expand Down
Loading

0 comments on commit 4d58f8f

Please sign in to comment.