Skip to content

Commit

Permalink
Merge pull request #1078 from Privado-Inc/dev
Browse files Browse the repository at this point in the history
Release PR
  • Loading branch information
khemrajrathore authored Apr 22, 2024
2 parents 2abffe5 + 309893e commit 3628e36
Show file tree
Hide file tree
Showing 15 changed files with 238 additions and 122 deletions.
9 changes: 9 additions & 0 deletions src/main/scala/ai/privado/cache/RuleCache.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ class RuleCache {
}
}

def addThirdPartyRuleInfo(thirdPartyAPIRuleInfo: RuleInfo, domain: String): String = {
val newRuleIdToUse = s"${Constants.thirdPartiesAPIRuleId}.$domain"
this.setRuleInfo(
thirdPartyAPIRuleInfo
.copy(id = newRuleIdToUse, name = s"${thirdPartyAPIRuleInfo.name} $domain", isGenerated = true)
)
newRuleIdToUse
}

def addStorageRuleInfo(ruleInfo: RuleInfo): Unit = storageRuleInfo.addOne(ruleInfo)

def getStorageRuleInfo(): List[RuleInfo] = storageRuleInfo.toList
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ class JavaAPIRetrofitTagger(cpg: Cpg, ruleCache: RuleCache) extends PrivadoParal
private val apiMatchingRegex =
ruleCache.getAllRuleInfo.filter(_.nodeType == NodeType.API).map(_.combinedRulePattern).mkString("(", "|", ")")

private val thirdPartyRuleInfo = ruleCache.getRuleInfo(Constants.thirdPartiesAPIRuleId)

