Skip to content

Commit

Permalink
[FR] Add Support for ES|QL Rule Type and Remote Validation (#3281)
Browse files Browse the repository at this point in the history
* add suuport for esql type
* add unit tests
* set clients in RemoteConnector from auth methods
* thread remote rules; add engine test
* Add versions to remote validation results

---------

Co-authored-by: Terrance DeJesus <[email protected]>
Co-authored-by: brokensound77 <[email protected]>
Co-authored-by: Justin Ibarra <[email protected]>

(cherry picked from commit 7514c0a)
  • Loading branch information
Mikaayenson authored and github-actions[bot] committed Dec 8, 2023
1 parent c395b3d commit 2932d33
Show file tree
Hide file tree
Showing 10 changed files with 497 additions and 159 deletions.
8 changes: 6 additions & 2 deletions detection_rules/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pathlib import Path

from functools import wraps
from typing import NoReturn
from typing import NoReturn, Optional

import click
import requests
Expand Down Expand Up @@ -270,12 +270,16 @@ def load_current_package_version() -> str:
return load_etc_dump('packages.yml')['package']['name']


def get_default_config() -> Optional[Path]:
return next(Path(get_path()).glob('.detection-rules-cfg.*'), None)


@cached
def parse_config():
"""Parse a default config file."""
import eql

config_file = next(Path(get_path()).glob('.detection-rules-cfg.*'), None)
config_file = get_default_config()
config = {}

if config_file and config_file.exists():
Expand Down
203 changes: 203 additions & 0 deletions detection_rules/remote_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.

from dataclasses import dataclass
from datetime import datetime
from functools import cached_property
from multiprocessing.pool import ThreadPool
from typing import Dict, List, Optional

import elasticsearch
from elasticsearch import Elasticsearch
from marshmallow import ValidationError
from requests import HTTPError

from kibana import Kibana

from .misc import ClientError, getdefault, get_elasticsearch_client, get_kibana_client, load_current_package_version
from .rule import TOMLRule, TOMLRuleContents
from .schemas import definitions


@dataclass
class RemoteValidationResult:
"""Dataclass for remote validation results."""
rule_id: definitions.UUIDString
rule_name: str
contents: dict
rule_version: int
stack_version: str
query_results: Optional[dict]
engine_results: Optional[dict]


class RemoteConnector:
"""Base client class for remote validation and testing."""

MAX_RETRIES = 5

def __init__(self, parse_config: bool = False, **kwargs):
es_args = ['cloud_id', 'ignore_ssl_errors', 'elasticsearch_url', 'es_user', 'es_password', 'timeout']
kibana_args = [
'cloud_id', 'ignore_ssl_errors', 'kibana_url', 'kibana_user', 'kibana_password', 'space', 'kibana_cookie',
'provider_type', 'provider_name'
]

if parse_config:
es_kwargs = {arg: getdefault(arg)() for arg in es_args}
kibana_kwargs = {arg: getdefault(arg)() for arg in kibana_args}

try:
if 'max_retries' not in es_kwargs:
es_kwargs['max_retries'] = self.MAX_RETRIES
self.es_client = get_elasticsearch_client(**es_kwargs, **kwargs)
except ClientError:
self.es_client = None

try:
self.kibana_client = get_kibana_client(**kibana_kwargs, **kwargs)
except HTTPError:
self.kibana_client = None

def auth_es(self, *, cloud_id: Optional[str] = None, ignore_ssl_errors: Optional[bool] = None,
elasticsearch_url: Optional[str] = None, es_user: Optional[str] = None,
es_password: Optional[str] = None, timeout: Optional[int] = None, **kwargs) -> Elasticsearch:
"""Return an authenticated Elasticsearch client."""
if 'max_retries' not in kwargs:
kwargs['max_retries'] = self.MAX_RETRIES
self.es_client = get_elasticsearch_client(cloud_id=cloud_id, ignore_ssl_errors=ignore_ssl_errors,
elasticsearch_url=elasticsearch_url, es_user=es_user,
es_password=es_password, timeout=timeout, **kwargs)
return self.es_client

def auth_kibana(self, *, cloud_id: Optional[str] = None, ignore_ssl_errors: Optional[bool] = None,
kibana_url: Optional[str] = None, kibana_user: Optional[str] = None,
kibana_password: Optional[str] = None, space: Optional[str] = None,
kibana_cookie: Optional[str] = None, provider_type: Optional[str] = None,
provider_name: Optional[str] = None, **kwargs) -> Kibana:
"""Return an authenticated Kibana client."""
self.kibana_client = get_kibana_client(cloud_id=cloud_id, ignore_ssl_errors=ignore_ssl_errors,
kibana_url=kibana_url, kibana_user=kibana_user,
kibana_password=kibana_password, space=space,
kibana_cookie=kibana_cookie, provider_type=provider_type,
provider_name=provider_name, **kwargs)
return self.kibana_client


class RemoteValidator(RemoteConnector):
"""Client class for remote validation."""

def __init__(self, parse_config: bool = False):
super(RemoteValidator, self).__init__(parse_config=parse_config)

@cached_property
def get_validate_methods(self) -> List[str]:
"""Return all validate methods."""
exempt = ('validate_rule', 'validate_rules')
methods = [m for m in self.__dir__() if m.startswith('validate_') and m not in exempt]
return methods

def get_validate_method(self, name: str) -> callable:
"""Return validate method by name."""
assert name in self.get_validate_methods, f'validate method {name} not found'
return getattr(self, name)

@staticmethod
def prep_for_preview(contents: TOMLRuleContents) -> dict:
"""Prepare rule for preview."""
end_time = datetime.utcnow().isoformat()
dumped = contents.to_api_format().copy()
dumped.update(timeframeEnd=end_time, invocationCount=1)
return dumped

def engine_preview(self, contents: TOMLRuleContents) -> dict:
"""Get results from detection engine preview API."""
dumped = self.prep_for_preview(contents)
return self.kibana_client.post('/api/detection_engine/rules/preview', json=dumped)

def validate_rule(self, contents: TOMLRuleContents) -> RemoteValidationResult:
"""Validate a single rule query."""
method = self.get_validate_method(f'validate_{contents.data.type}')
query_results = method(contents)
engine_results = self.engine_preview(contents)
rule_version = contents.autobumped_version
stack_version = load_current_package_version()
return RemoteValidationResult(contents.data.rule_id, contents.data.name, contents.to_api_format(),
rule_version, stack_version, query_results, engine_results)

def validate_rules(self, rules: List[TOMLRule], threads: int = 5) -> Dict[str, RemoteValidationResult]:
"""Validate a collection of rules via threads."""
responses = {}

def request(c: TOMLRuleContents):
try:
responses[c.data.rule_id] = self.validate_rule(c)
except ValidationError as e:
responses[c.data.rule_id] = e.messages

pool = ThreadPool(processes=threads)
pool.map(request, [r.contents for r in rules])
pool.close()
pool.join()

return responses

def validate_esql(self, contents: TOMLRuleContents) -> dict:
query = contents.data.query
rule_id = contents.data.rule_id
headers = {"accept": "application/json", "content-type": "application/json"}
body = {'query': f'{query} | LIMIT 0'}
try:
response = self.es_client.perform_request('POST', '/_query', headers=headers, params={'pretty': True},
body=body)
except Exception as exc:
if isinstance(exc, elasticsearch.BadRequestError):
raise ValidationError(f'ES|QL query failed: {exc} for rule: {rule_id}, query: \n{query}')
else:
raise Exception(f'ES|QL query failed for rule: {rule_id}, query: \n{query}') from exc

return response.body

def validate_eql(self, contents: TOMLRuleContents) -> dict:
"""Validate query for "eql" rule types."""
query = contents.data.query
rule_id = contents.data.rule_id
index = contents.data.index
time_range = {"range": {"@timestamp": {"gt": 'now-1h/h', "lte": 'now', "format": "strict_date_optional_time"}}}
body = {'query': query}
try:
response = self.es_client.eql.search(index=index, body=body, ignore_unavailable=True, filter=time_range)
except Exception as exc:
if isinstance(exc, elasticsearch.BadRequestError):
raise ValidationError(f'EQL query failed: {exc} for rule: {rule_id}, query: \n{query}')
else:
raise Exception(f'EQL query failed for rule: {rule_id}, query: \n{query}') from exc

return response.body

@staticmethod
def validate_query(self, contents: TOMLRuleContents) -> dict:
"""Validate query for "query" rule types."""
return {'results': 'Unable to remote validate query rules'}

@staticmethod
def validate_threshold(self, contents: TOMLRuleContents) -> dict:
"""Validate query for "threshold" rule types."""
return {'results': 'Unable to remote validate threshold rules'}

@staticmethod
def validate_new_terms(self, contents: TOMLRuleContents) -> dict:
"""Validate query for "new_terms" rule types."""
return {'results': 'Unable to remote validate new_terms rules'}

@staticmethod
def validate_threat_match(self, contents: TOMLRuleContents) -> dict:
"""Validate query for "threat_match" rule types."""
return {'results': 'Unable to remote validate threat_match rules'}

@staticmethod
def validate_machine_learning(self, contents: TOMLRuleContents) -> dict:
"""Validate query for "machine_learning" rule types."""
return {'results': 'Unable to remote validate machine_learning rules'}
45 changes: 24 additions & 21 deletions detection_rules/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,26 +594,11 @@ def get_required_fields(self, index: str) -> List[dict]:
@validates_schema
def validates_query_data(self, data, **kwargs):
"""Custom validation for query rule type and subclasses."""

# alert suppression is only valid for query rule type and not any of its subclasses
if data.get('alert_suppression') and data['type'] != 'query':
raise ValidationError("Alert suppression is only valid for query rule type.")


@dataclass(frozen=True)
class ESQLRuleData(QueryRuleData):
"""ESQL rules are a special case of query rules."""
type: Literal["esql"]
language: Literal["esql"]
query: str

@validates_schema
def validate_esql_data(self, data, **kwargs):
"""Custom validation for esql rule type."""
if data.get('index'):
raise ValidationError("Index is not valid for esql rule type.")


@dataclass(frozen=True)
class MachineLearningRuleData(BaseRuleData):
type: Literal["machine_learning"]
Expand Down Expand Up @@ -726,6 +711,20 @@ def interval_ratio(self) -> Optional[float]:
return interval / self.max_span


@dataclass(frozen=True)
class ESQLRuleData(QueryRuleData):
"""ESQL rules are a special case of query rules."""
type: Literal["esql"]
language: Literal["esql"]
query: str

@validates_schema
def validates_esql_data(self, data, **kwargs):
"""Custom validation for query rule type and subclasses."""
if data.get('index'):
raise ValidationError("Index is not a valid field for ES|QL rule type.")


@dataclass(frozen=True)
class ThreatMatchRuleData(QueryRuleData):
"""Specific fields for indicator (threat) match rule."""
Expand Down Expand Up @@ -1096,12 +1095,11 @@ def get_packaged_integrations(cls, data: QueryRuleData, meta: RuleMeta,
packaged_integrations = []
datasets = set()

if data.type != "esql":
for node in data.get('ast', []):
if isinstance(node, eql.ast.Comparison) and str(node.left) == 'event.dataset':
datasets.update(set(n.value for n in node if isinstance(n, eql.ast.Literal)))
elif isinstance(node, FieldComparison) and str(node.field) == 'event.dataset':
datasets.update(set(str(n) for n in node if isinstance(n, kql.ast.Value)))
for node in data.get('ast') or []:
if isinstance(node, eql.ast.Comparison) and str(node.left) == 'event.dataset':
datasets.update(set(n.value for n in node if isinstance(n, eql.ast.Literal)))
elif isinstance(node, FieldComparison) and str(node.field) == 'event.dataset':
datasets.update(set(str(n) for n in node if isinstance(n, kql.ast.Value)))

# integration is None to remove duplicate references upstream in Kibana
# chronologically, event.dataset is checked for package:integration, then rule tags
Expand Down Expand Up @@ -1139,6 +1137,10 @@ def post_conversion_validation(self, value: dict, **kwargs):
data.data_validator.validate_bbr(metadata.get('bypass_bbr_timing'))
data.validate(metadata) if hasattr(data, 'validate') else False

@staticmethod
def validate_remote(remote_validator: 'RemoteValidator', contents: 'TOMLRuleContents'):
remote_validator.validate_rule(contents)

def to_dict(self, strip_none_values=True) -> dict:
# Load schemas directly from the data and metadata classes to avoid schema ambiguity which can
# result from union fields which contain classes and related subclasses (AnyRuleData). See issue #1141
Expand Down Expand Up @@ -1347,3 +1349,4 @@ def get_unique_query_fields(rule: TOMLRule) -> List[str]:

# avoid a circular import
from .rule_validators import EQLValidator, ESQLValidator, KQLValidator # noqa: E402
from .remote_validation import RemoteValidator # noqa: E402
17 changes: 12 additions & 5 deletions detection_rules/rule_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@

"""Validation logic for rules containing queries."""
from functools import cached_property
from typing import List, Optional, Union, Tuple
from semver import Version
from typing import List, Optional, Tuple, Union

import eql
from marshmallow import ValidationError
from semver import Version

import kql

from . import ecs, endgame
from .integrations import get_integration_schema_data, load_integrations_manifests
from .integrations import (get_integration_schema_data,
load_integrations_manifests)
from .misc import load_current_package_version
from .rule import (EQLRuleData, QueryRuleData, QueryValidator, RuleMeta,
TOMLRuleContents, set_eql_config)
from .schemas import get_stack_schemas
from .rule import QueryRuleData, QueryValidator, RuleMeta, TOMLRuleContents, EQLRuleData, set_eql_config

EQL_ERROR_TYPES = Union[eql.EqlCompileError,
eql.EqlError,
Expand Down Expand Up @@ -351,7 +354,6 @@ class ESQLValidator(QueryValidator):

@cached_property
def ast(self):
"""Return an AST."""
return None

@cached_property
Expand All @@ -365,6 +367,11 @@ def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
"""Validate an ESQL query while checking TOMLRule."""
# temporarily override to NOP until ES|QL query parsing is supported

def validate_integration(self, data: QueryRuleData, meta: RuleMeta, package_integrations: List[dict]) -> Union[
ValidationError, None, ValueError]:
# return self.validate(data, meta)
pass


def extract_error_field(exc: Union[eql.EqlParseError, kql.KqlParseError]) -> Optional[str]:
"""Extract the field name from an EQL or KQL parse error."""
Expand Down
2 changes: 1 addition & 1 deletion detection_rules/schemas/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@
CodeString = NewType("CodeString", str)
ConditionSemVer = NewType('ConditionSemVer', str, validate=validate.Regexp(CONDITION_VERSION_PATTERN))
Date = NewType('Date', str, validate=validate.Regexp(DATE_PATTERN))
FilterLanguages = Literal["kuery", "lucene", "eql", "esql"]
FilterLanguages = Literal["eql", "esql", "kuery", "lucene"]
Interval = NewType('Interval', str, validate=validate.Regexp(INTERVAL_PATTERN))
InvestigateProviderQueryType = Literal["phrase", "range"]
InvestigateProviderValueType = Literal["string", "boolean"]
Expand Down
4 changes: 4 additions & 0 deletions tests/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
Loading

0 comments on commit 2932d33

Please sign in to comment.