Skip to content

Commit

Permalink
#1872 Added quota for alerts, case templates, custom fields & organis…
Browse files Browse the repository at this point in the history
…ations
  • Loading branch information
rriclet committed Mar 23, 2021
1 parent a8611b5 commit a7ba8bd
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class OrganisationCtrl @Inject() (
val inputOrganisation: InputOrganisation = request.body("organisation")
for {
user <- userSrv.current.getOrFail("User")
organisation <- organisationSrv.create(inputOrganisation.toOrganisation, user)
organisation <- organisationSrv.createWithUserAsOrgadmin(inputOrganisation.toOrganisation, user)
} yield Results.Created(organisation.toJson)
}

Expand Down
17 changes: 17 additions & 0 deletions thehive/app/org/thp/thehive/services/AlertSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import org.thp.thehive.services.CaseOps._
import org.thp.thehive.services.CaseTemplateOps._
import org.thp.thehive.services.CustomFieldOps._
import org.thp.thehive.services.ObservableOps._
import org.thp.thehive.services.OrganisationOps._
import play.api.Configuration
import play.api.libs.json.{JsObject, JsValue, Json}

import java.lang.{Long => JLong}
Expand All @@ -28,6 +30,7 @@ import scala.util.{Failure, Success, Try}

@Singleton
class AlertSrv @Inject() (
configuration: Configuration,
caseSrv: CaseSrv,
tagSrv: TagSrv,
organisationSrv: OrganisationSrv,
Expand Down Expand Up @@ -79,6 +82,7 @@ class AlertSrv @Inject() (
Failure(CreateError(s"Alert ${alert.`type`}:${alert.source}:${alert.sourceRef} already exist in organisation ${organisation.name}"))
else
for {
_ <- checkAlertQuota(organisation)
createdAlert <- createEntity(alert.copy(organisationId = organisation._id))
_ <- alertOrganisationSrv.create(AlertOrganisation(), createdAlert, organisation)
_ <- caseTemplate.map(ct => alertCaseTemplateSrv.create(AlertCaseTemplate(), createdAlert, ct)).flip
Expand All @@ -89,6 +93,19 @@ class AlertSrv @Inject() (
} yield richAlert
}

private def checkAlertQuota(organisation: Organisation with Entity)(implicit
graph: Graph,
authContext: AuthContext
): Try[Unit] = {
val alertQuota = configuration.getOptional[Long]("quota.organisation.alert.count")
val alertCount = organisationSrv.get(organisation).alerts.getCount

alertQuota.fold[Try[Unit]](Success(()))(quota =>
if (alertCount < quota) Success(())
else Failure(BadRequestError(s"Alert quota is reached, this organisation cannot have more alerts"))
)
}

override def update(
traversal: Traversal.V[Alert],
propertyUpdaters: Seq[PropertyUpdater]
Expand Down
18 changes: 17 additions & 1 deletion thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,23 @@ import org.thp.scalligraph.query.PropertyUpdater
import org.thp.scalligraph.services._
import org.thp.scalligraph.traversal.TraversalOps._
import org.thp.scalligraph.traversal.{Converter, Graph, StepLabel, Traversal}
import org.thp.scalligraph.{CreateError, EntityIdOrName, EntityName, RichSeq}
import org.thp.scalligraph.{BadRequestError, CreateError, EntityIdOrName, EntityName, RichSeq}
import org.thp.thehive.controllers.v1.Conversion._
import org.thp.thehive.models._
import org.thp.thehive.services.CaseTemplateOps._
import org.thp.thehive.services.CustomFieldOps._
import org.thp.thehive.services.OrganisationOps._
import org.thp.thehive.services.TaskOps._
import org.thp.thehive.services.UserOps._
import play.api.Configuration
import play.api.libs.json.{JsObject, Json}

import java.util.{Map => JMap}
import javax.inject.{Inject, Named}
import scala.util.{Failure, Success, Try}

class CaseTemplateSrv @Inject() (
configuration: Configuration,
customFieldSrv: CustomFieldSrv,
organisationSrv: OrganisationSrv,
tagSrv: TagSrv,
Expand Down Expand Up @@ -57,6 +59,7 @@ class CaseTemplateSrv @Inject() (
Failure(CreateError(s"""The case template "${caseTemplate.name}" already exists"""))
else
for {
_ <- checkCaseTemplateQuota(organisation)
createdCaseTemplate <- createEntity(caseTemplate)
_ <- caseTemplateOrganisationSrv.create(CaseTemplateOrganisation(), createdCaseTemplate, organisation)
createdTasks <- tasks.toTry(createTask(createdCaseTemplate, _))
Expand All @@ -66,6 +69,19 @@ class CaseTemplateSrv @Inject() (
_ <- auditSrv.caseTemplate.create(createdCaseTemplate, richCaseTemplate.toJson)
} yield richCaseTemplate

private def checkCaseTemplateQuota(organisation: Organisation with Entity)(implicit
graph: Graph,
authContext: AuthContext
): Try[Unit] = {
val caseTemplateQuota = configuration.getOptional[Long]("quota.organisation.caseTemplate.count")
val caseTemplateCount = organisationSrv.get(organisation).caseTemplates.getCount

caseTemplateQuota.fold[Try[Unit]](Success(()))(quota =>
if (caseTemplateCount < quota) Success(())
else Failure(BadRequestError(s"Case template quota is reached, this organisation cannot have more case templates"))
)
}

def createTask(caseTemplate: CaseTemplate with Entity, task: Task)(implicit graph: Graph, authContext: AuthContext): Try[RichTask] =
for {
assignee <- task.assignee.map(u => organisationSrv.current.users(Permissions.manageTask).getByName(u).getOrFail("User")).flip
Expand Down
20 changes: 18 additions & 2 deletions thehive/app/org/thp/thehive/services/CustomFieldSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,21 @@ import org.thp.scalligraph.query.PropertyUpdater
import org.thp.scalligraph.services.{IntegrityCheckOps, VertexSrv}
import org.thp.scalligraph.traversal.TraversalOps._
import org.thp.scalligraph.traversal._
import org.thp.scalligraph.{EntityIdOrName, RichSeq}
import org.thp.scalligraph.{BadRequestError, EntityIdOrName, RichSeq}
import org.thp.thehive.controllers.v1.Conversion._
import org.thp.thehive.models._
import org.thp.thehive.services.CustomFieldOps._
import play.api.Configuration
import play.api.cache.SyncCacheApi
import play.api.libs.json.{JsObject, JsValue}

import java.util.{Map => JMap}
import javax.inject.{Inject, Named, Singleton}
import scala.util.{Success, Try}
import scala.util.{Failure, Success, Try}

@Singleton
class CustomFieldSrv @Inject() (
configuration: Configuration,
auditSrv: AuditSrv,
organisationSrv: OrganisationSrv,
@Named("integrity-check-actor") integrityCheckActor: ActorRef,
Expand All @@ -36,10 +38,24 @@ class CustomFieldSrv @Inject() (

def create(e: CustomField)(implicit graph: Graph, authContext: AuthContext): Try[CustomField with Entity] =
for {
_ <- checkCustomFieldQuota
created <- createEntity(e)
_ <- auditSrv.customField.create(created, created.toJson)
} yield created

private def checkCustomFieldQuota(implicit
graph: Graph,
authContext: AuthContext
): Try[Unit] = {
val customFieldQuota = configuration.getOptional[Long]("quota.customField.count")
val customFieldCount = startTraversal.getCount

customFieldQuota.fold[Try[Unit]](Success(()))(quota =>
if (customFieldCount < quota) Success(())
else Failure(BadRequestError(s"Custom field quota is reached, no more custom fields can be created"))
)
}

override def exists(e: CustomField)(implicit graph: Graph): Boolean = startTraversal.getByName(e.name).exists

def delete(c: CustomField with Entity, force: Boolean)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = {
Expand Down
21 changes: 20 additions & 1 deletion thehive/app/org/thp/thehive/services/OrganisationSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import org.thp.thehive.models._
import org.thp.thehive.services.OrganisationOps._
import org.thp.thehive.services.RoleOps._
import org.thp.thehive.services.UserOps._
import play.api.Configuration
import play.api.cache.SyncCacheApi
import play.api.libs.json.JsObject

Expand All @@ -22,6 +23,7 @@ import scala.util.{Failure, Success, Try}

@Singleton
class OrganisationSrv @Inject() (
configuration: Configuration,
taxonomySrvProvider: Provider[TaxonomySrv],
roleSrv: RoleSrv,
profileSrv: ProfileSrv,
Expand All @@ -42,7 +44,10 @@ class OrganisationSrv @Inject() (

override def getByName(name: String)(implicit graph: Graph): Traversal.V[Organisation] = startTraversal.getByName(name)

def create(organisation: Organisation, user: User with Entity)(implicit graph: Graph, authContext: AuthContext): Try[Organisation with Entity] =
def createWithUserAsOrgadmin(organisation: Organisation, user: User with Entity)(implicit
graph: Graph,
authContext: AuthContext
): Try[Organisation with Entity] =
for {
createdOrganisation <- create(organisation)
_ <- roleSrv.create(user, createdOrganisation, profileSrv.orgAdmin)
Expand All @@ -51,13 +56,27 @@ class OrganisationSrv @Inject() (
def create(e: Organisation)(implicit graph: Graph, authContext: AuthContext): Try[Organisation with Entity] = {
val activeTaxos = getByName("admin").taxonomies.toSeq
for {
_ <- checkOrganisationQuota
newOrga <- createEntity(e)
_ <- taxonomySrv.createFreetagTaxonomy(newOrga)
_ <- activeTaxos.toTry(t => organisationTaxonomySrv.create(OrganisationTaxonomy(), newOrga, t))
_ <- auditSrv.organisation.create(newOrga, newOrga.toJson)
} yield newOrga
}

private def checkOrganisationQuota(implicit
graph: Graph,
authContext: AuthContext
): Try[Unit] = {
val organisationQuota = configuration.getOptional[Long]("quota.organisation.count")
val organisationCount = startTraversal.getCount

organisationQuota.fold[Try[Unit]](Success(()))(quota =>
if (organisationCount < quota) Success(())
else Failure(BadRequestError(s"Organisation quota is reached, no more organisations can be created"))
)
}

def current(implicit graph: Graph, authContext: AuthContext): Traversal.V[Organisation] = get(authContext.organisation)

def currentId(implicit graph: Graph, authContext: AuthContext): EntityId =
Expand Down
2 changes: 1 addition & 1 deletion thehive/app/org/thp/thehive/services/UserSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class UserSrv @Inject() (
else Failure(BadRequestError(s"User login is invalid, it must be an email address (found: ${user.login})"))
}

def checkUserQuota(organisation: Organisation with Entity)(implicit
private def checkUserQuota(organisation: Organisation with Entity)(implicit
graph: Graph,
authContext: AuthContext
): Try[Unit] = {
Expand Down

0 comments on commit a7ba8bd

Please sign in to comment.