From 309893e36cdfa2fad6d3cd11798b2a4a5c2861e8 Mon Sep 17 00:00:00 2001 From: KhemrajSingh Rathore Date: Mon, 22 Apr 2024 16:23:14 +0530 Subject: [PATCH] Support for Inference via endpoint variable in API (#1075) * initial work * revert - setRule name change * first working version * added a todo * minor refactor --- .../scala/ai/privado/cache/RuleCache.scala | 9 ++ .../sink/api/JavaAPIRetrofitTagger.scala | 14 +- ...inkByParameterMarkByAnnotationTagger.scala | 3 +- .../api/JavaAPISinkByParameterTagger.scala | 3 +- ...APISinkEndpointMapperByNonInitMethod.scala | 38 ++---- .../java/tagger/sink/api/JavaAPITagger.scala | 4 + .../java/tagger/sink/api/Utility.scala | 44 ++---- .../php/tagger/sink/APITagger.scala | 1 + .../python/tagger/sink/PythonAPITagger.scala | 1 + .../ruby/tagger/sink/APITagger.scala | 1 + .../scala/ai/privado/model/PrivadoTag.scala | 2 + .../ai/privado/tagger/sink/APITagger.scala | 1 + .../sink/api/InferenceAPIEndpointTagger.scala | 14 +- .../tagger/utility/APITaggerUtility.scala | 125 +++++++++++++----- .../api/InferenceAPIEndpointTaggerTest.scala | 100 +++++++++++++- 15 files changed, 238 insertions(+), 122 deletions(-) diff --git a/src/main/scala/ai/privado/cache/RuleCache.scala b/src/main/scala/ai/privado/cache/RuleCache.scala index 271e4337e..f5970d3d5 100644 --- a/src/main/scala/ai/privado/cache/RuleCache.scala +++ b/src/main/scala/ai/privado/cache/RuleCache.scala @@ -64,6 +64,15 @@ class RuleCache { } } + def addThirdPartyRuleInfo(thirdPartyAPIRuleInfo: RuleInfo, domain: String): String = { + val newRuleIdToUse = s"${Constants.thirdPartiesAPIRuleId}.$domain" + this.setRuleInfo( + thirdPartyAPIRuleInfo + .copy(id = newRuleIdToUse, name = s"${thirdPartyAPIRuleInfo.name} $domain", isGenerated = true) + ) + newRuleIdToUse + } + def addStorageRuleInfo(ruleInfo: RuleInfo): Unit = storageRuleInfo.addOne(ruleInfo) def getStorageRuleInfo(): List[RuleInfo] = storageRuleInfo.toList diff --git a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPIRetrofitTagger.scala b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPIRetrofitTagger.scala index a724030f1..25ca85496 100644 --- a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPIRetrofitTagger.scala +++ b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPIRetrofitTagger.scala @@ -16,8 +16,6 @@ class JavaAPIRetrofitTagger(cpg: Cpg, ruleCache: RuleCache) extends PrivadoParal private val apiMatchingRegex = ruleCache.getAllRuleInfo.filter(_.nodeType == NodeType.API).map(_.combinedRulePattern).mkString("(", "|", ")") - private val thirdPartyRuleInfo = ruleCache.getRuleInfo(Constants.thirdPartiesAPIRuleId) - private val logger = getLogger(this.getClass) override def generateParts(): Array[(Call, Call)] = { cpg @@ -46,19 +44,11 @@ class JavaAPIRetrofitTagger(cpg: Cpg, ruleCache: RuleCache) extends PrivadoParal Try(baseUrlCall.argument.last).toOption match case Some(lit: Literal) => // Mark the nodes as API sink - tagAPICallByItsUrlMethod(cpg, builder, lit, sinkCalls, apiMatchingRegex, thirdPartyRuleInfo, ruleCache) + tagAPICallByItsUrlMethod(cpg, builder, lit, sinkCalls, apiMatchingRegex, ruleCache) case _ => Try { val methodNode = createCall.method - tagAPICallByItsUrlMethod( - cpg, - builder, - methodNode, - sinkCalls, - apiMatchingRegex, - thirdPartyRuleInfo, - ruleCache - ) + tagAPICallByItsUrlMethod(cpg, builder, methodNode, sinkCalls, apiMatchingRegex, ruleCache) } match case Failure(e) => logger.debug(s"Failed to get to a method node for retrofit create call") case _ => diff --git a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkByParameterMarkByAnnotationTagger.scala b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkByParameterMarkByAnnotationTagger.scala index e3712bd81..b339f4de3 100644 --- a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkByParameterMarkByAnnotationTagger.scala +++ b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkByParameterMarkByAnnotationTagger.scala @@ -17,7 +17,6 @@ class JavaAPISinkByParameterMarkByAnnotationTagger(cpg: Cpg, ruleCache: RuleCach private val apiMatchingRegex = ruleCache.getAllRuleInfo.filter(_.nodeType == NodeType.API).map(_.combinedRulePattern).mkString("(", "|", ")") - private val thirdPartyRuleInfo = ruleCache.getRuleInfo(Constants.thirdPartiesAPIRuleId) override def generateParts(): Array[(Method, List[String])] = { val methodWithAnnotationCode = cpg.parameter @@ -70,7 +69,7 @@ class JavaAPISinkByParameterMarkByAnnotationTagger(cpg: Cpg, ruleCache: RuleCach .l // Mark the nodes as API sink - tagAPICallByItsUrlMethod(cpg, builder, methodNode, sinkCalls, apiMatchingRegex, thirdPartyRuleInfo, ruleCache) + tagAPICallByItsUrlMethod(cpg, builder, methodNode, sinkCalls, apiMatchingRegex, ruleCache) } diff --git a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkByParameterTagger.scala b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkByParameterTagger.scala index f424150d6..48b98dccb 100644 --- a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkByParameterTagger.scala +++ b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkByParameterTagger.scala @@ -19,7 +19,6 @@ class JavaAPISinkByParameterTagger(cpg: Cpg, ruleCache: RuleCache) private val apiMatchingRegex = ruleCache.getAllRuleInfo.filter(_.nodeType == NodeType.API).map(_.combinedRulePattern).mkString("(", "|", ")") - private val thirdPartyRuleInfo = ruleCache.getRuleInfo(Constants.thirdPartiesAPIRuleId) override def generateParts(): Array[(Method, String)] = { /* Below query looks for methods whose parameter names ends with `url|endpoint`, @@ -96,7 +95,7 @@ class JavaAPISinkByParameterTagger(cpg: Cpg, ruleCache: RuleCache) .l // Mark the nodes as API sink - tagAPICallByItsUrlMethod(cpg, builder, methodNode, sinkCalls, apiMatchingRegex, thirdPartyRuleInfo, ruleCache) + tagAPICallByItsUrlMethod(cpg, builder, methodNode, sinkCalls, apiMatchingRegex, ruleCache) } } diff --git a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkEndpointMapperByNonInitMethod.scala b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkEndpointMapperByNonInitMethod.scala index 68e0e5034..780170460 100644 --- a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkEndpointMapperByNonInitMethod.scala +++ b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkEndpointMapperByNonInitMethod.scala @@ -10,7 +10,7 @@ import ai.privado.languageEngine.java.tagger.sink.api.Utility.tagAPICallByItsUrl import ai.privado.tagger.utility.APITaggerUtility.{ getLiteralCode, resolveDomainFromSource, - tagAPIWithDomainAndUpdateRuleCache + tagThirdPartyAPIWithDomainAndUpdateRuleCache } import ai.privado.utility.Utilities.{addRuleTags, getDomainFromString, storeForTag} import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, Call, Method} @@ -23,7 +23,6 @@ class JavaAPISinkEndpointMapperByNonInitMethod(cpg: Cpg, ruleCache: RuleCache) private val apiMatchingRegex = ruleCache.getAllRuleInfo.filter(_.nodeType == NodeType.API).map(_.combinedRulePattern).mkString("(", "|", ")") - private val thirdPartyRuleInfo = ruleCache.getRuleInfo(Constants.thirdPartiesAPIRuleId) override def generateParts(): Array[String] = { /* General assumption - there is a function which creates a client, and the usage of the client and binding @@ -32,20 +31,17 @@ class JavaAPISinkEndpointMapperByNonInitMethod(cpg: Cpg, ruleCache: RuleCache) we can say the client uses the following endpoint */ - if (thirdPartyRuleInfo.isDefined) { - cpg.call - .where(_.tag.nameExact(InternalTag.API_SINK_MARKED.toString)) - .methodFullName - .map(_.split(methodFullNameSplitter).headOption.getOrElse("")) - .filter(_.nonEmpty) - .map { methodNamespace => - val parts = methodNamespace.split("[.]") - if parts.nonEmpty then parts.dropRight(1).mkString(".") else "" - } - .dedup - .toArray - } else - Array[String]() + cpg.call + .where(_.tag.nameExact(InternalTag.API_SINK_MARKED.toString)) + .methodFullName + .map(_.split(methodFullNameSplitter).headOption.getOrElse("")) + .filter(_.nonEmpty) + .map { methodNamespace => + val parts = methodNamespace.split("[.]") + if parts.nonEmpty then parts.dropRight(1).mkString(".") else "" + } + .dedup + .toArray } override def runOnPart(builder: DiffGraphBuilder, typeFullName: String): Unit = { @@ -57,15 +53,7 @@ class JavaAPISinkEndpointMapperByNonInitMethod(cpg: Cpg, ruleCache: RuleCache) .where(_.tag.nameExact(InternalTag.API_SINK_MARKED.toString)) .l - tagAPICallByItsUrlMethod( - cpg, - builder, - clientReturningMethod, - impactedApiCalls, - apiMatchingRegex, - thirdPartyRuleInfo, - ruleCache - ) + tagAPICallByItsUrlMethod(cpg, builder, clientReturningMethod, impactedApiCalls, apiMatchingRegex, ruleCache) } } } diff --git a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPITagger.scala b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPITagger.scala index c759a7d4a..93db155fd 100644 --- a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPITagger.scala +++ b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPITagger.scala @@ -165,6 +165,7 @@ class JavaAPITagger(cpg: Cpg, ruleCache: RuleCache, privadoInputConfig: PrivadoI logger.debug("Using brute API Tagger to find API sinks") println(s"${Calendar.getInstance().getTime} - --API TAGGER V1 invoked...") sinkTagger( + cpg, apiInternalSources ++ propertySources ++ identifierSource ++ serviceSource, apis, builder, @@ -174,6 +175,7 @@ class JavaAPITagger(cpg: Cpg, ruleCache: RuleCache, privadoInputConfig: PrivadoI privadoInputConfig.enableAPIDisplay ) sinkTagger( + cpg, apiInternalSources ++ propertySources ++ identifierSource ++ serviceSource, feignAPISinks ++ grpcSinks ++ soapSinks ++ markedAPISinks, builder, @@ -185,6 +187,7 @@ class JavaAPITagger(cpg: Cpg, ruleCache: RuleCache, privadoInputConfig: PrivadoI logger.debug("Using Enhanced API tagger to find API sinks") println(s"${Calendar.getInstance().getTime} - --API TAGGER V2 invoked...") sinkTagger( + cpg, apiInternalSources ++ propertySources ++ identifierSource ++ serviceSource, apis.methodFullName(commonHttpPackages).l ++ feignAPISinks ++ grpcSinks ++ soapSinks ++ markedAPISinks, builder, @@ -196,6 +199,7 @@ class JavaAPITagger(cpg: Cpg, ruleCache: RuleCache, privadoInputConfig: PrivadoI logger.debug("Skipping API Tagger because valid match not found, only applying Feign client") println(s"${Calendar.getInstance().getTime} - --API TAGGER SKIPPED, applying Feign client API...") sinkTagger( + cpg, apiInternalSources ++ propertySources ++ identifierSource ++ serviceSource, feignAPISinks ++ grpcSinks ++ soapSinks ++ markedAPISinks, builder, diff --git a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/Utility.scala b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/Utility.scala index 3bbdc2755..9c6d90531 100644 --- a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/Utility.scala +++ b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/Utility.scala @@ -2,7 +2,10 @@ package ai.privado.languageEngine.java.tagger.sink.api import ai.privado.cache.RuleCache import ai.privado.model.{CatLevelOne, Constants, InternalTag, RuleInfo} -import ai.privado.tagger.utility.APITaggerUtility.{resolveDomainFromSource, tagAPIWithDomainAndUpdateRuleCache} +import ai.privado.tagger.utility.APITaggerUtility.{ + resolveDomainFromSource, + tagThirdPartyAPIWithDomainAndUpdateRuleCache +} import ai.privado.utility.Utilities.{getDomainFromString, storeForTag} import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, Call, Literal, Method} import overflowdb.BatchedUpdate.DiffGraphBuilder @@ -20,7 +23,6 @@ object Utility { nodeToComputeUrl: AstNode, apiCalls: List[Call], apiMatchingRegex: String, - thirdPartyRuleInfo: Option[RuleInfo], ruleCache: RuleCache ): Unit = { @@ -33,17 +35,9 @@ object Utility { nodeToComputeUrl match { case lit: Literal => - extractUrlFromLiteralAndTag(cpg, lit, builder, impactedApiCalls, thirdPartyRuleInfo, ruleCache) + extractUrlFromLiteralAndTag(cpg, lit, builder, impactedApiCalls, ruleCache) case method: Method => - extractUrlFromMethodAndTag( - cpg, - method, - builder, - impactedApiCalls, - apiMatchingRegex, - thirdPartyRuleInfo, - ruleCache - ) + extractUrlFromMethodAndTag(cpg, method, builder, impactedApiCalls, apiMatchingRegex, ruleCache) } } @@ -54,13 +48,12 @@ object Utility { literalNode: Literal, builder: DiffGraphBuilder, impactedApiCalls: List[Call], - thirdPartyRuleInfo: Option[RuleInfo], ruleCache: RuleCache ): Unit = { val domain = resolveDomainFromSource(literalNode) impactedApiCalls.foreach { apiCall => - tagAPIWithDomainAndUpdateRuleCache(builder, thirdPartyRuleInfo.get, ruleCache, domain, apiCall, literalNode) + tagThirdPartyAPIWithDomainAndUpdateRuleCache(builder, cpg, ruleCache, domain, apiCall, literalNode) storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString) storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString) } @@ -72,7 +65,6 @@ object Utility { builder: DiffGraphBuilder, impactedApiCalls: List[Call], apiMatchingRegex: String, - thirdPartyRuleInfo: Option[RuleInfo], ruleCache: RuleCache ): Unit = { /* @@ -83,7 +75,7 @@ object Utility { matchingProperties.foreach { propertyNode => val domain = getDomainFromString(propertyNode.value) impactedApiCalls.foreach { apiCall => - tagAPIWithDomainAndUpdateRuleCache(builder, thirdPartyRuleInfo.get, ruleCache, domain, apiCall, propertyNode) + tagThirdPartyAPIWithDomainAndUpdateRuleCache(builder, cpg, ruleCache, domain, apiCall, propertyNode) storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString) storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString) } @@ -144,14 +136,7 @@ object Utility { if (endpointNode.isDefined) { val domain = resolveDomainFromSource(endpointNode.get) impactedApiCalls.foreach { apiCall => - tagAPIWithDomainAndUpdateRuleCache( - builder, - thirdPartyRuleInfo.get, - ruleCache, - domain, - apiCall, - endpointNode.get - ) + tagThirdPartyAPIWithDomainAndUpdateRuleCache(builder, cpg, ruleCache, domain, apiCall, endpointNode.get) storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString) storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString) } @@ -165,7 +150,7 @@ object Utility { matchingParameters.head // Pick only the first parameter as we don't want to tag same sink with multiple API's val domain = resolveDomainFromSource(parameter) impactedApiCalls.foreach { apiCall => - tagAPIWithDomainAndUpdateRuleCache(builder, thirdPartyRuleInfo.get, ruleCache, domain, apiCall, parameter) + tagThirdPartyAPIWithDomainAndUpdateRuleCache(builder, cpg, ruleCache, domain, apiCall, parameter) storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString) storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString) } @@ -177,14 +162,7 @@ object Utility { matchingIdentifiers.head // Pick only the first identifier as we don't want to tag same sink with multiple API's val domain = resolveDomainFromSource(identifier) impactedApiCalls.foreach { apiCall => - tagAPIWithDomainAndUpdateRuleCache( - builder, - thirdPartyRuleInfo.get, - ruleCache, - domain, - apiCall, - identifier - ) + tagThirdPartyAPIWithDomainAndUpdateRuleCache(builder, cpg, ruleCache, domain, apiCall, identifier) storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString) storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString) } diff --git a/src/main/scala/ai/privado/languageEngine/php/tagger/sink/APITagger.scala b/src/main/scala/ai/privado/languageEngine/php/tagger/sink/APITagger.scala index d5b69b4ed..b1508d2fd 100644 --- a/src/main/scala/ai/privado/languageEngine/php/tagger/sink/APITagger.scala +++ b/src/main/scala/ai/privado/languageEngine/php/tagger/sink/APITagger.scala @@ -61,6 +61,7 @@ class APITagger(cpg: Cpg, ruleCache: RuleCache, privadoInput: PrivadoInput, appC logger.debug("Using Enhanced API tagger to find API sinks") sinkTagger( + cpg, apiInternalSources ++ propertySources ++ identifierSource, httpApis.distinct, builder, diff --git a/src/main/scala/ai/privado/languageEngine/python/tagger/sink/PythonAPITagger.scala b/src/main/scala/ai/privado/languageEngine/python/tagger/sink/PythonAPITagger.scala index 126d84486..a923c8191 100644 --- a/src/main/scala/ai/privado/languageEngine/python/tagger/sink/PythonAPITagger.scala +++ b/src/main/scala/ai/privado/languageEngine/python/tagger/sink/PythonAPITagger.scala @@ -75,6 +75,7 @@ class PythonAPITagger(cpg: Cpg, ruleCache: RuleCache, privadoInput: PrivadoInput logger.debug("Using Enhanced API tagger to find API sinks") println(s"${Calendar.getInstance().getTime} - --API TAGGER Common HTTP Libraries Used...") sinkTagger( + cpg, apiInternalSources ++ propertySources ++ identifierSource, apis.methodFullName(commonHttpPackages).l, builder, diff --git a/src/main/scala/ai/privado/languageEngine/ruby/tagger/sink/APITagger.scala b/src/main/scala/ai/privado/languageEngine/ruby/tagger/sink/APITagger.scala index 999ca2f3b..ec3f7f19b 100644 --- a/src/main/scala/ai/privado/languageEngine/ruby/tagger/sink/APITagger.scala +++ b/src/main/scala/ai/privado/languageEngine/ruby/tagger/sink/APITagger.scala @@ -65,6 +65,7 @@ class APITagger(cpg: Cpg, ruleCache: RuleCache, privadoInput: PrivadoInput, appC logger.debug("Using Enhanced API tagger to find API sinks") println(s"${Calendar.getInstance().getTime} - --API TAGGER Common HTTP Libraries Used...") sinkTagger( + cpg, apiInternalSources ++ propertySources ++ identifierSource, (httpApis ++ clientLikeApis).distinct, builder, diff --git a/src/main/scala/ai/privado/model/PrivadoTag.scala b/src/main/scala/ai/privado/model/PrivadoTag.scala index 147b17c8e..792bad3ca 100644 --- a/src/main/scala/ai/privado/model/PrivadoTag.scala +++ b/src/main/scala/ai/privado/model/PrivadoTag.scala @@ -217,6 +217,8 @@ object FilterProperty extends Enumeration { // For Inference API Endpoint mapping val METHOD_FULL_NAME_WITH_LITERAL: model.FilterProperty.Value = Value("method_full_name_with_literal") val METHOD_FULL_NAME_WITH_PROPERTY_NAME: model.FilterProperty.Value = Value("method_full_name_with_property_name") + val ENDPOINT_DOMAIN_WITH_LITERAL: model.FilterProperty.Value = Value("endpoint_domain_with_literal") + val ENDPOINT_DOMAIN_WITH_PROPERTY_NAME: model.FilterProperty.Value = Value("endpoint_domain_with_property_name") def withNameWithDefault(name: String): Value = { try { diff --git a/src/main/scala/ai/privado/tagger/sink/APITagger.scala b/src/main/scala/ai/privado/tagger/sink/APITagger.scala index 259917a74..de928bed7 100644 --- a/src/main/scala/ai/privado/tagger/sink/APITagger.scala +++ b/src/main/scala/ai/privado/tagger/sink/APITagger.scala @@ -78,6 +78,7 @@ class APITagger(cpg: Cpg, ruleCache: RuleCache, privadoInput: PrivadoInput, appC List() } sinkTagger( + cpg, apiInternalSources ++ propertySources ++ identifierSource ++ serviceSource, apis, builder, diff --git a/src/main/scala/ai/privado/tagger/sink/api/InferenceAPIEndpointTagger.scala b/src/main/scala/ai/privado/tagger/sink/api/InferenceAPIEndpointTagger.scala index c886461a2..c34daf758 100644 --- a/src/main/scala/ai/privado/tagger/sink/api/InferenceAPIEndpointTagger.scala +++ b/src/main/scala/ai/privado/tagger/sink/api/InferenceAPIEndpointTagger.scala @@ -5,7 +5,7 @@ import ai.privado.languageEngine.java.language.* import ai.privado.model.FilterProperty.* import ai.privado.model.{Constants, InternalTag, RuleInfo} import ai.privado.tagger.PrivadoParallelCpgPass -import ai.privado.tagger.utility.APITaggerUtility.tagAPIWithDomainAndUpdateRuleCache +import ai.privado.tagger.utility.APITaggerUtility.tagThirdPartyAPIWithDomainAndUpdateRuleCache import ai.privado.utility.Utilities.{getDomainFromString, storeForTag} import io.shiftleft.codepropertygraph.generated.Cpg import io.shiftleft.codepropertygraph.generated.nodes.AstNode @@ -18,8 +18,7 @@ import org.slf4j.LoggerFactory */ class InferenceAPIEndpointTagger(cpg: Cpg, ruleCache: RuleCache) extends PrivadoParallelCpgPass[RuleInfo](cpg) { - private val logger = LoggerFactory.getLogger(getClass) - private val thirdPartyRuleInfo = ruleCache.getRuleInfo(Constants.thirdPartiesAPIRuleId) + private val logger = LoggerFactory.getLogger(getClass) override def generateParts(): Array[RuleInfo] = ruleCache.getRule.inferences.filter(_.catLevelTwo.equals(Constants.apiEndpoint)).toArray @@ -58,14 +57,7 @@ class InferenceAPIEndpointTagger(cpg: Cpg, ruleCache: RuleCache) extends Privado } private def tagNode(builder: DiffGraphBuilder, apiCall: AstNode, apiUrl: String) = { - tagAPIWithDomainAndUpdateRuleCache( - builder, - thirdPartyRuleInfo.get, - ruleCache, - getDomainFromString(apiUrl), - apiCall, - apiUrl - ) + tagThirdPartyAPIWithDomainAndUpdateRuleCache(builder, cpg, ruleCache, getDomainFromString(apiUrl), apiCall, apiUrl) storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString) storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString) } diff --git a/src/main/scala/ai/privado/tagger/utility/APITaggerUtility.scala b/src/main/scala/ai/privado/tagger/utility/APITaggerUtility.scala index aa43706a3..53091694b 100644 --- a/src/main/scala/ai/privado/tagger/utility/APITaggerUtility.scala +++ b/src/main/scala/ai/privado/tagger/utility/APITaggerUtility.scala @@ -28,7 +28,7 @@ import ai.privado.dataflow.DuplicateFlowProcessor import ai.privado.entrypoint.{PrivadoInput, ScanProcessor} import ai.privado.languageEngine.java.language.NodeToProperty import ai.privado.languageEngine.java.semantic.JavaSemanticGenerator -import ai.privado.model.{Constants, RuleInfo} +import ai.privado.model.{Constants, FilterProperty, RuleInfo} import ai.privado.utility.Utilities.{ addRuleTags, getDomainFromString, @@ -41,6 +41,9 @@ import io.joern.dataflowengineoss.queryengine.{EngineConfig, EngineContext} import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, CfgNode, JavaProperty, Member} import overflowdb.BatchedUpdate import overflowdb.BatchedUpdate.DiffGraphBuilder +import io.shiftleft.codepropertygraph.generated.Cpg +import ai.privado.languageEngine.java.language._ +import io.shiftleft.semanticcpg.language._ object APITaggerUtility { @@ -58,6 +61,7 @@ object APITaggerUtility { } def sinkTagger( + cpg: Cpg, apiInternalSinkPattern: List[AstNode], apis: List[CfgNode], builder: BatchedUpdate.DiffGraphBuilder, @@ -80,33 +84,27 @@ object APITaggerUtility { val sourceNode = flow.elements.head val apiNode = flow.elements.last // Tag API's when we find a dataflow to them - var newRuleIdToUse = ruleInfo.id - if (ruleInfo.id.equals(Constants.internalAPIRuleId)) addRuleTags(builder, apiNode, ruleInfo, ruleCache) - else { - val domain = resolveDomainFromSource(sourceNode) - newRuleIdToUse = ruleInfo.id + "." + domain - ruleCache.setRuleInfo( - ruleInfo.copy(id = newRuleIdToUse, name = ruleInfo.name + " " + domain, isGenerated = true) + if ruleInfo.id.equals(Constants.internalAPIRuleId) then + tagAPINode(builder, ruleCache, ruleInfo, apiNode, getLiteralCode(sourceNode)) + else + tagThirdPartyAPIWithDomainAndUpdateRuleCache( + builder, + cpg, + ruleCache, + resolveDomainFromSource(sourceNode), + apiNode, + sourceNode ) - addRuleTags(builder, apiNode, ruleInfo, ruleCache, Some(newRuleIdToUse)) - } - storeForTag(builder, apiNode, ruleCache)(Constants.apiUrl + newRuleIdToUse, getLiteralCode(sourceNode)) }) // Add url as 'API' for non Internal api nodes, so that at-least we show API without domains if (showAPI && !ruleInfo.id.equals(Constants.internalAPIRuleId)) { val literalPathApiNodes = apiFlows.map(_.elements.last).toSet val apiNodesWithoutLiteralPath = apis.toSet.diff(literalPathApiNodes) - apiNodesWithoutLiteralPath.foreach(apiNode => { - addRuleTags(builder, apiNode, ruleInfo, ruleCache) - storeForTag(builder, apiNode, ruleCache)(Constants.apiUrl + ruleInfo.id, Constants.API) - }) + apiNodesWithoutLiteralPath.foreach(tagAPINode(builder, ruleCache, ruleInfo, _, Constants.API)) } } // Add url as 'API' for non Internal api nodes, for cases where there is no http literal present in source code else if (showAPI && filteredSourceNode.isEmpty && !ruleInfo.id.equals(Constants.internalAPIRuleId)) { - apis.foreach(apiNode => { - addRuleTags(builder, apiNode, ruleInfo, ruleCache) - storeForTag(builder, apiNode, ruleCache)(Constants.apiUrl + ruleInfo.id, Constants.API) - }) + apis.foreach(tagAPINode(builder, ruleCache, ruleInfo, _, Constants.API)) } } @@ -123,33 +121,90 @@ object APITaggerUtility { } } - def tagAPIWithDomainAndUpdateRuleCache( + /** Tag API or internal api nodes, Don't use this for tagging Third Party API's + * @param builder + * @param ruleCache + * @param apiNode + * @param apiUrl + * @return + */ + private def tagAPINode( builder: DiffGraphBuilder, + ruleCache: RuleCache, ruleInfo: RuleInfo, + apiNode: AstNode, + apiUrl: String + ) = { + addRuleTags(builder, apiNode, ruleInfo, ruleCache) + storeForTag(builder, apiNode, ruleCache)(Constants.apiUrl + ruleInfo.id, apiUrl) + } + + def tagThirdPartyAPIWithDomainAndUpdateRuleCache( + builder: DiffGraphBuilder, + cpg: Cpg, ruleCache: RuleCache, domain: String, apiNode: AstNode, - apiUrlNode: AstNode - ): DiffGraphBuilder = { - val newRuleIdToUse = ruleInfo.id + "." + domain - ruleCache.setRuleInfo(ruleInfo.copy(id = newRuleIdToUse, name = ruleInfo.name + " " + domain, isGenerated = true)) - addRuleTags(builder, apiNode, ruleInfo, ruleCache, Some(newRuleIdToUse)) - storeForTag(builder, apiNode, ruleCache)(Constants.apiUrl + newRuleIdToUse, getLiteralCode(apiUrlNode)) - + apiUrlNode: Any + ): Unit = { + // TODO Optimise this by adding a cache mechanism + val flatMapList = ruleCache.getRule.inferences + .filter(_.catLevelTwo.equals(Constants.apiEndpoint)) + .filter(_.domains.nonEmpty) + .flatMap { ruleInfo => + ruleInfo.filterProperty match { + case FilterProperty.ENDPOINT_DOMAIN_WITH_LITERAL if domain.matches(ruleInfo.combinedRulePattern) => + Some(ruleInfo.domains.head) + case FilterProperty.ENDPOINT_DOMAIN_WITH_PROPERTY_NAME + if domain.matches(ruleInfo.combinedRulePattern) && cpg.property.name(ruleInfo.domains.head).nonEmpty => + Some(cpg.property.name(ruleInfo.domains.head).value.head) + case _ => None + } + } + ruleCache.getRule.inferences + .filter(_.catLevelTwo.equals(Constants.apiEndpoint)) + .filter(_.domains.nonEmpty) + .flatMap { ruleInfo => + ruleInfo.filterProperty match { + case FilterProperty.ENDPOINT_DOMAIN_WITH_LITERAL if domain.matches(ruleInfo.combinedRulePattern) => + Some(ruleInfo.domains.head) + case FilterProperty.ENDPOINT_DOMAIN_WITH_PROPERTY_NAME + if domain.matches(ruleInfo.combinedRulePattern) && cpg.property.name(ruleInfo.domains.head).nonEmpty => + Some(cpg.property.name(ruleInfo.domains.head).value.head) + case _ => None + } + } + .headOption match + case Some(inferredDomain) => + addThirdPartyRuleAndTagAPI(builder, ruleCache, inferredDomain, apiNode, inferredDomain) + case None => + apiUrlNode match { + case x: String => addThirdPartyRuleAndTagAPI(builder, ruleCache, domain, apiNode, x) + case x: AstNode => addThirdPartyRuleAndTagAPI(builder, ruleCache, domain, apiNode, getLiteralCode(x)) + case _ => + } } - def tagAPIWithDomainAndUpdateRuleCache( + /** Generates a new third party rule, updates ruleCache, and tag the apiSink with this generated rule + * @param builder + * @param ruleCache + * @param domain + * @param apiNode + * @param apiUrl + * @return + */ + private def addThirdPartyRuleAndTagAPI( builder: DiffGraphBuilder, - ruleInfo: RuleInfo, ruleCache: RuleCache, domain: String, apiNode: AstNode, apiUrl: String - ): DiffGraphBuilder = { - val newRuleIdToUse = ruleInfo.id + "." + domain - ruleCache.setRuleInfo(ruleInfo.copy(id = newRuleIdToUse, name = ruleInfo.name + " " + domain, isGenerated = true)) - addRuleTags(builder, apiNode, ruleInfo, ruleCache, Some(newRuleIdToUse)) - storeForTag(builder, apiNode, ruleCache)(Constants.apiUrl + newRuleIdToUse, apiUrl) - + ) = { + ruleCache.getRuleInfo(Constants.thirdPartiesAPIRuleId) match + case Some(thirdPartyAPIRuleInfo) => + val newRuleId = ruleCache.addThirdPartyRuleInfo(thirdPartyAPIRuleInfo, domain) + addRuleTags(builder, apiNode, thirdPartyAPIRuleInfo, ruleCache, Some(newRuleId)) + storeForTag(builder, apiNode, ruleCache)(Constants.apiUrl + newRuleId, apiUrl) + case None => // Third party rule doesn't exist, which is ideally not possible } } diff --git a/src/test/scala/ai/privado/tagger/sink/api/InferenceAPIEndpointTaggerTest.scala b/src/test/scala/ai/privado/tagger/sink/api/InferenceAPIEndpointTaggerTest.scala index ecd37e8e1..3934ac52b 100644 --- a/src/test/scala/ai/privado/tagger/sink/api/InferenceAPIEndpointTaggerTest.scala +++ b/src/test/scala/ai/privado/tagger/sink/api/InferenceAPIEndpointTaggerTest.scala @@ -1,10 +1,10 @@ package ai.privado.tagger.sink.api import ai.privado.cache.RuleCache -import ai.privado.model.{CatLevelOne, Constants, FilterProperty, RuleInfo} +import ai.privado.model.{CatLevelOne, Constants, FilterProperty, RuleInfo, SystemConfig} import ai.privado.rule.RuleInfoTestData import ai.privado.testfixtures.JavaFrontendTestSuite -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class InferenceAPIEndpointTaggerTest extends JavaFrontendTestSuite with APIValidator { @@ -98,4 +98,100 @@ class InferenceAPIEndpointTaggerTest extends JavaFrontendTestSuite with APIValid assertAPIEndpointURL(getUsersCall, "http://user-service.com") } } + + "Inference tagger when used against filterProperty endpoint_domain_with_literal" should { + + val inferences = List( + RuleInfo( + id = "Inferences.API.UserServiceUrl.ByLiteral", + name = "User service url", + category = "", + filterProperty = FilterProperty.ENDPOINT_DOMAIN_WITH_LITERAL, + domains = Array("http://user-service.com"), + patterns = List("userServiceUrl"), + catLevelOne = CatLevelOne.INFERENCES, + catLevelTwo = Constants.apiEndpoint + ) + ) + + val systemConfig = List(SystemConfig(Constants.apiIdentifier, ".*Url")) + val ruleCache = + RuleCache().setRule(RuleInfoTestData.rule.copy(inferences = inferences, systemConfig = systemConfig)) + + val cpg = code(""" + |import org.apache.http.client.fluent.Request; + |import java.io.IOException; + |import io.privado.urls.userServiceUrl; + | + |public class Main { + | public static void main(String[] args) throws IOException { + | + | String responseBody = Request.Get(userServiceUrl).execute().returnContent().asString(); + | System.out.println("Response: " + responseBody); + | } + |} + | + |""".stripMargin).withRuleCache(ruleCache) + + "tag the api sink matching fullName" in { + val List(getUsersCall) = cpg.call("Get").l + assertAPISinkCall(getUsersCall) + } + + "tag the api sink with passed literal" in { + val List(getUsersCall) = cpg.call("Get").l + assertAPIEndpointURL(getUsersCall, "http://user-service.com") + } + } + + "Inference tagger when used against filterProperty endpoint_domain_with_property_name" should { + + val inferences = List( + RuleInfo( + id = "Inferences.API.UserServiceUrl.ByPropertyName", + name = "User service url", + category = "", + filterProperty = FilterProperty.ENDPOINT_DOMAIN_WITH_PROPERTY_NAME, + domains = Array(".*user.*url"), + patterns = List(".*userServiceUrl"), + catLevelOne = CatLevelOne.INFERENCES, + catLevelTwo = Constants.apiEndpoint + ) + ) + val systemConfig = List(SystemConfig(Constants.apiIdentifier, ".*Url")) + val ruleCache = + RuleCache().setRule(RuleInfoTestData.rule.copy(inferences = inferences, systemConfig = systemConfig)) + + val cpg = code(""" + |import org.apache.http.client.fluent.Request; + |import java.io.IOException; + |import io.privado.urls.userServiceUrl; + | + |public class Main { + | public static void main(String[] args) throws IOException { + | + | String responseBody = Request.Get(userServiceUrl).execute().returnContent().asString(); + | System.out.println("Response: " + responseBody); + | } + |} + |""".stripMargin) + .moreCode( + """ + |user.service.url = http://user-service.com + |""".stripMargin, + "application.properties" + ) + .withRuleCache(ruleCache) + + "tag the api sink matching fullName" in { + val List(getUsersCall) = cpg.call("Get").l + assertAPISinkCall(getUsersCall) + } + + "tag the api sink with passed literal" in { + val List(getUsersCall) = cpg.call("Get").l + assertAPIEndpointURL(getUsersCall, "http://user-service.com") + } + } + }