private val logger = getLogger(this.getClass)
override def generateParts(): Array[(Call, Call)] = {
cpg
Expand Down Expand Up @@ -46,19 +44,11 @@ class JavaAPIRetrofitTagger(cpg: Cpg, ruleCache: RuleCache) extends PrivadoParal
Try(baseUrlCall.argument.last).toOption match
case Some(lit: Literal) =>
// Mark the nodes as API sink
tagAPICallByItsUrlMethod(cpg, builder, lit, sinkCalls, apiMatchingRegex, thirdPartyRuleInfo, ruleCache)
tagAPICallByItsUrlMethod(cpg, builder, lit, sinkCalls, apiMatchingRegex, ruleCache)
case _ =>
Try {
val methodNode = createCall.method
tagAPICallByItsUrlMethod(
cpg,
builder,
methodNode,
sinkCalls,
apiMatchingRegex,
thirdPartyRuleInfo,
ruleCache
)
tagAPICallByItsUrlMethod(cpg, builder, methodNode, sinkCalls, apiMatchingRegex, ruleCache)
} match
case Failure(e) => logger.debug(s"Failed to get to a method node for retrofit create call")
case _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ class JavaAPISinkByParameterMarkByAnnotationTagger(cpg: Cpg, ruleCache: RuleCach
private val apiMatchingRegex =
ruleCache.getAllRuleInfo.filter(_.nodeType == NodeType.API).map(_.combinedRulePattern).mkString("(", "|", ")")

private val thirdPartyRuleInfo = ruleCache.getRuleInfo(Constants.thirdPartiesAPIRuleId)
override def generateParts(): Array[(Method, List[String])] = {

val methodWithAnnotationCode = cpg.parameter
Expand Down Expand Up @@ -70,7 +69,7 @@ class JavaAPISinkByParameterMarkByAnnotationTagger(cpg: Cpg, ruleCache: RuleCach
.l

// Mark the nodes as API sink
tagAPICallByItsUrlMethod(cpg, builder, methodNode, sinkCalls, apiMatchingRegex, thirdPartyRuleInfo, ruleCache)
tagAPICallByItsUrlMethod(cpg, builder, methodNode, sinkCalls, apiMatchingRegex, ruleCache)

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ class JavaAPISinkByParameterTagger(cpg: Cpg, ruleCache: RuleCache)
private val apiMatchingRegex =
ruleCache.getAllRuleInfo.filter(_.nodeType == NodeType.API).map(_.combinedRulePattern).mkString("(", "|", ")")

private val thirdPartyRuleInfo = ruleCache.getRuleInfo(Constants.thirdPartiesAPIRuleId)
override def generateParts(): Array[(Method, String)] = {

/* Below query looks for methods whose parameter names ends with `url|endpoint`,
Expand Down Expand Up @@ -96,7 +95,7 @@ class JavaAPISinkByParameterTagger(cpg: Cpg, ruleCache: RuleCache)
.l

// Mark the nodes as API sink
tagAPICallByItsUrlMethod(cpg, builder, methodNode, sinkCalls, apiMatchingRegex, thirdPartyRuleInfo, ruleCache)
tagAPICallByItsUrlMethod(cpg, builder, methodNode, sinkCalls, apiMatchingRegex, ruleCache)

}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import ai.privado.languageEngine.java.tagger.sink.api.Utility.tagAPICallByItsUrl
import ai.privado.tagger.utility.APITaggerUtility.{
getLiteralCode,
resolveDomainFromSource,
tagAPIWithDomainAndUpdateRuleCache
tagThirdPartyAPIWithDomainAndUpdateRuleCache
}
import ai.privado.utility.Utilities.{addRuleTags, getDomainFromString, storeForTag}
import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, Call, Method}
Expand All @@ -23,7 +23,6 @@ class JavaAPISinkEndpointMapperByNonInitMethod(cpg: Cpg, ruleCache: RuleCache)
private val apiMatchingRegex =
ruleCache.getAllRuleInfo.filter(_.nodeType == NodeType.API).map(_.combinedRulePattern).mkString("(", "|", ")")

private val thirdPartyRuleInfo = ruleCache.getRuleInfo(Constants.thirdPartiesAPIRuleId)
override def generateParts(): Array[String] = {

/* General assumption - there is a function which creates a client, and the usage of the client and binding
Expand All @@ -32,20 +31,17 @@ class JavaAPISinkEndpointMapperByNonInitMethod(cpg: Cpg, ruleCache: RuleCache)
we can say the client uses the following endpoint
*/

if (thirdPartyRuleInfo.isDefined) {
cpg.call
.where(_.tag.nameExact(InternalTag.API_SINK_MARKED.toString))
.methodFullName
.map(_.split(methodFullNameSplitter).headOption.getOrElse(""))
.filter(_.nonEmpty)
.map { methodNamespace =>
val parts = methodNamespace.split("[.]")
if parts.nonEmpty then parts.dropRight(1).mkString(".") else ""
}
.dedup
.toArray
} else
Array[String]()
cpg.call
.where(_.tag.nameExact(InternalTag.API_SINK_MARKED.toString))
.methodFullName
.map(_.split(methodFullNameSplitter).headOption.getOrElse(""))
.filter(_.nonEmpty)
.map { methodNamespace =>
val parts = methodNamespace.split("[.]")
if parts.nonEmpty then parts.dropRight(1).mkString(".") else ""
}
.dedup
.toArray
}

override def runOnPart(builder: DiffGraphBuilder, typeFullName: String): Unit = {
Expand All @@ -57,15 +53,7 @@ class JavaAPISinkEndpointMapperByNonInitMethod(cpg: Cpg, ruleCache: RuleCache)
.where(_.tag.nameExact(InternalTag.API_SINK_MARKED.toString))
.l

tagAPICallByItsUrlMethod(
cpg,
builder,
clientReturningMethod,
impactedApiCalls,
apiMatchingRegex,
thirdPartyRuleInfo,
ruleCache
)
tagAPICallByItsUrlMethod(cpg, builder, clientReturningMethod, impactedApiCalls, apiMatchingRegex, ruleCache)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class JavaAPITagger(cpg: Cpg, ruleCache: RuleCache, privadoInputConfig: PrivadoI
logger.debug("Using brute API Tagger to find API sinks")
println(s"${Calendar.getInstance().getTime} - --API TAGGER V1 invoked...")
sinkTagger(
cpg,
apiInternalSources ++ propertySources ++ identifierSource ++ serviceSource,
apis,
builder,
Expand All @@ -174,6 +175,7 @@ class JavaAPITagger(cpg: Cpg, ruleCache: RuleCache, privadoInputConfig: PrivadoI
privadoInputConfig.enableAPIDisplay
)
sinkTagger(
cpg,
apiInternalSources ++ propertySources ++ identifierSource ++ serviceSource,
feignAPISinks ++ grpcSinks ++ soapSinks ++ markedAPISinks,
builder,
Expand All @@ -185,6 +187,7 @@ class JavaAPITagger(cpg: Cpg, ruleCache: RuleCache, privadoInputConfig: PrivadoI
logger.debug("Using Enhanced API tagger to find API sinks")
println(s"${Calendar.getInstance().getTime} - --API TAGGER V2 invoked...")
sinkTagger(
cpg,
apiInternalSources ++ propertySources ++ identifierSource ++ serviceSource,
apis.methodFullName(commonHttpPackages).l ++ feignAPISinks ++ grpcSinks ++ soapSinks ++ markedAPISinks,
builder,
Expand All @@ -196,6 +199,7 @@ class JavaAPITagger(cpg: Cpg, ruleCache: RuleCache, privadoInputConfig: PrivadoI
logger.debug("Skipping API Tagger because valid match not found, only applying Feign client")
println(s"${Calendar.getInstance().getTime} - --API TAGGER SKIPPED, applying Feign client API...")
sinkTagger(
cpg,
apiInternalSources ++ propertySources ++ identifierSource ++ serviceSource,
feignAPISinks ++ grpcSinks ++ soapSinks ++ markedAPISinks,
builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package ai.privado.languageEngine.java.tagger.sink.api

import ai.privado.cache.RuleCache
import ai.privado.model.{CatLevelOne, Constants, InternalTag, RuleInfo}
import ai.privado.tagger.utility.APITaggerUtility.{resolveDomainFromSource, tagAPIWithDomainAndUpdateRuleCache}
import ai.privado.tagger.utility.APITaggerUtility.{
resolveDomainFromSource,
tagThirdPartyAPIWithDomainAndUpdateRuleCache
}
import ai.privado.utility.Utilities.{getDomainFromString, storeForTag}
import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, Call, Literal, Method}
import overflowdb.BatchedUpdate.DiffGraphBuilder
Expand All @@ -20,7 +23,6 @@ object Utility {
nodeToComputeUrl: AstNode,
apiCalls: List[Call],
apiMatchingRegex: String,
thirdPartyRuleInfo: Option[RuleInfo],
ruleCache: RuleCache
): Unit = {

Expand All @@ -33,17 +35,9 @@ object Utility {

nodeToComputeUrl match {
case lit: Literal =>
extractUrlFromLiteralAndTag(cpg, lit, builder, impactedApiCalls, thirdPartyRuleInfo, ruleCache)
extractUrlFromLiteralAndTag(cpg, lit, builder, impactedApiCalls, ruleCache)
case method: Method =>
extractUrlFromMethodAndTag(
cpg,
method,
builder,
impactedApiCalls,
apiMatchingRegex,
thirdPartyRuleInfo,
ruleCache
)
extractUrlFromMethodAndTag(cpg, method, builder, impactedApiCalls, apiMatchingRegex, ruleCache)
}

}
Expand All @@ -54,13 +48,12 @@ object Utility {
literalNode: Literal,
builder: DiffGraphBuilder,
impactedApiCalls: List[Call],
thirdPartyRuleInfo: Option[RuleInfo],
ruleCache: RuleCache
): Unit = {

val domain = resolveDomainFromSource(literalNode)
impactedApiCalls.foreach { apiCall =>
tagAPIWithDomainAndUpdateRuleCache(builder, thirdPartyRuleInfo.get, ruleCache, domain, apiCall, literalNode)
tagThirdPartyAPIWithDomainAndUpdateRuleCache(builder, cpg, ruleCache, domain, apiCall, literalNode)
storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString)
storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString)
}
Expand All @@ -72,7 +65,6 @@ object Utility {
builder: DiffGraphBuilder,
impactedApiCalls: List[Call],
apiMatchingRegex: String,
thirdPartyRuleInfo: Option[RuleInfo],
ruleCache: RuleCache
): Unit = {
/*
Expand All @@ -83,7 +75,7 @@ object Utility {
matchingProperties.foreach { propertyNode =>
val domain = getDomainFromString(propertyNode.value)
impactedApiCalls.foreach { apiCall =>
tagAPIWithDomainAndUpdateRuleCache(builder, thirdPartyRuleInfo.get, ruleCache, domain, apiCall, propertyNode)
tagThirdPartyAPIWithDomainAndUpdateRuleCache(builder, cpg, ruleCache, domain, apiCall, propertyNode)
storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString)
storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString)
}
Expand Down Expand Up @@ -144,14 +136,7 @@ object Utility {
if (endpointNode.isDefined) {
val domain = resolveDomainFromSource(endpointNode.get)
impactedApiCalls.foreach { apiCall =>
tagAPIWithDomainAndUpdateRuleCache(
builder,
thirdPartyRuleInfo.get,
ruleCache,
domain,
apiCall,
endpointNode.get
)
tagThirdPartyAPIWithDomainAndUpdateRuleCache(builder, cpg, ruleCache, domain, apiCall, endpointNode.get)
storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString)
storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString)
}
Expand All @@ -165,7 +150,7 @@ object Utility {
matchingParameters.head // Pick only the first parameter as we don't want to tag same sink with multiple API's
val domain = resolveDomainFromSource(parameter)
impactedApiCalls.foreach { apiCall =>
tagAPIWithDomainAndUpdateRuleCache(builder, thirdPartyRuleInfo.get, ruleCache, domain, apiCall, parameter)
tagThirdPartyAPIWithDomainAndUpdateRuleCache(builder, cpg, ruleCache, domain, apiCall, parameter)
storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString)
storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString)
}
Expand All @@ -177,14 +162,7 @@ object Utility {
matchingIdentifiers.head // Pick only the first identifier as we don't want to tag same sink with multiple API's
val domain = resolveDomainFromSource(identifier)
impactedApiCalls.foreach { apiCall =>
tagAPIWithDomainAndUpdateRuleCache(
builder,
thirdPartyRuleInfo.get,
ruleCache,
domain,
apiCall,
identifier
)
tagThirdPartyAPIWithDomainAndUpdateRuleCache(builder, cpg, ruleCache, domain, apiCall, identifier)
storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString)
storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class APITagger(cpg: Cpg, ruleCache: RuleCache, privadoInput: PrivadoInput, appC

logger.debug("Using Enhanced API tagger to find API sinks")
sinkTagger(
cpg,
apiInternalSources ++ propertySources ++ identifierSource,
httpApis.distinct,
builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class PythonAPITagger(cpg: Cpg, ruleCache: RuleCache, privadoInput: PrivadoInput
logger.debug("Using Enhanced API tagger to find API sinks")
println(s"${Calendar.getInstance().getTime} - --API TAGGER Common HTTP Libraries Used...")
sinkTagger(
cpg,
apiInternalSources ++ propertySources ++ identifierSource,
apis.methodFullName(commonHttpPackages).l,
builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class APITagger(cpg: Cpg, ruleCache: RuleCache, privadoInput: PrivadoInput, appC
logger.debug("Using Enhanced API tagger to find API sinks")
println(s"${Calendar.getInstance().getTime} - --API TAGGER Common HTTP Libraries Used...")
sinkTagger(
cpg,
apiInternalSources ++ propertySources ++ identifierSource,
(httpApis ++ clientLikeApis).distinct,
builder,
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/ai/privado/model/PrivadoTag.scala
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ object FilterProperty extends Enumeration {
// For Inference API Endpoint mapping
val METHOD_FULL_NAME_WITH_LITERAL: model.FilterProperty.Value = Value("method_full_name_with_literal")
val METHOD_FULL_NAME_WITH_PROPERTY_NAME: model.FilterProperty.Value = Value("method_full_name_with_property_name")
val ENDPOINT_DOMAIN_WITH_LITERAL: model.FilterProperty.Value = Value("endpoint_domain_with_literal")
val ENDPOINT_DOMAIN_WITH_PROPERTY_NAME: model.FilterProperty.Value = Value("endpoint_domain_with_property_name")

def withNameWithDefault(name: String): Value = {
try {
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/ai/privado/tagger/sink/APITagger.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class APITagger(cpg: Cpg, ruleCache: RuleCache, privadoInput: PrivadoInput, appC
List()
}
sinkTagger(
cpg,
apiInternalSources ++ propertySources ++ identifierSource ++ serviceSource,
apis,
builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import ai.privado.languageEngine.java.language.*
import ai.privado.model.FilterProperty.*
import ai.privado.model.{Constants, InternalTag, RuleInfo}
import ai.privado.tagger.PrivadoParallelCpgPass
import ai.privado.tagger.utility.APITaggerUtility.tagAPIWithDomainAndUpdateRuleCache
import ai.privado.tagger.utility.APITaggerUtility.tagThirdPartyAPIWithDomainAndUpdateRuleCache
import ai.privado.utility.Utilities.{getDomainFromString, storeForTag}
import io.shiftleft.codepropertygraph.generated.Cpg
import io.shiftleft.codepropertygraph.generated.nodes.AstNode
Expand All @@ -18,8 +18,7 @@ import org.slf4j.LoggerFactory
*/
class InferenceAPIEndpointTagger(cpg: Cpg, ruleCache: RuleCache) extends PrivadoParallelCpgPass[RuleInfo](cpg) {

private val logger = LoggerFactory.getLogger(getClass)
private val thirdPartyRuleInfo = ruleCache.getRuleInfo(Constants.thirdPartiesAPIRuleId)
private val logger = LoggerFactory.getLogger(getClass)

override def generateParts(): Array[RuleInfo] =
ruleCache.getRule.inferences.filter(_.catLevelTwo.equals(Constants.apiEndpoint)).toArray
Expand Down Expand Up @@ -58,14 +57,7 @@ class InferenceAPIEndpointTagger(cpg: Cpg, ruleCache: RuleCache) extends Privado
}

private def tagNode(builder: DiffGraphBuilder, apiCall: AstNode, apiUrl: String) = {
tagAPIWithDomainAndUpdateRuleCache(
builder,
thirdPartyRuleInfo.get,
ruleCache,
getDomainFromString(apiUrl),
apiCall,
apiUrl
)
tagThirdPartyAPIWithDomainAndUpdateRuleCache(builder, cpg, ruleCache, getDomainFromString(apiUrl), apiCall, apiUrl)
storeForTag(builder, apiCall, ruleCache)(InternalTag.API_SINK_MARKED.toString)
storeForTag(builder, apiCall, ruleCache)(InternalTag.API_URL_MARKED.toString)
}
Expand Down
Loading

0 comments on commit 3628e36

Please sign in to comment.