diff --git a/detection_rules/misc.py b/detection_rules/misc.py index e940f920316..82d8e4893af 100644 --- a/detection_rules/misc.py +++ b/detection_rules/misc.py @@ -11,7 +11,7 @@ from pathlib import Path from functools import wraps -from typing import NoReturn +from typing import NoReturn, Optional import click import requests @@ -270,12 +270,16 @@ def load_current_package_version() -> str: return load_etc_dump('packages.yml')['package']['name'] +def get_default_config() -> Optional[Path]: + return next(Path(get_path()).glob('.detection-rules-cfg.*'), None) + + @cached def parse_config(): """Parse a default config file.""" import eql - config_file = next(Path(get_path()).glob('.detection-rules-cfg.*'), None) + config_file = get_default_config() config = {} if config_file and config_file.exists(): diff --git a/detection_rules/remote_validation.py b/detection_rules/remote_validation.py new file mode 100644 index 00000000000..bab2646041b --- /dev/null +++ b/detection_rules/remote_validation.py @@ -0,0 +1,203 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0; you may not use this file except in compliance with the Elastic License +# 2.0. + +from dataclasses import dataclass +from datetime import datetime +from functools import cached_property +from multiprocessing.pool import ThreadPool +from typing import Dict, List, Optional + +import elasticsearch +from elasticsearch import Elasticsearch +from marshmallow import ValidationError +from requests import HTTPError + +from kibana import Kibana + +from .misc import ClientError, getdefault, get_elasticsearch_client, get_kibana_client, load_current_package_version +from .rule import TOMLRule, TOMLRuleContents +from .schemas import definitions + + +@dataclass +class RemoteValidationResult: + """Dataclass for remote validation results.""" + rule_id: definitions.UUIDString + rule_name: str + contents: dict + rule_version: int + stack_version: str + query_results: Optional[dict] + engine_results: Optional[dict] + + +class RemoteConnector: + """Base client class for remote validation and testing.""" + + MAX_RETRIES = 5 + + def __init__(self, parse_config: bool = False, **kwargs): + es_args = ['cloud_id', 'ignore_ssl_errors', 'elasticsearch_url', 'es_user', 'es_password', 'timeout'] + kibana_args = [ + 'cloud_id', 'ignore_ssl_errors', 'kibana_url', 'kibana_user', 'kibana_password', 'space', 'kibana_cookie', + 'provider_type', 'provider_name' + ] + + if parse_config: + es_kwargs = {arg: getdefault(arg)() for arg in es_args} + kibana_kwargs = {arg: getdefault(arg)() for arg in kibana_args} + + try: + if 'max_retries' not in es_kwargs: + es_kwargs['max_retries'] = self.MAX_RETRIES + self.es_client = get_elasticsearch_client(**es_kwargs, **kwargs) + except ClientError: + self.es_client = None + + try: + self.kibana_client = get_kibana_client(**kibana_kwargs, **kwargs) + except HTTPError: + self.kibana_client = None + + def auth_es(self, *, cloud_id: Optional[str] = None, ignore_ssl_errors: Optional[bool] = None, + elasticsearch_url: Optional[str] = None, es_user: Optional[str] = None, + es_password: Optional[str] = None, timeout: Optional[int] = None, **kwargs) -> Elasticsearch: + """Return an authenticated Elasticsearch client.""" + if 'max_retries' not in kwargs: + kwargs['max_retries'] = self.MAX_RETRIES + self.es_client = get_elasticsearch_client(cloud_id=cloud_id, ignore_ssl_errors=ignore_ssl_errors, + elasticsearch_url=elasticsearch_url, es_user=es_user, + es_password=es_password, timeout=timeout, **kwargs) + return self.es_client + + def auth_kibana(self, *, cloud_id: Optional[str] = None, ignore_ssl_errors: Optional[bool] = None, + kibana_url: Optional[str] = None, kibana_user: Optional[str] = None, + kibana_password: Optional[str] = None, space: Optional[str] = None, + kibana_cookie: Optional[str] = None, provider_type: Optional[str] = None, + provider_name: Optional[str] = None, **kwargs) -> Kibana: + """Return an authenticated Kibana client.""" + self.kibana_client = get_kibana_client(cloud_id=cloud_id, ignore_ssl_errors=ignore_ssl_errors, + kibana_url=kibana_url, kibana_user=kibana_user, + kibana_password=kibana_password, space=space, + kibana_cookie=kibana_cookie, provider_type=provider_type, + provider_name=provider_name, **kwargs) + return self.kibana_client + + +class RemoteValidator(RemoteConnector): + """Client class for remote validation.""" + + def __init__(self, parse_config: bool = False): + super(RemoteValidator, self).__init__(parse_config=parse_config) + + @cached_property + def get_validate_methods(self) -> List[str]: + """Return all validate methods.""" + exempt = ('validate_rule', 'validate_rules') + methods = [m for m in self.__dir__() if m.startswith('validate_') and m not in exempt] + return methods + + def get_validate_method(self, name: str) -> callable: + """Return validate method by name.""" + assert name in self.get_validate_methods, f'validate method {name} not found' + return getattr(self, name) + + @staticmethod + def prep_for_preview(contents: TOMLRuleContents) -> dict: + """Prepare rule for preview.""" + end_time = datetime.utcnow().isoformat() + dumped = contents.to_api_format().copy() + dumped.update(timeframeEnd=end_time, invocationCount=1) + return dumped + + def engine_preview(self, contents: TOMLRuleContents) -> dict: + """Get results from detection engine preview API.""" + dumped = self.prep_for_preview(contents) + return self.kibana_client.post('/api/detection_engine/rules/preview', json=dumped) + + def validate_rule(self, contents: TOMLRuleContents) -> RemoteValidationResult: + """Validate a single rule query.""" + method = self.get_validate_method(f'validate_{contents.data.type}') + query_results = method(contents) + engine_results = self.engine_preview(contents) + rule_version = contents.autobumped_version + stack_version = load_current_package_version() + return RemoteValidationResult(contents.data.rule_id, contents.data.name, contents.to_api_format(), + rule_version, stack_version, query_results, engine_results) + + def validate_rules(self, rules: List[TOMLRule], threads: int = 5) -> Dict[str, RemoteValidationResult]: + """Validate a collection of rules via threads.""" + responses = {} + + def request(c: TOMLRuleContents): + try: + responses[c.data.rule_id] = self.validate_rule(c) + except ValidationError as e: + responses[c.data.rule_id] = e.messages + + pool = ThreadPool(processes=threads) + pool.map(request, [r.contents for r in rules]) + pool.close() + pool.join() + + return responses + + def validate_esql(self, contents: TOMLRuleContents) -> dict: + query = contents.data.query + rule_id = contents.data.rule_id + headers = {"accept": "application/json", "content-type": "application/json"} + body = {'query': f'{query} | LIMIT 0'} + try: + response = self.es_client.perform_request('POST', '/_query', headers=headers, params={'pretty': True}, + body=body) + except Exception as exc: + if isinstance(exc, elasticsearch.BadRequestError): + raise ValidationError(f'ES|QL query failed: {exc} for rule: {rule_id}, query: \n{query}') + else: + raise Exception(f'ES|QL query failed for rule: {rule_id}, query: \n{query}') from exc + + return response.body + + def validate_eql(self, contents: TOMLRuleContents) -> dict: + """Validate query for "eql" rule types.""" + query = contents.data.query + rule_id = contents.data.rule_id + index = contents.data.index + time_range = {"range": {"@timestamp": {"gt": 'now-1h/h', "lte": 'now', "format": "strict_date_optional_time"}}} + body = {'query': query} + try: + response = self.es_client.eql.search(index=index, body=body, ignore_unavailable=True, filter=time_range) + except Exception as exc: + if isinstance(exc, elasticsearch.BadRequestError): + raise ValidationError(f'EQL query failed: {exc} for rule: {rule_id}, query: \n{query}') + else: + raise Exception(f'EQL query failed for rule: {rule_id}, query: \n{query}') from exc + + return response.body + + @staticmethod + def validate_query(self, contents: TOMLRuleContents) -> dict: + """Validate query for "query" rule types.""" + return {'results': 'Unable to remote validate query rules'} + + @staticmethod + def validate_threshold(self, contents: TOMLRuleContents) -> dict: + """Validate query for "threshold" rule types.""" + return {'results': 'Unable to remote validate threshold rules'} + + @staticmethod + def validate_new_terms(self, contents: TOMLRuleContents) -> dict: + """Validate query for "new_terms" rule types.""" + return {'results': 'Unable to remote validate new_terms rules'} + + @staticmethod + def validate_threat_match(self, contents: TOMLRuleContents) -> dict: + """Validate query for "threat_match" rule types.""" + return {'results': 'Unable to remote validate threat_match rules'} + + @staticmethod + def validate_machine_learning(self, contents: TOMLRuleContents) -> dict: + """Validate query for "machine_learning" rule types.""" + return {'results': 'Unable to remote validate machine_learning rules'} diff --git a/detection_rules/rule.py b/detection_rules/rule.py index d63619a2e4d..b001c0bab08 100644 --- a/detection_rules/rule.py +++ b/detection_rules/rule.py @@ -594,26 +594,11 @@ def get_required_fields(self, index: str) -> List[dict]: @validates_schema def validates_query_data(self, data, **kwargs): """Custom validation for query rule type and subclasses.""" - # alert suppression is only valid for query rule type and not any of its subclasses if data.get('alert_suppression') and data['type'] != 'query': raise ValidationError("Alert suppression is only valid for query rule type.") -@dataclass(frozen=True) -class ESQLRuleData(QueryRuleData): - """ESQL rules are a special case of query rules.""" - type: Literal["esql"] - language: Literal["esql"] - query: str - - @validates_schema - def validate_esql_data(self, data, **kwargs): - """Custom validation for esql rule type.""" - if data.get('index'): - raise ValidationError("Index is not valid for esql rule type.") - - @dataclass(frozen=True) class MachineLearningRuleData(BaseRuleData): type: Literal["machine_learning"] @@ -726,6 +711,20 @@ def interval_ratio(self) -> Optional[float]: return interval / self.max_span +@dataclass(frozen=True) +class ESQLRuleData(QueryRuleData): + """ESQL rules are a special case of query rules.""" + type: Literal["esql"] + language: Literal["esql"] + query: str + + @validates_schema + def validates_esql_data(self, data, **kwargs): + """Custom validation for query rule type and subclasses.""" + if data.get('index'): + raise ValidationError("Index is not a valid field for ES|QL rule type.") + + @dataclass(frozen=True) class ThreatMatchRuleData(QueryRuleData): """Specific fields for indicator (threat) match rule.""" @@ -1096,12 +1095,11 @@ def get_packaged_integrations(cls, data: QueryRuleData, meta: RuleMeta, packaged_integrations = [] datasets = set() - if data.type != "esql": - for node in data.get('ast', []): - if isinstance(node, eql.ast.Comparison) and str(node.left) == 'event.dataset': - datasets.update(set(n.value for n in node if isinstance(n, eql.ast.Literal))) - elif isinstance(node, FieldComparison) and str(node.field) == 'event.dataset': - datasets.update(set(str(n) for n in node if isinstance(n, kql.ast.Value))) + for node in data.get('ast') or []: + if isinstance(node, eql.ast.Comparison) and str(node.left) == 'event.dataset': + datasets.update(set(n.value for n in node if isinstance(n, eql.ast.Literal))) + elif isinstance(node, FieldComparison) and str(node.field) == 'event.dataset': + datasets.update(set(str(n) for n in node if isinstance(n, kql.ast.Value))) # integration is None to remove duplicate references upstream in Kibana # chronologically, event.dataset is checked for package:integration, then rule tags @@ -1139,6 +1137,10 @@ def post_conversion_validation(self, value: dict, **kwargs): data.data_validator.validate_bbr(metadata.get('bypass_bbr_timing')) data.validate(metadata) if hasattr(data, 'validate') else False + @staticmethod + def validate_remote(remote_validator: 'RemoteValidator', contents: 'TOMLRuleContents'): + remote_validator.validate_rule(contents) + def to_dict(self, strip_none_values=True) -> dict: # Load schemas directly from the data and metadata classes to avoid schema ambiguity which can # result from union fields which contain classes and related subclasses (AnyRuleData). See issue #1141 @@ -1347,3 +1349,4 @@ def get_unique_query_fields(rule: TOMLRule) -> List[str]: # avoid a circular import from .rule_validators import EQLValidator, ESQLValidator, KQLValidator # noqa: E402 +from .remote_validation import RemoteValidator # noqa: E402 diff --git a/detection_rules/rule_validators.py b/detection_rules/rule_validators.py index ea9185b072f..9dcfdb468f2 100644 --- a/detection_rules/rule_validators.py +++ b/detection_rules/rule_validators.py @@ -5,18 +5,21 @@ """Validation logic for rules containing queries.""" from functools import cached_property -from typing import List, Optional, Union, Tuple -from semver import Version +from typing import List, Optional, Tuple, Union import eql +from marshmallow import ValidationError +from semver import Version import kql from . import ecs, endgame -from .integrations import get_integration_schema_data, load_integrations_manifests +from .integrations import (get_integration_schema_data, + load_integrations_manifests) from .misc import load_current_package_version +from .rule import (EQLRuleData, QueryRuleData, QueryValidator, RuleMeta, + TOMLRuleContents, set_eql_config) from .schemas import get_stack_schemas -from .rule import QueryRuleData, QueryValidator, RuleMeta, TOMLRuleContents, EQLRuleData, set_eql_config EQL_ERROR_TYPES = Union[eql.EqlCompileError, eql.EqlError, @@ -351,7 +354,6 @@ class ESQLValidator(QueryValidator): @cached_property def ast(self): - """Return an AST.""" return None @cached_property @@ -365,6 +367,11 @@ def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None: """Validate an ESQL query while checking TOMLRule.""" # temporarily override to NOP until ES|QL query parsing is supported + def validate_integration(self, data: QueryRuleData, meta: RuleMeta, package_integrations: List[dict]) -> Union[ + ValidationError, None, ValueError]: + # return self.validate(data, meta) + pass + def extract_error_field(exc: Union[eql.EqlParseError, kql.KqlParseError]) -> Optional[str]: """Extract the field name from an EQL or KQL parse error.""" diff --git a/detection_rules/schemas/definitions.py b/detection_rules/schemas/definitions.py index e1b1c5273a3..38e4d28222b 100644 --- a/detection_rules/schemas/definitions.py +++ b/detection_rules/schemas/definitions.py @@ -138,7 +138,7 @@ CodeString = NewType("CodeString", str) ConditionSemVer = NewType('ConditionSemVer', str, validate=validate.Regexp(CONDITION_VERSION_PATTERN)) Date = NewType('Date', str, validate=validate.Regexp(DATE_PATTERN)) -FilterLanguages = Literal["kuery", "lucene", "eql", "esql"] +FilterLanguages = Literal["eql", "esql", "kuery", "lucene"] Interval = NewType('Interval', str, validate=validate.Regexp(INTERVAL_PATTERN)) InvestigateProviderQueryType = Literal["phrase", "range"] InvestigateProviderValueType = Literal["string", "boolean"] diff --git a/tests/data/__init__.py b/tests/data/__init__.py new file mode 100644 index 00000000000..72ea1f6e244 --- /dev/null +++ b/tests/data/__init__.py @@ -0,0 +1,4 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0; you may not use this file except in compliance with the Elastic License +# 2.0. diff --git a/tests/data/command_control_dummy_production_rule.toml b/tests/data/command_control_dummy_production_rule.toml new file mode 100644 index 00000000000..caa798ff22a --- /dev/null +++ b/tests/data/command_control_dummy_production_rule.toml @@ -0,0 +1,37 @@ +[metadata] +creation_date = "2023/11/20" +integration = ["endpoint"] +maturity = "production" +min_stack_comments = "ES|QL Rule" +min_stack_version = "8.11.0" +updated_date = "2023/11/20" + +[rule] +author = ["Elastic"] +description = """ +Sample ES|QL rule for unit tests. +""" +from = "now-9m" +language = "esql" +license = "Elastic License v2" +name = "Sample ES|QL rule for unit tests" +risk_score = 47 +rule_id = "24220495-cffe-45a1-996c-37b599ba0e43" +severity = "medium" +tags = ["Data Source: Elastic Endpoint", "Domain: Endpoint", "OS: Windows", "Use Case: Threat Detection", "Tactic: Command and Control", "Data Source: Elastic Defend"] +timestamp_override = "event.ingested" +type = "esql" +query = ''' +from .ds-logs-endpoint.events.process-default-* + | where event.action == "start" and process.code_signature.subject_name like "Microsoft*" and process.parent.name in ("winword.exe", "WINWORD.EXE", "EXCEL.EXE", "excel.exe") + | eval process_path = replace(process.executable, """[cC]:\\[uU][sS][eE][rR][sS]\\[a-zA-Z0-9\.\-\_\$]+\\""", "C:\\\\users\\\\user\\\\") + | stats cc = count(*) by process_path, process.parent.name | where cc <= 5 +''' + +[[rule.threat]] +framework = "MITRE ATT&CK" + +[rule.threat.tactic] +id = "TA0011" +name = "Command and Control" +reference = "https://attack.mitre.org/tactics/TA0011/" diff --git a/tests/test_all_rules.py b/tests/test_all_rules.py index 50d1d1af82d..5fba8ea4a94 100644 --- a/tests/test_all_rules.py +++ b/tests/test_all_rules.py @@ -13,6 +13,7 @@ from pathlib import Path import eql.ast + from marshmallow import ValidationError from semver import Version @@ -29,8 +30,7 @@ from detection_rules.rule_loader import FILE_PATTERN from detection_rules.rule_validators import EQLValidator, KQLValidator from detection_rules.schemas import definitions, get_stack_schemas -from detection_rules.utils import (INTEGRATION_RULE_DIR, PatchedTemplate, - get_path, load_etc_dump) +from detection_rules.utils import INTEGRATION_RULE_DIR, PatchedTemplate, get_path, load_etc_dump from detection_rules.version_lock import default_version_lock from rta import get_available_tests @@ -666,7 +666,7 @@ def test_integration_tag(self): "f3e22c8b-ea47-45d1-b502-b57b6de950b3" ] if any([re.search("|".join(non_dataset_packages), i, re.IGNORECASE) - for i in rule.contents.data.index]): + for i in rule.contents.data.get('index') or []]): if not rule.contents.metadata.integration and rule.id not in ignore_ids and \ rule.contents.data.type not in definitions.MACHINE_LEARNING: err_msg = f'substrings {non_dataset_packages} found in '\ @@ -1182,35 +1182,6 @@ def test_rule_risk_score_severity_mismatch(self): self.fail(err_msg) -class TestEndpointQuery(BaseRuleTest): - """Test endpoint-specific rules.""" - - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.3.0"), - "Test only applicable to 8.3+ stacks since query updates are min_stacked at 8.3.0") - def test_os_and_platform_in_query(self): - """Test that all endpoint rules have an os defined and linux includes platform.""" - for rule in self.production_rules: - if not rule.contents.data.get('language') in ('eql', 'kuery'): - continue - if rule.path.parent.name not in ('windows', 'macos', 'linux'): - # skip cross-platform for now - continue - - ast = rule.contents.data.ast - fields = [str(f) for f in ast if isinstance(f, (kql.ast.Field, eql.ast.Field))] - - err_msg = f'{self.rule_str(rule)} missing required field for endpoint rule' - if 'host.os.type' not in fields: - # Exception for Forwarded Events which contain Windows-only fields. - if rule.path.parent.name == 'windows' and not any(field.startswith('winlog.') for field in fields): - self.assertIn('host.os.type', fields, err_msg) - - # going to bypass this for now - # if rule.path.parent.name == 'linux': - # err_msg = f'{self.rule_str(rule)} missing required field for linux endpoint rule' - # self.assertIn('host.os.platform', fields, err_msg) - - class TestNoteMarkdownPlugins(BaseRuleTest): """Test if a guide containing Osquery Plugin syntax contains the version note.""" @@ -1334,101 +1305,3 @@ def test_group_field_in_schemas(self): if fld not in schema.keys(): self.fail(f"{self.rule_str(rule)} alert suppression field {fld} not \ found in ECS, Beats, or non-ecs schemas") - - -class TestNewTerms(BaseRuleTest): - """Test new term rules.""" - - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), - "Test only applicable to 8.4+ stacks for new terms feature.") - def test_history_window_start(self): - """Test new terms history window start field.""" - - for rule in self.production_rules: - if rule.contents.data.type == "new_terms": - - # validate history window start field exists and is correct - assert rule.contents.data.new_terms.history_window_start, \ - "new terms field found with no history_window_start field defined" - assert rule.contents.data.new_terms.history_window_start[0].field == "history_window_start", \ - f"{rule.contents.data.new_terms.history_window_start} should be 'history_window_start'" - - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), - "Test only applicable to 8.4+ stacks for new terms feature.") - def test_new_terms_field_exists(self): - # validate new terms and history window start fields are correct - for rule in self.production_rules: - if rule.contents.data.type == "new_terms": - assert rule.contents.data.new_terms.field == "new_terms_fields", \ - f"{rule.contents.data.new_terms.field} should be 'new_terms_fields' for new_terms rule type" - - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), - "Test only applicable to 8.4+ stacks for new terms feature.") - def test_new_terms_fields(self): - """Test new terms fields are schema validated.""" - # ecs validation - for rule in self.production_rules: - if rule.contents.data.type == "new_terms": - meta = rule.contents.metadata - feature_min_stack = Version.parse('8.4.0') - current_package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) - min_stack_version = Version.parse(meta.get("min_stack_version")) if \ - meta.get("min_stack_version") else None - min_stack_version = current_package_version if min_stack_version is None or min_stack_version < \ - current_package_version else min_stack_version - - assert min_stack_version >= feature_min_stack, \ - f"New Terms rule types only compatible with {feature_min_stack}+" - ecs_version = get_stack_schemas()[str(min_stack_version)]['ecs'] - beats_version = get_stack_schemas()[str(min_stack_version)]['beats'] - - # checks if new terms field(s) are in ecs, beats non-ecs or integration schemas - queryvalidator = QueryValidator(rule.contents.data.query) - _, _, schema = queryvalidator.get_beats_schema([], beats_version, ecs_version) - integration_manifests = load_integrations_manifests() - integration_schemas = load_integrations_schemas() - integration_tags = meta.get("integration") - if integration_tags: - for tag in integration_tags: - latest_tag_compat_ver, _ = find_latest_compatible_version( - package=tag, - integration="", - rule_stack_version=min_stack_version, - packages_manifest=integration_manifests) - if latest_tag_compat_ver: - integration_schema = integration_schemas[tag][latest_tag_compat_ver] - for policy_template in integration_schema.keys(): - schema.update(**integration_schemas[tag][latest_tag_compat_ver][policy_template]) - for new_terms_field in rule.contents.data.new_terms.value: - assert new_terms_field in schema.keys(), \ - f"{new_terms_field} not found in ECS, Beats, or non-ecs schemas" - - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), - "Test only applicable to 8.4+ stacks for new terms feature.") - def test_new_terms_max_limit(self): - """Test new terms max limit.""" - # validates length of new_terms to stack version - https://github.com/elastic/kibana/issues/142862 - for rule in self.production_rules: - if rule.contents.data.type == "new_terms": - meta = rule.contents.metadata - feature_min_stack = Version.parse('8.4.0') - feature_min_stack_extended_fields = Version.parse('8.6.0') - current_package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) - min_stack_version = Version.parse(meta.get("min_stack_version")) if \ - meta.get("min_stack_version") else None - min_stack_version = current_package_version if min_stack_version is None or min_stack_version < \ - current_package_version else min_stack_version - if min_stack_version >= feature_min_stack and \ - min_stack_version < feature_min_stack_extended_fields: - assert len(rule.contents.data.new_terms.value) == 1, \ - f"new terms have a max limit of 1 for stack versions below {feature_min_stack_extended_fields}" - - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.6.0"), - "Test only applicable to 8.4+ stacks for new terms feature.") - def test_new_terms_fields_unique(self): - """Test new terms fields are unique.""" - # validate fields are unique - for rule in self.production_rules: - if rule.contents.data.type == "new_terms": - assert len(set(rule.contents.data.new_terms.value)) == len(rule.contents.data.new_terms.value), \ - f"new terms fields values are not unique - {rule.contents.data.new_terms.value}" diff --git a/tests/test_rules_remote.py b/tests/test_rules_remote.py new file mode 100644 index 00000000000..e422239ce62 --- /dev/null +++ b/tests/test_rules_remote.py @@ -0,0 +1,21 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0; you may not use this file except in compliance with the Elastic License +# 2.0. + +import unittest + +from .base import BaseRuleTest +from detection_rules.misc import get_default_config +# from detection_rules.remote_validation import RemoteValidator + + +@unittest.skipIf(get_default_config() is None, 'Skipping remote validation due to missing config') +class TestRemoteRules(BaseRuleTest): + """Test rules against a remote Elastic stack instance.""" + + # def test_esql_rules(self): + # """Temporarily explicitly test all ES|QL rules remotely pending parsing lib.""" + # esql_rules = [r for r in self.all_rules if r.contents.data.type == 'esql'] + # rv = RemoteValidator(parse_config=True) + # rv.validate_rules(esql_rules) diff --git a/tests/test_specific_rules.py b/tests/test_specific_rules.py new file mode 100644 index 00000000000..f844f89f4aa --- /dev/null +++ b/tests/test_specific_rules.py @@ -0,0 +1,186 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0; you may not use this file except in compliance with the Elastic License +# 2.0. + +import unittest +from copy import deepcopy +from pathlib import Path + +import eql.ast + +from semver import Version + +import kql +from detection_rules.integrations import ( + find_latest_compatible_version, load_integrations_manifests, load_integrations_schemas +) +from detection_rules.misc import load_current_package_version +from detection_rules.packaging import current_stack_version +from detection_rules.rule import QueryValidator +from detection_rules.rule_loader import RuleCollection +from detection_rules.schemas import get_stack_schemas +from detection_rules.utils import get_path, load_rule_contents + +from .base import BaseRuleTest +PACKAGE_STACK_VERSION = Version.parse(current_stack_version(), optional_minor_and_patch=True) + + +class TestEndpointQuery(BaseRuleTest): + """Test endpoint-specific rules.""" + + @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.3.0"), + "Test only applicable to 8.3+ stacks since query updates are min_stacked at 8.3.0") + def test_os_and_platform_in_query(self): + """Test that all endpoint rules have an os defined and linux includes platform.""" + for rule in self.production_rules: + if not rule.contents.data.get('language') in ('eql', 'kuery'): + continue + if rule.path.parent.name not in ('windows', 'macos', 'linux'): + # skip cross-platform for now + continue + + ast = rule.contents.data.ast + fields = [str(f) for f in ast if isinstance(f, (kql.ast.Field, eql.ast.Field))] + + err_msg = f'{self.rule_str(rule)} missing required field for endpoint rule' + if 'host.os.type' not in fields: + # Exception for Forwarded Events which contain Windows-only fields. + if rule.path.parent.name == 'windows' and not any(field.startswith('winlog.') for field in fields): + self.assertIn('host.os.type', fields, err_msg) + + # going to bypass this for now + # if rule.path.parent.name == 'linux': + # err_msg = f'{self.rule_str(rule)} missing required field for linux endpoint rule' + # self.assertIn('host.os.platform', fields, err_msg) + + +class TestNewTerms(BaseRuleTest): + """Test new term rules.""" + + @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), + "Test only applicable to 8.4+ stacks for new terms feature.") + def test_history_window_start(self): + """Test new terms history window start field.""" + + for rule in self.production_rules: + if rule.contents.data.type == "new_terms": + + # validate history window start field exists and is correct + assert rule.contents.data.new_terms.history_window_start, \ + "new terms field found with no history_window_start field defined" + assert rule.contents.data.new_terms.history_window_start[0].field == "history_window_start", \ + f"{rule.contents.data.new_terms.history_window_start} should be 'history_window_start'" + + @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), + "Test only applicable to 8.4+ stacks for new terms feature.") + def test_new_terms_field_exists(self): + # validate new terms and history window start fields are correct + for rule in self.production_rules: + if rule.contents.data.type == "new_terms": + assert rule.contents.data.new_terms.field == "new_terms_fields", \ + f"{rule.contents.data.new_terms.field} should be 'new_terms_fields' for new_terms rule type" + + @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), + "Test only applicable to 8.4+ stacks for new terms feature.") + def test_new_terms_fields(self): + """Test new terms fields are schema validated.""" + # ecs validation + for rule in self.production_rules: + if rule.contents.data.type == "new_terms": + meta = rule.contents.metadata + feature_min_stack = Version.parse('8.4.0') + current_package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) + min_stack_version = Version.parse(meta.get("min_stack_version")) if \ + meta.get("min_stack_version") else None + min_stack_version = current_package_version if min_stack_version is None or min_stack_version < \ + current_package_version else min_stack_version + + assert min_stack_version >= feature_min_stack, \ + f"New Terms rule types only compatible with {feature_min_stack}+" + ecs_version = get_stack_schemas()[str(min_stack_version)]['ecs'] + beats_version = get_stack_schemas()[str(min_stack_version)]['beats'] + + # checks if new terms field(s) are in ecs, beats non-ecs or integration schemas + queryvalidator = QueryValidator(rule.contents.data.query) + _, _, schema = queryvalidator.get_beats_schema([], beats_version, ecs_version) + integration_manifests = load_integrations_manifests() + integration_schemas = load_integrations_schemas() + integration_tags = meta.get("integration") + if integration_tags: + for tag in integration_tags: + latest_tag_compat_ver, _ = find_latest_compatible_version( + package=tag, + integration="", + rule_stack_version=min_stack_version, + packages_manifest=integration_manifests) + if latest_tag_compat_ver: + integration_schema = integration_schemas[tag][latest_tag_compat_ver] + for policy_template in integration_schema.keys(): + schema.update(**integration_schemas[tag][latest_tag_compat_ver][policy_template]) + for new_terms_field in rule.contents.data.new_terms.value: + assert new_terms_field in schema.keys(), \ + f"{new_terms_field} not found in ECS, Beats, or non-ecs schemas" + + @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), + "Test only applicable to 8.4+ stacks for new terms feature.") + def test_new_terms_max_limit(self): + """Test new terms max limit.""" + # validates length of new_terms to stack version - https://github.com/elastic/kibana/issues/142862 + for rule in self.production_rules: + if rule.contents.data.type == "new_terms": + meta = rule.contents.metadata + feature_min_stack = Version.parse('8.4.0') + feature_min_stack_extended_fields = Version.parse('8.6.0') + current_package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) + min_stack_version = Version.parse(meta.get("min_stack_version")) if \ + meta.get("min_stack_version") else None + min_stack_version = current_package_version if min_stack_version is None or min_stack_version < \ + current_package_version else min_stack_version + if feature_min_stack <= min_stack_version < feature_min_stack_extended_fields: + assert len(rule.contents.data.new_terms.value) == 1, \ + f"new terms have a max limit of 1 for stack versions below {feature_min_stack_extended_fields}" + + @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.6.0"), + "Test only applicable to 8.4+ stacks for new terms feature.") + def test_new_terms_fields_unique(self): + """Test new terms fields are unique.""" + # validate fields are unique + for rule in self.production_rules: + if rule.contents.data.type == "new_terms": + assert len(set(rule.contents.data.new_terms.value)) == len(rule.contents.data.new_terms.value), \ + f"new terms fields values are not unique - {rule.contents.data.new_terms.value}" + + +class TestESQLRules(BaseRuleTest): + """Test ESQL Rules.""" + + def run_esql_test(self, esql_query, expectation, message): + """Test that the query validation is working correctly.""" + rc = RuleCollection() + file_path = Path(get_path("tests", "data", "command_control_dummy_production_rule.toml")) + original_production_rule = load_rule_contents(file_path) + + # Test that a ValidationError is raised if the query doesn't match the schema + production_rule = deepcopy(original_production_rule)[0] + production_rule["rule"]["query"] = esql_query + + expectation.match_expr = message + with expectation: + rc.load_dict(production_rule) + + def test_esql_queries(self): + """Test ESQL queries.""" + # test_cases = [ + # # invalid queries + # ('from .ds-logs-endpoint.events.process-default-* | wheres process.name like "Microsoft*"', + # pytest.raises(marshmallow.exceptions.ValidationError), r"ESQL query failed"), + # ('from .ds-logs-endpoint.events.process-default-* | where process.names like "Microsoft*"', + # pytest.raises(marshmallow.exceptions.ValidationError), r"ESQL query failed"), + # + # # valid queries + # ('from .ds-logs-endpoint.events.process-default-* | where process.name like "Microsoft*"', + # does_not_raise(), None), + # ] + # for esql_query, expectation, message in test_cases: + # self.run_esql_test(esql_query, expectation, message)