Skip to content

Commit

Permalink
[c#] Hotfix for CSharp Call Resolution (joernio#4179)
Browse files Browse the repository at this point in the history
Hot and messy call resolution. Some technical debt here in `astForInvocationExpression` but calls are now largely being resolved.

Fixed some bugs and tests
  • Loading branch information
DavidBakerEffendi authored Feb 15, 2024
1 parent ed42be2 commit e779842
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
"returnType": "void",
"parameterTypes": [
["format", "object"]
]
],
"isStatic": true
},
{
"name": "WriteLine",
"returnType": "void",
"parameterTypes": [
["format", "object"]
]
],
"isStatic": true
}
],
"fields": []
Expand All @@ -30,12 +32,14 @@
"returnType": "void",
"parameterTypes": [
["count", "int"]
]
],
"isStatic": true
},
{
"name": "Clear",
"returnType": "void",
"parameterTypes": []
"parameterTypes": [],
"isStatic": true
}
],
"fields": []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,14 @@ package io.joern.csharpsrc2cpg
import com.typesafe.config.impl.*
import com.typesafe.config.{Config, ConfigFactory}
import io.joern.csharpsrc2cpg.astcreation.AstCreatorHelper
import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.{
ClassDeclaration,
FieldDeclaration,
FileScopedNamespaceDeclaration,
InterfaceDeclaration,
MethodDeclaration,
NamespaceDeclaration,
RecordDeclaration,
StructDeclaration,
DeclarationExpr
}
import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.*
import io.joern.csharpsrc2cpg.parser.{DotNetJsonAst, DotNetJsonParser, DotNetNodeInfo, ParserKeys}
import io.joern.x2cpg.Ast
import io.joern.x2cpg.astgen.AstGenRunner.AstGenRunnerResult
import io.joern.x2cpg.datastructures.Stack.Stack
import io.joern.x2cpg.utils.ConcurrentTaskUtil
import io.shiftleft.codepropertygraph.generated.nodes.NewNode
import io.shiftleft.codepropertygraph.generated.ModifierTypes
import io.shiftleft.codepropertygraph.generated.nodes.{NewModifier, NewNode}
import io.shiftleft.semanticcpg.language.*
import org.slf4j.LoggerFactory
import upickle.default.*
Expand Down Expand Up @@ -64,7 +56,10 @@ class TypeMap(astGenResult: AstGenRunnerResult, initialMappings: List[NamespaceT
val compilationUnit = AstCreatorHelper.createDotNetNodeInfo(parserResult.json(ParserKeys.AstRoot))
() => parseCompilationUnit(compilationUnit)
}.iterator
val typeMaps = ConcurrentTaskUtil.runUsingSpliterator(typeMapTasks).flatMap(_.toOption)
val typeMaps = ConcurrentTaskUtil.runUsingSpliterator(typeMapTasks).flatMap {
case Failure(exception) => logger.warn("Exception encountered during pre-parsing", exception); None
case Success(typeMap) => Option(typeMap)
}
(builtinTypes +: typeMaps ++: initialMappings).foldLeft(Map.empty[String, Set[CSharpType]])((a, b) => {
val accumulator = mutable.HashMap.from(a)
val allKeys = accumulator.keySet ++ b.keySet
Expand Down Expand Up @@ -114,14 +109,15 @@ class TypeMap(astGenResult: AstGenRunnerResult, initialMappings: List[NamespaceT
case None => Option(typesInNamespace)
}
}
case ClassDeclaration | InterfaceDeclaration => {
case _: TypeDeclaration => {
val globalClass = Set(parseClassDeclaration(parserNode, "global"))
namespaceTypeMap.updateWith("global") {
case Some(types) =>
Option(types ++ globalClass)
case None => Option(globalClass)
}
}
case _ =>
}

namespaceTypeMap.toMap
Expand Down Expand Up @@ -166,17 +162,17 @@ class TypeMap(astGenResult: AstGenRunnerResult, initialMappings: List[NamespaceT
.json(ParserKeys.ParameterList)
.obj(ParserKeys.Parameters)
.arr

val isStatic = methodDecl.json(ParserKeys.Modifiers).arr.exists(_(ParserKeys.Value).str == "static")
val methodReturn = AstCreatorHelper.createDotNetNodeInfo(methodDecl.json(ParserKeys.ReturnType)).code
val paramTypes = params
.map(param => AstCreatorHelper.createDotNetNodeInfo(param))
.map { param =>
val typ = param.json(ParserKeys.Type)(ParserKeys.Keyword)(ParserKeys.Value).str
val typ = AstCreatorHelper.createDotNetNodeInfo(param.json(ParserKeys.Type)).code
val name = param.json(ParserKeys.Identifier)(ParserKeys.Value).str
(name, typ)
}

List(CSharpMethod(AstCreatorHelper.nameFromNode(methodDecl), methodReturn, paramTypes.toList))
List(CSharpMethod(AstCreatorHelper.nameFromNode(methodDecl), methodReturn, paramTypes.toList, isStatic))
}

private def parseFieldDeclaration(fieldDecl: DotNetNodeInfo): List[CSharpField] = {
Expand All @@ -194,6 +190,7 @@ class TypeMap(astGenResult: AstGenRunnerResult, initialMappings: List[NamespaceT

case class CSharpField(name: String) derives ReadWriter

case class CSharpMethod(name: String, returnType: String, parameterTypes: List[(String, String)]) derives ReadWriter
case class CSharpMethod(name: String, returnType: String, parameterTypes: List[(String, String)], isStatic: Boolean)
derives ReadWriter

case class CSharpType(name: String, methods: List[CSharpMethod], fields: List[CSharpField]) derives ReadWriter
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.joern.csharpsrc2cpg.astcreation

import io.joern.csharpsrc2cpg.astcreation.AstCreatorHelper.nameFromIdentifier
import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.*
import io.joern.csharpsrc2cpg.parser.{DotNetJsonAst, DotNetNodeInfo, ParserKeys}
import io.joern.csharpsrc2cpg.{Constants, astcreation}
Expand Down Expand Up @@ -155,6 +154,8 @@ object AstCreatorHelper {
case IdentifierName | Parameter | _: DeclarationExpr => nameFromIdentifier(node)
case QualifiedName => nameFromQualifiedName(node)
case SimpleMemberAccessExpression => nameFromIdentifier(createDotNetNodeInfo(node.json(ParserKeys.Name)))
case ObjectCreationExpression => nameFromNode(createDotNetNodeInfo(node.json(ParserKeys.Type)))
case ThisExpression => "this"
case _ => "<empty>"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.*
import io.joern.csharpsrc2cpg.parser.{DotNetNodeInfo, ParserKeys}
import io.joern.x2cpg.utils.NodeBuilders.{newIdentifierNode, newOperatorCallNode}
import io.joern.x2cpg.{Ast, Defines, ValidationMode}
import io.shiftleft.codepropertygraph.generated.nodes.{NewFieldIdentifier, NewIdentifier, NewMethodParameterIn}
import io.shiftleft.codepropertygraph.generated.nodes.NewFieldIdentifier
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators}
trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

Expand Down Expand Up @@ -123,11 +123,9 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
astForNode(rhsNode)
}

// TODO: This method is becoming a mess...
private def astForInvocationExpression(invocationExpr: DotNetNodeInfo): Seq[Ast] = {
val dispatchType = DispatchTypes.STATIC_DISPATCH // TODO
val arguments = astForArgumentList(createDotNetNodeInfo(invocationExpr.json(ParserKeys.ArgumentList)))
val argString =
s"${arguments.flatMap(_.root).collect { case x: NewMethodParameterIn => x.typeFullName }.mkString(",")}"
val arguments = astForArgumentList(createDotNetNodeInfo(invocationExpr.json(ParserKeys.ArgumentList)))

val expression = createDotNetNodeInfo(invocationExpr.json(ParserKeys.Expression))
val name = nameFromNode(expression)
Expand All @@ -137,35 +135,55 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
val baseNode = createDotNetNodeInfo(
createDotNetNodeInfo(invocationExpr.json(ParserKeys.Expression)).json(ParserKeys.Expression)
)
val baseNodeName = nameFromNode(baseNode)
val staticTypeRef = scope.tryResolveTypeReference(baseNodeName)
val baseIdentifier = astForIdentifier(baseNode, staticTypeRef.getOrElse(Defines.Any))
val _typeFullName = getTypeFullNameFromAstNode(Seq(baseIdentifier))

val baseIdentifier =
astForIdentifier(baseNode)
val _typeFullName = getTypeFullNameFromAstNode(Seq(baseIdentifier))

if (_typeFullName.isEmpty) {
if (_typeFullName != Defines.Any) {
val _identifierNode =
identifierNode(baseNode, nameFromNode(baseNode), code(baseNode), nodeTypeFullName(baseNode))
identifierNode(baseNode, nameFromNode(baseNode), code(baseNode), _typeFullName)
(Option(Ast(_identifierNode)), Option(_identifierNode.typeFullName))
} else if (staticTypeRef.isDefined) {
(Option(baseIdentifier), Option(staticTypeRef.get))
} else {
(Option(baseIdentifier), Option(_typeFullName))
}
case _ => (None, None)

val partialFullName = baseTypeFullName match
case Some(typeFullName) => s"$typeFullName.$name"
lazy val partialFullName = baseTypeFullName match
case Some(typeFullName) =>
s"$typeFullName.$name"
case _ =>
s"${Defines.UnresolvedNamespace}.$name"

val returnType =
scope.tryResolveMethodReturn(baseTypeFullName.getOrElse(scope.surroundingTypeDeclFullName.getOrElse("")), name);

val signature = scope
val parameterSignature = scope
.tryResolveMethodSignature(baseTypeFullName.getOrElse(scope.surroundingTypeDeclFullName.getOrElse("")), name)
.getOrElse(Defines.UnresolvedSignature)
val typeFullName = returnType.getOrElse(Defines.Any);

val methodFullName =
s"$partialFullName:${returnType.getOrElse(Defines.Unknown)}(${signature})"
val typeFullName = returnType.getOrElse(Defines.Any)

val methodSignature = s"${returnType.getOrElse(Defines.Unknown)}($parameterSignature)"
val defaultFullName = s"$partialFullName:$methodSignature"
val (methodFullName, dispatchType) = baseTypeFullName match {
case Some(baseFullName) if scope.tryResolveMethodInvocation(baseFullName, name).isDefined =>
val methodMetaData = scope.tryResolveMethodInvocation(baseFullName, name).get
s"$baseFullName.${methodMetaData.name}:$methodSignature" -> (if methodMetaData.isStatic then
DispatchTypes.STATIC_DISPATCH
else DispatchTypes.DYNAMIC_DISPATCH)
case None
if scope.surroundingTypeDeclFullName.isDefined && scope
.tryResolveMethodInvocation(scope.surroundingTypeDeclFullName.get, name)
.isDefined =>
val baseTypeFullName = scope.surroundingTypeDeclFullName.get
val methodMetaData = scope.tryResolveMethodInvocation(baseTypeFullName, name).get
s"$baseTypeFullName.${methodMetaData.name}:$methodSignature" -> (if methodMetaData.isStatic then
DispatchTypes.STATIC_DISPATCH
else DispatchTypes.DYNAMIC_DISPATCH)
case _ => defaultFullName -> DispatchTypes.STATIC_DISPATCH
}

val _callAst = callAst(
callNode(
Expand All @@ -174,7 +192,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
name,
methodFullName,
dispatchType,
Option(signature),
Option(methodSignature),
Option(typeFullName)
),
arguments,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ package io.joern.csharpsrc2cpg.astcreation

import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.*
import io.joern.csharpsrc2cpg.parser.{DotNetNodeInfo, ParserKeys}
import io.joern.x2cpg.{Ast, ValidationMode}
import io.joern.x2cpg.{Ast, Defines, ValidationMode}
import io.shiftleft.codepropertygraph.generated.nodes.{NewCall, NewLocal}
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators}

trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

protected def astForIdentifier(ident: DotNetNodeInfo, typeFullName: String = ""): Ast = {
protected def astForIdentifier(ident: DotNetNodeInfo, typeFullName: String = Defines.Any): Ast = {
val identifierName = nameFromNode(ident)
if identifierName != "_" then {
val variableOption = scope.lookupVariable(identifierName)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package io.joern.csharpsrc2cpg.datastructures

import io.joern.csharpsrc2cpg.astcreation.{AstCreatorHelper, BuiltinTypes}
import io.joern.csharpsrc2cpg.{CSharpType, TypeMap}
import io.joern.csharpsrc2cpg.astcreation.BuiltinTypes
import io.joern.csharpsrc2cpg.{CSharpMethod, CSharpType, TypeMap}
import io.joern.x2cpg.Defines
import io.joern.x2cpg.datastructures.{Scope, ScopeElement}
import io.shiftleft.codepropertygraph.generated.nodes.DeclarationNew

import scala.collection.mutable
import scala.util.boundary, boundary.break
import scala.util.boundary

class CSharpScope(typeMap: TypeMap) extends Scope[String, DeclarationNew, ScopeType] {

Expand Down Expand Up @@ -48,22 +48,23 @@ class CSharpScope(typeMap: TypeMap) extends Scope[String, DeclarationNew, ScopeT
.exists(x => x.scopeNode.isInstanceOf[MethodScope] || x.scopeNode.isInstanceOf[TypeLikeScope])

def tryResolveTypeReference(typeName: String): Option[String] = {
typesInScope
.find(_.name.endsWith(typeName))
.flatMap(typeMap.namespaceFor)
.map(n => {
// To avoid recursive type prefixing on assignment calls.
if (typeName.startsWith(n)) {
return Some(typeName)
} else {
return Some(s"$n.$typeName")
if (typeName == "this") {
surroundingTypeDeclFullName
} else {
typesInScope
.find(_.name.endsWith(typeName))
.flatMap(typeMap.namespaceFor)
.map {
// To avoid recursive type prefixing on assignment calls.
case n if typeName.startsWith(n) => typeName
case n => s"$n.$typeName"
}
})
}
}

def tryResolveMethodInvocation(typeFullName: String, callName: String): Option[String] = {
def tryResolveMethodInvocation(typeFullName: String, callName: String): Option[CSharpMethod] = {
typesInScope.find(_.name.endsWith(typeFullName)).flatMap { t =>
t.methods.find(_.name == callName).map { m => s"${t.name}.${m.name}" }
t.methods.find(_.name == callName)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,19 @@ object DotNetJsonAst {

sealed trait DeclarationExpr extends BaseExpr

object ClassDeclaration extends DeclarationExpr
sealed trait TypeDeclaration extends DeclarationExpr

object StructDeclaration extends DeclarationExpr
object ClassDeclaration extends TypeDeclaration

object RecordDeclaration extends DeclarationExpr
object StructDeclaration extends TypeDeclaration

object EnumDeclaration extends DeclarationExpr
object RecordDeclaration extends TypeDeclaration

object EnumDeclaration extends TypeDeclaration

object EnumMemberDeclaration extends DeclarationExpr

object InterfaceDeclaration extends DeclarationExpr
object InterfaceDeclaration extends TypeDeclaration

object MethodDeclaration extends DeclarationExpr

Expand Down
Loading

0 comments on commit e779842

Please sign in to comment.