Skip to content

Commit

Permalink
[C#] Handle chained calls and fields with conditional access and null…
Browse files Browse the repository at this point in the history
… forgiving operators (joernio#4369)

* handle chained calls with conditional access and null forgiving operators

* refactor

* refactor
  • Loading branch information
karan-batavia authored Mar 20, 2024
1 parent d407c93 commit 709c05d
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -495,40 +495,59 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
)
}

private def astForConditionalAccessExpression(condAccExpr: DotNetNodeInfo): Seq[Ast] = {
private def astForConditionalAccessExpression(
condAccExpr: DotNetNodeInfo,
baseType: Option[String] = None
): Seq[Ast] = {
val baseNode = createDotNetNodeInfo(condAccExpr.json(ParserKeys.Expression))
val baseAst = astForNode(baseNode)

val fieldIdentifier = fieldIdentifierNode(baseNode, baseNode.code, baseNode.code)

val baseTypeFullName = getTypeFullNameFromAstNode(baseAst)
val baseTypeFullName =
if (getTypeFullNameFromAstNode(baseAst).equals(Defines.Any)) baseType
else Option(getTypeFullNameFromAstNode(baseAst))

Try(createDotNetNodeInfo(condAccExpr.json(ParserKeys.WhenNotNull)(ParserKeys.Name))).toOption match {
Try(createDotNetNodeInfo(condAccExpr.json(ParserKeys.WhenNotNull))).toOption match {
case Some(node) =>
// Got a member access
val typ = scope
.tryResolveFieldAccess(node.code, Option(baseTypeFullName))
.map(_.typeName)
.orElse(Option(Defines.Any))

val identifier = newIdentifierNode(baseNode.code, baseTypeFullName)
val fieldAccessCode = s"${baseNode.code}?.${node.code}"
val fieldAccess = newOperatorCallNode(
Operators.fieldAccess,
fieldAccessCode,
typ,
condAccExpr.lineNumber,
condAccExpr.columnNumber
)
val fieldIdentifierAst = Ast(fieldIdentifier)

Seq(callAst(fieldAccess, baseAst ++ Seq(fieldIdentifierAst)))
case _ => astForInvocationExpression(createDotNetNodeInfo(condAccExpr.json(ParserKeys.WhenNotNull)))
node.node match {
case ConditionalAccessExpression =>
astForConditionalAccessExpression(node, baseTypeFullName)
case MemberBindingExpression => astForMemberBindingExpression(node, baseTypeFullName)
case InvocationExpression =>
astForInvocationExpression(node)
case _ => astForNode(node)
}
case None => Seq.empty[Ast]
}
}

private def astForSuppressNullableWarningExpression(suppressNullableExpr: DotNetNodeInfo): Seq[Ast] = {
val _identifierNode = createDotNetNodeInfo(suppressNullableExpr.json(ParserKeys.Operand))
Seq(astForIdentifier(_identifierNode))
}

private def astForMemberBindingExpression(
memberBindingExpr: DotNetNodeInfo,
baseTypeFullName: Option[String] = None
): Seq[Ast] = {
val typ = scope
.tryResolveFieldAccess(nameFromNode(memberBindingExpr), baseTypeFullName)
.map(_.typeName)
.map(f => scope.tryResolveTypeReference(f).map(_.name).orElse(Option(f)))
.getOrElse(Option(Defines.Any))

val fieldIdentifier = fieldIdentifierNode(memberBindingExpr, memberBindingExpr.code, memberBindingExpr.code)

val identifier = newIdentifierNode(memberBindingExpr.code, baseTypeFullName.getOrElse(Defines.Any))
val fieldAccess =
newOperatorCallNode(
Operators.fieldAccess,
memberBindingExpr.code,
typ,
memberBindingExpr.lineNumber,
memberBindingExpr.columnNumber
)
val fieldIdentifierAst = Ast(fieldIdentifier)

Seq(callAst(fieldAccess, Seq(Ast(identifier)) ++ Seq(fieldIdentifierAst)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,125 @@ class MemberAccessTests extends CSharpCode2CpgFixture {
}
}
}

"conditional method access expressions for chained calls" should {
val cpg = code("""
|namespace Foo {
| public class Baz {
| public int Qux() {}
| public Baz Fred(int a) {}
| }
| public class Bar {
| public static void Main() {
| var baz = new Baz();
| var b = baz?.Fred(1)?.Fred(2)?.Qux();
| }
| }
|}
|""".stripMargin)

"have correct types and attributes both on the LHS and RHS" in {
inside(cpg.assignment.l.sortBy(_.lineNumber).drop(1).l) {
case a :: Nil =>
inside(a.argument.l) {
case (lhs: Identifier) :: (rhs: Call) :: Nil =>
lhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int)
rhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int)
case _ => fail("Expected 2 arguments under the assignment call")
}
case _ => fail("Expected 1 assignment call.")
}
}
}

"chained calls for fields and methods together" should {
"combination of method access expression for chained fields" should {
val cpg = code("""
|namespace Foo {
| public class Baz {
| public Baz Qux {get;}
| public int Fred() {}
| }
| public class Bar {
| public static void Main() {
| var baz = new Baz();
| var b = baz.Qux.Qux.Qux.Fred();
| }
| }
|}
|""".stripMargin)

"have correct types and attributes both on the LHS and RHS" in {
inside(cpg.assignment.l.sortBy(_.lineNumber).drop(1).l) {
case a :: Nil =>
inside(a.argument.l) {
case (lhs: Identifier) :: (rhs: Call) :: Nil =>
lhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int)
rhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int)
case _ => fail("Expected 2 arguments under the assignment call")
}
case _ => fail("Expected 1 assignment call.")
}
}
}
}

"conditional method access expression for chained fields" should {
val cpg = code("""
|namespace Foo {
| public class Baz {
| public Baz Qux {get;}
| }
| public class Bar {
| public static void Main() {
| var baz = new Baz();
| var b = baz?.Qux?.Qux;
| }
| }
|}
|""".stripMargin)

"have correct types and attributes both on the LHS and RHS" in {
inside(cpg.assignment.l.sortBy(_.lineNumber).drop(1).l) {
case a :: Nil =>
inside(a.argument.l) {
case (lhs: Identifier) :: (rhs: Call) :: Nil =>
lhs.typeFullName shouldBe "Foo.Baz"
rhs.typeFullName shouldBe "Foo.Baz"
case _ => fail("Expected 2 arguments under the assignment call")
}
case _ => fail("Expected 1 assignment call.")
}
}
}

"combination of method access expression for chained fields" should {
val cpg = code("""
|namespace Foo {
| public class Baz {
| public Baz Qux {get;}
| public int Fred() {}
| }
| public class Bar {
| public static void Main() {
| var baz = new Baz();
| var b = baz.Qux?.Qux!.Qux.Fred();
| }
| }
|}
|""".stripMargin)

"have correct types and attributes both on the LHS and RHS" in {
inside(cpg.assignment.l.sortBy(_.lineNumber).drop(1).l) {
case a :: Nil =>
inside(a.argument.l) {
case (lhs: Identifier) :: (rhs: Call) :: Nil =>
lhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int)
rhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int)
case _ => fail("Expected 2 arguments under the assignment call")
}
case _ => fail("Expected 1 assignment call.")
}
}
}
}

0 comments on commit 709c05d

Please sign in to comment.