diff --git a/.license_ignore b/.license_ignore
index 16bf6665..035b5130 100644
--- a/.license_ignore
+++ b/.license_ignore
@@ -2,6 +2,7 @@
 *.g
 *.ipynb
 *.md
+*.ndjson
 *.png
 *.tar.gz
 *.toml
@@ -28,6 +29,7 @@
 scripts/test-qaf-stack.sh
 tests/config.yaml
 tests/deployment.json
+tests/data/.gitignore
 tests/data/config_elastic-package.yaml
 tests/data/config_geneve-test-env.yaml
 tests/data/test-package-1.2.3/*
diff --git a/geneve/utils/__init__.py b/geneve/utils/__init__.py
index d4a112fc..aa6301dd 100644
--- a/geneve/utils/__init__.py
+++ b/geneve/utils/__init__.py
@@ -94,17 +94,21 @@ def resource(uri, basedir=None, cachedir=None, cachefile=None, validate=None):
             kwargs = {}
             if sys.version_info >= (3, 12) and ".tar" in local_file.suffixes:
                 kwargs = {"filter": "data"}
-            shutil.unpack_archive(local_file, tmpdir, **kwargs)
-            if local_file.parent == tmpdir:
-                local_file.unlink()
-            inner_entries = tmpdir.glob("*")
-            new_tmpdir = next(inner_entries)
             try:
-                # check if there are other directories or files
-                _ = next(inner_entries)
-            except StopIteration:
-                # lone entry, probably a directory, let's use it as base
-                tmpdir = new_tmpdir
+                shutil.unpack_archive(local_file, tmpdir, **kwargs)
+            except shutil.ReadError:
+                tmpdir = local_file
+            else:
+                if local_file.parent == tmpdir:
+                    local_file.unlink()
+                inner_entries = tmpdir.glob("*")
+                new_tmpdir = next(inner_entries)
+                try:
+                    # check if there are other directories or files
+                    _ = next(inner_entries)
+                except StopIteration:
+                    # lone entry, probably a directory, let's use it as base
+                    tmpdir = new_tmpdir
 
         yield tmpdir
 
diff --git a/scripts/test-stacks.sh b/scripts/test-stacks.sh
index 6f043ad1..44874ec4 100755
--- a/scripts/test-stacks.sh
+++ b/scripts/test-stacks.sh
@@ -255,6 +255,7 @@ while [ $ITERATIONS -lt 0 ] || [ $ITERATION -lt $ITERATIONS ]; do
 	echo TEST_STACK_VERSION: $TEST_STACK_VERSION
 	echo TEST_SCHEMA_URI: $TEST_SCHEMA_URI
 	echo TEST_DETECTION_RULES_URI: $TEST_DETECTION_RULES_URI
+	echo TEST_CORPUS_URI: $TEST_CORPUS_URI
 
 	if [ "$ONLINE_TESTS" = "1" ]; then
 		TEST_SIGNALS_QUERIES=1
diff --git a/tests/data/.gitignore b/tests/data/.gitignore
new file mode 100644
index 00000000..13df6a73
--- /dev/null
+++ b/tests/data/.gitignore
@@ -0,0 +1 @@
+*.ndjson
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4f1fea8e..f97ef17b 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -24,7 +24,7 @@
 from geneve.utils import deep_merge, resource, tempdir
 from geneve.utils.hdict import hdict
 
-from .utils import data_dir, http_server, tempenv
+from .utils import data_dir, flat_walk, http_server, tempenv
 
 
 class TestDictUtils(unittest.TestCase):
@@ -253,3 +253,30 @@ def test_groups(self):
 
         del d["ecs.version"]
         self.assertEqual([], list(d.groups()))
+
+
+class TestFlatWalk(unittest.TestCase):
+    def test_flat_walk(self):
+        doc = {
+            "0": {
+                "a": {
+                    "I": None,
+                },
+                "b": None,
+            },
+            "1.a": {
+                "I": None,
+                "II": None,
+            },
+            "2.a.I": None,
+        }
+
+        fields = [
+            "0.a.I",
+            "0.b",
+            "1.a.I",
+            "1.a.II",
+            "2.a.I",
+        ]
+
+        self.assertEqual(fields, list(flat_walk(doc)))
diff --git a/tests/utils.py b/tests/utils.py
index 9a23a40b..85be2825 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -33,7 +33,7 @@
 from pathlib import Path
 
 from geneve.events_emitter import SourceEvents
-from geneve.utils import batched, load_schema, random
+from geneve.utils import batched, dirs, load_schema, random, resource
 
 from . import jupyter
 
@@ -164,6 +164,15 @@ def diff_files(first, second):
     return out.decode("utf-8")
 
 
+def flat_walk(doc, path=()):
+    """Yield the dotted path of every non-dict leaf found in a nested dict."""
+    for k, v in doc.items():
+        if isinstance(v, dict):
+            yield from flat_walk(v, [*path, k])
+        else:
+            yield ".".join([*path, k])
+
+
 def assertIdenticalFiles(tc, first, second):  # noqa: N802
     with open(first) as f:
         first_hash = hashlib.sha256(f.read().encode("utf-8")).hexdigest()
@@ -318,9 +327,60 @@ class SignalsTestCase:
     multiplying_factor = int(os.getenv("TEST_SIGNALS_MULTI") or 0) or 1
     test_tags = ["Geneve"]
 
+    def load_corpus(self):
+        """Load the optional documents corpus named by TEST_CORPUS_URI.
+
+        Returns a (corpus, corpus_fields) tuple: a generator that cycles
+        endlessly over the corpus documents and the set of dotted field names
+        found in them, or (None, set()) when TEST_CORPUS_URI is not set.
+        """
+        corpus = None
+        corpus_fields = set()
+
+        corpus_uri = os.getenv("TEST_CORPUS_URI")
+        if corpus_uri:
+            if verbose:
+                sys.stderr.write("\n Loading corpus: ")
+                sys.stderr.flush()
+
+            with resource(corpus_uri, cachedir=dirs.cache) as corpus_file:
+                import mmap
+
+                def reader(*, wrap_around):
+                    with open(corpus_file, "r") as f:
+                        with mmap.mmap(f.fileno(), 0, prot=mmap.PROT_READ) as mm:
+                            while True:
+                                mm.seek(0)
+                                while line := mm.readline():
+                                    yield json.loads(line)
+                                if not wrap_around:
+                                    break
+
+                count = 0
+                fields = set()
+                for doc in reader(wrap_around=False):
+                    fields |= set(flat_walk(doc))
+                    if verbose and count % 100000 == 0:
+                        sys.stderr.write(f"{count} ")
+                        sys.stderr.flush()
+                    count += 1
+
+                corpus = reader(wrap_around=True)
+                corpus_fields = fields
+
+            if verbose:
+                if count % 100000:
+                    sys.stderr.write(f"{count} ")
+                sys.stderr.write(f"docs, {len(fields)} fields")
+                sys.stderr.flush()
+
+        return corpus, corpus_fields
+
     def generate_docs_and_mappings(self, rules, asts):
+        corpus, corpus_fields = self.load_corpus()
+
         schema = load_test_schema()
-        se = SourceEvents(schema)
+        se = SourceEvents(schema, corpus=corpus)
         se.stack_version = self.get_version()
 
         if verbose and verbose <= 2:
@@ -362,7 +422,7 @@ def generate_docs_and_mappings(self, rules, asts):
             sys.stderr.write(f"{ok_rules}/{len(bulk)} ")
             sys.stderr.flush()
 
-        return (bulk, se.mappings())
+        return (bulk, se.mappings(extra_fields=corpus_fields))
 
     def load_rules_and_docs(self, rules, asts, *, docs_chunk_size=200, rules_chunk_size=50):
         docs, mappings = self.generate_docs_and_mappings(rules, asts)