From 9644ce813cde192c5dbbff8bfd7b9bcd5f83caca Mon Sep 17 00:00:00 2001 From: KhemrajSingh Rathore Date: Wed, 11 Oct 2023 13:17:18 +0530 Subject: [PATCH] [rubysrc2cpg] Revert Refactoring Commits to Preserve Stability (#3732) * Revert "[rubysrc2cpg] Do-Block Function as Conditional (#3729)" This reverts commit c726bb528c217441c783c95fb99e50383ceb3a26. * Revert "[rubysrc2cpg] Chained Higher Order Methods (#3727)" This reverts commit 39dd417d5ba074edeaaaa408484e0c4915c43cad. * Revert "[rubysrc2cpg] Fixed Duplicate Do-Blocks Due to Side-Effect (#3720)" This reverts commit 3f0e745f715bbbaf2a6294d69e5c8a8df18938ed. * Revert "[rubysrc2cpg] Comprehensive follow up to #3708's fix (#3714)" This reverts commit 662bd23ea0fe5e78c2cb518d522fd5fc8bb4b7ce. * Revert "[rubysrc2cpg] Fixed Bug with Higher-Order Functions (#3708)" This reverts commit 3e9ea9775b651a579b759099c737b7e90da8c1af. * Revert "[rubysrc2cpg] General Do-Block Function Fixes (#3676)" This reverts commit 4ef5cdc23ffecf7fc279bd6ca34c21c03251a364. --- .../rubysrc2cpg/astcreation/AstCreator.scala | 192 +++++++++--------- .../astcreation/AstCreatorHelper.scala | 31 --- .../AstForControlStructuresCreator.scala | 51 ----- .../AstForExpressionsCreator.scala | 166 +++++++++------ .../astcreation/AstForFunctionsCreator.scala | 101 ++------- .../astcreation/AstForStatementsCreator.scala | 71 ++++--- .../rubysrc2cpg/astcreation/RubyScope.scala | 51 +---- .../io/joern/rubysrc2cpg/passes/Defines.scala | 16 -- .../rubysrc2cpg/passes/ast/DoBlockTest.scala | 98 --------- .../passes/ast/MethodTwoTests.scala | 58 ++++-- .../ast/SimpleAstCreationPassTest.scala | 2 +- .../querying/ControlStructureTests.scala | 6 +- .../rubysrc2cpg/querying/MiscTests.scala | 2 +- 13 files changed, 294 insertions(+), 551 deletions(-) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala index e35f05e8730a..c3f9852e185f 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala @@ -106,11 +106,7 @@ class AstCreator( programCtx.compoundStatement() != null && programCtx.compoundStatement().statements() != null ) { - val stmts = astForStatements(programCtx.compoundStatement().statements(), false, false) - val methodNodes = stmts.flatMap(_.nodes).collect { case x: NewMethod => x }.toSet - // Block methods is largely unnecessary, but will keep it for AST functions that still populate it instead of - // attaching it to AST - stmts ++ blockMethods.filterNot(_.root.collect { case x: NewMethod => x }.exists(methodNodes.contains)) + astForStatements(programCtx.compoundStatement().statements(), false, false) ++ blockMethods } else { logger.error(s"File $filename has no compound statement. Needs to be examined") List[Ast](Ast()) @@ -143,7 +139,23 @@ class AstCreator( .filterNot(_.astParentType == NodeTypes.TYPE_DECL) .map { methodNode => // Create a methodRefNode and assign it to the identifier version of the method, which will help in type propagation to resolve calls - methodRefAssignmentFromMethod(methodNode, Option(lineColNum), Option(lineColNum)) + val methodRefNode = NewMethodRef() + .code("def " + methodNode.name + "(...)") + .methodFullName(methodNode.fullName) + .typeFullName(methodNode.fullName) + .lineNumber(lineColNum) + .columnNumber(lineColNum) + + val methodNameIdentifier = NewIdentifier() + .code(methodNode.name) + .name(methodNode.name) + .typeFullName(Defines.Any) + .lineNumber(lineColNum) + .columnNumber(lineColNum) + scope.addToScope(methodNode.name, methodNameIdentifier) + val methodRefAssignmentAst = + astForAssignment(methodNameIdentifier, methodRefNode, methodNode.lineNumber, methodNode.columnNumber) + methodRefAssignmentAst } .toList @@ -199,13 +211,13 @@ class AstCreator( case ctx: MethodDefinitionPrimaryContext => astForMethodDefinitionContext(ctx.methodDefinition()) case ctx: ProcDefinitionPrimaryContext => astForProcDefinitionContext(ctx.procDefinition()) case ctx: YieldWithOptionalArgumentPrimaryContext => - astForYieldCall(ctx, Option(ctx.yieldWithOptionalArgument().arguments())) - case ctx: IfExpressionPrimaryContext => astForIfExpression(ctx.ifExpression()) - case ctx: UnlessExpressionPrimaryContext => astForUnlessExpression(ctx.unlessExpression()) + Seq(astForYieldCall(ctx, Option(ctx.yieldWithOptionalArgument().arguments()))) + case ctx: IfExpressionPrimaryContext => Seq(astForIfExpression(ctx.ifExpression())) + case ctx: UnlessExpressionPrimaryContext => Seq(astForUnlessExpression(ctx.unlessExpression())) case ctx: CaseExpressionPrimaryContext => astForCaseExpressionPrimaryContext(ctx) - case ctx: WhileExpressionPrimaryContext => astForWhileExpression(ctx.whileExpression()) - case ctx: UntilExpressionPrimaryContext => astForUntilExpression(ctx.untilExpression()) - case ctx: ForExpressionPrimaryContext => astForForExpression(ctx.forExpression()) + case ctx: WhileExpressionPrimaryContext => Seq(astForWhileExpression(ctx.whileExpression())) + case ctx: UntilExpressionPrimaryContext => Seq(astForUntilExpression(ctx.untilExpression())) + case ctx: ForExpressionPrimaryContext => Seq(astForForExpression(ctx.forExpression())) case ctx: ReturnWithParenthesesPrimaryContext => Seq(returnAst(returnNode(ctx, text(ctx)), astForArgumentsWithParenthesesContext(ctx.argumentsWithParentheses()))) case ctx: JumpExpressionPrimaryContext => astForJumpExpressionPrimaryContext(ctx) @@ -224,8 +236,8 @@ class AstCreator( case ctx: RegexInterpolationPrimaryContext => astForRegexInterpolationPrimaryContext(ctx.regexInterpolation) case ctx: QuotedRegexInterpolationPrimaryContext => astForQuotedRegexInterpolation(ctx.quotedRegexInterpolation) - case ctx: IsDefinedPrimaryContext => astForIsDefinedPrimaryExpression(ctx) - case ctx: SuperExpressionPrimaryContext => astForSuperExpression(ctx) + case ctx: IsDefinedPrimaryContext => Seq(astForIsDefinedPrimaryExpression(ctx)) + case ctx: SuperExpressionPrimaryContext => Seq(astForSuperExpression(ctx)) case ctx: IndexingExpressionPrimaryContext => astForIndexingExpressionPrimaryContext(ctx) case ctx: MethodOnlyIdentifierPrimaryContext => astForMethodOnlyIdentifier(ctx.methodOnlyIdentifier()) case ctx: InvocationWithBlockOnlyPrimaryContext => astForInvocationWithBlockOnlyPrimaryContext(ctx) @@ -240,23 +252,23 @@ class AstCreator( def astForExpressionContext(ctx: ExpressionContext): Seq[Ast] = ctx match { case ctx: PrimaryExpressionContext => astForPrimaryContext(ctx.primary()) - case ctx: UnaryExpressionContext => astForUnaryExpression(ctx) - case ctx: PowerExpressionContext => astForPowerExpression(ctx) - case ctx: UnaryMinusExpressionContext => astForUnaryMinusExpression(ctx) - case ctx: MultiplicativeExpressionContext => astForMultiplicativeExpression(ctx) - case ctx: AdditiveExpressionContext => astForAdditiveExpression(ctx) - case ctx: BitwiseShiftExpressionContext => astForBitwiseShiftExpression(ctx) - case ctx: BitwiseAndExpressionContext => astForBitwiseAndExpression(ctx) - case ctx: BitwiseOrExpressionContext => astForBitwiseOrExpression(ctx) - case ctx: RelationalExpressionContext => astForRelationalExpression(ctx) - case ctx: EqualityExpressionContext => astForEqualityExpression(ctx) - case ctx: OperatorAndExpressionContext => astForAndExpression(ctx) - case ctx: OperatorOrExpressionContext => astForOrExpression(ctx) + case ctx: UnaryExpressionContext => Seq(astForUnaryExpression(ctx)) + case ctx: PowerExpressionContext => Seq(astForPowerExpression(ctx)) + case ctx: UnaryMinusExpressionContext => Seq(astForUnaryMinusExpression(ctx)) + case ctx: MultiplicativeExpressionContext => Seq(astForMultiplicativeExpression(ctx)) + case ctx: AdditiveExpressionContext => Seq(astForAdditiveExpression(ctx)) + case ctx: BitwiseShiftExpressionContext => Seq(astForBitwiseShiftExpression(ctx)) + case ctx: BitwiseAndExpressionContext => Seq(astForBitwiseAndExpression(ctx)) + case ctx: BitwiseOrExpressionContext => Seq(astForBitwiseOrExpression(ctx)) + case ctx: RelationalExpressionContext => Seq(astForRelationalExpression(ctx)) + case ctx: EqualityExpressionContext => Seq(astForEqualityExpression(ctx)) + case ctx: OperatorAndExpressionContext => Seq(astForAndExpression(ctx)) + case ctx: OperatorOrExpressionContext => Seq(astForOrExpression(ctx)) case ctx: RangeExpressionContext => astForRangeExpressionContext(ctx) case ctx: ConditionalOperatorExpressionContext => Seq(astForTernaryConditionalOperator(ctx)) case ctx: SingleAssignmentExpressionContext => astForSingleAssignmentExpressionContext(ctx) case ctx: MultipleAssignmentExpressionContext => astForMultipleAssignmentExpressionContext(ctx) - case ctx: IsDefinedExpressionContext => astForIsDefinedExpression(ctx) + case ctx: IsDefinedExpressionContext => Seq(astForIsDefinedExpression(ctx)) case _ => logger.error(s"astForExpressionContext() $relativeFilename, ${text(ctx)} All contexts mismatched.") Seq(Ast()) @@ -279,10 +291,9 @@ class AstCreator( .asScala .flatMap(astForExpressionContext) .toSeq - val splatAsts = astForExpressionOrCommand(ctx.splattingArgument().expressionOrCommand()) - val callNode = createOpCall(ctx.COMMA, Operators.arrayInitializer, text(ctx)) - val (argAsts, otherAsts) = (expAsts ++ splatAsts).partitionExprAst - otherAsts :+ callAst(callNode, argAsts) + val splatAsts = astForExpressionOrCommand(ctx.splattingArgument().expressionOrCommand()) + val callNode = createOpCall(ctx.COMMA, Operators.arrayInitializer, text(ctx)) + Seq(callAst(callNode, expAsts ++ splatAsts)) case ctx: AssociationsOnlyIndexingArgumentsContext => astForAssociationsContext(ctx.associations()) case ctx: RubyParser.SplattingOnlyIndexingArgumentsContext => @@ -310,11 +321,11 @@ class AstCreator( ctx.AMPDOT() } - val (argsAst, otherAst) = (if (ctx.argumentsWithParentheses() != null) { - astForArgumentsWithParenthesesContext(ctx.argumentsWithParentheses()) - } else { - Seq() - }).partitionExprAst + val argsAst = if (ctx.argumentsWithParentheses() != null) { + astForArgumentsWithParenthesesContext(ctx.argumentsWithParentheses()) + } else { + Seq() + } if (hasBlockStmt) { val blockName = methodNameAst.head.nodes.head @@ -331,35 +342,29 @@ class AstCreator( lineEnd(ctx).head, columnEnd(ctx).head ) + val blockMethodNode = + blockMethodAsts.head.nodes.head + .asInstanceOf[NewMethod] - blockMethodAsts - ++ blockMethodAsts - .flatMap(_.nodes) - .collectFirst { case methodRefNode: NewMethodRef => - val callNode = NewCall() - .name(blockName) - .methodFullName(methodRefNode.methodFullName) - .typeFullName(Defines.Any) - .code(methodRefNode.code) - .lineNumber(methodRefNode.lineNumber) - .columnNumber(methodRefNode.columnNumber) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - // TODO: primaryAst.headOption is broken when primaryAst is an array - val (exprAst, otherAst) = primaryAst.partitionExprAst - - exprAst.lastOption match - case Some(receiverCallAst: Ast) - if exprAst.size > 1 && receiverCallAst.root.exists(_.isInstanceOf[NewCall]) => - val siblings = exprAst.take(exprAst.length - 2) - (otherAst ++ siblings) :+ callAst( - callNode, - argsAst ++ Seq(Ast(methodRefNode.copy)), - Option(receiverCallAst) - ) - case _ => - otherAst :+ callAst(callNode, argsAst ++ Seq(Ast(methodRefNode.copy)), exprAst.headOption) - } - .getOrElse(Seq.empty) + blockMethods.addOne(blockMethodAsts.head) + + val callNode = NewCall() + .name(blockName) + .methodFullName(blockMethodNode.fullName) + .typeFullName(Defines.Any) + .code(blockMethodNode.code) + .lineNumber(blockMethodNode.lineNumber) + .columnNumber(blockMethodNode.columnNumber) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + + val methodRefNode = NewMethodRef() + .methodFullName(blockMethodNode.fullName) + .typeFullName(Defines.Any) + .code(blockMethodNode.code) + .lineNumber(blockMethodNode.lineNumber) + .columnNumber(blockMethodNode.columnNumber) + + Seq(callAst(callNode, argsAst ++ Seq(Ast(methodRefNode)), primaryAst.headOption)) } else { val callNode = methodNameAst.head.nodes .filter(node => node.isInstanceOf[NewCall]) @@ -369,7 +374,7 @@ class AstCreator( if (callNode.name == "call" && ctx.primary().isInstanceOf[ProcDefinitionPrimaryContext]) { // this is a proc.call val baseCallNode = primaryAst.head.nodes.head.asInstanceOf[NewCall] - otherAst :+ callAst(baseCallNode, argsAst) + Seq(callAst(baseCallNode, argsAst)) } else { callNode .code(text(ctx)) @@ -383,9 +388,9 @@ class AstCreator( .methodFullName(methodNode.fullName) .typeFullName(Defines.Any) blockMethods.addOne(primaryAst.head) - otherAst :+ callAst(callNode, Seq(Ast(methodRefNode)) ++ argsAst) + Seq(callAst(callNode, Seq(Ast(methodRefNode)) ++ argsAst)) case _ => - otherAst :+ callAst(callNode, argsAst, primaryAst.headOption) + Seq(callAst(callNode, argsAst, primaryAst.headOption)) } } } @@ -417,17 +422,16 @@ class AstCreator( val baseAst = astForPrimaryContext(ctx.primary()) val blocksAst = if (ctx.block() != null) { - astForBlock(ctx.block()) + Seq(astForBlock(ctx.block())) } else { Seq() } val callNode = methodNameAst.head.nodes.filter(node => node.isInstanceOf[NewCall]).head.asInstanceOf[NewCall] callNode .code(text(ctx)) - .lineNumber(ctx.COLON2.lineNumber) - .columnNumber(ctx.COLON2.columnNumber) - val (argsAst, otherAst) = (baseAst ++ blocksAst).partitionExprAst - otherAst :+ callAst(callNode, argsAst) + .lineNumber(ctx.COLON2().getSymbol().getLine()) + .columnNumber(ctx.COLON2().getSymbol().getCharPositionInLine()) + Seq(callAst(callNode, baseAst ++ blocksAst)) } private def astForChainedScopedConstantReferencePrimaryContext( @@ -530,8 +534,8 @@ class AstCreator( .asInstanceOf[NewCall] .name - val isYieldMethod = if (blockName.endsWith(Defines.YIELD_SUFFIX)) { - val lookupMethodName = blockName.take(blockName.length - Defines.YIELD_SUFFIX.length) + val isYieldMethod = if (blockName.endsWith(YIELD_SUFFIX)) { + val lookupMethodName = blockName.take(blockName.length - YIELD_SUFFIX.length) methodNamesWithYield.contains(lookupMethodName) } else { false @@ -551,7 +555,7 @@ class AstCreator( columnEnd(ctx).head ) } else { - val blockAst = astForBlock(ctx.block()) + val blockAst = Seq(astForBlock(ctx.block())) // this is expected to be a call node val callNode = methodIdAst.head.nodes.head.asInstanceOf[NewCall] Seq(callAst(callNode, blockAst)) @@ -565,26 +569,18 @@ class AstCreator( callNode.name(resolveAlias(callNode.name)) if (ctx.block() != null) { - val isYieldMethod = if (callNode.name.endsWith(Defines.YIELD_SUFFIX)) { - val lookupMethodName = callNode.name.take(callNode.name.length - Defines.YIELD_SUFFIX.length) + val isYieldMethod = if (callNode.name.endsWith(YIELD_SUFFIX)) { + val lookupMethodName = callNode.name.take(callNode.name.length - YIELD_SUFFIX.length) methodNamesWithYield.contains(lookupMethodName) } else { false } if (isYieldMethod) { val methAst = astForBlock(ctx.block(), Some(callNode.name)) - methAst - .collectFirst { case x: Ast if x.root.isDefined && x.root.get.isInstanceOf[NewMethod] => x } - .foreach(blockMethods.addOne) + blockMethods.addOne(methAst) Seq(callAst(callNode, parenAst)) - } else if (callNode.name == Defines.DEFINE_METHOD) { - parenAst.headOption - .flatMap(_.root) - .collect { case x: AstNodeNew => stripQuotes(x.code).stripPrefix(":") } - .map(methodName => astForBlock(ctx.block(), Option(methodName))) - .getOrElse(Seq.empty) } else { - val blockAst = astForBlock(ctx.block()) + val blockAst = Seq(astForBlock(ctx.block())) Seq(callAst(callNode, parenAst ++ blockAst)) } } else @@ -593,7 +589,7 @@ class AstCreator( def astForCallNode(ctx: ParserRuleContext, code: String, isYieldBlock: Boolean = false): Ast = { val name = if (isYieldBlock) { - s"${resolveAlias(text(ctx))}${Defines.YIELD_SUFFIX}" + s"${resolveAlias(text(ctx))}$YIELD_SUFFIX" } else { val calleeName = resolveAlias(text(ctx)) // Add the call name to the global builtIn callNames set @@ -668,27 +664,27 @@ class AstCreator( private def astForCommandWithDoBlockContext(ctx: CommandWithDoBlockContext): Seq[Ast] = ctx match { case ctx: ArgsAndDoBlockCommandWithDoBlockContext => val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments()) - val doBlockAst = astForDoBlock(ctx.doBlock()) + val doBlockAst = Seq(astForDoBlock(ctx.doBlock())) argsAsts ++ doBlockAst case ctx: RubyParser.ArgsAndDoBlockAndMethodIdCommandWithDoBlockContext => val methodIdAsts = astForMethodIdentifierContext(ctx.methodIdentifier(), text(ctx)) methodIdAsts.headOption.flatMap(_.root) match - case Some(methodIdRoot: NewCall) if methodIdRoot.name == Defines.DEFINE_METHOD => + case Some(methodIdRoot: NewCall) if methodIdRoot.name == "define_method" => ctx.argumentsWithoutParentheses.arguments.argument.asScala.headOption .map { methodArg => // TODO: methodArg will name the method, but this could be an identifier or even a string concatenation // which is not assumed below - val methodName = stripQuotes(methodArg.getText).stripPrefix(":") - astForDoBlock(ctx.doBlock(), Option(methodName)) + val methodName = stripQuotes(methodArg.getText) + Seq(astForDoBlock(ctx.doBlock(), Option(methodName))) } .getOrElse(Seq.empty) case _ => val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments()) - val doBlockAsts = astForDoBlock(ctx.doBlock()) + val doBlockAsts = Seq(astForDoBlock(ctx.doBlock())) methodIdAsts ++ argsAsts ++ doBlockAsts case ctx: RubyParser.PrimaryMethodArgsDoBlockCommandWithDoBlockContext => val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments()) - val doBlockAsts = astForDoBlock(ctx.doBlock()) + val doBlockAsts = Seq(astForDoBlock(ctx.doBlock())) val methodNameAsts = astForMethodNameContext(ctx.methodName()) val primaryAsts = astForPrimaryContext(ctx.primary()) primaryAsts ++ methodNameAsts ++ argsAsts ++ doBlockAsts @@ -745,8 +741,8 @@ class AstCreator( val operatorText = getOperatorName(terminalNode.getSymbol) val expressions = ctx.expression.asScala - val (callArgs, otherAst) = - (Option(ctx.keyword) match { + val callArgs = + Option(ctx.keyword) match { case Some(ctxKeyword) => val expr1Ast = astForCallNode(ctx, ctxKeyword.getText) val expr2Asts = astForExpressionContext(expressions.head) @@ -755,10 +751,10 @@ class AstCreator( val expr1Asts = astForExpressionContext(expressions.head) val expr2Asts = expressions.lift(1).flatMap(astForExpressionContext) expr1Asts ++ expr2Asts - }).partitionExprAst + } val callNode = createOpCall(terminalNode, operatorText, text(ctx)) - otherAst ++ Seq(callAst(callNode, callArgs)) + Seq(callAst(callNode, callArgs)) } private def astForAssociationsContext(ctx: AssociationsContext): Seq[Ast] = { diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala index 231bb6d6ad4b..42a056f81609 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala @@ -209,37 +209,6 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As } - implicit class AstIterExt(a: Iterable[Ast]) { - - /** Partitions a sequence of Ast objects into those with roots start with an expression, and those that don't. - * @return - * a tuple of sequences where the first has expression roots and the second does not. - */ - def partitionExprAst: (Seq[Ast], Seq[Ast]) = { - val (as, bs) = a.partition(_.root match - case Some(_: ExpressionNew) => true - case _ => false - ) - (as.toSeq, bs.toSeq) - } - - /** Partitions a sequence of Ast objects into the boilerplate for do-block functions and the call node at the end. - * - * @return - * a tuple where the first element is the closure boilerplate and the latter is the last expression. - */ - def partitionClosureFromExpr: (Seq[Ast], Option[Ast]) = { - val (as, bs) = a.partition(_.root match - case Some(_: NewMethod) => true - case Some(_: NewTypeDecl) => true - case Some(x: NewCall) if x.name.startsWith(Operators.assignment) => true - case _ => false - ) - (as.toSeq, bs.lastOption) - } - - } - } object RubyOperators { diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala index 84d9ed6baf2b..717235dc7fa1 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala @@ -132,55 +132,4 @@ trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMo tryCatchAst(tryNode, tryBodyAst, catchAsts, finallyAst) } - protected def astForUntilExpression(ctx: UntilExpressionContext): Seq[Ast] = { - val (boilerplate, exprAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionClosureFromExpr - val bodyAst = astForCompoundStatement(ctx.doClause().compoundStatement()) - // TODO: testAst should be negated if it's going to be modelled as a while stmt. - boilerplate :+ whileAst(exprAst, bodyAst, Some(text(ctx)), line(ctx), column(ctx)) - } - - protected def astForForExpression(ctx: ForExpressionContext): Seq[Ast] = { - val forVarAst = astForForVariableContext(ctx.forVariable()) - val (boilerplate, forExprAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionClosureFromExpr - val forBodyAst = astForCompoundStatement(ctx.doClause().compoundStatement()) - // TODO: for X in Y is not properly modelled by while Y - val forRootAst = whileAst(forExprAst, forBodyAst, Some(text(ctx)), line(ctx), column(ctx)) - boilerplate :+ forVarAst.headOption.map(forRootAst.withChild).getOrElse(forRootAst) - } - - private def astForForVariableContext(ctx: ForVariableContext): Seq[Ast] = { - if (ctx.singleLeftHandSide() != null) { - astForSingleLeftHandSideContext(ctx.singleLeftHandSide()) - } else if (ctx.multipleLeftHandSide() != null) { - astForMultipleLeftHandSideContext(ctx.multipleLeftHandSide()) - } else { - Seq(Ast()) - } - } - - protected def astForWhileExpression(ctx: WhileExpressionContext): Seq[Ast] = { - val (boilerplate, exprAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionClosureFromExpr - val bodyAst = astForCompoundStatement(ctx.doClause().compoundStatement()) - boilerplate :+ whileAst(exprAst, bodyAst, Some(text(ctx)), line(ctx), column(ctx)) - } - - protected def astForIfExpression(ctx: IfExpressionContext): Seq[Ast] = { - val (boilerplate, exprAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionClosureFromExpr - val thenAst = astForCompoundStatement(ctx.thenClause().compoundStatement()) - val elsifAsts = Option(ctx.elsifClause).map(_.asScala).getOrElse(Seq()).flatMap(astForElsifClause) - val elseAst = Option(ctx.elseClause()).map(ctx => astForCompoundStatement(ctx.compoundStatement())).getOrElse(Seq()) - val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, text(ctx)) - boilerplate :+ controlStructureAst(ifNode, exprAst) - .withChildren(thenAst) - .withChildren(elsifAsts.toSeq) - .withChildren(elseAst) - } - - private def astForElsifClause(ctx: ElsifClauseContext): Seq[Ast] = { - val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, text(ctx)) - val (boilerplate, exprAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionClosureFromExpr - val bodyAst = astForCompoundStatement(ctx.thenClause().compoundStatement()) - boilerplate :+ controlStructureAst(ifNode, exprAst, bodyAst) - } - } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index 87d670eeb3c3..44d2c8bc4e09 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -4,7 +4,7 @@ import io.joern.rubysrc2cpg.parser.RubyParser.* import io.joern.rubysrc2cpg.passes.Defines import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType import io.joern.x2cpg.{Ast, ValidationMode} -import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, NewCall, NewIdentifier} import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, ModifierTypes, Operators} import org.antlr.v4.runtime.ParserRuleContext import org.slf4j.LoggerFactory @@ -17,45 +17,45 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { private val logger = LoggerFactory.getLogger(this.getClass) protected var lastModifier: Option[String] = None - protected def astForPowerExpression(ctx: PowerExpressionContext): Seq[Ast] = + protected def astForPowerExpression(ctx: PowerExpressionContext): Ast = astForBinaryOperatorExpression(ctx, Operators.exponentiation, ctx.expression().asScala) - protected def astForOrExpression(ctx: OperatorOrExpressionContext): Seq[Ast] = + protected def astForOrExpression(ctx: OperatorOrExpressionContext): Ast = astForBinaryOperatorExpression(ctx, Operators.or, ctx.expression().asScala) - protected def astForAndExpression(ctx: OperatorAndExpressionContext): Seq[Ast] = + protected def astForAndExpression(ctx: OperatorAndExpressionContext): Ast = astForBinaryOperatorExpression(ctx, Operators.and, ctx.expression().asScala) - protected def astForUnaryExpression(ctx: UnaryExpressionContext): Seq[Ast] = ctx.op.getType match { + protected def astForUnaryExpression(ctx: UnaryExpressionContext): Ast = ctx.op.getType match { case TILDE => astForBinaryOperatorExpression(ctx, Operators.not, Seq(ctx.expression())) case PLUS => astForBinaryOperatorExpression(ctx, Operators.plus, Seq(ctx.expression())) case EMARK => astForBinaryOperatorExpression(ctx, Operators.not, Seq(ctx.expression())) } - protected def astForUnaryMinusExpression(ctx: UnaryMinusExpressionContext): Seq[Ast] = + protected def astForUnaryMinusExpression(ctx: UnaryMinusExpressionContext): Ast = astForBinaryOperatorExpression(ctx, Operators.minus, Seq(ctx.expression())) - protected def astForAdditiveExpression(ctx: AdditiveExpressionContext): Seq[Ast] = ctx.op.getType match { + protected def astForAdditiveExpression(ctx: AdditiveExpressionContext): Ast = ctx.op.getType match { case PLUS => astForBinaryOperatorExpression(ctx, Operators.addition, ctx.expression().asScala) case MINUS => astForBinaryOperatorExpression(ctx, Operators.subtraction, ctx.expression().asScala) } - protected def astForMultiplicativeExpression(ctx: MultiplicativeExpressionContext): Seq[Ast] = ctx.op.getType match { + protected def astForMultiplicativeExpression(ctx: MultiplicativeExpressionContext): Ast = ctx.op.getType match { case STAR => astForMultiplicativeStarExpression(ctx) case SLASH => astForMultiplicativeSlashExpression(ctx) case PERCENT => astForMultiplicativePercentExpression(ctx) } - protected def astForMultiplicativeStarExpression(ctx: MultiplicativeExpressionContext): Seq[Ast] = + protected def astForMultiplicativeStarExpression(ctx: MultiplicativeExpressionContext): Ast = astForBinaryOperatorExpression(ctx, Operators.multiplication, ctx.expression().asScala) - protected def astForMultiplicativeSlashExpression(ctx: MultiplicativeExpressionContext): Seq[Ast] = + protected def astForMultiplicativeSlashExpression(ctx: MultiplicativeExpressionContext): Ast = astForBinaryOperatorExpression(ctx, Operators.division, ctx.expression().asScala) - protected def astForMultiplicativePercentExpression(ctx: MultiplicativeExpressionContext): Seq[Ast] = + protected def astForMultiplicativePercentExpression(ctx: MultiplicativeExpressionContext): Ast = astForBinaryOperatorExpression(ctx, Operators.modulo, ctx.expression().asScala) - protected def astForEqualityExpression(ctx: EqualityExpressionContext): Seq[Ast] = ctx.op.getType match { + protected def astForEqualityExpression(ctx: EqualityExpressionContext): Ast = ctx.op.getType match { case LTEQGT => astForBinaryOperatorExpression(ctx, Operators.compare, ctx.expression().asScala) case EQ2 => astForBinaryOperatorExpression(ctx, Operators.equals, ctx.expression().asScala) case EQ3 => astForBinaryOperatorExpression(ctx, Operators.is, ctx.expression().asScala) @@ -64,22 +64,22 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case EMARKTILDE => astForBinaryOperatorExpression(ctx, RubyOperators.notPatternMatch, ctx.expression().asScala) } - protected def astForRelationalExpression(ctx: RelationalExpressionContext): Seq[Ast] = ctx.op.getType match { + protected def astForRelationalExpression(ctx: RelationalExpressionContext): Ast = ctx.op.getType match { case GT => astForBinaryOperatorExpression(ctx, Operators.greaterThan, ctx.expression().asScala) case GTEQ => astForBinaryOperatorExpression(ctx, Operators.greaterEqualsThan, ctx.expression().asScala) case LT => astForBinaryOperatorExpression(ctx, Operators.lessThan, ctx.expression().asScala) case LTEQ => astForBinaryOperatorExpression(ctx, Operators.lessEqualsThan, ctx.expression().asScala) } - protected def astForBitwiseOrExpression(ctx: BitwiseOrExpressionContext): Seq[Ast] = ctx.op.getType match { + protected def astForBitwiseOrExpression(ctx: BitwiseOrExpressionContext): Ast = ctx.op.getType match { case BAR => astForBinaryOperatorExpression(ctx, Operators.logicalOr, ctx.expression().asScala) case CARET => astForBinaryOperatorExpression(ctx, Operators.logicalOr, ctx.expression().asScala) } - protected def astForBitwiseAndExpression(ctx: BitwiseAndExpressionContext): Seq[Ast] = + protected def astForBitwiseAndExpression(ctx: BitwiseAndExpressionContext): Ast = astForBinaryOperatorExpression(ctx, Operators.logicalAnd, ctx.expression().asScala) - protected def astForBitwiseShiftExpression(ctx: BitwiseShiftExpressionContext): Seq[Ast] = ctx.op.getType match { + protected def astForBitwiseShiftExpression(ctx: BitwiseShiftExpressionContext): Ast = ctx.op.getType match { case LT2 => astForBinaryOperatorExpression(ctx, Operators.shiftLeft, ctx.expression().asScala) case GT2 => astForBinaryOperatorExpression(ctx, Operators.logicalShiftRight, ctx.expression().asScala) } @@ -88,22 +88,20 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { ctx: ParserRuleContext, name: String, arguments: Iterable[ExpressionContext] - ): Seq[Ast] = { - val (argsAst, otherAst) = arguments - .flatMap(astForExpressionContext) - .partitionExprAst - val call = callNode(ctx, text(ctx), name, name, DispatchTypes.STATIC_DISPATCH) - otherAst :+ callAst(call, argsAst.toList) + ): Ast = { + val argsAst = arguments.flatMap(astForExpressionContext) + val call = callNode(ctx, text(ctx), name, name, DispatchTypes.STATIC_DISPATCH) + callAst(call, argsAst.toList) } - protected def astForIsDefinedExpression(ctx: IsDefinedExpressionContext): Seq[Ast] = + protected def astForIsDefinedExpression(ctx: IsDefinedExpressionContext): Ast = astForBinaryOperatorExpression(ctx, RubyOperators.defined, Seq(ctx.expression())) // TODO: Maybe merge (in RubyParser.g4) isDefinedExpression with isDefinedPrimaryExpression? - protected def astForIsDefinedPrimaryExpression(ctx: IsDefinedPrimaryContext): Seq[Ast] = { - val (argsAst, otherAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionExprAst - val call = callNode(ctx, text(ctx), RubyOperators.defined, RubyOperators.defined, DispatchTypes.STATIC_DISPATCH) - otherAst :+ callAst(call, argsAst.toList) + protected def astForIsDefinedPrimaryExpression(ctx: IsDefinedPrimaryContext): Ast = { + val argsAst = astForExpressionOrCommand(ctx.expressionOrCommand()) + val call = callNode(ctx, text(ctx), RubyOperators.defined, RubyOperators.defined, DispatchTypes.STATIC_DISPATCH) + callAst(call, argsAst.toList) } protected def astForLiteralPrimaryExpression(ctx: LiteralPrimaryContext): Seq[Ast] = ctx.literal() match { @@ -158,10 +156,10 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case ctx: VariableIdentifierOnlySingleLeftHandSideContext => Seq(astForVariableIdentifierHelper(ctx.variableIdentifier, true)) case ctx: PrimaryInsideBracketsSingleLeftHandSideContext => - val primaryAsts = astForPrimaryContext(ctx.primary) - val (argsAsts, otherAst) = astForArguments(ctx.arguments).partitionExprAst - val indexAccessCall = createOpCall(ctx.LBRACK, Operators.indexAccess, text(ctx)) - otherAst :+ callAst(indexAccessCall, primaryAsts ++ argsAsts) + val primaryAsts = astForPrimaryContext(ctx.primary) + val argsAsts = astForArguments(ctx.arguments) + val indexAccessCall = createOpCall(ctx.LBRACK, Operators.indexAccess, text(ctx)) + Seq(callAst(indexAccessCall, primaryAsts ++ argsAsts)) case ctx: XdotySingleLeftHandSideContext => // TODO handle obj.foo=arg being interpreted as obj.foo(arg) here. val xAsts = astForPrimaryContext(ctx.primary) @@ -198,21 +196,12 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { .lineNumber(ctx.op.getLine) .columnNumber(ctx.op.getCharPositionInLine) if (leftAst.size == 1 && rightAst.size > 1) { - if (rightAst.headOption.flatMap(_.root).exists(_.isInstanceOf[NewMethod])) { - /* - * Here we expect to be assigned the result of some dynamically defined function's application to some variable - */ - val lastAst = rightAst.takeRight(1) - rightAst.filterNot(_ == lastAst.head) ++ Seq(callAst(opCallNode, leftAst ++ lastAst)) - } else { - /* - * This is multiple RHS packed into a single LHS. That is, packing left hand side. - * This is as good as multiple RHS packed into an array and put into a single LHS - */ - val packedRHS = getPackedRHS(rightAst, wrapInBrackets = true) - val (argsAst, otherAst) = (leftAst ++ packedRHS).partitionExprAst - otherAst :+ callAst(opCallNode, argsAst) - } + /* + * This is multiple RHS packed into a single LHS. That is, packing left hand side. + * This is as good as multiple RHS packed into an array and put into a single LHS + */ + val packedRHS = getPackedRHS(rightAst, wrapInBrackets = true) + Seq(callAst(opCallNode, leftAst ++ packedRHS)) } else { Seq(callAst(opCallNode, leftAst ++ rightAst)) } @@ -261,9 +250,8 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case Some(node) if node.name == "Array" => Operators.arrayInitializer case _ => Operators.indexAccess - val callNode = createOpCall(ctx.LBRACK, operator, text(ctx)) - val (argsAst, otherAst) = (lhsExpressionAst ++ rhsExpressionAst).partitionExprAst - otherAst :+ callAst(callNode, argsAst) + val callNode = createOpCall(ctx.LBRACK, operator, text(ctx)) + Seq(callAst(callNode, lhsExpressionAst ++ rhsExpressionAst)) } @@ -279,8 +267,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { .typeFullName(Defines.Any) .dispatchType(DispatchTypes.STATIC_DISPATCH) .code(if (wrapInBrackets) s"[$code]" else code) - val (argsAst, otherAst) = astsToConcat.partitionExprAst - otherAst :+ callAst(callNode, argsAst) + Seq(callAst(callNode, astsToConcat)) } def astForStringInterpolationContext(ctx: InterpolatedStringExpressionContext): Seq[Ast] = { @@ -367,9 +354,9 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } def astForRangeExpressionContext(ctx: RangeExpressionContext): Seq[Ast] = - astForBinaryOperatorExpression(ctx, Operators.range, ctx.expression().asScala) + Seq(astForBinaryOperatorExpression(ctx, Operators.range, ctx.expression().asScala)) - protected def astForSuperExpression(ctx: SuperExpressionPrimaryContext): Seq[Ast] = { + protected def astForSuperExpression(ctx: SuperExpressionPrimaryContext): Ast = { val argsAst = Option(ctx.argumentsWithParentheses()) match case Some(ctxArgs) => astForArgumentsWithParenthesesContext(ctxArgs) case None => Seq() @@ -379,20 +366,67 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { // TODO: Handle the optional block. // NOTE: `super` is quite complicated semantically speaking. We'll need // to revisit how to represent them. - protected def astForSuperCall(ctx: ParserRuleContext, arguments: Seq[Ast]): Seq[Ast] = { + protected def astForSuperCall(ctx: ParserRuleContext, arguments: Seq[Ast]): Ast = { val call = callNode(ctx, text(ctx), RubyOperators.superKeyword, RubyOperators.superKeyword, DispatchTypes.STATIC_DISPATCH) - - val (argsAst, otherAst) = arguments.partitionExprAst - otherAst :+ callAst(call, argsAst) + callAst(call, arguments.toList) } - protected def astForYieldCall(ctx: ParserRuleContext, argumentsCtx: Option[ArgumentsContext]): Seq[Ast] = { + protected def astForYieldCall(ctx: ParserRuleContext, argumentsCtx: Option[ArgumentsContext]): Ast = { val args = argumentsCtx.map(astForArguments).getOrElse(Seq()) - val call = - callNode(ctx, text(ctx), Defines.UNRESOLVED_YIELD, Defines.UNRESOLVED_YIELD, DispatchTypes.STATIC_DISPATCH) - val (argsAst, otherAst) = args.partitionExprAst - otherAst :+ callAst(call, argsAst) + val call = callNode(ctx, text(ctx), UNRESOLVED_YIELD, UNRESOLVED_YIELD, DispatchTypes.STATIC_DISPATCH) + callAst(call, args) + } + + protected def astForUntilExpression(ctx: UntilExpressionContext): Ast = { + val testAst = astForExpressionOrCommand(ctx.expressionOrCommand()).headOption + val bodyAst = astForCompoundStatement(ctx.doClause().compoundStatement()) + // TODO: testAst should be negated if it's going to be modelled as a while stmt. + whileAst(testAst, bodyAst, Some(text(ctx)), line(ctx), column(ctx)) + } + + protected def astForForExpression(ctx: ForExpressionContext): Ast = { + val forVarAst = astForForVariableContext(ctx.forVariable()) + val forExprAst = astForExpressionOrCommand(ctx.expressionOrCommand()) + val forBodyAst = astForCompoundStatement(ctx.doClause().compoundStatement()) + // TODO: for X in Y is not properly modelled by while Y + val forRootAst = whileAst(forExprAst.headOption, forBodyAst, Some(text(ctx)), line(ctx), column(ctx)) + forVarAst.headOption.map(forRootAst.withChild).getOrElse(forRootAst) + } + + private def astForForVariableContext(ctx: ForVariableContext): Seq[Ast] = { + if (ctx.singleLeftHandSide() != null) { + astForSingleLeftHandSideContext(ctx.singleLeftHandSide()) + } else if (ctx.multipleLeftHandSide() != null) { + astForMultipleLeftHandSideContext(ctx.multipleLeftHandSide()) + } else { + Seq(Ast()) + } + } + + protected def astForWhileExpression(ctx: WhileExpressionContext): Ast = { + val testAst = astForExpressionOrCommand(ctx.expressionOrCommand()) + val bodyAst = astForCompoundStatement(ctx.doClause().compoundStatement()) + whileAst(testAst.headOption, bodyAst, Some(text(ctx)), line(ctx), column(ctx)) + } + + protected def astForIfExpression(ctx: IfExpressionContext): Ast = { + val testAst = astForExpressionOrCommand(ctx.expressionOrCommand()) + val thenAst = astForCompoundStatement(ctx.thenClause().compoundStatement()) + val elsifAsts = Option(ctx.elsifClause).map(_.asScala).getOrElse(Seq()).map(astForElsifClause) + val elseAst = Option(ctx.elseClause()).map(ctx => astForCompoundStatement(ctx.compoundStatement())).getOrElse(Seq()) + val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, text(ctx)) + controlStructureAst(ifNode, testAst.headOption) + .withChildren(thenAst) + .withChildren(elsifAsts.toSeq) + .withChildren(elseAst) + } + + private def astForElsifClause(ctx: ElsifClauseContext): Ast = { + val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, text(ctx)) + val testAst = astForExpressionOrCommand(ctx.expressionOrCommand()) + val bodyAst = astForCompoundStatement(ctx.thenClause().compoundStatement()) + controlStructureAst(ifNode, testAst.headOption, bodyAst) } protected def astForVariableReference(ctx: VariableReferenceContext): Ast = ctx match { @@ -450,13 +484,13 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } } - protected def astForUnlessExpression(ctx: UnlessExpressionContext): Seq[Ast] = { - val (exprAst, otherAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionExprAst - val thenAst = astForCompoundStatement(ctx.thenClause().compoundStatement()) + protected def astForUnlessExpression(ctx: UnlessExpressionContext): Ast = { + val testAst = astForExpressionOrCommand(ctx.expressionOrCommand()) + val thenAst = astForCompoundStatement(ctx.thenClause().compoundStatement()) val elseAst = Option(ctx.elseClause()).map(_.compoundStatement()).map(st => astForCompoundStatement(st)).getOrElse(Seq()) val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, text(ctx)) - otherAst :+ controlStructureAst(ifNode, exprAst.headOption, thenAst ++ elseAst) + controlStructureAst(ifNode, testAst.headOption, thenAst ++ elseAst) } protected def astForQuotedStringExpression(ctx: QuotedStringExpressionContext): Seq[Ast] = ctx match diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala index 137c87a20474..5e595a231378 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala @@ -19,6 +19,17 @@ trait AstForFunctionsCreator(packageContext: PackageContext)(implicit withSchema private val logger = LoggerFactory.getLogger(getClass) + /* + *Fake methods created from yield blocks and their yield calls will have this suffix in their names + */ + protected val YIELD_SUFFIX = "_yield" + + /* + * This is used to mark call nodes created due to yield calls. This is set in their names at creation. + * The appropriate name wrt the names of their actual methods is set later in them. + */ + protected val UNRESOLVED_YIELD = "unresolved_yield" + /* * Stack of variable identifiers incorrectly identified as method identifiers * Each AST contains exactly one call or identifier node @@ -90,12 +101,12 @@ trait AstForFunctionsCreator(packageContext: PackageContext)(implicit withSchema // process yield calls. astBody - .flatMap(_.nodes.collect { case x: NewCall => x }.filter(_.name == Defines.UNRESOLVED_YIELD)) + .flatMap(_.nodes.collect { case x: NewCall => x }.filter(_.name == UNRESOLVED_YIELD)) .foreach { yieldCallNode => val name = newMethodNode.name val methodFullName = classStack.reverse :+ callNode.name mkString pathSep - yieldCallNode.name(name + Defines.YIELD_SUFFIX) - yieldCallNode.methodFullName(methodFullName + Defines.YIELD_SUFFIX) + yieldCallNode.name(name + YIELD_SUFFIX) + yieldCallNode.methodFullName(methodFullName + YIELD_SUFFIX) methodNamesWithYield.add(newMethodNode.name) /* * These are calls to the yield block of this method. @@ -337,8 +348,6 @@ trait AstForFunctionsCreator(packageContext: PackageContext)(implicit withSchema } } - /** Creates a method, methodRef, and type decl binding for this block method. - */ protected def astForBlockFunction( ctxStmt: StatementsContext, ctxParam: Option[BlockParameterContext], @@ -354,17 +363,17 @@ trait AstForFunctionsCreator(packageContext: PackageContext)(implicit withSchema val methodFullName = classStack.reverse :+ blockMethodName mkString pathSep val newMethodNode = methodNode(ctxStmt, blockMethodName, text(ctxStmt), methodFullName, None, relativeFilename) .lineNumber(lineStart) - .lineNumberEnd(lineEnd + 1) // this requires a +1 due to the `end` token + .lineNumberEnd(lineEnd) .columnNumber(colStart) .columnNumberEnd(colEnd) scope.pushNewScope(newMethodNode) val astMethodParam = ctxParam.map(astForBlockParameterContext).getOrElse(Seq()) + val publicModifier = NewModifier().modifierType(ModifierTypes.PUBLIC) val paramSeq = astMethodParam.flatMap(_.root).map { /* In majority of cases, node will be an identifier */ case identifierNode: NewIdentifier => - scope.removeFromScope(identifierNode) val param = NewMethodParameterIn() .name(identifierNode.name) .code(identifierNode.code) @@ -380,28 +389,19 @@ trait AstForFunctionsCreator(packageContext: PackageContext)(implicit withSchema case _ => Ast() } - val paramNames = paramSeq + val paramNames = (astMethodParam ++ paramSeq) .flatMap(_.root) .collect { case x: NewMethodParameterIn => x.name case x: NewIdentifier => x.name } .toSet - - val astBody = astForStatements(ctxStmt, true) - val locals = scope.createAndLinkLocalNodes(diffGraph, paramNames).map(Ast.apply) - paramSeq.flatMap(_.root).collect { case x: NewMethodParameterIn => x }.foreach(scope.linkParamNode(diffGraph, _)) + val astBody = astForStatements(ctxStmt, true) + val locals = scope.createAndLinkLocalNodes(diffGraph, paramNames).map(Ast.apply) val methodRetNode = NewMethodReturn().typeFullName(Defines.Any) scope.popScope() - // Create a method ref & type binding for this node - val methodRefAssignmentAst = methodRefAssignmentFromMethod(newMethodNode) - val binding = NewBinding() - .name(blockMethodName) - .methodFullName(methodFullName) - val typeDecl = typeDeclFromMethod(newMethodNode) - Seq( methodAst( newMethodNode, @@ -409,69 +409,8 @@ trait AstForFunctionsCreator(packageContext: PackageContext)(implicit withSchema blockAst(blockNode(ctxStmt), locals ++ astBody.toList), methodRetNode, Seq(publicModifier) - ), - methodRefAssignmentAst, - Ast(typeDecl).withBindsEdge(typeDecl, binding).withRefEdge(binding, newMethodNode) + ) ) } - private def methodPositionWithFallback( - method: NewMethod, - lineNum: Option[Integer] = None, - colNum: Option[Integer] = None - ): (Option[Integer], Option[Integer]) = { - val lineNumber = lineNum match - case Some(x) => Some(x) - case None if method.lineNumber.isDefined => method.lineNumber - case _ => None - val columnNumber = colNum match - case Some(x) => Some(x) - case None if method.columnNumber.isDefined => method.columnNumber - case _ => None - - (lineNumber, columnNumber) - } - - /** Creates a method ref node assigned to an identifier of the same name from a method and adds the identifier to the - * scope. - */ - protected def methodRefAssignmentFromMethod( - method: NewMethod, - lineNum: Option[Integer] = None, - colNum: Option[Integer] = None - ): Ast = { - val (lineNumber, columnNumber) = methodPositionWithFallback(method, lineNum, colNum) - val methodRefNode = NewMethodRef() - .code("def " + method.name + "(...)") - .methodFullName(method.fullName) - .typeFullName(method.fullName) - .lineNumber(lineNumber) - .columnNumber(columnNumber) - - val methodNameIdentifier = NewIdentifier() - .code(method.name) - .name(method.name) - .typeFullName(Defines.Any) - .lineNumber(lineNumber) - .columnNumber(columnNumber) - scope.addToScope(method.name, methodNameIdentifier) - val methodRefAssignmentAst = - astForAssignment(methodNameIdentifier, methodRefNode, lineNumber, columnNumber) - methodRefAssignmentAst - } - - protected def typeDeclFromMethod( - method: NewMethod, - lineNum: Option[Integer] = None, - colNum: Option[Integer] = None - ): NewTypeDecl = { - val (lineNumber, columnNumber) = methodPositionWithFallback(method, lineNum, colNum) - NewTypeDecl() - .code("def " + method.name + "(...)") - .name(method.name) - .fullName(method.fullName) - .lineNumber(lineNumber) - .columnNumber(columnNumber) - } - } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala index dcd1fcd99d4a..432b4944f4af 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala @@ -208,54 +208,53 @@ trait AstForStatementsCreator(filename: String)(implicit withSchemaValidation: V // TODO: return Ast instead of Seq[Ast] protected def astForExpressionOrCommand(ctx: ExpressionOrCommandContext): Seq[Ast] = ctx match { case ctx: InvocationExpressionOrCommandContext => astForInvocationExpressionOrCommandContext(ctx) - case ctx: NotExpressionOrCommandContext => astForNotKeywordExpressionOrCommand(ctx) - case ctx: OrAndExpressionOrCommandContext => astForOrAndExpressionOrCommand(ctx) + case ctx: NotExpressionOrCommandContext => Seq(astForNotKeywordExpressionOrCommand(ctx)) + case ctx: OrAndExpressionOrCommandContext => Seq(astForOrAndExpressionOrCommand(ctx)) case ctx: ExpressionExpressionOrCommandContext => astForExpressionContext(ctx.expression()) case _ => logger.error(s"astForExpressionOrCommand() $relativeFilename, ${text(ctx)} All contexts mismatched.") Seq(Ast()) } - private def astForNotKeywordExpressionOrCommand(ctx: NotExpressionOrCommandContext): Seq[Ast] = { - val exprOrCommandAst = astForExpressionOrCommand(ctx.expressionOrCommand()) - val call = callNode(ctx, text(ctx), Operators.not, Operators.not, DispatchTypes.STATIC_DISPATCH) - val (argsAst, otherAst) = exprOrCommandAst.partitionExprAst - otherAst :+ callAst(call, argsAst) + private def astForNotKeywordExpressionOrCommand(ctx: NotExpressionOrCommandContext): Ast = { + val exprOrCommandAst = astForExpressionOrCommand(ctx.expressionOrCommand()) + val call = callNode(ctx, text(ctx), Operators.not, Operators.not, DispatchTypes.STATIC_DISPATCH) + callAst(call, exprOrCommandAst) } - private def astForOrAndExpressionOrCommand(ctx: OrAndExpressionOrCommandContext): Seq[Ast] = ctx.op.getType match { + private def astForOrAndExpressionOrCommand(ctx: OrAndExpressionOrCommandContext): Ast = ctx.op.getType match { case OR => astForOrExpressionOrCommand(ctx) case AND => astForAndExpressionOrCommand(ctx) } - private def astForOrExpressionOrCommand(ctx: OrAndExpressionOrCommandContext): Seq[Ast] = { - val args = ctx.expressionOrCommand().asScala.flatMap(astForExpressionOrCommand) - val call = callNode(ctx, text(ctx), Operators.or, Operators.or, DispatchTypes.STATIC_DISPATCH) - val (argsAst, otherAst) = args.partitionExprAst - otherAst :+ callAst(call, argsAst) + private def astForOrExpressionOrCommand(ctx: OrAndExpressionOrCommandContext): Ast = { + val argsAst = ctx.expressionOrCommand().asScala.flatMap(astForExpressionOrCommand) + val call = callNode(ctx, text(ctx), Operators.or, Operators.or, DispatchTypes.STATIC_DISPATCH) + callAst(call, argsAst.toList) } - private def astForAndExpressionOrCommand(ctx: OrAndExpressionOrCommandContext): Seq[Ast] = { - val args = ctx.expressionOrCommand().asScala.flatMap(astForExpressionOrCommand) - val call = callNode(ctx, text(ctx), Operators.and, Operators.and, DispatchTypes.STATIC_DISPATCH) - val (argsAst, otherAst) = args.partitionExprAst - otherAst :+ callAst(call, argsAst) + private def astForAndExpressionOrCommand(ctx: OrAndExpressionOrCommandContext): Ast = { + val argsAst = ctx.expressionOrCommand().asScala.flatMap(astForExpressionOrCommand) + val call = callNode(ctx, text(ctx), Operators.and, Operators.and, DispatchTypes.STATIC_DISPATCH) + callAst(call, argsAst.toList) } - private def astForSuperCommand(ctx: SuperCommandContext): Seq[Ast] = + private def astForSuperCommand(ctx: SuperCommandContext): Ast = astForSuperCall(ctx, astForArguments(ctx.argumentsWithoutParentheses().arguments())) - private def astForYieldCommand(ctx: YieldCommandContext): Seq[Ast] = + private def astForYieldCommand(ctx: YieldCommandContext): Ast = astForYieldCall(ctx, Option(ctx.argumentsWithoutParentheses().arguments())) private def astForSimpleMethodCommand(ctx: SimpleMethodCommandContext): Seq[Ast] = { val methodIdentifierAsts = astForMethodIdentifierContext(ctx.methodIdentifier(), text(ctx)) methodIdentifierAsts.headOption.foreach(methodNameAsIdentifierStack.push) - val args = astForArguments(ctx.argumentsWithoutParentheses().arguments()) - val (argsAst, _) = args.partitionExprAst + val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments()) + + /* get args without the method def in it */ + val argAstsWithoutMethods = argsAsts.filterNot(_.root.exists(_.isInstanceOf[NewMethod])) /* isolate methods from the original args and create identifier ASTs from it */ - val methodDefAsts = args.filter(_.root.exists(_.isInstanceOf[NewMethod])) + val methodDefAsts = argsAsts.filter(_.root.exists(_.isInstanceOf[NewMethod])) val methodToIdentifierAsts = methodDefAsts.flatMap { _.nodes.collectFirst { case methodNode: NewMethod => Ast( @@ -278,17 +277,17 @@ trait AstForStatementsCreator(filename: String)(implicit withSchemaValidation: V if (callNodes.size == 1) { val callNode = callNodes.head if (callNode.name == "require" || callNode.name == "load") { - resolveRequireOrLoadPath(args, callNode) + resolveRequireOrLoadPath(argsAsts, callNode) } else if (callNode.name == "require_relative") { - resolveRelativePath(filename, args, callNode) + resolveRelativePath(filename, argsAsts, callNode) } else if (prefixMethods.contains(callNode.name)) { /* we remove the method definition AST from argument and add its corresponding identifier form */ - Seq(callAst(callNode, argsAst ++ methodToIdentifierAsts)) + Seq(callAst(callNode, argAstsWithoutMethods ++ methodToIdentifierAsts)) } else { - Seq(callAst(callNode, argsAst)) + Seq(callAst(callNode, argsAsts)) } } else { - args + argsAsts } } @@ -326,8 +325,8 @@ trait AstForStatementsCreator(filename: String)(implicit withSchemaValidation: V methodRefNode(ctx, s"def ${methodNode.name}(...)", methodNode.fullName, Defines.Any) protected def astForCommand(ctx: CommandContext): Seq[Ast] = ctx match { - case ctx: YieldCommandContext => astForYieldCommand(ctx) - case ctx: SuperCommandContext => astForSuperCommand(ctx) + case ctx: YieldCommandContext => Seq(astForYieldCommand(ctx)) + case ctx: SuperCommandContext => Seq(astForSuperCommand(ctx)) case ctx: SimpleMethodCommandContext => astForSimpleMethodCommand(ctx) case ctx: MemberAccessCommandContext => astForMemberAccessCommand(ctx) } @@ -370,7 +369,7 @@ trait AstForStatementsCreator(filename: String)(implicit withSchemaValidation: V } } - protected def astForBlock(ctx: BlockContext, blockMethodName: Option[String] = None): Seq[Ast] = ctx match + protected def astForBlock(ctx: BlockContext, blockMethodName: Option[String] = None): Ast = ctx match case ctx: DoBlockBlockContext => astForDoBlock(ctx.doBlock(), blockMethodName) case ctx: BraceBlockBlockContext => astForBraceBlock(ctx.braceBlock(), blockMethodName) @@ -379,7 +378,7 @@ trait AstForStatementsCreator(filename: String)(implicit withSchemaValidation: V blockParamCtx: Option[BlockParameterContext], compoundStmtCtx: CompoundStatementContext, blockMethodName: Option[String] = None - ): Seq[Ast] = { + ) = { blockMethodName match { case Some(blockMethodName) => astForBlockFunction( @@ -390,20 +389,20 @@ trait AstForStatementsCreator(filename: String)(implicit withSchemaValidation: V lineEnd(compoundStmtCtx).head, column(compoundStmtCtx).head, columnEnd(compoundStmtCtx).head - ) + ).head case None => val blockNode_ = blockNode(ctx, text(ctx), Defines.Any) val blockBodyAst = astForCompoundStatement(compoundStmtCtx) val blockParamAst = blockParamCtx.flatMap(astForBlockParameterContext) - Seq(blockAst(blockNode_, blockBodyAst.toList ++ blockParamAst)) + blockAst(blockNode_, blockBodyAst.toList ++ blockParamAst) } } - protected def astForDoBlock(ctx: DoBlockContext, blockMethodName: Option[String] = None): Seq[Ast] = { + protected def astForDoBlock(ctx: DoBlockContext, blockMethodName: Option[String] = None): Ast = { astForBlockHelper(ctx, Option(ctx.blockParameter), ctx.bodyStatement().compoundStatement(), blockMethodName) } - private def astForBraceBlock(ctx: BraceBlockContext, blockMethodName: Option[String] = None): Seq[Ast] = { + private def astForBraceBlock(ctx: BraceBlockContext, blockMethodName: Option[String] = None): Ast = { astForBlockHelper(ctx, Option(ctx.blockParameter), ctx.bodyStatement().compoundStatement(), blockMethodName) } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyScope.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyScope.scala index 5871be927e8f..393a860cbc44 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyScope.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyScope.scala @@ -1,9 +1,8 @@ package io.joern.rubysrc2cpg.astcreation -import io.joern.x2cpg.Ast import io.joern.x2cpg.datastructures.Scope import io.shiftleft.codepropertygraph.generated.EdgeTypes -import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.nodes.{DeclarationNew, NewIdentifier, NewLocal, NewNode} import overflowdb.BatchedUpdate import scala.collection.mutable @@ -31,10 +30,6 @@ class RubyScope extends Scope[String, NewIdentifier, NewNode] { scopeNode } - def removeFromScope(variable: NewIdentifier): Unit = { - stack.headOption.foreach(head => scopeToVarMap.removeIdentifierFromVarGroup(head.scopeNode, variable)) - } - override def popScope(): Option[NewNode] = { stack.headOption.map(_.scopeNode).foreach(scopeToVarMap.remove) super.popScope() @@ -47,18 +42,9 @@ class RubyScope extends Scope[String, NewIdentifier, NewNode] { def createAndLinkLocalNodes( diffGraph: BatchedUpdate.DiffGraphBuilder, paramNames: Set[String] = Set.empty - ): List[DeclarationNew] = { - stack.headOption match - case Some(top) => scopeToVarMap.buildVariableGroupings(top.scopeNode, paramNames ++ Set("this"), diffGraph) - case None => List.empty[DeclarationNew] - } - - /** Links the parameter node to the referenced identifiers in this scope. - */ - def linkParamNode(diffGraph: BatchedUpdate.DiffGraphBuilder, param: NewMethodParameterIn): Unit = - stack.headOption match - case Some(top) => scopeToVarMap.buildParameterGrouping(top.scopeNode, param, diffGraph) - case None => List.empty[DeclarationNew] + ): List[DeclarationNew] = stack.headOption match + case Some(top) => scopeToVarMap.buildVariableGroupings(top.scopeNode, paramNames ++ Set("this"), diffGraph) + case None => List.empty[DeclarationNew] private implicit class IdentifierExt(node: NewIdentifier) { @@ -91,18 +77,6 @@ class RubyScope extends Scope[String, NewIdentifier, NewNode] { Some(Map(identifier.name -> identifier.toNewVarGroup)) } - /** Removes an identifier from the var group. - */ - def removeIdentifierFromVarGroup(key: ScopeNodeType, identifier: NewIdentifier): Unit = - scopeMap.updateWith(key) { - case Some(varMap: VarMap) => - Some(varMap.updatedWith(identifier.name) { - case Some(varGroup: VarGroup) => Some(varGroup.copy(ids = varGroup.ids.filterNot(_ == identifier))) - case None => None - }) - case None => None - } - /** Will persist the variable groupings that do not represent parameter nodes and link them with REF edges. * @return * the list of persisted local nodes. @@ -122,23 +96,6 @@ class RubyScope extends Scope[String, NewIdentifier, NewNode] { } .toList case None => List.empty[DeclarationNew] - - /** Will persist a REF edge between the given parameter and its corresponding identifiers. - */ - def buildParameterGrouping( - key: ScopeNodeType, - param: NewMethodParameterIn, - diffGraph: BatchedUpdate.DiffGraphBuilder - ): Unit = { - scopeMap - .get(key) - .map(_.values) - .foreach(_.filter { case VarGroup(local, _) => local.name == param.name } - .foreach { case VarGroup(_, ids) => - ids.foreach(id => diffGraph.addEdge(id, param, EdgeTypes.REF)) - }) - } - } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala index 80edb4858ffa..42b32a2d4f69 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala @@ -35,21 +35,5 @@ object Defines { // Constructor method val Initialize = "initialize" - /* - * Fake methods created from yield blocks and their yield calls will have this suffix in their names - */ - val YIELD_SUFFIX = "_yield" - - /* - * This is used to mark call nodes created due to yield calls. This is set in their names at creation. - * The appropriate name wrt the names of their actual methods is set later in them. - */ - val UNRESOLVED_YIELD = "unresolved_yield" - - /* - * Ruby provides a dynamic method declaration via its metaprogramming keyword `define_method`. - */ - val DEFINE_METHOD = "define_method" - def getBuiltInType(typeInString: String) = s"${GlobalTypes.builtinPrefix}.$typeInString" } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/DoBlockTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/DoBlockTest.scala index ac378c86f80c..4361d21718c0 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/DoBlockTest.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/DoBlockTest.scala @@ -1,9 +1,7 @@ package io.joern.rubysrc2cpg.passes.ast import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.nodes.{Call, ControlStructure, Identifier, MethodRef} import io.shiftleft.semanticcpg.language.* -import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, NodeTypes} class DoBlockTest extends RubyCode2CpgFixture { @@ -33,100 +31,4 @@ class DoBlockTest extends RubyCode2CpgFixture { } } - "a do-block function used as a higher-order function" should { - val cpg = code("""class TransactionsController < ApplicationController - | def permitted_column_name(column_name) - | %w[trx_date description amount].find { |permitted| column_name == permitted } || 'trx_date' - | end - |end - | - |""".stripMargin) - - "create a do-block method named from the surrounding function" in { - val findMethod :: _ = cpg.method.name("find.*").l: @unchecked - findMethod.name should startWith("find") - findMethod.parameter.size shouldBe 1 - val permitParam :: _ = findMethod.parameter.l: @unchecked - permitParam.name shouldBe "permitted" - } - - } - - "a do-block function wrapped within an active record association" should { - val cpg = code(""" - |refunds = [] - |attrs = { - | refunds: refunds.sort_by! { |r| r["created"] } - |} - |""".stripMargin) - - "create a do-block method named from the surrounding function" in { - val findMethod :: _ = cpg.method.name("sort_by.*").l: @unchecked - findMethod.name should startWith("sort_by") - findMethod.parameter.size shouldBe 1 - val permitParam :: _ = findMethod.parameter.l: @unchecked - permitParam.name shouldBe "r" - } - } - - "a do-block function wrapped within a chained invocation inside of a call argument" should { - val cpg = code("OpenStruct.new(obj.map { |key, val| [key, to_recursive_ostruct(val)] }.to_h)") - - "create a do-block method named from the surrounding function" in { - val mapMethod :: _ = cpg.method.name("map.*").l: @unchecked - mapMethod.name should startWith("map") - mapMethod.parameter.size shouldBe 2 - val k :: v :: _ = mapMethod.parameter.l: @unchecked - k.name shouldBe "key" - v.name shouldBe "val" - } - } - - "chained higher-order functions as do-block functions" should { - val cpg = code(""" - |xs - | .select { |x| x.foo } - | .each { |y| puts(y) } - |""".stripMargin) - - "chain the two methods calls where one is the argument of the other" in { - val selectCall = cpg.call.nameExact("select").head - val eachCall = cpg.call.nameExact("each").head - selectCall.astParent shouldBe eachCall - - selectCall.astChildren.l match - case ::(xs: Identifier, ::(selectRef: MethodRef, Nil)) => - xs.name shouldBe "xs" - selectRef.referencedMethod.name should startWith("select") - case _ => fail("'select' call children are not what is expected.") - - eachCall.astChildren.l match - case ::(sCall: Call, ::(eachRef: MethodRef, Nil)) => - sCall shouldBe selectCall - eachRef.referencedMethod.name should startWith("each") - case _ => fail("'each' call children are not what is expected.") - } - } - - "a boolean do-block function as a conditional argument" should { - val cpg = code(""" - |if @items.any? { |x| x > 1 } - | puts "foo" - |else - | puts "bar" - |end - |""".stripMargin) - - "be defined outside of the control structure" in { - val anyMethod = cpg.method.name("any.*").head - anyMethod.astParent.label shouldBe NodeTypes.BLOCK - } - - "have the call to the method ref as the conditional argument" in { - val ifStmt = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).head - val ::(anyCall: Call, _) = ifStmt.condition.l: @unchecked - anyCall.name should startWith("any") - } - } - } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/MethodTwoTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/MethodTwoTests.scala index 8f8605204d4a..86bb049f5cf4 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/MethodTwoTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/MethodTwoTests.scala @@ -1,7 +1,9 @@ package io.joern.rubysrc2cpg.passes.ast import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.{EvaluationStrategies, NodeTypes} import io.shiftleft.semanticcpg.language.* +import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal class MethodTwoTests extends RubyCode2CpgFixture { @@ -12,76 +14,88 @@ class MethodTwoTests extends RubyCode2CpgFixture { |end |""".stripMargin) - "should contain exactly one method node with correct fields" in { + // TODO: This test cases needs to be fixed. + "should contain exactly one method node with correct fields" ignore { inside(cpg.method.name("foo").l) { case List(x) => x.name shouldBe "foo" x.isExternal shouldBe false - x.fullName shouldBe "Test0.rb::program.foo" - x.code should startWith("return \"\"") + x.fullName shouldBe "Test0.rb::program:foo" + x.code should startWith("def foo(a, b)") x.isExternal shouldBe false - x.order shouldBe 2 + x.order shouldBe 1 x.filename endsWith "Test0.rb" x.lineNumber shouldBe Option(2) x.lineNumberEnd shouldBe Option(4) } } - "should return correct number of lines" in { + // TODO: This test cases needs to be fixed. + "should return correct number of lines" ignore { cpg.method.name("foo").numberOfLines.l shouldBe List(3) } - "should allow traversing to parameters" in { + // TODO: This test cases needs to be fixed. + "should allow traversing to parameters" ignore { cpg.method.name("foo").parameter.name.toSetMutable shouldBe Set("a", "b") } - "should allow traversing to methodReturn" in { + // TODO: This test cases needs to be fixed. + "should allow traversing to methodReturn" ignore { cpg.method.name("foo").methodReturn.l.size shouldBe 1 cpg.method.name("foo").methodReturn.typeFullName.head shouldBe "ANY" } - "should allow traversing to method" in { - cpg.methodReturn.method.isExternal(false).name.l shouldBe List("foo", ":program") + // TODO: This test cases needs to be fixed. + "should allow traversing to method" ignore { + cpg.methodReturn.method.name.l shouldBe List("foo", ":program") } - "should allow traversing to file" in { + // TODO: This test cases needs to be fixed. + "should allow traversing to file" ignore { cpg.method.name("foo").file.name.l should not be empty } - "test function method ref" in { + // TODO: Need to be fixed + "test function method ref" ignore { cpg.methodRef("foo").referencedMethod.fullName.l should not be empty - cpg.methodRef("foo").referencedMethod.fullName.head shouldBe "Test0.rb::program.foo" + cpg.methodRef("foo").referencedMethod.fullName.head shouldBe + "Test0.rb::program:foo" } - "test existence of local variable in module function" in { + // TODO: Need to be fixed. + "test existence of local variable in module function" ignore { cpg.method.fullName("Test0.rb::program").local.name.l should contain("foo") } - "test corresponding type, typeDecl and binding" in { - cpg.method.fullName("Test0.rb::program.foo").referencingBinding.bindingTypeDecl.l should not be empty + // TODO: need to be fixed. + "test corresponding type, typeDecl and binding" ignore { + cpg.method.fullName("Test0.rb::program:foo").referencingBinding.bindingTypeDecl.l should not be empty val bindingTypeDecl = - cpg.method.fullName("Test0.rb::program.foo").referencingBinding.bindingTypeDecl.head + cpg.method.fullName("Test0.rb::program:foo").referencingBinding.bindingTypeDecl.head bindingTypeDecl.name shouldBe "foo" - bindingTypeDecl.fullName shouldBe "Test0.rb::program.foo" + bindingTypeDecl.fullName shouldBe "Test0.rb::program:foo" bindingTypeDecl.referencingType.name.head shouldBe "foo" - bindingTypeDecl.referencingType.fullName.head shouldBe "Test0.rb::program.foo" + bindingTypeDecl.referencingType.fullName.head shouldBe "Test0.rb::program:foo" } - "test method parameter nodes" in { + // TODO: Need to be fixed + "test method parameter nodes" ignore { cpg.method.name("foo").parameter.name.l.size shouldBe 2 - val parameter1 = cpg.method.fullName("Test0.rb::program.foo").parameter.order(1).head + val parameter1 = cpg.method.fullName("Test0.rb::program:foo").parameter.order(1).head parameter1.name shouldBe "a" parameter1.index shouldBe 1 parameter1.typeFullName shouldBe "ANY" - val parameter2 = cpg.method.fullName("Test0.rb::program.foo").parameter.order(2).head + val parameter2 = cpg.method.fullName("Test0.rb::program:foo").parameter.order(2).head parameter2.name shouldBe "b" parameter2.index shouldBe 2 parameter2.typeFullName shouldBe "ANY" } - "should allow traversing from parameter to method" in { + // TODO: Need to be fixed + "should allow traversing from parameter to method" ignore { cpg.parameter.name("a").method.name.l shouldBe List("foo") cpg.parameter.name("b").method.name.l shouldBe List("foo") } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/SimpleAstCreationPassTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/SimpleAstCreationPassTest.scala index 97909ef2c446..9e5ccfe1d5ef 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/SimpleAstCreationPassTest.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/SimpleAstCreationPassTest.scala @@ -733,7 +733,7 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { val cpg = code("object::foo do\nputs \"right here\"\nend") val List(callNode1) = cpg.call.name("foo").l - callNode1.code shouldBe "def foo1(...)" + callNode1.code shouldBe "puts \"right here\"" callNode1.lineNumber shouldBe Some(1) callNode1.columnNumber shouldBe Some(3) diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala index 18c7cfd4625e..48e197ddcbd5 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala @@ -4,7 +4,6 @@ import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.shiftleft.codepropertygraph.generated.ControlStructureTypes import io.shiftleft.codepropertygraph.generated.nodes.{Block, ControlStructure} import io.shiftleft.semanticcpg.language.* - class ControlStructureTests extends RubyCode2CpgFixture { "CPG for code with doBlock iterating over a constant array" should { @@ -16,7 +15,7 @@ class ControlStructureTests extends RubyCode2CpgFixture { "recognise all identifier nodes" in { cpg.identifier.name("n").size shouldBe 1 - cpg.identifier.size shouldBe 3 // 1 identifier node is for `puts = typeDef(__builtin.puts)` and similarly for `each2` + cpg.identifier.size shouldBe 2 // 1 identifier node is for `puts = typeDef(__builtin.puts)` } "recognize all call nodes" in { @@ -57,7 +56,8 @@ class ControlStructureTests extends RubyCode2CpgFixture { "recognise all identifier nodes" in { cpg.identifier.name("n").size shouldBe 2 cpg.identifier.name("m").size shouldBe 1 - cpg.identifier.size shouldBe 6 // includes each2 = def each2(...) + cpg.identifier.size shouldBe 5 + cpg.method.name("fakeName").dotAst.l } "recognize all call nodes" in { diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MiscTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MiscTests.scala index d030f7f3a01a..76997217302a 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MiscTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MiscTests.scala @@ -54,7 +54,7 @@ class MiscTests extends RubyCode2CpgFixture { cpg.identifier.name("Formatter").size shouldBe 1 cpg.identifier.name("Logger").size shouldBe 1 cpg.identifier.name("log_formatter").size shouldBe 1 - cpg.identifier.size shouldBe 6 + cpg.identifier.size shouldBe 5 } }