From bf592ab0e4e00f8b7707f14951f7b910e697fc0e Mon Sep 17 00:00:00 2001 From: KhemrajSingh Rathore Date: Fri, 19 Apr 2024 18:04:21 +0530 Subject: [PATCH] Jdk frontend (#1042) * trial * added a check * Update README.md * Update README.md * Update README.md * add - handle more case * skip already tagged api sinks * add - api by inference rule * fix - missing flow * use kotlin and java language as one in flows filtering * create new rule per ruleId only once * minor refactor * fix failing build * fix * added test case for retrofit * add test cases for inference api --- README.md | 2 +- .../rulevalidator/schema/inferences.json | 104 +++++++ .../scala/ai/privado/cache/RuleCache.scala | 21 +- .../scala/ai/privado/dataflow/Dataflow.scala | 5 +- .../dataflow/DuplicateFlowProcessor.scala | 3 +- .../ai/privado/entrypoint/ScanProcessor.scala | 60 ++++- .../go/tagger/PrivadoTagger.scala | 4 +- .../go/tagger/sink/GoAPISinkTagger.scala | 17 ++ .../java/tagger/PrivadoTagger.scala | 7 +- .../sink/api/JavaAPIRetrofitTagger.scala | 67 +++++ .../tagger/sink/api/JavaAPISinkTagger.scala | 10 +- .../tagger/sink/{ => api}/JavaAPITagger.scala | 27 +- .../java/tagger/sink/api/Utility.scala | 196 +++++++++----- .../kotlin/tagger/PrivadoTagger.scala | 5 +- src/main/scala/ai/privado/model/Config.scala | 31 ++- .../scala/ai/privado/model/Constants.scala | 3 + .../scala/ai/privado/model/PrivadoTag.scala | 10 +- .../rulevalidator/YamlFileValidator.scala | 3 + .../ai/privado/tagger/sink/APITagger.scala | 3 +- .../tagger/sink/api/APISinkTagger.scala | 6 +- .../sink/api/InferenceAPIEndpointTagger.scala | 72 +++++ .../tagger/utility/APITaggerUtility.scala | 17 +- .../config/JavaYamlLinkerPassTest.scala | 2 +- .../sink/api/JavaAPIRetrofitTaggerTest.scala | 255 ++++++++++++++++++ ...avaAPISinkByMethodFullNameTaggerTest.scala | 4 +- .../JavaAPISinkByParameterTaggerTest.scala | 3 +- .../tagger/sink/api/APIValidator.scala | 41 +++ .../api/InferenceAPIEndpointTaggerTest.scala | 101 +++++++ 28 files changed, 939 insertions(+), 140 deletions(-) create mode 100644 src/main/resources/ai/privado/rulevalidator/schema/inferences.json create mode 100644 src/main/scala/ai/privado/languageEngine/go/tagger/sink/GoAPISinkTagger.scala create mode 100644 src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPIRetrofitTagger.scala rename src/main/scala/ai/privado/languageEngine/java/tagger/sink/{ => api}/JavaAPITagger.scala (89%) create mode 100644 src/main/scala/ai/privado/tagger/sink/api/InferenceAPIEndpointTagger.scala create mode 100644 src/test/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPIRetrofitTaggerTest.scala create mode 100644 src/test/scala/ai/privado/tagger/sink/api/APIValidator.scala create mode 100644 src/test/scala/ai/privado/tagger/sink/api/InferenceAPIEndpointTaggerTest.scala diff --git a/README.md b/README.md index 556c079cb..8a10604e3 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ Privado Core ============================================= -Branch structure +Branch structure main - This branch will contain the released version of the code. diff --git a/src/main/resources/ai/privado/rulevalidator/schema/inferences.json b/src/main/resources/ai/privado/rulevalidator/schema/inferences.json new file mode 100644 index 000000000..dbff9f0a9 --- /dev/null +++ b/src/main/resources/ai/privado/rulevalidator/schema/inferences.json @@ -0,0 +1,104 @@ +{ + "definitions": {}, + "$schema": "http://json-schema.org/draft-07/schema#", + "$id": "https://github.com/Privado-Inc/privado-core/tree/main/src/main/resources/ai/privado/rulevalidator/schema/inferences.json", + "title": "Root", + "type": "object", + "required": [ + "inferences" + ], + "additionalProperties": false, + "properties": { + "inferences": { + "$id": "#root/inferences", + "title": "Inferences", + "type": "array", + "default": [], + "items":{ + "$id": "#root/inferences/items", + "title": "Items", + "type": "object", + "required": [ + "id", + "name", + "domains", + "patterns" + ], + "additionalProperties": false, + "properties": { + "id": { + "$id": "#root/inferences/items/id", + "title": "Id", + "type": "string", + "default": "", + "examples": [ + "Storages.AmazonS3.Read" + ], + "pattern": "^.*$" + }, + "name": { + "$id": "#root/inferences/items/name", + "title": "Name", + "type": "string", + "default": "", + "examples": [ + "Amazon S3(Read)" + ], + "pattern": "^.*$" + }, + "filterProperty": { + "$id": "#root/inferences/items/filterProperty", + "title": "FilterProperty", + "type": "string", + "default": "method_full_name", + "examples": [ + "code", + "method_full_name" + ], + "pattern": "^(code|method_full_name|method_full_name_with_literal|method_full_name_with_property_name)$" + }, + "domains": { + "$id": "#root/inferences/items/domains", + "title": "Domains", + "type": "array", + "default": [], + "items":{ + "$id": "#root/inferences/items/domains/items", + "title": "Items", + "type": "string", + "default": "", + "examples": [ + "aws.amazon.com" + ], + "pattern": "^.*$" + } + }, + "patterns": { + "$id": "#root/inferences/items/patterns", + "title": "Patterns", + "type": "array", + "default": [], + "items":{ + "$id": "#root/inferences/items/patterns/items", + "title": "Items", + "type": "string", + "format": "regex", + "default": "", + "examples": [ + ".*(AmazonS3).*" + ], + "pattern": "^.*$" + } + }, + "tags": { + "$id": "#root/inferences/items/tags", + "title": "Tags", + "type": ["object", "null"], + "default": null + } + } + } + + } + } +} diff --git a/src/main/scala/ai/privado/cache/RuleCache.scala b/src/main/scala/ai/privado/cache/RuleCache.scala index 54a994ebe..271e4337e 100644 --- a/src/main/scala/ai/privado/cache/RuleCache.scala +++ b/src/main/scala/ai/privado/cache/RuleCache.scala @@ -38,25 +38,30 @@ class RuleCache { val internalPolicies = mutable.Set[String]() private val storageRuleInfo = mutable.ListBuffer[RuleInfo]() - def setRule(rule: ConfigAndRules): Unit = { + // TODO, rename setRule to withRule as it return the ruleCache object and setters are Unit functions + def setRule(rule: ConfigAndRules): RuleCache = { this.rule = rule rule.sources.foreach(r => ruleInfoMap.addOne(r.id -> r)) rule.sinks.foreach(r => ruleInfoMap.addOne(r.id -> r)) rule.collections.foreach(r => ruleInfoMap.addOne(r.id -> r)) rule.policies.foreach(r => policyOrThreatMap.addOne(r.id -> r)) rule.threats.foreach(r => policyOrThreatMap.addOne(r.id -> r)) + this } def getRule: ConfigAndRules = rule def setRuleInfo(ruleInfo: RuleInfo): Unit = { - ruleInfoMap.addOne(ruleInfo.id -> ruleInfo) - rule = ruleInfo.catLevelOne match { - case ai.privado.model.CatLevelOne.SOURCES => rule.copy(sources = rule.sources.appended(ruleInfo)) - case ai.privado.model.CatLevelOne.SINKS => rule.copy(sinks = rule.sinks.appended(ruleInfo)) - case ai.privado.model.CatLevelOne.COLLECTIONS => rule.copy(collections = rule.collections.appended(ruleInfo)) - case _ => rule - } + ruleInfoMap.get(ruleInfo.id) match + case Some(_) => // Rule already exists, skip adding again + case None => + ruleInfoMap.addOne(ruleInfo.id -> ruleInfo) + rule = ruleInfo.catLevelOne match { + case ai.privado.model.CatLevelOne.SOURCES => rule.copy(sources = rule.sources.appended(ruleInfo)) + case ai.privado.model.CatLevelOne.SINKS => rule.copy(sinks = rule.sinks.appended(ruleInfo)) + case ai.privado.model.CatLevelOne.COLLECTIONS => rule.copy(collections = rule.collections.appended(ruleInfo)) + case _ => rule + } } def addStorageRuleInfo(ruleInfo: RuleInfo): Unit = storageRuleInfo.addOne(ruleInfo) diff --git a/src/main/scala/ai/privado/dataflow/Dataflow.scala b/src/main/scala/ai/privado/dataflow/Dataflow.scala index 9523ca670..7fdb27f60 100644 --- a/src/main/scala/ai/privado/dataflow/Dataflow.scala +++ b/src/main/scala/ai/privado/dataflow/Dataflow.scala @@ -139,10 +139,11 @@ class Dataflow(cpg: Cpg) { println(s"${Calendar.getInstance().getTime} - --Filtering flows 1 invoked...") appCache.totalFlowFromReachableBy = dataflowPathsUnfiltered.size - // Apply `this` filtering for JS & JAVA also + // Apply `this` filtering for JS, JAVA val dataflowPaths = { if ( - privadoScanConfig.disableThisFiltering || (appCache.repoLanguage != Language.JAVA && appCache.repoLanguage != Language.JAVASCRIPT) + privadoScanConfig.disableThisFiltering || (!List(Language.JAVA, Language.JAVASCRIPT) + .contains(appCache.repoLanguage)) ) dataflowPathsUnfiltered else diff --git a/src/main/scala/ai/privado/dataflow/DuplicateFlowProcessor.scala b/src/main/scala/ai/privado/dataflow/DuplicateFlowProcessor.scala index 67267558a..8176734f6 100644 --- a/src/main/scala/ai/privado/dataflow/DuplicateFlowProcessor.scala +++ b/src/main/scala/ai/privado/dataflow/DuplicateFlowProcessor.scala @@ -177,7 +177,8 @@ object DuplicateFlowProcessor { auditCache.addIntoBeforeSecondFiltering(SourcePathInfo(flow.pathSourceId, flow.sinkId, flow.sinkPathId)) } if ( - privadoScanConfig.disableFlowSeparationByDataElement || (appCache.repoLanguage != Language.JAVA && appCache.repoLanguage != Language.JAVASCRIPT) + privadoScanConfig.disableFlowSeparationByDataElement || (!List(Language.JAVA, Language.JAVASCRIPT) + .contains(appCache.repoLanguage)) ) { // Filter out flows where source is cookie and sink is cookie read if ( diff --git a/src/main/scala/ai/privado/entrypoint/ScanProcessor.scala b/src/main/scala/ai/privado/entrypoint/ScanProcessor.scala index 3a686e7e1..b05aa7afa 100644 --- a/src/main/scala/ai/privado/entrypoint/ScanProcessor.scala +++ b/src/main/scala/ai/privado/entrypoint/ScanProcessor.scala @@ -52,6 +52,7 @@ import io.joern.console.cpgcreation.guessLanguage import io.shiftleft.codepropertygraph.generated.Languages import org.slf4j.LoggerFactory import ai.privado.languageEngine.csharp.processor.CSharpProcessor +import io.joern.x2cpg.SourceFiles import java.util.Calendar import scala.collection.parallel.CollectionConverters.ImmutableIterableIsParallelizable @@ -199,6 +200,19 @@ object ScanProcessor extends CommandProcessor { language = Language.withNameWithDefault(pathTree.last) ) ) + .filter(filterByLang), + inferences = configAndRules.inferences + .filter(rule => isValidRule(rule.combinedRulePattern, rule.id, fullPath)) + .map(x => + x.copy( + file = fullPath, + catLevelOne = CatLevelOne.INFERENCES, + catLevelTwo = pathTree.apply(2), + categoryTree = pathTree, + language = Language.withNameWithDefault(pathTree.last), + nodeType = NodeType.withNameWithDefault(pathTree.apply(3)) + ) + ) .filter(filterByLang) ) case Left(error) => @@ -223,7 +237,8 @@ object ScanProcessor extends CommandProcessor { semantics = a.semantics ++ b.semantics, sinkSkipList = a.sinkSkipList ++ b.sinkSkipList, systemConfig = a.systemConfig ++ b.systemConfig, - auditConfig = a.auditConfig ++ b.auditConfig + auditConfig = a.auditConfig ++ b.auditConfig, + inferences = a.inferences ++ b.inferences ) ) catch { @@ -281,6 +296,7 @@ object ScanProcessor extends CommandProcessor { val sinkSkipList = externalConfigAndRules.sinkSkipList ++ internalConfigAndRules.sinkSkipList val systemConfig = externalConfigAndRules.systemConfig ++ internalConfigAndRules.systemConfig val auditConfig = externalConfigAndRules.auditConfig ++ internalConfigAndRules.auditConfig + val inferences = externalConfigAndRules.inferences ++ internalConfigAndRules.inferences val mergedRules = ConfigAndRules( sources = mergePatterns(sources), @@ -292,7 +308,8 @@ object ScanProcessor extends CommandProcessor { semantics = semantics.distinctBy(_.signature), sinkSkipList = sinkSkipList.distinctBy(_.id), systemConfig = systemConfig, - auditConfig = auditConfig.distinctBy(_.id) + auditConfig = auditConfig.distinctBy(_.id), + inferences = mergePatterns(inferences) ) logger.trace(mergedRules.toString) println(s"${Calendar.getInstance().getTime} - Configuration parsed...") @@ -306,7 +323,8 @@ object ScanProcessor extends CommandProcessor { mergedRules.collections.size + mergedRules.policies.size + mergedRules.exclusions.size + - mergedRules.auditConfig.size + mergedRules.auditConfig.size + + mergedRules.inferences.size ) } @@ -366,16 +384,34 @@ object ScanProcessor extends CommandProcessor { lang match { case language if language == Languages.JAVASRC || language == Languages.JAVA => println(s"${Calendar.getInstance().getTime} - Detected language 'Java'") - new JavaProcessor( - getProcessedRule(Set(Language.JAVA), appCache), - this.config, + val kotlinPlusJavaRules = getProcessedRule(Set(Language.KOTLIN, Language.JAVA), appCache) + val filesWithKtExtension = SourceFiles.determine( sourceRepoLocation, - dataFlowCache = getDataflowCache, - auditCache, - s3DatabaseDetailsCache, - appCache, - propertyFilterCache = propertyFilterCache - ).processCpg() + Set(".kt"), + ignoredFilesRegex = Option(kotlinPlusJavaRules.getExclusionRegex.r) + ) + if (filesWithKtExtension.isEmpty) + new JavaProcessor( + getProcessedRule(Set(Language.JAVA), appCache), + this.config, + sourceRepoLocation, + dataFlowCache = getDataflowCache, + auditCache, + s3DatabaseDetailsCache, + appCache, + propertyFilterCache = propertyFilterCache + ).processCpg() + else + new KotlinProcessor( + kotlinPlusJavaRules, + this.config, + sourceRepoLocation, + dataFlowCache = getDataflowCache, + auditCache, + s3DatabaseDetailsCache, + appCache, + propertyFilterCache = propertyFilterCache + ).processCpg() case language if language == Languages.JSSRC => println(s"${Calendar.getInstance().getTime} - Detected language 'JavaScript'") JavascriptProcessor.createJavaScriptCpg( diff --git a/src/main/scala/ai/privado/languageEngine/go/tagger/PrivadoTagger.scala b/src/main/scala/ai/privado/languageEngine/go/tagger/PrivadoTagger.scala index 021df89bc..686490d8c 100644 --- a/src/main/scala/ai/privado/languageEngine/go/tagger/PrivadoTagger.scala +++ b/src/main/scala/ai/privado/languageEngine/go/tagger/PrivadoTagger.scala @@ -12,7 +12,7 @@ import overflowdb.traversal.Traversal import io.shiftleft.semanticcpg.language.* import ai.privado.languageEngine.go.tagger.source.IdentifierTagger import ai.privado.languageEngine.go.tagger.config.GoDBConfigTagger -import ai.privado.languageEngine.go.tagger.sink.GoAPITagger +import ai.privado.languageEngine.go.tagger.sink.{GoAPISinkTagger, GoAPITagger} import ai.privado.tagger.sink.RegularSinkTagger import ai.privado.utility.Utilities.ingressUrls @@ -37,7 +37,7 @@ class PrivadoTagger(cpg: Cpg) extends PrivadoBaseTagger { new GoDBConfigTagger(cpg).createAndApply() - new GoAPITagger(cpg, ruleCache, privadoInput = privadoInputConfig, appCache = appCache).createAndApply() + GoAPISinkTagger.applyTagger(cpg, ruleCache, privadoInputConfig, appCache) new RegularSinkTagger(cpg, ruleCache).createAndApply() diff --git a/src/main/scala/ai/privado/languageEngine/go/tagger/sink/GoAPISinkTagger.scala b/src/main/scala/ai/privado/languageEngine/go/tagger/sink/GoAPISinkTagger.scala new file mode 100644 index 000000000..55a4ac620 --- /dev/null +++ b/src/main/scala/ai/privado/languageEngine/go/tagger/sink/GoAPISinkTagger.scala @@ -0,0 +1,17 @@ +package ai.privado.languageEngine.go.tagger.sink + +import ai.privado.cache.{AppCache, RuleCache} +import ai.privado.entrypoint.PrivadoInput +import ai.privado.tagger.sink.api.APISinkTagger +import io.shiftleft.codepropertygraph.generated.Cpg + +object GoAPISinkTagger extends APISinkTagger { + + override def applyTagger(cpg: Cpg, ruleCache: RuleCache, privadoInput: PrivadoInput, appCache: AppCache): Unit = { + + super.applyTagger(cpg, ruleCache, privadoInput, appCache) + + new GoAPITagger(cpg, ruleCache, privadoInput, appCache).createAndApply() + } + +} diff --git a/src/main/scala/ai/privado/languageEngine/java/tagger/PrivadoTagger.scala b/src/main/scala/ai/privado/languageEngine/java/tagger/PrivadoTagger.scala index 65d9dc858..f5d1cc911 100644 --- a/src/main/scala/ai/privado/languageEngine/java/tagger/PrivadoTagger.scala +++ b/src/main/scala/ai/privado/languageEngine/java/tagger/PrivadoTagger.scala @@ -40,8 +40,8 @@ import ai.privado.languageEngine.java.tagger.collection.{ SOAPCollectionTagger } import ai.privado.languageEngine.java.tagger.config.JavaDBConfigTagger -import ai.privado.languageEngine.java.tagger.sink.api.JavaAPISinkTagger -import ai.privado.languageEngine.java.tagger.sink.{InheritMethodTagger, JavaAPITagger, MessagingConsumerCustomTagger} +import ai.privado.languageEngine.java.tagger.sink.api.{JavaAPISinkTagger, JavaAPITagger} +import ai.privado.languageEngine.java.tagger.sink.{InheritMethodTagger, MessagingConsumerCustomTagger} import ai.privado.languageEngine.java.tagger.source.* import ai.privado.tagger.PrivadoBaseTagger import ai.privado.tagger.collection.{AndroidCollectionTagger, WebFormsCollectionTagger} @@ -83,9 +83,8 @@ class PrivadoTagger(cpg: Cpg) extends PrivadoBaseTagger { new JavaS3Tagger(cpg, s3DatabaseDetailsCache).createAndApply() - JavaAPISinkTagger.applyTagger(cpg, ruleCache, privadoInputConfig) + JavaAPISinkTagger.applyTagger(cpg, ruleCache, privadoInputConfig, appCache) - new JavaAPITagger(cpg, ruleCache, privadoInputConfig, appCache).createAndApply() // Custom Rule tagging if (!privadoInputConfig.ignoreInternalRules) { // Adding custom rule to cache diff --git a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPIRetrofitTagger.scala b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPIRetrofitTagger.scala new file mode 100644 index 000000000..a724030f1 --- /dev/null +++ b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPIRetrofitTagger.scala @@ -0,0 +1,67 @@ +package ai.privado.languageEngine.java.tagger.sink.api + +import ai.privado.cache.RuleCache +import ai.privado.languageEngine.java.tagger.sink.api.Utility.tagAPICallByItsUrlMethod +import ai.privado.model.{Constants, NodeType} +import ai.privado.tagger.PrivadoParallelCpgPass +import io.shiftleft.codepropertygraph.generated.nodes.{Call, Literal, Method} +import io.shiftleft.codepropertygraph.generated.{Cpg, Operators} +import io.shiftleft.semanticcpg.language.* +import org.slf4j.LoggerFactory.getLogger + +import scala.util.{Failure, Success, Try} + +class JavaAPIRetrofitTagger(cpg: Cpg, ruleCache: RuleCache) extends PrivadoParallelCpgPass[(Call, Call)](cpg) { + + private val apiMatchingRegex = + ruleCache.getAllRuleInfo.filter(_.nodeType == NodeType.API).map(_.combinedRulePattern).mkString("(", "|", ")") + + private val thirdPartyRuleInfo = ruleCache.getRuleInfo(Constants.thirdPartiesAPIRuleId) + + private val logger = getLogger(this.getClass) + override def generateParts(): Array[(Call, Call)] = { + cpg + .call("create") + .code(".*class.*") + .map(c => (c, c.start.repeat(_.receiver.isCall)(_.until(_.name("baseUrl"))).l)) + .filter(_._2.nonEmpty) + .map(item => (item._1, item._2.head)) + .toArray + + } + + override def runOnPart(builder: DiffGraphBuilder, createCallWithBaseUrlCall: (Call, Call)): Unit = { + + val createCall = createCallWithBaseUrlCall._1 + val baseUrlCall = createCallWithBaseUrlCall._2 + + // Strip .class and ::.class.java as after stripping these we get the class name of client + val clientClassName = Try( + createCall.argument.isCall.name(Operators.fieldAccess).head.code.stripSuffix(".class").stripSuffix("::class.java") + ).toOption + + if (clientClassName.isDefined) { + val sinkCalls = cpg.call.methodFullName(s".*${clientClassName.get}[.].*").nameNot("getClass").l + + Try(baseUrlCall.argument.last).toOption match + case Some(lit: Literal) => + // Mark the nodes as API sink + tagAPICallByItsUrlMethod(cpg, builder, lit, sinkCalls, apiMatchingRegex, thirdPartyRuleInfo, ruleCache) + case _ => + Try { + val methodNode = createCall.method + tagAPICallByItsUrlMethod( + cpg, + builder, + methodNode, + sinkCalls, + apiMatchingRegex, + thirdPartyRuleInfo, + ruleCache + ) + } match + case Failure(e) => logger.debug(s"Failed to get to a method node for retrofit create call") + case _ => + } + } +} diff --git a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkTagger.scala b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkTagger.scala index aab68008c..e13a756f7 100644 --- a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkTagger.scala +++ b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkTagger.scala @@ -1,6 +1,6 @@ package ai.privado.languageEngine.java.tagger.sink.api -import ai.privado.cache.RuleCache +import ai.privado.cache.{AppCache, RuleCache} import ai.privado.entrypoint.PrivadoInput import ai.privado.tagger.sink.api.APISinkTagger import io.shiftleft.codepropertygraph.generated.Cpg @@ -11,7 +11,10 @@ object JavaAPISinkTagger extends APISinkTagger { * @param cpg * @param ruleCache */ - override def applyTagger(cpg: Cpg, ruleCache: RuleCache, privadoInput: PrivadoInput): Unit = { + override def applyTagger(cpg: Cpg, ruleCache: RuleCache, privadoInput: PrivadoInput, appCache: AppCache): Unit = { + + super.applyTagger(cpg, ruleCache, privadoInput, appCache) + new JavaAPIRetrofitTagger(cpg, ruleCache).createAndApply() if (privadoInput.enableAPIByParameter) { new JavaAPISinkByParameterMarkByAnnotationTagger(cpg, ruleCache).createAndApply() @@ -20,8 +23,7 @@ object JavaAPISinkTagger extends APISinkTagger { new JavaAPISinkByMethodFullNameTagger(cpg, ruleCache).createAndApply() - // Invoke API Endpoint mappers - // new JavaAPISinkEndpointMapperByNonInitMethod(cpg, ruleCache).createAndApply() + new JavaAPITagger(cpg, ruleCache, privadoInput, appCache).createAndApply() } } diff --git a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/JavaAPITagger.scala b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPITagger.scala similarity index 89% rename from src/main/scala/ai/privado/languageEngine/java/tagger/sink/JavaAPITagger.scala rename to src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPITagger.scala index 750eb5f1f..c759a7d4a 100644 --- a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/JavaAPITagger.scala +++ b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPITagger.scala @@ -20,15 +20,16 @@ * For more information, contact support@privado.ai * */ -package ai.privado.languageEngine.java.tagger.sink +package ai.privado.languageEngine.java.tagger.sink.api import ai.privado.cache.{AppCache, RuleCache} import ai.privado.entrypoint.{PrivadoInput, ScanProcessor} import ai.privado.languageEngine.java.language.* import ai.privado.languageEngine.java.semantic.JavaSemanticGenerator import ai.privado.languageEngine.java.tagger.Utility.{GRPCTaggerUtility, SOAPTaggerUtility} +import ai.privado.languageEngine.java.tagger.sink.FeignAPI import ai.privado.metric.MetricHandler -import ai.privado.model.{Constants, InternalTag, Language, NodeType, RuleInfo} +import ai.privado.model.* import ai.privado.tagger.PrivadoParallelCpgPass import ai.privado.tagger.utility.APITaggerUtility.{SERVICE_URL_REGEX_PATTERN, sinkTagger} import ai.privado.utility.{ImportUtility, Utilities} @@ -96,12 +97,23 @@ class JavaAPITagger(cpg: Cpg, ruleCache: RuleCache, privadoInputConfig: PrivadoI APITaggerVersionJava.V2Tagger } - apis = apis.whereNot(_.tag.nameExact(InternalTag.API_URL_MARKED.toString)).l + apis = apis + .whereNot(_.tag.nameExact(InternalTag.API_URL_MARKED.toString)) + .whereNot(_.tag.nameExact(Constants.nodeType).valueExact(NodeType.API.toString)) + .l val commonHttpPackages: String = ruleCache.getSystemConfigByKey(Constants.apiHttpLibraries) - val grpcSinks = GRPCTaggerUtility.getGrpcSinks(cpg).whereNot(_.tag.nameExact(InternalTag.API_URL_MARKED.toString)).l + val grpcSinks = GRPCTaggerUtility + .getGrpcSinks(cpg) + .whereNot(_.tag.nameExact(InternalTag.API_URL_MARKED.toString)) + .whereNot(_.tag.nameExact(Constants.nodeType).valueExact(NodeType.API.toString)) + .l val soapSinks = - SOAPTaggerUtility.getAPICallNodes(cpg).whereNot(_.tag.nameExact(InternalTag.API_URL_MARKED.toString)).l + SOAPTaggerUtility + .getAPICallNodes(cpg) + .whereNot(_.tag.nameExact(InternalTag.API_URL_MARKED.toString)) + .whereNot(_.tag.nameExact(Constants.nodeType).valueExact(NodeType.API.toString)) + .l override def generateParts(): Array[_ <: AnyRef] = { ruleCache.getAllRuleInfo @@ -138,11 +150,14 @@ class JavaAPITagger(cpg: Cpg, ruleCache: RuleCache, privadoInputConfig: PrivadoI ) else List() - }.whereNot(_.tag.nameExact(InternalTag.API_URL_MARKED.toString)).l + }.whereNot(_.tag.nameExact(InternalTag.API_URL_MARKED.toString)) + .whereNot(_.tag.nameExact(Constants.nodeType).valueExact(NodeType.API.toString)) + .l val markedAPISinks = cpg.call .where(_.tag.nameExact(InternalTag.API_SINK_MARKED.toString)) .whereNot(_.tag.nameExact(InternalTag.API_URL_MARKED.toString)) + .whereNot(_.tag.nameExact(Constants.nodeType).valueExact(NodeType.API.toString)) .l apiTaggerToUse match { diff --git a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/Utility.scala b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/Utility.scala index c1161158a..3bbdc2755 100644 --- a/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/Utility.scala +++ b/src/main/scala/ai/privado/languageEngine/java/tagger/sink/api/Utility.scala @@ -1,10 +1,10 @@ package ai.privado.languageEngine.java.tagger.sink.api import ai.privado.cache.RuleCache -import ai.privado.model.{Constants, InternalTag, RuleInfo} +import ai.privado.model.{CatLevelOne, Constants, InternalTag, RuleInfo} import ai.privado.tagger.utility.APITaggerUtility.{resolveDomainFromSource, tagAPIWithDomainAndUpdateRuleCache} import ai.privado.utility.Utilities.{getDomainFromString, storeForTag} -import io.shiftleft.codepropertygraph.generated.nodes.{Call, Method} +import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, Call, Literal, Method} import overflowdb.BatchedUpdate.DiffGraphBuilder import io.shiftleft.semanticcpg.language.* import ai.privado.languageEngine.java.language.* @@ -17,48 +17,97 @@ object Utility { def tagAPICallByItsUrlMethod( cpg: Cpg, builder: DiffGraphBuilder, - methodNode: Method, + nodeToComputeUrl: AstNode, apiCalls: List[Call], apiMatchingRegex: String, thirdPartyRuleInfo: Option[RuleInfo], ruleCache: RuleCache ): Unit = { - val impactedApiCalls = apiCalls.whereNot(_.tag.nameExact(InternalTag.API_URL_MARKED.toString)).l + val impactedApiCalls = apiCalls + .whereNot(_.tag.nameExact(InternalTag.API_URL_MARKED.toString)) + .whereNot(_.tag.nameExact(Constants.catLevelOne).valueExact(CatLevelOne.SINKS.name)) + .l if (impactedApiCalls.nonEmpty) { - /* - Try if we can get to a node in the methodNode which points to a property node, and matches the api regex on the properties value - */ - val matchingProperties = methodNode.ast.originalProperty.value(apiMatchingRegex).dedup.l - if (matchingProperties.nonEmpty) { - matchingProperties.foreach { propertyNode => - val domain = getDomainFromString(propertyNode.value) - impactedApiCalls.foreach { apiCall => - tagAPIWithDomainAndUpdateRuleCache( - builder, - thirdPartyRuleInfo.get, - ruleCache, - domain, - apiCall, - propertyNode - ) - storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString) - storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString) - } + + nodeToComputeUrl match { + case lit: Literal => + extractUrlFromLiteralAndTag(cpg, lit, builder, impactedApiCalls, thirdPartyRuleInfo, ruleCache) + case method: Method => + extractUrlFromMethodAndTag( + cpg, + method, + builder, + impactedApiCalls, + apiMatchingRegex, + thirdPartyRuleInfo, + ruleCache + ) + } + + } + } + + private def extractUrlFromLiteralAndTag( + cpg: Cpg, + literalNode: Literal, + builder: DiffGraphBuilder, + impactedApiCalls: List[Call], + thirdPartyRuleInfo: Option[RuleInfo], + ruleCache: RuleCache + ): Unit = { + + val domain = resolveDomainFromSource(literalNode) + impactedApiCalls.foreach { apiCall => + tagAPIWithDomainAndUpdateRuleCache(builder, thirdPartyRuleInfo.get, ruleCache, domain, apiCall, literalNode) + storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString) + storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString) + } + } + + private def extractUrlFromMethodAndTag( + cpg: Cpg, + methodNode: Method, + builder: DiffGraphBuilder, + impactedApiCalls: List[Call], + apiMatchingRegex: String, + thirdPartyRuleInfo: Option[RuleInfo], + ruleCache: RuleCache + ): Unit = { + /* + Try if we can get to a node in the methodNode which points to a property node, and matches the api regex on the properties value + */ + val matchingProperties = methodNode.ast.originalProperty.value(apiMatchingRegex).dedup.l + if (matchingProperties.nonEmpty) { + matchingProperties.foreach { propertyNode => + val domain = getDomainFromString(propertyNode.value) + impactedApiCalls.foreach { apiCall => + tagAPIWithDomainAndUpdateRuleCache(builder, thirdPartyRuleInfo.get, ruleCache, domain, apiCall, propertyNode) + storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString) + storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString) } - } else { - /* Try fetching the url from the injection happening via Named annotation - Looks for parameters marked with @Named annotation, try getting to the binding and resolve the api url, - If we are able to resolve the api url use that or else, return the matching parameterAssign node's code - */ - val apiVariableRegex = ruleCache.getSystemConfigByKey(Constants.apiIdentifier) - val endpointNode = Try { - val namedUrlNode = methodNode.parameter.annotation.name("Named").parameterAssign.code(apiVariableRegex).head + } + } else { + /* Try fetching the url from the injection happening via Named annotation + Looks for parameters marked with @Named annotation, try getting to the binding and resolve the api url, + If we are able to resolve the api url use that or else, return the matching parameterAssign node's code + */ + val apiVariableRegex = ruleCache.getSystemConfigByKey(Constants.apiIdentifier) + val namedUrlNode = Try( + methodNode.parameter.annotation + .name("Named") + .parameterAssign + .code(apiVariableRegex) + .headOption + .getOrElse(methodNode.parameter.annotation.name("Named").code(s"@Named($apiVariableRegex)").head) + ).toOption + val endpointNode = { + val endpointNodeOption = Try { val fieldAccessNode = cpg .call("named") .whereNot(_.file.name(".*Mock.*")) - .where(_.argument.code(namedUrlNode.code)) + .where(_.argument.code(namedUrlNode.get.code)) .inCall .inCall .argument @@ -85,62 +134,63 @@ object Utility { else if (lastArg.isLiteral) // Being literal point to a url lastArg else // Return the parameterAssign node, which is the value inside the @Named annotation, Ex- @Name(ConfigKeys.MY_SERVICE_ENDPOINT), returns ConfigKeys.MY_SERVICE_ENDPOINT - namedUrlNode + namedUrlNode.get } endpointNode }.toOption - if (endpointNode.isDefined) { - val domain = resolveDomainFromSource(endpointNode.get) + + if endpointNodeOption.isDefined then endpointNodeOption else namedUrlNode + } + if (endpointNode.isDefined) { + val domain = resolveDomainFromSource(endpointNode.get) + impactedApiCalls.foreach { apiCall => + tagAPIWithDomainAndUpdateRuleCache( + builder, + thirdPartyRuleInfo.get, + ruleCache, + domain, + apiCall, + endpointNode.get + ) + storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString) + storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString) + } + + } else { // There is no property node available to be used, try matching against the parameter name + val variableRegex = ruleCache.getSystemConfigByKey(Constants.apiIdentifier) + val matchingParameters = methodNode.parameter.name(variableRegex).l + + if (matchingParameters.nonEmpty) { + val parameter = + matchingParameters.head // Pick only the first parameter as we don't want to tag same sink with multiple API's + val domain = resolveDomainFromSource(parameter) impactedApiCalls.foreach { apiCall => - tagAPIWithDomainAndUpdateRuleCache( - builder, - thirdPartyRuleInfo.get, - ruleCache, - domain, - apiCall, - endpointNode.get - ) + tagAPIWithDomainAndUpdateRuleCache(builder, thirdPartyRuleInfo.get, ruleCache, domain, apiCall, parameter) storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString) storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString) } - } else { // There is no property node available to be used, try matching against the parameter name - val variableRegex = ruleCache.getSystemConfigByKey(Constants.apiIdentifier) - val matchingParameters = methodNode.parameter.name(variableRegex).l - - if (matchingParameters.nonEmpty) { - val parameter = - matchingParameters.head // Pick only the first parameter as we don't want to tag same sink with multiple API's - val domain = resolveDomainFromSource(parameter) + } else { // There is no matching parameter to be used, try matching against the identifier name + val matchingIdentifiers = methodNode.ast.isIdentifier.name(variableRegex).l + if (matchingIdentifiers.nonEmpty) { + val identifier = + matchingIdentifiers.head // Pick only the first identifier as we don't want to tag same sink with multiple API's + val domain = resolveDomainFromSource(identifier) impactedApiCalls.foreach { apiCall => - tagAPIWithDomainAndUpdateRuleCache(builder, thirdPartyRuleInfo.get, ruleCache, domain, apiCall, parameter) + tagAPIWithDomainAndUpdateRuleCache( + builder, + thirdPartyRuleInfo.get, + ruleCache, + domain, + apiCall, + identifier + ) storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString) storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString) } - - } else { // There is no matching parameter to be used, try matching against the identifier name - val matchingIdentifiers = methodNode.ast.isIdentifier.name(variableRegex).l - if (matchingIdentifiers.nonEmpty) { - val identifier = - matchingIdentifiers.head // Pick only the first identifier as we don't want to tag same sink with multiple API's - val domain = resolveDomainFromSource(identifier) - impactedApiCalls.foreach { apiCall => - tagAPIWithDomainAndUpdateRuleCache( - builder, - thirdPartyRuleInfo.get, - ruleCache, - domain, - apiCall, - identifier - ) - storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString) - storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString) - } - } } } } } } - } diff --git a/src/main/scala/ai/privado/languageEngine/kotlin/tagger/PrivadoTagger.scala b/src/main/scala/ai/privado/languageEngine/kotlin/tagger/PrivadoTagger.scala index 70e21e3e5..373e4ec40 100644 --- a/src/main/scala/ai/privado/languageEngine/kotlin/tagger/PrivadoTagger.scala +++ b/src/main/scala/ai/privado/languageEngine/kotlin/tagger/PrivadoTagger.scala @@ -6,7 +6,8 @@ import ai.privado.feeder.PermissionSourceRule import ai.privado.languageEngine.java.feeder.StorageInheritRule import ai.privado.languageEngine.java.tagger.collection.{CollectionTagger, MethodFullNameCollectionTagger} import ai.privado.languageEngine.java.tagger.config.JavaDBConfigTagger -import ai.privado.languageEngine.java.tagger.sink.{InheritMethodTagger, JavaAPITagger} +import ai.privado.languageEngine.java.tagger.sink.api.{JavaAPISinkTagger, JavaAPITagger} +import ai.privado.languageEngine.java.tagger.sink.InheritMethodTagger import ai.privado.languageEngine.java.tagger.source.* import ai.privado.languageEngine.kotlin.feeder.StorageAnnotationRule import ai.privado.languageEngine.kotlin.tagger.sink.StorageAnnotationTagger @@ -58,7 +59,7 @@ class PrivadoTagger(cpg: Cpg) extends PrivadoBaseTagger { new StorageAnnotationTagger(cpg, ruleCache).createAndApply() } - new APITagger(cpg, ruleCache, privadoInputConfig, appCache = appCache).createAndApply() + JavaAPISinkTagger.applyTagger(cpg, ruleCache, privadoInputConfig, appCache) new AndroidCollectionTagger( cpg, diff --git a/src/main/scala/ai/privado/model/Config.scala b/src/main/scala/ai/privado/model/Config.scala index 24c4c9564..73f402124 100644 --- a/src/main/scala/ai/privado/model/Config.scala +++ b/src/main/scala/ai/privado/model/Config.scala @@ -35,15 +35,15 @@ case class RuleInfo( filterProperty: FilterProperty.FilterProperty, domains: Array[String], patterns: List[String], - isSensitive: Boolean, - sensitivity: String, - tags: Map[String, String], - nodeType: NodeType.NodeType, - file: String, - catLevelOne: CatLevelOne.CatLevelOne, - catLevelTwo: String, - language: Language.Language, - categoryTree: Array[String], + isSensitive: Boolean = false, + sensitivity: String = "", + tags: Map[String, String] = Map(), + nodeType: NodeType.NodeType = NodeType.UNKNOWN, + file: String = "", + catLevelOne: CatLevelOne.CatLevelOne = CatLevelOne.UNKNOWN, + catLevelTwo: String = "", + language: Language.Language = Language.UNKNOWN, + categoryTree: Array[String] = Array(), isGenerated: Boolean = false // mark this true, if the rule is generated by privado-core ) { def combinedRulePattern: String = { @@ -60,7 +60,8 @@ case class ConfigAndRules( semantics: List[Semantic] = List(), sinkSkipList: List[RuleInfo] = List(), systemConfig: List[SystemConfig] = List(), - auditConfig: List[RuleInfo] = List() + auditConfig: List[RuleInfo] = List(), + inferences: List[RuleInfo] = List() ) case class AllowedSourceFilters(sources: List[String]) @@ -110,9 +111,9 @@ case class Semantic( case class SystemConfig( key: String, value: String, - language: Language.Language, - file: String, - categoryTree: Array[String] + language: Language.Language = Language.UNKNOWN, + file: String = "", + categoryTree: Array[String] = Array() ) object CirceEnDe { @@ -274,6 +275,7 @@ object CirceEnDe { val sinkSkipList = c.downField(Constants.sinkSkipList).as[List[RuleInfo]] val systemConfig = c.downField(Constants.systemConfig).as[List[SystemConfig]] val auditConfig = c.downField(Constants.auditConfig).as[List[RuleInfo]] + val inferences = c.downField(Constants.inferences).as[List[RuleInfo]] Right( ConfigAndRules( sources = sources.getOrElse(List[RuleInfo]()), @@ -285,7 +287,8 @@ object CirceEnDe { semantics = semantics.getOrElse(List[Semantic]()), sinkSkipList = sinkSkipList.getOrElse(List[RuleInfo]()), systemConfig = systemConfig.getOrElse(List[SystemConfig]()), - auditConfig = auditConfig.getOrElse(List[RuleInfo]()) + auditConfig = auditConfig.getOrElse(List[RuleInfo]()), + inferences = inferences.getOrElse(List[RuleInfo]()) ) ) } diff --git a/src/main/scala/ai/privado/model/Constants.scala b/src/main/scala/ai/privado/model/Constants.scala index 963a54211..e1c7ff18c 100644 --- a/src/main/scala/ai/privado/model/Constants.scala +++ b/src/main/scala/ai/privado/model/Constants.scala @@ -108,6 +108,7 @@ object Constants { val sinkType = "sinkType" val collectionFilters = "collectionFilters" val collectionType = "collectionType" + val inferences = "inferences" val localScanPath = "localScanPath" val processing = "processing" val sinkProcessing = "sinkProcessing" @@ -188,8 +189,10 @@ object Constants { val UnknownDomain = "unknown-domain" val Unknown = "unknown" + // catlevelTwo val annotations = "annotations" val default = "default" + val apiEndpoint = "apiEndpoint" val semanticDelimeter = "_A_" val thisConstant = "this" diff --git a/src/main/scala/ai/privado/model/PrivadoTag.scala b/src/main/scala/ai/privado/model/PrivadoTag.scala index 2beb744c2..147b17c8e 100644 --- a/src/main/scala/ai/privado/model/PrivadoTag.scala +++ b/src/main/scala/ai/privado/model/PrivadoTag.scala @@ -22,6 +22,7 @@ package ai.privado.model +import ai.privado.model import ai.privado.model.Language.{JAVA, Language} object InternalTag extends Enumeration { @@ -80,6 +81,7 @@ object CatLevelOne extends Enumeration { val COLLECTIONS = CatLevelOneIn("collections", "Collections") val POLICIES = CatLevelOneIn("policies", "Policies") val THREATS = CatLevelOneIn("threats", "Threats") + val INFERENCES = CatLevelOneIn("inferences", "Inferences") val UNKNOWN = CatLevelOneIn("unknown", "Unknown") // internal CatLevelOne @@ -209,8 +211,12 @@ object ConfigRuleType extends Enumeration { object FilterProperty extends Enumeration { type FilterProperty = Value - val METHOD_FULL_NAME = Value("method_full_name") - val CODE = Value("code") + val METHOD_FULL_NAME: model.FilterProperty.Value = Value("method_full_name") + val CODE: model.FilterProperty.Value = Value("code") + + // For Inference API Endpoint mapping + val METHOD_FULL_NAME_WITH_LITERAL: model.FilterProperty.Value = Value("method_full_name_with_literal") + val METHOD_FULL_NAME_WITH_PROPERTY_NAME: model.FilterProperty.Value = Value("method_full_name_with_property_name") def withNameWithDefault(name: String): Value = { try { diff --git a/src/main/scala/ai/privado/rulevalidator/YamlFileValidator.scala b/src/main/scala/ai/privado/rulevalidator/YamlFileValidator.scala index c47815ecc..06c25a610 100644 --- a/src/main/scala/ai/privado/rulevalidator/YamlFileValidator.scala +++ b/src/main/scala/ai/privado/rulevalidator/YamlFileValidator.scala @@ -34,6 +34,8 @@ object YamlFileValidator { private val THREATS = Source.fromInputStream(getClass.getResourceAsStream(s"${SCHEMA_DIR_PATH}threats.json")).mkString private val COLLECTIONS = Source.fromInputStream(getClass.getResourceAsStream(s"${SCHEMA_DIR_PATH}collections.json")).mkString + private val INFERENCES = + Source.fromInputStream(getClass.getResourceAsStream(s"${SCHEMA_DIR_PATH}inferences.json")).mkString private val EXCLUSIONS = Source.fromInputStream(getClass.getResourceAsStream(s"${SCHEMA_DIR_PATH}exclusions.json")).mkString private val SEMANTICS = @@ -178,6 +180,7 @@ object YamlFileValidator { case CatLevelOne.THREATS => Right(CatLevelOne.THREATS.name, THREATS) case CatLevelOne.COLLECTIONS => Right(CatLevelOne.COLLECTIONS.name, COLLECTIONS) case CatLevelOne.SINKS => Right(CatLevelOne.SINKS.name, SINKS) + case CatLevelOne.INFERENCES => Right(CatLevelOne.INFERENCES.name, INFERENCES) case _ => matchSchemaConfigFile(ruleFile, ruleJsonTree, callerCommand) } diff --git a/src/main/scala/ai/privado/tagger/sink/APITagger.scala b/src/main/scala/ai/privado/tagger/sink/APITagger.scala index c07070b80..259917a74 100644 --- a/src/main/scala/ai/privado/tagger/sink/APITagger.scala +++ b/src/main/scala/ai/privado/tagger/sink/APITagger.scala @@ -27,7 +27,7 @@ import ai.privado.cache.{AppCache, RuleCache} import ai.privado.entrypoint.{PrivadoInput, ScanProcessor} import ai.privado.languageEngine.java.language.{NodeStarters, StepsForProperty} import ai.privado.languageEngine.java.semantic.JavaSemanticGenerator -import ai.privado.model.{Constants, Language, NodeType, RuleInfo} +import ai.privado.model.{CatLevelOne, Constants, Language, NodeType, RuleInfo} import ai.privado.tagger.PrivadoParallelCpgPass import ai.privado.tagger.utility.APITaggerUtility.{SERVICE_URL_REGEX_PATTERN, sinkTagger} import ai.privado.utility.Utilities @@ -50,6 +50,7 @@ class APITagger(cpg: Cpg, ruleCache: RuleCache, privadoInput: PrivadoInput, appC .name(APISINKS_REGEX) .methodFullNameNot(COMMON_IGNORED_SINKS_REGEX) .methodFullName(commonHttpPackages) + .whereNot(_.tag.nameExact(Constants.nodeType).valueExact(NodeType.API.toString)) .l implicit val engineContext: EngineContext = diff --git a/src/main/scala/ai/privado/tagger/sink/api/APISinkTagger.scala b/src/main/scala/ai/privado/tagger/sink/api/APISinkTagger.scala index f99f41950..14985501e 100644 --- a/src/main/scala/ai/privado/tagger/sink/api/APISinkTagger.scala +++ b/src/main/scala/ai/privado/tagger/sink/api/APISinkTagger.scala @@ -1,11 +1,13 @@ package ai.privado.tagger.sink.api -import ai.privado.cache.RuleCache +import ai.privado.cache.{AppCache, RuleCache} import ai.privado.entrypoint.PrivadoInput import io.shiftleft.codepropertygraph.generated.Cpg trait APISinkTagger { - def applyTagger(cpg: Cpg, ruleCache: RuleCache, privadoInput: PrivadoInput): Unit = ??? + def applyTagger(cpg: Cpg, ruleCache: RuleCache, privadoInput: PrivadoInput, appCache: AppCache): Unit = { + new InferenceAPIEndpointTagger(cpg, ruleCache).createAndApply() + } } diff --git a/src/main/scala/ai/privado/tagger/sink/api/InferenceAPIEndpointTagger.scala b/src/main/scala/ai/privado/tagger/sink/api/InferenceAPIEndpointTagger.scala new file mode 100644 index 000000000..c886461a2 --- /dev/null +++ b/src/main/scala/ai/privado/tagger/sink/api/InferenceAPIEndpointTagger.scala @@ -0,0 +1,72 @@ +package ai.privado.tagger.sink.api + +import ai.privado.cache.RuleCache +import ai.privado.languageEngine.java.language.* +import ai.privado.model.FilterProperty.* +import ai.privado.model.{Constants, InternalTag, RuleInfo} +import ai.privado.tagger.PrivadoParallelCpgPass +import ai.privado.tagger.utility.APITaggerUtility.tagAPIWithDomainAndUpdateRuleCache +import ai.privado.utility.Utilities.{getDomainFromString, storeForTag} +import io.shiftleft.codepropertygraph.generated.Cpg +import io.shiftleft.codepropertygraph.generated.nodes.AstNode +import io.shiftleft.semanticcpg.language.* +import org.slf4j.LoggerFactory + +/** Read the inference rule defined in passed ruleCache, and tag the API Sink and the corresponding API Endpoint + * @param cpg + * @param ruleCache + */ +class InferenceAPIEndpointTagger(cpg: Cpg, ruleCache: RuleCache) extends PrivadoParallelCpgPass[RuleInfo](cpg) { + + private val logger = LoggerFactory.getLogger(getClass) + private val thirdPartyRuleInfo = ruleCache.getRuleInfo(Constants.thirdPartiesAPIRuleId) + + override def generateParts(): Array[RuleInfo] = + ruleCache.getRule.inferences.filter(_.catLevelTwo.equals(Constants.apiEndpoint)).toArray + + override def runOnPart(builder: DiffGraphBuilder, ruleInfo: RuleInfo): Unit = { + + val domain = ruleInfo.domains.headOption.getOrElse("") + + ruleInfo.filterProperty match { + case METHOD_FULL_NAME_WITH_LITERAL => + val apiSinks = cpg.call + .or( + _.methodFullName(ruleInfo.combinedRulePattern), + _.filter(_.dynamicTypeHintFullName.exists(_.matches(ruleInfo.combinedRulePattern))) + ) + .l + + apiSinks.foreach { apiCall => tagNode(builder, apiCall, domain) } + + case METHOD_FULL_NAME_WITH_PROPERTY_NAME => + val apiUrlFromProperty = cpg.property.name(domain).value.headOption + + if (apiUrlFromProperty.isDefined) { + val apiSinks = cpg.call + .or( + _.methodFullName(ruleInfo.combinedRulePattern), + _.filter(_.dynamicTypeHintFullName.exists(_.matches(ruleInfo.combinedRulePattern))) + ) + .l + + apiSinks.foreach { apiCall => tagNode(builder, apiCall, apiUrlFromProperty.get) } + } + + case _ => + } + } + + private def tagNode(builder: DiffGraphBuilder, apiCall: AstNode, apiUrl: String) = { + tagAPIWithDomainAndUpdateRuleCache( + builder, + thirdPartyRuleInfo.get, + ruleCache, + getDomainFromString(apiUrl), + apiCall, + apiUrl + ) + storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString) + storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString) + } +} diff --git a/src/main/scala/ai/privado/tagger/utility/APITaggerUtility.scala b/src/main/scala/ai/privado/tagger/utility/APITaggerUtility.scala index 3019ddcf0..aa43706a3 100644 --- a/src/main/scala/ai/privado/tagger/utility/APITaggerUtility.scala +++ b/src/main/scala/ai/privado/tagger/utility/APITaggerUtility.scala @@ -130,11 +130,26 @@ object APITaggerUtility { domain: String, apiNode: AstNode, apiUrlNode: AstNode - ) = { + ): DiffGraphBuilder = { val newRuleIdToUse = ruleInfo.id + "." + domain ruleCache.setRuleInfo(ruleInfo.copy(id = newRuleIdToUse, name = ruleInfo.name + " " + domain, isGenerated = true)) addRuleTags(builder, apiNode, ruleInfo, ruleCache, Some(newRuleIdToUse)) storeForTag(builder, apiNode, ruleCache)(Constants.apiUrl + newRuleIdToUse, getLiteralCode(apiUrlNode)) } + + def tagAPIWithDomainAndUpdateRuleCache( + builder: DiffGraphBuilder, + ruleInfo: RuleInfo, + ruleCache: RuleCache, + domain: String, + apiNode: AstNode, + apiUrl: String + ): DiffGraphBuilder = { + val newRuleIdToUse = ruleInfo.id + "." + domain + ruleCache.setRuleInfo(ruleInfo.copy(id = newRuleIdToUse, name = ruleInfo.name + " " + domain, isGenerated = true)) + addRuleTags(builder, apiNode, ruleInfo, ruleCache, Some(newRuleIdToUse)) + storeForTag(builder, apiNode, ruleCache)(Constants.apiUrl + newRuleIdToUse, apiUrl) + + } } diff --git a/src/test/scala/ai/privado/languageEngine/java/passes/config/JavaYamlLinkerPassTest.scala b/src/test/scala/ai/privado/languageEngine/java/passes/config/JavaYamlLinkerPassTest.scala index 79c6980bd..97f7ffc5a 100644 --- a/src/test/scala/ai/privado/languageEngine/java/passes/config/JavaYamlLinkerPassTest.scala +++ b/src/test/scala/ai/privado/languageEngine/java/passes/config/JavaYamlLinkerPassTest.scala @@ -3,7 +3,7 @@ package ai.privado.languageEngine.java.passes.config import ai.privado.cache.{AppCache, RuleCache, TaggerCache} import ai.privado.entrypoint.PrivadoInput import ai.privado.languageEngine.java.language.* -import ai.privado.languageEngine.java.tagger.sink.JavaAPITagger +import ai.privado.languageEngine.java.tagger.sink.api.JavaAPITagger import ai.privado.languageEngine.java.tagger.source.* import ai.privado.model.* import ai.privado.utility.PropertyParserPass diff --git a/src/test/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPIRetrofitTaggerTest.scala b/src/test/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPIRetrofitTaggerTest.scala new file mode 100644 index 000000000..46998132e --- /dev/null +++ b/src/test/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPIRetrofitTaggerTest.scala @@ -0,0 +1,255 @@ +package ai.privado.languageEngine.java.tagger.sink.api + +import ai.privado.cache.RuleCache +import ai.privado.model.{CatLevelOne, Constants, Language, NodeType, SystemConfig} +import ai.privado.rule.RuleInfoTestData +import ai.privado.tagger.sink.api.APIValidator +import ai.privado.testfixtures.JavaFrontendTestSuite +import io.shiftleft.semanticcpg.language.* + +class JavaAPIRetrofitTaggerTest extends JavaFrontendTestSuite with APIValidator { + + "Java api retrofit tagger (case - finding url by literal)" should { + val cpg = code( + """ + |import retrofit2.Call; + |import retrofit2.http.GET; + |import java.util.List; + | + |public interface UserService { + | @GET("users") + | Call> getUsers(); + |} + | + |""".stripMargin, + "UserService.java" + ).moreCode( + """ + |import retrofit2.Retrofit; + |import retrofit2.converter.gson.GsonConverterFactory; + | + |public class RetrofitClient { + | + | public static UserService getUserService() { + | return new Retrofit.Builder() + | .baseUrl("https://api.example.com/") + | .addConverterFactory(GsonConverterFactory.create()) + | .build() + | .create(UserService.class); + | } + |} + |""".stripMargin, + "RetrofitClient.java" + ).moreCode( + """ + |import retrofit2.Call; + |import retrofit2.Callback; + |import retrofit2.Response; + | + |public class Main { + | public static void main(String[] args) { + | UserService userService = RetrofitClient.getUserService(); + | + | Call> call = userService.getUsers(); + | call.enqueue(new Callback>() { + | @Override + | public void onResponse(Call> call, Response> response) { + | if (response.isSuccessful()) { + | List users = response.body(); + | for (User user : users) { + | System.out.println(user.getName()); + | } + | } else { + | System.out.println("Failed to fetch users: " + response.message()); + | } + | } + | + | @Override + | public void onFailure(Call> call, Throwable t) { + | System.out.println("Failed to fetch users: " + t.getMessage()); + | } + | }); + | } + |} + | + |""".stripMargin, + "Main.java" + ) + + "tag retrofit sink as api sink" in { + val List(getUserCall) = cpg.call("getUsers").l + assertAPISinkCall(getUserCall) + } + + "tag retrofit sink with endpoint" in { + val List(getUserCall) = cpg.call("getUsers").l + assertAPIEndpointURL(getUserCall, "\"https://api.example.com/\"") + } + } + + "Java api retrofit tagger (case - finding url by method parameter regex match)" should { + + val systemConfig = List(SystemConfig(Constants.apiIdentifier, "(?i).*endpoint", Language.JAVA, "", Array())) + val ruleCache = RuleCache().setRule(RuleInfoTestData.rule.copy(systemConfig = systemConfig)) + + val cpg = code( + """ + |import retrofit2.Call; + |import retrofit2.http.GET; + |import java.util.List; + | + |public interface UserService { + | @GET("users") + | Call> getUsers(); + |} + | + |""".stripMargin, + "UserService.java" + ).moreCode( + """ + |import retrofit2.Retrofit; + |import retrofit2.converter.gson.GsonConverterFactory; + | + |public class RetrofitClient { + | + | public static UserService getUserService(String exampleAPIEndpoint) { + | return new Retrofit.Builder() + | .baseUrl(exampleAPIEndpoint) + | .addConverterFactory(GsonConverterFactory.create()) + | .build() + | .create(UserService.class); + | } + |} + |""".stripMargin, + "RetrofitClient.java" + ).moreCode( + """ + |import retrofit2.Call; + |import retrofit2.Callback; + |import retrofit2.Response; + | + |public class Main { + | public static void main(String[] args) { + | UserService userService = RetrofitClient.getUserService("someXYZEndpoint"); + | + | Call> call = userService.getUsers(); + | call.enqueue(new Callback>() { + | @Override + | public void onResponse(Call> call, Response> response) { + | if (response.isSuccessful()) { + | List users = response.body(); + | for (User user : users) { + | System.out.println(user.getName()); + | } + | } else { + | System.out.println("Failed to fetch users: " + response.message()); + | } + | } + | + | @Override + | public void onFailure(Call> call, Throwable t) { + | System.out.println("Failed to fetch users: " + t.getMessage()); + | } + | }); + | } + |} + | + |""".stripMargin, + "Main.java" + ).withRuleCache(ruleCache) + + "tag retrofit sink as api sink" in { + val List(getUserCall) = cpg.call("getUsers").l + assertAPISinkCall(getUserCall) + } + + "tag retrofit sink with endpoint" in { + val List(getUserCall) = cpg.call("getUsers").l + assertAPIEndpointURL(getUserCall, "exampleAPIEndpoint") + } + } + + "Java api retrofit tagger (case - finding url by matching identifier in method)" should { + + val systemConfig = List(SystemConfig(Constants.apiIdentifier, "(?i).*endpoint", Language.JAVA, "", Array())) + val ruleCache = RuleCache().setRule(RuleInfoTestData.rule.copy(systemConfig = systemConfig)) + + val cpg = code( + """ + |import retrofit2.Call; + |import retrofit2.http.GET; + |import java.util.List; + | + |public interface UserService { + | @GET("users") + | Call> getUsers(); + |} + | + |""".stripMargin, + "UserService.java" + ).moreCode( + """ + |import retrofit2.Retrofit; + |import retrofit2.converter.gson.GsonConverterFactory; + | + |public class RetrofitClient { + | + | public static UserService getUserService() { + | String exampleAPIEndpoint = "someXYZEndpoint"; + | + | return new Retrofit.Builder() + | .baseUrl(exampleAPIEndpoint) + | .addConverterFactory(GsonConverterFactory.create()) + | .build() + | .create(UserService.class); + | } + |} + |""".stripMargin, + "RetrofitClient.java" + ).moreCode( + """ + |import retrofit2.Call; + |import retrofit2.Callback; + |import retrofit2.Response; + | + |public class Main { + | public static void main(String[] args) { + | UserService userService = RetrofitClient.getUserService(); + | + | Call> call = userService.getUsers(); + | call.enqueue(new Callback>() { + | @Override + | public void onResponse(Call> call, Response> response) { + | if (response.isSuccessful()) { + | List users = response.body(); + | for (User user : users) { + | System.out.println(user.getName()); + | } + | } else { + | System.out.println("Failed to fetch users: " + response.message()); + | } + | } + | + | @Override + | public void onFailure(Call> call, Throwable t) { + | System.out.println("Failed to fetch users: " + t.getMessage()); + | } + | }); + | } + |} + | + |""".stripMargin, + "Main.java" + ).withRuleCache(ruleCache) + + "tag retrofit sink as api sink" in { + val List(getUserCall) = cpg.call("getUsers").l + assertAPISinkCall(getUserCall) + } + + "tag retrofit sink with endpoint" in { + val List(getUserCall) = cpg.call("getUsers").l + assertAPIEndpointURL(getUserCall, "exampleAPIEndpoint") + } + } +} diff --git a/src/test/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkByMethodFullNameTaggerTest.scala b/src/test/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkByMethodFullNameTaggerTest.scala index db10fce78..8e1af7e48 100644 --- a/src/test/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkByMethodFullNameTaggerTest.scala +++ b/src/test/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkByMethodFullNameTaggerTest.scala @@ -1,6 +1,6 @@ package ai.privado.languageEngine.java.tagger.sink.api -import ai.privado.cache.RuleCache +import ai.privado.cache.{AppCache, RuleCache} import ai.privado.entrypoint.PrivadoInput import org.scalatest.BeforeAndAfterAll import org.scalatest.matchers.should.Matchers @@ -63,7 +63,7 @@ class JavaAPISinkByMethodFullNameTaggerTest extends AnyWordSpec with Matchers wi ) ) ) - JavaAPISinkTagger.applyTagger(cpg, ruleCache = ruleCache, privadoInput = PrivadoInput()) + JavaAPISinkTagger.applyTagger(cpg, ruleCache = ruleCache, privadoInput = PrivadoInput(), appCache = AppCache()) val apiSinks = cpg.call("getResponseCode").l diff --git a/src/test/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkByParameterTaggerTest.scala b/src/test/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkByParameterTaggerTest.scala index 54e19a051..b423b8fe2 100644 --- a/src/test/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkByParameterTaggerTest.scala +++ b/src/test/scala/ai/privado/languageEngine/java/tagger/sink/api/JavaAPISinkByParameterTaggerTest.scala @@ -6,7 +6,6 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec import ai.privado.languageEngine.java.JavaTestBase.* -import ai.privado.languageEngine.java.tagger.sink.JavaAPITagger import ai.privado.model.{CatLevelOne, Constants, InternalTag, Language, NodeType, SourceCodeModel, SystemConfig} import ai.privado.rule.RuleInfoTestData import io.shiftleft.semanticcpg.language.* @@ -82,7 +81,7 @@ class JavaAPISinkByParameterTaggerTest extends AnyWordSpec with Matchers with Be val systemConfig = List(SystemConfig(Constants.apiIdentifier, "(?i).*endpoint.*", Language.UNKNOWN, "", Array[String]())) ruleCache.setRule(RuleInfoTestData.rule.copy(systemConfig = systemConfig)) - JavaAPISinkTagger.applyTagger(cpg, ruleCache = ruleCache, privadoInput = privadoInput) + JavaAPISinkTagger.applyTagger(cpg, ruleCache = ruleCache, privadoInput = privadoInput, appCache = AppCache()) new JavaAPITagger(cpg, ruleCache, privadoInputConfig = privadoInput, appCache = new AppCache()).createAndApply() diff --git a/src/test/scala/ai/privado/tagger/sink/api/APIValidator.scala b/src/test/scala/ai/privado/tagger/sink/api/APIValidator.scala new file mode 100644 index 000000000..bd439b0da --- /dev/null +++ b/src/test/scala/ai/privado/tagger/sink/api/APIValidator.scala @@ -0,0 +1,41 @@ +package ai.privado.tagger.sink.api + +import ai.privado.model.{CatLevelOne, Constants} +import io.shiftleft.codepropertygraph.generated.nodes.Call +import io.shiftleft.semanticcpg.language.* +import org.scalatest.matchers.should.Matchers +import ai.privado.model.NodeType +import org.scalatest.Assertion + +trait APIValidator extends Matchers { + + /** Asserts if the given callNode is an API sink node + * @param callNode + * \- Expected API call node + * @return + */ + def assertAPISinkCall(callNode: Call): Assertion = { + callNode.tag.nameExact(Constants.catLevelOne).valueExact(CatLevelOne.SINKS.name).size shouldBe 1 + callNode.tag.nameExact(Constants.nodeType).valueExact(NodeType.API.toString).size shouldBe 1 + } + + /** Asserts if the given callNode have the tagged API endpoint as apiUrl + * @param callNode + * \- API call node + * @param apiUrl + * \- Expected API URL on the call node + * @return + */ + def assertAPIEndpointURL(callNode: Call, apiUrl: String): Assertion = { + val domain = callNode.tag + .nameExact("third_partiesapi") + .value + .headOption + .getOrElse("") + .stripPrefix(s"${Constants.thirdPartiesAPIRuleId}.") + callNode.tag.nameExact(s"${Constants.apiUrl}${Constants.thirdPartiesAPIRuleId}.$domain").value.l shouldBe List( + apiUrl + ) + } + +} diff --git a/src/test/scala/ai/privado/tagger/sink/api/InferenceAPIEndpointTaggerTest.scala b/src/test/scala/ai/privado/tagger/sink/api/InferenceAPIEndpointTaggerTest.scala new file mode 100644 index 000000000..ecd37e8e1 --- /dev/null +++ b/src/test/scala/ai/privado/tagger/sink/api/InferenceAPIEndpointTaggerTest.scala @@ -0,0 +1,101 @@ +package ai.privado.tagger.sink.api + +import ai.privado.cache.RuleCache +import ai.privado.model.{CatLevelOne, Constants, FilterProperty, RuleInfo} +import ai.privado.rule.RuleInfoTestData +import ai.privado.testfixtures.JavaFrontendTestSuite +import io.shiftleft.semanticcpg.language._ + +class InferenceAPIEndpointTaggerTest extends JavaFrontendTestSuite with APIValidator { + + "Inference tagger when used against filterProperty method_full_name_with_literal" should { + + val inferences = List( + RuleInfo( + id = "Inferences.API.UserServiceClient.ByLiteral", + name = "User service client", + category = "", + filterProperty = FilterProperty.METHOD_FULL_NAME_WITH_LITERAL, + domains = Array("http://user-service.com"), + patterns = List(".*UserServiceClient[.]\\w+[:].*"), + catLevelOne = CatLevelOne.INFERENCES, + catLevelTwo = Constants.apiEndpoint + ) + ) + val ruleCache = RuleCache().setRule(RuleInfoTestData.rule.copy(inferences = inferences)) + + val cpg = code(""" + |import com.privado.clients.UserServiceClient; + |public class Main { + | public static void main(String[] args) { + | try { + | UserServiceClient client = new UserServiceClient(); + | String usersResponse = client.getUsers(); + | System.out.println(usersResponse); + | } catch (Exception e) { + | e.printStackTrace(); + | } + | } + |} + |""".stripMargin).withRuleCache(ruleCache) + + "tag the api sink matching fullName" in { + val List(getUsersCall) = cpg.call("getUsers").l + assertAPISinkCall(getUsersCall) + } + + "tag the api sink with passed literal" in { + val List(getUsersCall) = cpg.call("getUsers").l + assertAPIEndpointURL(getUsersCall, "http://user-service.com") + } + } + + "Inference tagger when used against filterProperty method_full_name_with_property_name" should { + + val inferences = List( + RuleInfo( + id = "Inferences.API.UserServiceClient.ByPropertyName", + name = "User service client", + category = "", + filterProperty = FilterProperty.METHOD_FULL_NAME_WITH_PROPERTY_NAME, + domains = Array(".*user.*url"), + patterns = List(".*UserServiceClient[.]\\w+[:].*"), + catLevelOne = CatLevelOne.INFERENCES, + catLevelTwo = Constants.apiEndpoint + ) + ) + val ruleCache = RuleCache().setRule(RuleInfoTestData.rule.copy(inferences = inferences)) + + val cpg = code(""" + |import com.privado.clients.UserServiceClient; + |public class Main { + | public static void main(String[] args) { + | try { + | UserServiceClient client = new UserServiceClient(); + | String usersResponse = client.getUsers(); + | System.out.println(usersResponse); + | } catch (Exception e) { + | e.printStackTrace(); + | } + | } + |} + |""".stripMargin) + .moreCode( + """ + |user.service.url = http://user-service.com + |""".stripMargin, + "application.properties" + ) + .withRuleCache(ruleCache) + + "tag the api sink matching fullName" in { + val List(getUsersCall) = cpg.call("getUsers").l + assertAPISinkCall(getUsersCall) + } + + "tag the api sink with passed literal" in { + val List(getUsersCall) = cpg.call("getUsers").l + assertAPIEndpointURL(getUsersCall, "http://user-service.com") + } + } +}