From a0c246b8c1300d5302b1797ecfae707640751344 Mon Sep 17 00:00:00 2001
From: qianheng
Date: Tue, 12 Nov 2024 03:51:12 +0800
Subject: [PATCH 1/4] Add sanity test script (#878)

* Add sanity test script

Signed-off-by: Heng Qian

* Add header

Signed-off-by: Heng Qian

* Minor fix

Signed-off-by: Heng Qian

* Minor fix

Signed-off-by: Heng Qian

* Minor fix

Signed-off-by: Heng Qian

* Support check expected_status if have that column in input file.

Signed-off-by: Heng Qian

* Add README.md

Signed-off-by: Heng Qian

* Minor fix

Signed-off-by: Heng Qian

* Support set log-level

Signed-off-by: Heng Qian

---------

Signed-off-by: Heng Qian
---
 integ-test/script/README.md      | 158 +++++++++
 integ-test/script/SanityTest.py  | 291 ++++++++++++++++
 integ-test/script/test_cases.csv | 567 +++++++++++++++++++++++++++++++
 3 files changed, 1016 insertions(+)
 create mode 100644 integ-test/script/README.md
 create mode 100644 integ-test/script/SanityTest.py
 create mode 100644 integ-test/script/test_cases.csv

diff --git a/integ-test/script/README.md b/integ-test/script/README.md
new file mode 100644
index 000000000..79b188158
--- /dev/null
+++ b/integ-test/script/README.md
@@ -0,0 +1,158 @@
+# Sanity Test Script
+
+### Description
+This Python script executes test queries from a CSV file using an asynchronous query API and generates comprehensive test reports.
+
+The script produces two report types:
+1. An Excel report with detailed test information for each query
+2. A JSON report containing both a test result overview and query-specific details
+
+Apart from the basic features, it also provides some advanced functionality:
+1. Concurrent query execution (note: the async query service limits the number of sessions, so keep the number of worker threads moderate even though the script reuses session IDs)
+2. Configurable query timeout, with periodic status checks and automatic cancellation when the timeout is exceeded
+3. Flexible row selection, by specifying the start and end rows of the input CSV file
+4. Expected status validation when an `expected_status` column is present in the input file
+5. Ability to generate partial reports if testing is interrupted
+
+### Usage
+To use this script, you need to have Python **3.6** or higher installed. It also requires the following Python libraries:
+```shell
+pip install requests pandas
+```
+
+After installing the required libraries, you can run the script from your shell with command line parameters like the following:
+```shell
+python SanityTest.py --base-url ${URL_ADDRESS} --username *** --password *** --datasource ${DATASOURCE_NAME} --input-csv test_cases.csv --output-file test_report --max-workers 2 --check-interval 10 --timeout 600
+```
+Replace the placeholders with your actual values for URL_ADDRESS and DATASOURCE_NAME, and with the USERNAME and PASSWORD used for authentication to your endpoint.
+
+For more details on the command line parameters, see the help output:
+```shell
+python SanityTest.py --help
+
+usage: SanityTest.py [-h] --base-url BASE_URL --username USERNAME --password PASSWORD --datasource DATASOURCE --input-csv INPUT_CSV
+                     --output-file OUTPUT_FILE [--max-workers MAX_WORKERS] [--check-interval CHECK_INTERVAL] [--timeout TIMEOUT]
+                     [--start-row START_ROW] [--end-row END_ROW] [--log-level LOG_LEVEL]
+
+Run tests from a CSV file and generate a report.
+
+options:
+  -h, --help            show this help message and exit
+  --base-url BASE_URL   Base URL of the service
+  --username USERNAME   Username for authentication
+  --password PASSWORD   Password for authentication
+  --datasource DATASOURCE
+                        Datasource name
+  --input-csv INPUT_CSV
+                        Path to the CSV file containing test queries
+  --output-file OUTPUT_FILE
+                        Path to the output report file
+  --max-workers MAX_WORKERS
+                        optional, Maximum number of worker threads (default: 2)
+  --check-interval CHECK_INTERVAL
+                        optional, Check interval in seconds (default: 5)
+  --timeout TIMEOUT     optional, Timeout in seconds (default: 600)
+  --start-row START_ROW
+                        optional, The start row of the query to run, starting from 1
+  --end-row END_ROW     optional, The end row of the query to run, not included
+  --log-level LOG_LEVEL
+                        optional, Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL, default: INFO)
+```
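+
+Under the hood, the script drives the asynchronous query API of the endpoint. The sketch below shows roughly what it does for each query (distilled from SanityTest.py; the URL, credentials and datasource are placeholders, not working values):
+```python
+import requests
+from requests.auth import HTTPBasicAuth
+
+base_url = "https://localhost:9200"           # placeholder, your --base-url value
+auth = HTTPBasicAuth("username", "password")  # placeholder credentials
+
+# Submit the query; reusing a previous sessionId avoids opening a new session
+payload = {
+    "datasource": "mydatasource",             # placeholder, your --datasource value
+    "lang": "ppl",
+    "query": "source = mydatasource.default.http_logs | head 1",
+    "sessionId": "empty_session_id",
+}
+submitted = requests.post(f"{base_url}/_plugins/_async_query", auth=auth, json=payload).json()
+
+# Poll the status until it becomes SUCCESS or FAILED; on timeout the script
+# cancels the query with a DELETE request on the same path
+result = requests.get(f"{base_url}/_plugins/_async_query/{submitted['queryId']}", auth=auth).json()
+print(result["status"])
+```
+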
+### Input CSV File
+As noted in the description, the input CSV file must at least contain a `query` column for the tests to run. It also supports an optional `expected_status` column: when present, the script checks the actual status against the expected status and adds a `check_status` column with the result -- TRUE means the status check passed; FALSE means it failed.
+
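+For example, a minimal input file with the optional `expected_status` column could look like this (both rows below are taken from `test_cases.csv`):
+```csv
+query,expected_status
+source = myglue_test.default.http_logs | where status = 200 | head 10,SUCCESS
+source = myglue_test.default.http_logs | dedup status CONSECUTIVE=true | fields status,FAILED
+```
+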
+We also provide a sample input CSV file `test_cases.csv` for reference. It includes all sanity test cases we currently have in Flint.
+
+**TODO**: the prerequisite data of the test cases and the ingestion process
+
+### Report Explanation
+The generated report contains two files:
+
+#### Excel Report
+The Excel report provides the test result details of each query, including the query name (currently its sequence number in the input CSV file), the query itself, the expected status, the actual status, and whether the actual status satisfies the expected status.
+
+If query execution failed, it provides an error message; otherwise it provides the query execution result with an empty error.
+
+It also provides the query_id, session_id and start/end time for each query, which can be used to debug the query execution in Flint.
+
+An example of the Excel report:
+
+| query_name | query | expected_status | status | check_status | error | result | Duration (s) | query_id | session_id | Start Time | End Time |
+|------------|-------|-----------------|--------|--------------|-------|--------|--------------|----------|------------|------------|----------|
+| 1 | describe myglue_test.default.http_logs | SUCCESS | SUCCESS | TRUE | | {'status': 'SUCCESS', 'schema': [{...}, ...], 'datarows': [[...], ...], 'total': 31, 'size': 31} | 37.51 | SHFEVWxDNnZjem15Z2x1ZV90ZXN0 | RkgzZm0xNlA5MG15Z2x1ZV90ZXN0 | 2024-11-07 13:34:10 | 2024-11-07 13:34:47 |
+| 2 | source = myglue_test.default.http_logs \| dedup status CONSECUTIVE=true | SUCCESS | FAILED | FALSE | {"Message":"Fail to run query. Cause: Consecutive deduplication is not supported"} | | 39.53 | dVNlaVVxOFZrZW15Z2x1ZV90ZXN0 | ZGU2MllVYmI4dG15Z2x1ZV90ZXN0 | 2024-11-07 13:34:10 | 2024-11-07 13:34:49 |
+| 3 | source = myglue_test.default.http_logs \| eval res = json_keys(json('{"account_number":1,"balance":39225,"age":32,"gender":"M"}')) \| head 1 \| fields res | SUCCESS | SUCCESS | TRUE | | {'status': 'SUCCESS', 'schema': [{'name': 'res', 'type': 'array'}], 'datarows': [[['account_number', 'balance', 'age', 'gender']]], 'total': 1, 'size': 1} | 12.77 | WHQxaXlVSGtGUm15Z2x1ZV90ZXN0 | RkgzZm0xNlA5MG15Z2x1ZV90ZXN0 | 2024-11-07 13:34:47 | 2024-11-07 13:38:45 |
+| ... | ... | ... | ... | ... | | | ... | ... | ... | ... | ... |
+
+
+#### JSON Report
+The JSON report provides the same information as the Excel report, but in JSON format. Additionally, it includes a statistical summary of the test results at the beginning of the report.
+
+An example of the JSON report:
+```json
+{
+  "summary": {
+    "total_queries": 115,
+    "successful_queries": 110,
+    "failed_queries": 3,
+    "submit_failed_queries": 0,
+    "timeout_queries": 2,
+    "execution_time": 16793.223807
+  },
+  "detailed_results": [
+    {
+      "query_name": 1,
+      "query": "source = myglue_test.default.http_logs | stats avg(size)",
+      "query_id": "eFZmTlpTa3EyTW15Z2x1ZV90ZXN0",
+      "session_id": "bFJDMWxzb2NVUm15Z2x1ZV90ZXN0",
+      "status": "SUCCESS",
+      "error": "",
+      "result": {
+        "status": "SUCCESS",
+        "schema": [
+          {
+            "name": "avg(size)",
+            "type": "double"
+          }
+        ],
+        "datarows": [
+          [
+            4654.305710913499
+          ]
+        ],
+        "total": 1,
+        "size": 1
+      },
+      "duration": 170.621145,
+      "start_time": "2024-11-07 14:56:13.869226",
+      "end_time": "2024-11-07 14:59:04.490371"
+    },
+    {
+      "query_name": 2,
+      "query": "source = myglue_test.default.http_logs | eval res = json_keys(json(\u2018{\"teacher\":\"Alice\",\"student\":[{\"name\":\"Bob\",\"rank\":1},{\"name\":\"Charlie\",\"rank\":2}]}')) | head 1 | fields res",
+      "query_id": "bjF4Y1VnbXdFYm15Z2x1ZV90ZXN0",
+      "session_id": "c3pvU1V6OW8xM215Z2x1ZV90ZXN0",
+      "status": "FAILED",
+      "error": "{\"Message\":\"Syntax error: \\n[PARSE_SYNTAX_ERROR] Syntax error at or near 'source'.(line 1, pos 0)\\n\\n== SQL ==\\nsource = myglue_test.default.http_logs | eval res = json_keys(json(\u2018{\\\"teacher\\\":\\\"Alice\\\",\\\"student\\\":[{\\\"name\\\":\\\"Bob\\\",\\\"rank\\\":1},{\\\"name\\\":\\\"Charlie\\\",\\\"rank\\\":2}]}')) | head 1 | fields res\\n^^^\\n\"}",
+      "result": null,
+      "duration": 14.051738,
+      "start_time": "2024-11-07 14:59:18.699335",
+      "end_time": "2024-11-07 14:59:32.751073"
+    },
+    {
+      "query_name": 3,
+      "query": "source = myglue_test.default.http_logs | eval col1 = size, col2 = clientip | stats avg(col1) by col2",
+      "query_id": "azVyMFFORnBFRW15Z2x1ZV90ZXN0",
+      "session_id": "VWF0SEtrNWM3bm15Z2x1ZV90ZXN0",
+      "status": "TIMEOUT",
+      "error": "Query execution exceeded 600 seconds with last status: running",
+      "result": null,
+      "duration": 673.710946,
+      "start_time": "2024-11-07 14:45:00.157875",
+      "end_time": "2024-11-07 14:56:13.868821"
+    },
+    ...
+  ]
+}
+```
diff --git a/integ-test/script/SanityTest.py b/integ-test/script/SanityTest.py
new file mode 100644
index 000000000..1c51d4d20
--- /dev/null
+++ b/integ-test/script/SanityTest.py
@@ -0,0 +1,291 @@
+"""
+Copyright OpenSearch Contributors
+SPDX-License-Identifier: Apache-2.0
+"""
+
+import signal
+import sys
+import requests
+import json
+import csv
+import time
+import logging
+from datetime import datetime
+import pandas as pd
+import argparse
+from requests.auth import HTTPBasicAuth
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import threading
+
+"""
+Environment: python3
+
+Example usage of this script:
+
+python SanityTest.py --base-url ${URL_ADDRESS} --username *** --password *** --datasource ${DATASOURCE_NAME} --input-csv test_queries.csv --output-file test_report --max-workers 2 --check-interval 10 --timeout 600
+
+The input file test_queries.csv should contain the column `query`
+
+For more details, please use the command:
+
+python SanityTest.py --help
+
+"""
+
+class FlintTester:
+    def __init__(self, base_url, username, password, datasource, max_workers, check_interval, timeout, output_file, start_row, end_row, log_level):
+        self.base_url = base_url
+        self.auth = HTTPBasicAuth(username, password)
+        self.datasource = datasource
+        self.headers = { 'Content-Type': 'application/json' }
+        self.max_workers = max_workers
+        self.check_interval = check_interval
+        self.timeout = timeout
+        self.output_file = output_file
+        self.start = start_row - 1 if start_row else None
+        self.end = end_row - 1 if end_row else None
+        self.log_level = log_level
+        # e.g. timeout=600 and check_interval=10 allow at most 60 status checks
+        self.max_attempts = int(timeout / check_interval)
+        self.logger = self._setup_logger()
+        self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
+        self.thread_local = threading.local()
+        self.test_results = []
+
+    def _setup_logger(self):
+        logger = logging.getLogger('FlintTester')
+        logger.setLevel(self.log_level)
+
+        fh = logging.FileHandler('flint_test.log')
+        fh.setLevel(self.log_level)
+
+        ch = logging.StreamHandler()
+        ch.setLevel(self.log_level)
+
+        formatter = logging.Formatter(
+            '%(asctime)s - %(threadName)s - %(levelname)s - %(message)s'
+        )
+        fh.setFormatter(formatter)
+        ch.setFormatter(formatter)
+
+        logger.addHandler(fh)
+        logger.addHandler(ch)
+
+        return logger
+
+
+    def get_session_id(self):
+        if not hasattr(self.thread_local, 'session_id'):
+            self.thread_local.session_id = "empty_session_id"
+        self.logger.debug(f"get session id {self.thread_local.session_id}")
+        return self.thread_local.session_id
+
+    def set_session_id(self, session_id):
+        """Reuse the session id for the same thread"""
+        self.logger.debug(f"set session id {session_id}")
+        self.thread_local.session_id = session_id
+
+    # Call submit API to submit the query
+    def submit_query(self, query, session_id="Empty"):
+        url = f"{self.base_url}/_plugins/_async_query"
+        payload = {
+            "datasource": self.datasource,
+            "lang": "ppl",
+            "query": query,
+            "sessionId": session_id
+        }
+        self.logger.debug(f"Submit query with payload: {payload}")
+        response_json = None
+        try:
+            response = requests.post(url, auth=self.auth, json=payload, headers=self.headers)
+            response_json = response.json()
+            response.raise_for_status()
+            return response_json
+        except Exception as e:
+            return {"error": str(e), "response": response_json}
+
+    # Call get API to check the query status
+    def get_query_result(self, query_id):
+        url = f"{self.base_url}/_plugins/_async_query/{query_id}"
+        response_json = None
+        try:
+            response = requests.get(url, auth=self.auth)
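+            # Parse the body before raise_for_status() so that error details
+            # returned by the service are preserved in the reported failure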
+            response_json = response.json()
+            response.raise_for_status()
+            return response_json
+        except Exception as e:
+            return {"status": "FAILED", "error": str(e), "response": response_json}
+
+    # Call delete API to cancel the query
+    def cancel_query(self, query_id):
+        url = f"{self.base_url}/_plugins/_async_query/{query_id}"
+        response_json = None
+        try:
+            response = requests.delete(url, auth=self.auth)
+            response_json = response.json()
+            response.raise_for_status()
+            self.logger.info(f"Cancelled query [{query_id}] with info {response_json}")
+            return response_json
+        except Exception as e:
+            self.logger.warning(f"Cancel query [{query_id}] error: {str(e)}, got response {response_json}")
+
+    # Run the test and return the result
+    def run_test(self, query, seq_id, expected_status):
+        self.logger.info(f"Starting test: {seq_id}, {query}")
+        start_time = datetime.now()
+        pre_session_id = self.get_session_id()
+        submit_result = self.submit_query(query, pre_session_id)
+        if "error" in submit_result:
+            self.logger.warning(f"Submit error: {submit_result}")
+            return {
+                "query_name": seq_id,
+                "query": query,
+                "expected_status": expected_status,
+                "status": "SUBMIT_FAILED",
+                "check_status": "SUBMIT_FAILED" == expected_status if expected_status else None,
+                "error": submit_result["error"],
+                "duration": 0,
+                "start_time": start_time,
+                "end_time": datetime.now()
+            }
+
+        query_id = submit_result["queryId"]
+        session_id = submit_result["sessionId"]
+        self.logger.info(f"Submit return: {submit_result}")
+        if session_id != pre_session_id:
+            self.logger.info(f"Update session id from {pre_session_id} to {session_id}")
+            self.set_session_id(session_id)
+
+        test_result = self.check_query_status(query_id)
+        end_time = datetime.now()
+        duration = (end_time - start_time).total_seconds()
+
+        return {
+            "query_name": seq_id,
+            "query": query,
+            "query_id": query_id,
+            "session_id": session_id,
+            "expected_status": expected_status,
+            "status": test_result["status"],
+            "check_status": test_result["status"] == expected_status if expected_status else None,
+            "error": test_result.get("error", ""),
+            "result": test_result if test_result["status"] == "SUCCESS" else None,
+            "duration": duration,
+            "start_time": start_time,
+            "end_time": end_time
+        }
+
+    # Check the status of the query periodically until it completes, fails, or exceeds the timeout
+    def check_query_status(self, query_id):
+        for attempt in range(self.max_attempts):
+            time.sleep(self.check_interval)
+            result = self.get_query_result(query_id)
+
+            if result["status"] == "FAILED" or result["status"] == "SUCCESS":
+                return result
+
+        # Cancel the query if it exceeds the timeout
+        self.cancel_query(query_id)
+        return {
+            "status": "TIMEOUT",
+            "error": "Query execution exceeded " + str(self.timeout) + " seconds with last status: " + result["status"],
+        }
+
+    def run_tests_from_csv(self, csv_file):
+        with open(csv_file, 'r') as f:
+            reader = csv.DictReader(f)
+            queries = [(row['query'], i, row.get('expected_status', None)) for i, row in enumerate(reader, start=1) if row['query'].strip()]
+
+        # Filter queries based on the start and end rows
+        queries = queries[self.start:self.end]
+
+        # Parallel execution
+        futures = [self.executor.submit(self.run_test, query, seq_id, expected_status) for query, seq_id, expected_status in queries]
+        for future in as_completed(futures):
+            result = future.result()
+            self.test_results.append(result)
+
+    def generate_report(self):
+        self.logger.info("Generating report...")
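+        # Tally per-status counts for the summary section of the report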
+        total_queries = len(self.test_results)
+        successful_queries = sum(1 for r in self.test_results if r['status'] == 'SUCCESS')
+        failed_queries = sum(1 for r in self.test_results if r['status'] == 'FAILED')
+        submit_failed_queries = sum(1 for r in self.test_results if r['status'] == 'SUBMIT_FAILED')
+        timeout_queries = sum(1 for r in self.test_results if r['status'] == 'TIMEOUT')
+
+        # Create report
+        report = {
+            "summary": {
+                "total_queries": total_queries,
+                "successful_queries": successful_queries,
+                "failed_queries": failed_queries,
+                "submit_failed_queries": submit_failed_queries,
+                "timeout_queries": timeout_queries,
+                "execution_time": sum(r['duration'] for r in self.test_results)
+            },
+            "detailed_results": self.test_results
+        }
+
+        # Save report to JSON file
+        with open(f"{self.output_file}.json", 'w') as f:
+            json.dump(report, f, indent=2, default=str)
+
+        # Save results to Excel file
+        df = pd.DataFrame(self.test_results)
+        df.to_excel(f"{self.output_file}.xlsx", index=False)
+
+        self.logger.info(f"Generated report in {self.output_file}.xlsx and {self.output_file}.json")
+
+def signal_handler(sig, frame, tester):
+    print(f"Signal {sig} received, generating report...")
+    try:
+        tester.executor.shutdown(wait=False, cancel_futures=True)
+        tester.generate_report()
+    finally:
+        sys.exit(0)
+
+def main():
+    # Parse command line arguments
+    parser = argparse.ArgumentParser(description="Run tests from a CSV file and generate a report.")
+    parser.add_argument("--base-url", required=True, help="Base URL of the service")
+    parser.add_argument("--username", required=True, help="Username for authentication")
+    parser.add_argument("--password", required=True, help="Password for authentication")
+    parser.add_argument("--datasource", required=True, help="Datasource name")
+    parser.add_argument("--input-csv", required=True, help="Path to the CSV file containing test queries")
+    parser.add_argument("--output-file", required=True, help="Path to the output report file")
+    parser.add_argument("--max-workers", type=int, default=2, help="optional, Maximum number of worker threads (default: 2)")
+    parser.add_argument("--check-interval", type=int, default=5, help="optional, Check interval in seconds (default: 5)")
+    parser.add_argument("--timeout", type=int, default=600, help="optional, Timeout in seconds (default: 600)")
+    parser.add_argument("--start-row", type=int, default=None, help="optional, The start row of the query to run, starting from 1")
+    parser.add_argument("--end-row", type=int, default=None, help="optional, The end row of the query to run, not included")
+    parser.add_argument("--log-level", default="INFO", help="optional, Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL, default: INFO)")
+
+    args = parser.parse_args()
+
+    tester = FlintTester(
+        base_url=args.base_url,
+        username=args.username,
+        password=args.password,
+        datasource=args.datasource,
+        max_workers=args.max_workers,
+        check_interval=args.check_interval,
+        timeout=args.timeout,
+        output_file=args.output_file,
+        start_row=args.start_row,
+        end_row=args.end_row,
+        log_level=args.log_level,
+    )
+
+    # Register signal handlers to generate report on interrupt
+    signal.signal(signal.SIGINT, lambda sig, frame: signal_handler(sig, frame, tester))
+    signal.signal(signal.SIGTERM, lambda sig, frame: signal_handler(sig, frame, tester))
+
+    # Run the tests
+    tester.run_tests_from_csv(args.input_csv)
+
+    # Generate report
+    tester.generate_report()
+
+if __name__ == "__main__":
+    main()
diff --git a/integ-test/script/test_cases.csv 
b/integ-test/script/test_cases.csv new file mode 100644 index 000000000..7df05f5a3 --- /dev/null +++ b/integ-test/script/test_cases.csv @@ -0,0 +1,567 @@ +query,expected_status +describe myglue_test.default.http_logs,FAILED +describe `myglue_test`.`default`.`http_logs`,FAILED +"source = myglue_test.default.http_logs | dedup 1 status | fields @timestamp, clientip, status, size | head 10",SUCCESS +"source = myglue_test.default.http_logs | dedup status, size | head 10",SUCCESS +source = myglue_test.default.http_logs | dedup 1 status keepempty=true | head 10,SUCCESS +"source = myglue_test.default.http_logs | dedup status, size keepempty=true | head 10",SUCCESS +source = myglue_test.default.http_logs | dedup 2 status | head 10,SUCCESS +"source = myglue_test.default.http_logs | dedup 2 status, size | head 10",SUCCESS +"source = myglue_test.default.http_logs | dedup 2 status, size keepempty=true | head 10",SUCCESS +source = myglue_test.default.http_logs | dedup status CONSECUTIVE=true | fields status,FAILED +"source = myglue_test.default.http_logs | dedup 2 status, size CONSECUTIVE=true | fields status",FAILED +"source = myglue_test.default.http_logs | sort stat | fields @timestamp, clientip, status | head 10",SUCCESS +"source = myglue_test.default.http_logs | fields @timestamp, notexisted | head 10",FAILED +"source = myglue_test.default.nested | fields int_col, struct_col.field1, struct_col2.field1 | head 10",FAILED +"source = myglue_test.default.nested | where struct_col2.field1.subfield > 'valueA' | sort int_col | fields int_col, struct_col.field1.subfield, struct_col2.field1.subfield",FAILED +"source = myglue_test.default.http_logs | fields - @timestamp, clientip, status | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval new_time = @timestamp, new_clientip = clientip | fields - new_time, new_clientip, status | head 10",SUCCESS +source = myglue_test.default.http_logs | eval new_clientip = lower(clientip) | fields - new_clientip | head 10,SUCCESS +"source = myglue_test.default.http_logs | fields + @timestamp, clientip, status | fields - clientip, status | head 10",SUCCESS +"source = myglue_test.default.http_logs | fields - clientip, status | fields + @timestamp, clientip, status| head 10",SUCCESS +source = myglue_test.default.http_logs | where status = 200 | head 10,SUCCESS +source = myglue_test.default.http_logs | where status != 200 | head 10,SUCCESS +source = myglue_test.default.http_logs | where size > 0 | head 10,SUCCESS +source = myglue_test.default.http_logs | where size <= 0 | head 10,SUCCESS +source = myglue_test.default.http_logs | where clientip = '236.14.2.0' | head 10,SUCCESS +source = myglue_test.default.http_logs | where size > 0 AND status = 200 OR clientip = '236.14.2.0' | head 100,SUCCESS +"source = myglue_test.default.http_logs | where size <= 0 AND like(request, 'GET%') | head 10",SUCCESS +source = myglue_test.default.http_logs status = 200 | head 10,SUCCESS +source = myglue_test.default.http_logs size > 0 AND status = 200 OR clientip = '236.14.2.0' | head 100,SUCCESS +"source = myglue_test.default.http_logs size <= 0 AND like(request, 'GET%') | head 10",SUCCESS +"source = myglue_test.default.http_logs substring(clientip, 5, 2) = ""12"" | head 10",SUCCESS +source = myglue_test.default.http_logs | where isempty(size),FAILED +source = myglue_test.default.http_logs | where ispresent(size),FAILED +source = myglue_test.default.http_logs | where isnull(size) | head 10,SUCCESS +source = myglue_test.default.http_logs | where isnotnull(size) | head 10,SUCCESS +"source 
= myglue_test.default.http_logs | where isnotnull(coalesce(size, status)) | head 10",FAILED +"source = myglue_test.default.http_logs | where like(request, 'GET%') | head 10",SUCCESS +"source = myglue_test.default.http_logs | where like(request, '%bordeaux%') | head 10",SUCCESS +"source = myglue_test.default.http_logs | where substring(clientip, 5, 2) = ""12"" | head 10",SUCCESS +"source = myglue_test.default.http_logs | where lower(request) = ""get /images/backnews.gif http/1.0"" | head 10",SUCCESS +source = myglue_test.default.http_logs | where length(request) = 38 | head 10,SUCCESS +"source = myglue_test.default.http_logs | where case(status = 200, 'success' else 'failed') = 'success' | head 10",FAILED +"source = myglue_test.default.http_logs | eval h = ""Hello"", w = ""World"" | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval @h = ""Hello"" | eval @w = ""World"" | fields @timestamp, @h, @w",SUCCESS +source = myglue_test.default.http_logs | eval newF = clientip | head 10,SUCCESS +"source = myglue_test.default.http_logs | eval newF = clientip | fields clientip, newF | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval f = size | where f > 1 | sort f | fields size, clientip, status | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval f = status * 2 | eval h = f * 2 | fields status, f, h | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval f = size * 2, h = status | stats sum(f) by h",SUCCESS +"source = myglue_test.default.http_logs | eval f = UPPER(request) | eval h = 40 | fields f, h | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval request = ""test"" | fields request | head 10",FAILED +source = myglue_test.default.http_logs | eval size = abs(size) | where size < 500,FAILED +"source = myglue_test.default.http_logs | eval status_string = case(status = 200, 'success' else 'failed') | head 10",FAILED +"source = myglue_test.default.http_logs | eval n = now() | eval t = unix_timestamp(@timestamp) | fields n, t | head 10",SUCCESS +source = myglue_test.default.http_logs | eval e = isempty(size) | eval p = ispresent(size) | head 10,FAILED +"source = myglue_test.default.http_logs | eval c = coalesce(size, status) | head 10",FAILED +source = myglue_test.default.http_logs | eval c = coalesce(request) | head 10,FAILED +source = myglue_test.default.http_logs | eval col1 = ln(size) | eval col2 = unix_timestamp(@timestamp) | sort - col1 | head 10,SUCCESS +"source = myglue_test.default.http_logs | eval col1 = 1 | sort col1 | head 4 | eval col2 = 2 | sort - col2 | sort - size | head 2 | fields @timestamp, clientip, col2",SUCCESS +"source = myglue_test.default.mini_http_logs | eval stat = status | where stat > 300 | sort stat | fields @timestamp,clientip,status | head 5",SUCCESS +"source = myglue_test.default.http_logs | eval col1 = size, col2 = clientip | stats avg(col1) by col2",SUCCESS +source = myglue_test.default.http_logs | stats avg(size) by clientip,SUCCESS +"source = myglue_test.default.http_logs | eval new_request = upper(request) | eval compound_field = concat('Hello ', if(like(new_request, '%bordeaux%'), 'World', clientip)) | fields new_request, compound_field | head 10",SUCCESS +source = myglue_test.default.http_logs | stats avg(size),SUCCESS +source = myglue_test.default.nested | stats max(int_col) by struct_col.field2,SUCCESS +source = myglue_test.default.nested | stats distinct_count(int_col),SUCCESS +source = myglue_test.default.nested | stats stddev_samp(int_col),SUCCESS +source = myglue_test.default.nested | 
stats stddev_pop(int_col),SUCCESS
+source = myglue_test.default.nested | stats percentile(int_col),SUCCESS
+source = myglue_test.default.nested | stats percentile_approx(int_col),SUCCESS
+source = myglue_test.default.mini_http_logs | stats stddev_samp(status),SUCCESS
+"source = myglue_test.default.mini_http_logs | where status > 200 | stats percentile_approx(status, 99)",SUCCESS
+"source = myglue_test.default.nested | stats count(int_col) by span(struct_col.field2, 10) as a_span",SUCCESS
+"source = myglue_test.default.nested | stats avg(int_col) by span(struct_col.field2, 10) as a_span, struct_col2.field2",SUCCESS
+"source = myglue_test.default.http_logs | stats sum(size) by span(@timestamp, 1d) as age_size_per_day | sort - age_size_per_day | head 10",SUCCESS
+"source = myglue_test.default.http_logs | stats distinct_count(clientip) by span(@timestamp, 1d) as age_size_per_day | sort - age_size_per_day | head 10",SUCCESS
+"source = myglue_test.default.http_logs | stats avg(size) as avg_size by status, year | stats avg(avg_size) as avg_avg_size by year",SUCCESS
+"source = myglue_test.default.http_logs | stats avg(size) as avg_size by status, year, month | stats avg(avg_size) as avg_avg_size by year, month | stats avg(avg_avg_size) as avg_avg_avg_size by year",SUCCESS
+"source = myglue_test.default.nested | stats avg(int_col) as avg_int by struct_col.field2, struct_col2.field2 | stats avg(avg_int) as avg_avg_int by struct_col2.field2",FAILED
+"source = myglue_test.default.nested | stats avg(int_col) as avg_int by struct_col.field2, struct_col2.field2 | eval new_col = avg_int | stats avg(avg_int) as avg_avg_int by new_col",SUCCESS
+source = myglue_test.default.nested | rare int_col,SUCCESS
+source = myglue_test.default.nested | rare int_col by struct_col.field2,SUCCESS
+source = myglue_test.default.http_logs | rare request,SUCCESS
+source = myglue_test.default.http_logs | where status > 300 | rare request by status,SUCCESS
+source = myglue_test.default.http_logs | rare clientip,SUCCESS
+source = myglue_test.default.http_logs | where status > 300 | rare clientip,SUCCESS
+source = myglue_test.default.http_logs | where status > 300 | rare clientip by day,SUCCESS
+source = myglue_test.default.nested | top int_col by struct_col.field2,SUCCESS
+source = myglue_test.default.nested | top 1 int_col by struct_col.field2,SUCCESS
+source = myglue_test.default.nested | top 2 int_col by struct_col.field2,SUCCESS
+source = myglue_test.default.nested | top int_col,SUCCESS
+source = myglue_test.default.http_logs | inner join left=l right=r on l.status = r.int_col myglue_test.default.nested | head 10,FAILED
+"source = myglue_test.default.http_logs | parse request 'GET /(?<domain>[a-zA-Z]+)/.*' | fields request, domain | head 10",SUCCESS
+source = myglue_test.default.http_logs | parse request 'GET /(?<domain>[a-zA-Z]+)/.*' | top 1 domain,SUCCESS
+source = myglue_test.default.http_logs | parse request 'GET /(?<domain>[a-zA-Z]+)/.*' | stats count() by domain,SUCCESS
+"source = myglue_test.default.http_logs | parse request 'GET /(?<domain>[a-zA-Z]+)/.*' | eval a = 1 | fields a, domain | head 10",SUCCESS
+"source = myglue_test.default.http_logs | parse request 'GET /(?<domain>[a-zA-Z]+)/.*' | where size > 0 | sort - size | fields size, domain | head 10",SUCCESS
+"source = myglue_test.default.http_logs | parse request 'GET /(?<domain>[a-zA-Z]+)/(?<picName>[a-zA-Z]+)/.*' | where domain = 'english' | sort - picName | fields domain, picName | head 10",SUCCESS
+source = myglue_test.default.http_logs | patterns request | fields patterns_field | head 10,SUCCESS
+source = 
myglue_test.default.http_logs | patterns request | where size > 0 | fields patterns_field | head 10,SUCCESS +"source = myglue_test.default.http_logs | patterns new_field='no_letter' pattern='[a-zA-Z]' request | fields request, no_letter | head 10",SUCCESS +source = myglue_test.default.http_logs | patterns new_field='no_letter' pattern='[a-zA-Z]' request | stats count() by no_letter,SUCCESS +"source = myglue_test.default.http_logs | patterns new_field='status' pattern='[a-zA-Z]' request | fields request, status | head 10",FAILED +source = myglue_test.default.http_logs | rename @timestamp as timestamp | head 10,FAILED +source = myglue_test.default.http_logs | sort size | head 10,SUCCESS +source = myglue_test.default.http_logs | sort + size | head 10,SUCCESS +source = myglue_test.default.http_logs | sort - size | head 10,SUCCESS +"source = myglue_test.default.http_logs | sort + size, + @timestamp | head 10",SUCCESS +"source = myglue_test.default.http_logs | sort - size, - @timestamp | head 10",SUCCESS +"source = myglue_test.default.http_logs | sort - size, @timestamp | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval c1 = upper(request) | eval c2 = concat('Hello ', if(like(c1, '%bordeaux%'), 'World', clientip)) | eval c3 = length(request) | eval c4 = ltrim(request) | eval c5 = rtrim(request) | eval c6 = substring(clientip, 5, 2) | eval c7 = trim(request) | eval c8 = upper(request) | eval c9 = position('bordeaux' IN request) | eval c10 = replace(request, 'GET', 'GGG') | fields c1, c2, c3, c4, c5, c6, c7, c8, c9, c10 | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval c1 = unix_timestamp(@timestamp) | eval c2 = now() | eval c3 = +DAY_OF_WEEK(@timestamp) | eval c4 = +DAY_OF_MONTH(@timestamp) | eval c5 = +DAY_OF_YEAR(@timestamp) | eval c6 = +WEEK_OF_YEAR(@timestamp) | eval c7 = +WEEK(@timestamp) | eval c8 = +MONTH_OF_YEAR(@timestamp) | eval c9 = +HOUR_OF_DAY(@timestamp) | eval c10 = +MINUTE_OF_HOUR(@timestamp) | eval c11 = +SECOND_OF_MINUTE(@timestamp) | eval c12 = +LOCALTIME() | fields c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12 | head 10",SUCCESS +"source=myglue_test.default.people | eval c1 = adddate(@timestamp, 1) | fields c1 | head 10",SUCCESS +"source=myglue_test.default.people | eval c2 = subdate(@timestamp, 1) | fields c2 | head 10",SUCCESS +source=myglue_test.default.people | eval c1 = date_add(@timestamp INTERVAL 1 DAY) | fields c1 | head 10,SUCCESS +source=myglue_test.default.people | eval c1 = date_sub(@timestamp INTERVAL 1 DAY) | fields c1 | head 10,SUCCESS +source=myglue_test.default.people | eval `CURDATE()` = CURDATE() | fields `CURDATE()`,SUCCESS +source=myglue_test.default.people | eval `CURRENT_DATE()` = CURRENT_DATE() | fields `CURRENT_DATE()`,SUCCESS +source=myglue_test.default.people | eval `CURRENT_TIMESTAMP()` = CURRENT_TIMESTAMP() | fields `CURRENT_TIMESTAMP()`,SUCCESS +source=myglue_test.default.people | eval `DATE('2020-08-26')` = DATE('2020-08-26') | fields `DATE('2020-08-26')`,SUCCESS +source=myglue_test.default.people | eval `DATE(TIMESTAMP('2020-08-26 13:49:00'))` = DATE(TIMESTAMP('2020-08-26 13:49:00')) | fields `DATE(TIMESTAMP('2020-08-26 13:49:00'))`,SUCCESS +source=myglue_test.default.people | eval `DATE('2020-08-26 13:49')` = DATE('2020-08-26 13:49') | fields `DATE('2020-08-26 13:49')`,SUCCESS +"source=myglue_test.default.people | eval `DATE_FORMAT('1998-01-31 13:14:15.012345', 'HH:mm:ss.SSSSSS')` = DATE_FORMAT('1998-01-31 13:14:15.012345', 'HH:mm:ss.SSSSSS'), `DATE_FORMAT(TIMESTAMP('1998-01-31 13:14:15.012345'), 
'yyyy-MMM-dd hh:mm:ss a')` = DATE_FORMAT(TIMESTAMP('1998-01-31 13:14:15.012345'), 'yyyy-MMM-dd hh:mm:ss a') | fields `DATE_FORMAT('1998-01-31 13:14:15.012345', 'HH:mm:ss.SSSSSS')`, `DATE_FORMAT(TIMESTAMP('1998-01-31 13:14:15.012345'), 'yyyy-MMM-dd hh:mm:ss a')`",SUCCESS +"source=myglue_test.default.people | eval `'2000-01-02' - '2000-01-01'` = DATEDIFF(TIMESTAMP('2000-01-02 00:00:00'), TIMESTAMP('2000-01-01 23:59:59')), `'2001-02-01' - '2004-01-01'` = DATEDIFF(DATE('2001-02-01'), TIMESTAMP('2004-01-01 00:00:00')) | fields `'2000-01-02' - '2000-01-01'`, `'2001-02-01' - '2004-01-01'`", +source=myglue_test.default.people | eval `DAY(DATE('2020-08-26'))` = DAY(DATE('2020-08-26')) | fields `DAY(DATE('2020-08-26'))`, +source=myglue_test.default.people | eval `DAYNAME(DATE('2020-08-26'))` = DAYNAME(DATE('2020-08-26')) | fields `DAYNAME(DATE('2020-08-26'))`,FAILED +source=myglue_test.default.people | eval `CURRENT_TIMEZONE()` = CURRENT_TIMEZONE() | fields `CURRENT_TIMEZONE()`,SUCCESS +source=myglue_test.default.people | eval `UTC_TIMESTAMP()` = UTC_TIMESTAMP() | fields `UTC_TIMESTAMP()`,SUCCESS +"source=myglue_test.default.people | eval `TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00')` = TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00') | eval `TIMESTAMPDIFF(SECOND, timestamp('1997-01-01 00:00:23'), timestamp('1997-01-01 00:00:00'))` = TIMESTAMPDIFF(SECOND, timestamp('1997-01-01 00:00:23'), timestamp('1997-01-01 00:00:00')) | fields `TIMESTAMPDIFF(YEAR, '1997-01-01 00:00:00', '2001-03-06 00:00:00')`, `TIMESTAMPDIFF(SECOND, timestamp('1997-01-01 00:00:23'), timestamp('1997-01-01 00:00:00'))`",SUCCESS +"source=myglue_test.default.people | eval `TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00')` = TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00') | eval `TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00')` = TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00') | fields `TIMESTAMPADD(DAY, 17, '2000-01-01 00:00:00')`, `TIMESTAMPADD(QUARTER, -1, '2000-01-01 00:00:00')`",SUCCESS + source = myglue_test.default.http_logs | stats count(),SUCCESS +"source = myglue_test.default.http_logs | stats avg(size) as c1, max(size) as c2, min(size) as c3, sum(size) as c4, percentile(size, 50) as c5, stddev_pop(size) as c6, stddev_samp(size) as c7, distinct_count(size) as c8",SUCCESS +"source = myglue_test.default.http_logs | eval c1 = abs(size) | eval c2 = ceil(size) | eval c3 = floor(size) | eval c4 = sqrt(size) | eval c5 = ln(size) | eval c6 = pow(size, 2) | eval c7 = mod(size, 2) | fields c1, c2, c3, c4, c5, c6, c7 | head 10",SUCCESS +"source = myglue_test.default.http_logs | eval c1 = isnull(request) | eval c2 = isnotnull(request) | eval c3 = ifnull(request, +""Unknown"") | eval c4 = nullif(request, +""Unknown"") | eval c5 = isnull(size) | eval c6 = if(like(request, '%bordeaux%'), 'hello', 'world') | fields c1, c2, c3, c4, c5, c6 | head 10",SUCCESS +/* this is block comment */ source = myglue_test.tpch_csv.orders | head 1 // this is line comment,SUCCESS +"/* test in tpch q16, q18, q20 */ source = myglue_test.tpch_csv.orders | head 1 // add source=xx to avoid failure in automation",SUCCESS +"/* test in tpch q4, q21, q22 */ source = myglue_test.tpch_csv.orders | head 1",SUCCESS +"/* test in tpch q2, q11, q15, q17, q20, q22 */ source = myglue_test.tpch_csv.orders | head 1",SUCCESS +"/* test in tpch q7, q8, q9, q13, q15, q22 */ source = myglue_test.tpch_csv.orders | head 1",SUCCESS +/* lots of inner join tests in tpch */ source = myglue_test.tpch_csv.orders | head 1,SUCCESS +/* left join test in tpch 
q13 */ source = myglue_test.tpch_csv.orders | head 1,SUCCESS
+"source = myglue_test.tpch_csv.orders
+  | right outer join ON c_custkey = o_custkey AND not like(o_comment, '%special%requests%')
+    myglue_test.tpch_csv.customer
+| stats count(o_orderkey) as c_count by c_custkey
+| sort - c_count",SUCCESS
+"source = myglue_test.tpch_csv.orders
+  | full outer join ON c_custkey = o_custkey AND not like(o_comment, '%special%requests%')
+    myglue_test.tpch_csv.customer
+| stats count(o_orderkey) as c_count by c_custkey
+| sort - c_count",SUCCESS
+"source = myglue_test.tpch_csv.customer
+| semi join ON c_custkey = o_custkey myglue_test.tpch_csv.orders
+| where c_mktsegment = 'BUILDING'
+| sort - c_custkey
+| head 10",SUCCESS
+"source = myglue_test.tpch_csv.customer
+| anti join ON c_custkey = o_custkey myglue_test.tpch_csv.orders
+| where c_mktsegment = 'BUILDING'
+| sort - c_custkey
+| head 10",SUCCESS
+"source = myglue_test.tpch_csv.supplier
+| where like(s_comment, '%Customer%Complaints%')
+| join ON s_nationkey > n_nationkey [ source = myglue_test.tpch_csv.nation | where n_name = 'SAUDI ARABIA' ]
+| sort - s_name
+| head 10",SUCCESS
+"source = myglue_test.tpch_csv.supplier
+| where like(s_comment, '%Customer%Complaints%')
+| join [ source = myglue_test.tpch_csv.nation | where n_name = 'SAUDI ARABIA' ]
+| sort - s_name
+| head 10",SUCCESS
+source=myglue_test.default.people | LOOKUP myglue_test.default.work_info uid AS id REPLACE department | stats distinct_count(department),SUCCESS
+source = myglue_test.default.people| LOOKUP myglue_test.default.work_info uid AS id APPEND department | stats distinct_count(department),SUCCESS
+source = myglue_test.default.people| LOOKUP myglue_test.default.work_info uid AS id REPLACE department AS country | stats distinct_count(country),SUCCESS
+source = myglue_test.default.people| LOOKUP myglue_test.default.work_info uid AS id APPEND department AS country | stats distinct_count(country),SUCCESS
+"source = myglue_test.default.people| LOOKUP myglue_test.default.work_info uID AS id, name REPLACE department | stats distinct_count(department)",SUCCESS
+"source = myglue_test.default.people| LOOKUP myglue_test.default.work_info uid AS ID, name APPEND department | stats distinct_count(department)",SUCCESS
+"source = myglue_test.default.people| LOOKUP myglue_test.default.work_info uID AS id, name | head 10",SUCCESS
+"source = myglue_test.default.people | eval major = occupation | fields id, name, major, country, salary | LOOKUP myglue_test.default.work_info name REPLACE occupation AS major | stats distinct_count(major)",SUCCESS
+"source = myglue_test.default.people | eval major = occupation | fields id, name, major, country, salary | LOOKUP myglue_test.default.work_info name APPEND occupation AS major | stats distinct_count(major)",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json('{""account_number"":1,""balance"":39225,""age"":32,""gender"":""M""}') | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json('{""f1"":""abc"",""f2"":{""f3"":""a"",""f4"":""b""}}') | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json('[1,2,3,{""f1"":1,""f2"":[5,6]},4]') | head 1 | fields res",SUCCESS
+source = myglue_test.default.http_logs | eval res = json('[]') | head 1 | fields res,SUCCESS
+"source = myglue_test.default.http_logs | eval res = json('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}') | head 1 | fields res",SUCCESS
+"source = 
myglue_test.default.http_logs | eval res = json('{""invalid"": ""json""') | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json('[1,2,3]') | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json('[1,2') | head 1 | fields res",SUCCESS
+source = myglue_test.default.http_logs | eval res = json('[invalid json]') | head 1 | fields res,SUCCESS
+source = myglue_test.default.http_logs | eval res = json('invalid json') | head 1 | fields res,SUCCESS
+source = myglue_test.default.http_logs | eval res = json(null) | head 1 | fields res,SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_array('this', 'is', 'a', 'string', 'array') | head 1 | fields res",SUCCESS
+source = myglue_test.default.http_logs | eval res = json_array() | head 1 | fields res,SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_array(1, 2, 0, -1, 1.1, -0.11) | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_array('this', 'is', 1.1, -0.11, true, false) | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = to_json_string(json_array(1,2,0,-1,1.1,-0.11)) | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = array_length(json_array(1,2,0,-1,1.1,-0.11)) | head 1 | fields res",SUCCESS
+source = myglue_test.default.http_logs | eval res = array_length(json_array()) | head 1 | fields res,SUCCESS
+source = myglue_test.default.http_logs | eval res = json_array_length('[]') | head 1 | fields res,SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_array_length('[1,2,3,{""f1"":1,""f2"":[5,6]},4]') | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_array_length('{\""key\"": 1}') | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_array_length('[1,2') | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = to_json_string(json_object('key', 'string_value')) | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = to_json_string(json_object('key', 123.45)) | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = to_json_string(json_object('key', true)) | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = to_json_string(json_object(""a"", 1, ""b"", 2, ""c"", 3)) | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = to_json_string(json_object('key', array())) | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = to_json_string(json_object('key', array(1, 2, 3))) | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = to_json_string(json_object('outer', json_object('inner', 123.45))) | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = to_json_string(json_object(""array"", json_array(1,2,0,-1,1.1,-0.11))) | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | where json_valid('{""account_number"":1,""balance"":39225,""age"":32,""gender"":""M""}') | head 1",SUCCESS
+"source = myglue_test.default.http_logs | where not json_valid('{""account_number"":1,""balance"":39225,""age"":32,""gender"":""M""}') | head 1",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_keys(json('{""account_number"":1,""balance"":39225,""age"":32,""gender"":""M""}')) | head 1 | fields res",SUCCESS
+"source = 
myglue_test.default.http_logs | eval res = json_keys(json('{""f1"":""abc"",""f2"":{""f3"":""a"",""f4"":""b""}}')) | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_keys(json('[1,2,3,{""f1"":1,""f2"":[5,6]},4]')) | head 1 | fields res",SUCCESS
+source = myglue_test.default.http_logs | eval res = json_keys(json('[]')) | head 1 | fields res,SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_keys(json('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}')) | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_keys(json('{""invalid"": ""json""')) | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_keys(json('[1,2,3]')) | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_keys(json('[1,2')) | head 1 | fields res",SUCCESS
+source = myglue_test.default.http_logs | eval res = json_keys(json('[invalid json]')) | head 1 | fields res,SUCCESS
+source = myglue_test.default.http_logs | eval res = json_keys(json('invalid json')) | head 1 | fields res,SUCCESS
+source = myglue_test.default.http_logs | eval res = json_keys(json(null)) | head 1 | fields res,SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$') | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$.teacher') | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$.student') | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$.student[*]') | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$.student[0]') | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$.student[*].name') | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$.student[1].name') | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$.student[0].not_exist_key') | head 1 | fields res",SUCCESS
+"source = myglue_test.default.http_logs | eval res = json_extract('{""teacher"":""Alice"",""student"":[{""name"":""Bob"",""rank"":1},{""name"":""Charlie"",""rank"":2}]}', '$.student[10]') | head 1 | fields res",SUCCESS
+"source = myglue_test.default.people | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > 0) | head 1 | fields result",SUCCESS
+"source = myglue_test.default.people | eval array = json_array(1,2,0,-1,1.1,-0.11), result = forall(array, x -> x > -10) | head 1 | fields 
result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(json_object(""a"",1,""b"",-1),json_object(""a"",-1,""b"",-1)), result = forall(array, x -> x.a > 0) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(json_object(""a"",1,""b"",-1),json_object(""a"",-1,""b"",-1)), result = exists(array, x -> x.b < 0) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 0) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,0,-1,1.1,-0.11), result = exists(array, x -> x > 10) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 0) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,0,-1,1.1,-0.11), result = filter(array, x -> x > 10) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,3), result = transform(array, x -> x + 1) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,3), result = transform(array, (x, y) -> x + y) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,3), result = reduce(array, 0, (acc, x) -> acc + x) | head 1 | fields result",SUCCESS +"source = myglue_test.default.people | eval array = json_array(1,2,3), result = reduce(array, 0, (acc, x) -> acc + x, acc -> acc * 10) | head 1 | fields result",SUCCESS +source=myglue_test.default.people | eval age = salary | eventstats avg(age) | sort id | head 10,SUCCESS +"source=myglue_test.default.people | eval age = salary | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count | sort id | head 10",SUCCESS +source=myglue_test.default.people | eventstats avg(salary) by country | sort id | head 10,SUCCESS +"source=myglue_test.default.people | eval age = salary | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by country | sort id | head 10",SUCCESS +"source=myglue_test.default.people | eval age = salary | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count +by span(age, 10) | sort id | head 10",SUCCESS +"source=myglue_test.default.people | eval age = salary | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by span(age, 10) as age_span, country | sort id | head 10",SUCCESS +"source=myglue_test.default.people | where country != 'USA' | eventstats stddev_samp(salary), stddev_pop(salary), percentile_approx(salary, 60) by span(salary, 1000) as salary_span | sort id | head 10",SUCCESS +"source=myglue_test.default.people | eval age = salary | eventstats avg(age) as avg_age by occupation, country | eventstats avg(avg_age) as avg_state_age by country | sort id | head 10",SUCCESS +"source=myglue_test.default.people | eventstats distinct_count(salary) by span(salary, 1000) as age_span",FAILED +"source = myglue_test.tpch_csv.lineitem +| where l_shipdate <= subdate(date('1998-12-01'), 90) +| stats sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) 
as avg_disc, + count() as count_order + by l_returnflag, l_linestatus +| sort l_returnflag, l_linestatus",SUCCESS +"source = myglue_test.tpch_csv.part +| join ON p_partkey = ps_partkey myglue_test.tpch_csv.partsupp +| join ON s_suppkey = ps_suppkey myglue_test.tpch_csv.supplier +| join ON s_nationkey = n_nationkey myglue_test.tpch_csv.nation +| join ON n_regionkey = r_regionkey myglue_test.tpch_csv.region +| where p_size = 15 AND like(p_type, '%BRASS') AND r_name = 'EUROPE' AND ps_supplycost = [ + source = myglue_test.tpch_csv.partsupp + | join ON s_suppkey = ps_suppkey myglue_test.tpch_csv.supplier + | join ON s_nationkey = n_nationkey myglue_test.tpch_csv.nation + | join ON n_regionkey = r_regionkey myglue_test.tpch_csv.region + | where r_name = 'EUROPE' + | stats MIN(ps_supplycost) + ] +| sort - s_acctbal, n_name, s_name, p_partkey +| head 100",SUCCESS +"source = myglue_test.tpch_csv.customer +| join ON c_custkey = o_custkey myglue_test.tpch_csv.orders +| join ON l_orderkey = o_orderkey myglue_test.tpch_csv.lineitem +| where c_mktsegment = 'BUILDING' AND o_orderdate < date('1995-03-15') AND l_shipdate > date('1995-03-15') +| stats sum(l_extendedprice * (1 - l_discount)) as revenue by l_orderkey, o_orderdate, o_shippriority + | sort - revenue, o_orderdate +| head 10",SUCCESS +"source = myglue_test.tpch_csv.orders +| where o_orderdate >= date('1993-07-01') + and o_orderdate < date_add(date('1993-07-01'), interval 3 month) + and exists [ + source = myglue_test.tpch_csv.lineitem + | where l_orderkey = o_orderkey and l_commitdate < l_receiptdate + ] +| stats count() as order_count by o_orderpriority +| sort o_orderpriority",SUCCESS +"source = myglue_test.tpch_csv.customer +| join ON c_custkey = o_custkey myglue_test.tpch_csv.orders +| join ON l_orderkey = o_orderkey myglue_test.tpch_csv.lineitem +| join ON l_suppkey = s_suppkey AND c_nationkey = s_nationkey myglue_test.tpch_csv.supplier +| join ON s_nationkey = n_nationkey myglue_test.tpch_csv.nation +| join ON n_regionkey = r_regionkey myglue_test.tpch_csv.region +| where r_name = 'ASIA' AND o_orderdate >= date('1994-01-01') AND o_orderdate < date_add(date('1994-01-01'), interval 1 year) +| stats sum(l_extendedprice * (1 - l_discount)) as revenue by n_name +| sort - revenue",SUCCESS +"source = myglue_test.tpch_csv.lineitem +| where l_shipdate >= date('1994-01-01') + and l_shipdate < adddate(date('1994-01-01'), 365) + and l_discount between .06 - 0.01 and .06 + 0.01 + and l_quantity < 24 +| stats sum(l_extendedprice * l_discount) as revenue",SUCCESS +"source = [ + source = myglue_test.tpch_csv.supplier + | join ON s_suppkey = l_suppkey myglue_test.tpch_csv.lineitem + | join ON o_orderkey = l_orderkey myglue_test.tpch_csv.orders + | join ON c_custkey = o_custkey myglue_test.tpch_csv.customer + | join ON s_nationkey = n1.n_nationkey myglue_test.tpch_csv.nation as n1 + | join ON c_nationkey = n2.n_nationkey myglue_test.tpch_csv.nation as n2 + | where l_shipdate between date('1995-01-01') and date('1996-12-31') + and n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY' or n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE' + | eval supp_nation = n1.n_name, cust_nation = n2.n_name, l_year = year(l_shipdate), volume = l_extendedprice * (1 - l_discount) + | fields supp_nation, cust_nation, l_year, volume + ] as shipping +| stats sum(volume) as revenue by supp_nation, cust_nation, l_year +| sort supp_nation, cust_nation, l_year",SUCCESS +"source = [ + source = myglue_test.tpch_csv.part + | join ON p_partkey = l_partkey myglue_test.tpch_csv.lineitem + | join ON 
s_suppkey = l_suppkey myglue_test.tpch_csv.supplier
+    | join ON l_orderkey = o_orderkey myglue_test.tpch_csv.orders
+    | join ON o_custkey = c_custkey myglue_test.tpch_csv.customer
+    | join ON c_nationkey = n1.n_nationkey myglue_test.tpch_csv.nation as n1
+    | join ON s_nationkey = n2.n_nationkey myglue_test.tpch_csv.nation as n2
+    | join ON n1.n_regionkey = r_regionkey myglue_test.tpch_csv.region
+    | where r_name = 'AMERICA' AND p_type = 'ECONOMY ANODIZED STEEL'
+      and o_orderdate between date('1995-01-01') and date('1996-12-31')
+    | eval o_year = year(o_orderdate)
+    | eval volume = l_extendedprice * (1 - l_discount)
+    | eval nation = n2.n_name
+    | fields o_year, volume, nation
+  ] as all_nations
+| stats sum(case(nation = 'BRAZIL', volume else 0)) as sum_case, sum(volume) as sum_volume by o_year
+| eval mkt_share = sum_case / sum_volume
+| fields mkt_share, o_year
+| sort o_year",SUCCESS
+"source = [
+    source = myglue_test.tpch_csv.part
+    | join ON p_partkey = l_partkey myglue_test.tpch_csv.lineitem
+    | join ON s_suppkey = l_suppkey myglue_test.tpch_csv.supplier
+    | join ON ps_partkey = l_partkey and ps_suppkey = l_suppkey myglue_test.tpch_csv.partsupp
+    | join ON o_orderkey = l_orderkey myglue_test.tpch_csv.orders
+    | join ON s_nationkey = n_nationkey myglue_test.tpch_csv.nation
+    | where like(p_name, '%green%')
+    | eval nation = n_name
+    | eval o_year = year(o_orderdate)
+    | eval amount = l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity
+    | fields nation, o_year, amount
+  ] as profit
+| stats sum(amount) as sum_profit by nation, o_year
+| sort nation, - o_year",SUCCESS
+"source = myglue_test.tpch_csv.customer
+| join ON c_custkey = o_custkey myglue_test.tpch_csv.orders
+| join ON l_orderkey = o_orderkey myglue_test.tpch_csv.lineitem
+| join ON c_nationkey = n_nationkey myglue_test.tpch_csv.nation
+| where o_orderdate >= date('1993-10-01')
+  AND o_orderdate < date_add(date('1993-10-01'), interval 3 month)
+  AND l_returnflag = 'R'
+| stats sum(l_extendedprice * (1 - l_discount)) as revenue by c_custkey, c_name, c_acctbal, c_phone, n_name, c_address, c_comment
+| sort - revenue
+| head 20",SUCCESS
+"source = myglue_test.tpch_csv.partsupp
+| join ON ps_suppkey = s_suppkey myglue_test.tpch_csv.supplier
+| join ON s_nationkey = n_nationkey myglue_test.tpch_csv.nation
+| where n_name = 'GERMANY'
+| stats sum(ps_supplycost * ps_availqty) as value by ps_partkey
+| where value > [
+    source = myglue_test.tpch_csv.partsupp
+    | join ON ps_suppkey = s_suppkey myglue_test.tpch_csv.supplier
+    | join ON s_nationkey = n_nationkey myglue_test.tpch_csv.nation
+    | where n_name = 'GERMANY'
+    | stats sum(ps_supplycost * ps_availqty) as check
+    | eval threshold = check * 0.0001000000
+    | fields threshold
+  ]
+| sort - value",SUCCESS
+"source = myglue_test.tpch_csv.orders
+| join ON o_orderkey = l_orderkey myglue_test.tpch_csv.lineitem
+| where l_commitdate < l_receiptdate
+  and l_shipdate < l_commitdate
+  and l_shipmode in ('MAIL', 'SHIP')
+  and l_receiptdate >= date('1994-01-01')
+  and l_receiptdate < date_add(date('1994-01-01'), interval 1 year)
+| stats sum(case(o_orderpriority = '1-URGENT' or o_orderpriority = '2-HIGH', 1 else 0)) as high_line_count,
+    sum(case(o_orderpriority != '1-URGENT' and o_orderpriority != '2-HIGH', 1 else 0)) as low_line_count
+    by l_shipmode
+| sort l_shipmode",SUCCESS
+"source = [
+    source = myglue_test.tpch_csv.customer
+    | left outer join ON c_custkey = o_custkey AND not like(o_comment, '%special%requests%')
+      myglue_test.tpch_csv.orders
+    | stats 
count(o_orderkey) as c_count by c_custkey + ] as c_orders +| stats count() as custdist by c_count +| sort - custdist, - c_count",SUCCESS +"source = myglue_test.tpch_csv.lineitem +| join ON l_partkey = p_partkey + AND l_shipdate >= date('1995-09-01') + AND l_shipdate < date_add(date('1995-09-01'), interval 1 month) + myglue_test.tpch_csv.part +| stats sum(case(like(p_type, 'PROMO%'), l_extendedprice * (1 - l_discount) else 0)) as sum1, + sum(l_extendedprice * (1 - l_discount)) as sum2 +| eval promo_revenue = 100.00 * sum1 / sum2 // Stats and Eval commands can combine when issues/819 resolved +| fields promo_revenue",SUCCESS +"source = myglue_test.tpch_csv.supplier +| join right = revenue0 ON s_suppkey = supplier_no [ + source = myglue_test.tpch_csv.lineitem + | where l_shipdate >= date('1996-01-01') AND l_shipdate < date_add(date('1996-01-01'), interval 3 month) + | eval supplier_no = l_suppkey + | stats sum(l_extendedprice * (1 - l_discount)) as total_revenue by supplier_no + ] +| where total_revenue = [ + source = [ + source = myglue_test.tpch_csv.lineitem + | where l_shipdate >= date('1996-01-01') AND l_shipdate < date_add(date('1996-01-01'), interval 3 month) + | eval supplier_no = l_suppkey + | stats sum(l_extendedprice * (1 - l_discount)) as total_revenue by supplier_no + ] + | stats max(total_revenue) + ] +| sort s_suppkey +| fields s_suppkey, s_name, s_address, s_phone, total_revenue",SUCCESS +"source = myglue_test.tpch_csv.partsupp +| join ON p_partkey = ps_partkey myglue_test.tpch_csv.part +| where p_brand != 'Brand#45' + and not like(p_type, 'MEDIUM POLISHED%') + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in [ + source = myglue_test.tpch_csv.supplier + | where like(s_comment, '%Customer%Complaints%') + | fields s_suppkey + ] +| stats distinct_count(ps_suppkey) as supplier_cnt by p_brand, p_type, p_size +| sort - supplier_cnt, p_brand, p_type, p_size",SUCCESS +"source = myglue_test.tpch_csv.lineitem +| join ON p_partkey = l_partkey myglue_test.tpch_csv.part +| where p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < [ + source = myglue_test.tpch_csv.lineitem + | where l_partkey = p_partkey + | stats avg(l_quantity) as avg + | eval `0.2 * avg` = 0.2 * avg + | fields `0.2 * avg` + ] +| stats sum(l_extendedprice) as sum +| eval avg_yearly = sum / 7.0 +| fields avg_yearly",SUCCESS +"source = myglue_test.tpch_csv.customer +| join ON c_custkey = o_custkey myglue_test.tpch_csv.orders +| join ON o_orderkey = l_orderkey myglue_test.tpch_csv.lineitem +| where o_orderkey in [ + source = myglue_test.tpch_csv.lineitem + | stats sum(l_quantity) as sum by l_orderkey + | where sum > 300 + | fields l_orderkey + ] +| stats sum(l_quantity) by c_name, c_custkey, o_orderkey, o_orderdate, o_totalprice +| sort - o_totalprice, o_orderdate +| head 100",SUCCESS +"source = myglue_test.tpch_csv.lineitem +| join ON p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + OR p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + OR p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG 
PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + myglue_test.tpch_csv.part",SUCCESS +"source = myglue_test.tpch_csv.supplier +| join ON s_nationkey = n_nationkey myglue_test.tpch_csv.nation +| where n_name = 'CANADA' + and s_suppkey in [ + source = myglue_test.tpch_csv.partsupp + | where ps_partkey in [ + source = myglue_test.tpch_csv.part + | where like(p_name, 'forest%') + | fields p_partkey + ] + and ps_availqty > [ + source = myglue_test.tpch_csv.lineitem + | where l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date('1994-01-01') + and l_shipdate < date_add(date('1994-01-01'), interval 1 year) + | stats sum(l_quantity) as sum_l_quantity + | eval half_sum_l_quantity = 0.5 * sum_l_quantity + | fields half_sum_l_quantity + ] + | fields ps_suppkey + ]",SUCCESS +"source = myglue_test.tpch_csv.supplier +| join ON s_suppkey = l1.l_suppkey myglue_test.tpch_csv.lineitem as l1 +| join ON o_orderkey = l1.l_orderkey myglue_test.tpch_csv.orders +| join ON s_nationkey = n_nationkey myglue_test.tpch_csv.nation +| where o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists [ + source = myglue_test.tpch_csv.lineitem as l2 + | where l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey != l1.l_suppkey + ] + and not exists [ + source = myglue_test.tpch_csv.lineitem as l3 + | where l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey != l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ] + and n_name = 'SAUDI ARABIA' +| stats count() as numwait by s_name +| sort - numwait, s_name +| head 100",SUCCESS +"source = [ + source = myglue_test.tpch_csv.customer + | where substring(c_phone, 1, 2) in ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > [ + source = myglue_test.tpch_csv.customer + | where c_acctbal > 0.00 + and substring(c_phone, 1, 2) in ('13', '31', '23', '29', '30', '18', '17') + | stats avg(c_acctbal) + ] + and not exists [ + source = myglue_test.tpch_csv.orders + | where o_custkey = c_custkey + ] + | eval cntrycode = substring(c_phone, 1, 2) + | fields cntrycode, c_acctbal + ] as custsale +| stats count() as numcust, sum(c_acctbal) as totacctbal by cntrycode +| sort cntrycode",SUCCESS From b53a6993ed028671d33a6debe36574751b05d9de Mon Sep 17 00:00:00 2001 From: YANGDB Date: Mon, 11 Nov 2024 15:51:24 -0800 Subject: [PATCH 2/4] Ppl count approximate support (#884) * add functional approximation support for: - distinct count - top - rare Signed-off-by: YANGDB * update license and scalafmt Signed-off-by: YANGDB * update additional tests using APPROX_COUNT_DISTINCT Signed-off-by: YANGDB * add visitFirstChild(node, context) method for the PlanVisitor for simplify node inner child access visibility Signed-off-by: YANGDB * update inline documentation Signed-off-by: YANGDB * update according to PR comments - DISTINCT_COUNT_APPROX should be added to keywordsCanBeId Signed-off-by: YANGDB --------- Signed-off-by: YANGDB --- docs/ppl-lang/PPL-Example-Commands.md | 5 + docs/ppl-lang/ppl-rare-command.md | 10 +- docs/ppl-lang/ppl-top-command.md | 7 +- ...ntSparkPPLAggregationWithSpanITSuite.scala | 39 +++ .../FlintSparkPPLAggregationsITSuite.scala | 124 ++++++++ .../ppl/FlintSparkPPLTopAndRareITSuite.scala | 270 ++++++++++++++++++ .../src/main/antlr4/OpenSearchPPLLexer.g4 | 3 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 9 +- .../sql/ast/tree/CountedAggregation.java | 16 ++ .../sql/ast/tree/RareAggregation.java | 10 +- 
.../sql/ast/tree/TopAggregation.java          |   2 +-
 .../function/BuiltinFunctionName.java         |   2 +
 .../sql/ppl/CatalystPlanContext.java          |   3 +-
 .../sql/ppl/CatalystQueryPlanVisitor.java     |  68 +++--
 .../opensearch/sql/ppl/parser/AstBuilder.java |  20 +-
 .../sql/ppl/parser/AstExpressionBuilder.java  |   3 +-
 .../sql/ppl/utils/AggregatorTransformer.java  |   2 +
 .../ppl/utils/BuiltinFunctionTransformer.java |   3 +
 ...ggregationQueriesTranslatorTestSuite.scala |  92 ++++++
 ...TopAndRareQueriesTranslatorTestSuite.scala |  36 +++
 20 files changed, 668 insertions(+), 56 deletions(-)
 create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/CountedAggregation.java

diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md
index 4ea564111..cb50431f6 100644
--- a/docs/ppl-lang/PPL-Example-Commands.md
+++ b/docs/ppl-lang/PPL-Example-Commands.md
@@ -177,6 +177,7 @@ source = table | where ispresent(a) |
 - `source = table | stats max(c) by b`
 - `source = table | stats count(c) by b | head 5`
 - `source = table | stats distinct_count(c)`
+- `source = table | stats distinct_count_approx(c)`
 - `source = table | stats stddev_samp(c)`
 - `source = table | stats stddev_pop(c)`
 - `source = table | stats percentile(c, 90)`
@@ -246,12 +248,15 @@ source = table | where ispresent(a) |
 
 - `source=accounts | rare gender`
 - `source=accounts | rare age by gender`
+- `source=accounts | rare 5 age by gender`
+- `source=accounts | rare_approx age by gender`
 
 #### **Top**
 [See additional command details](ppl-top-command.md)
 
 - `source=accounts | top gender`
 - `source=accounts | top 1 gender`
+- `source=accounts | top_approx 5 gender`
 - `source=accounts | top 1 age by gender`
 
 #### **Parse**
diff --git a/docs/ppl-lang/ppl-rare-command.md b/docs/ppl-lang/ppl-rare-command.md
index 5645382f8..e3ad21f4e 100644
--- a/docs/ppl-lang/ppl-rare-command.md
+++ b/docs/ppl-lang/ppl-rare-command.md
@@ -6,10 +6,13 @@ Using ``rare`` command to find the least common tuple of values of all fields in the field list.
 **Note**: A maximum of 10 results is returned for each distinct tuple of values of the group-by fields.
 
 **Syntax**
-`rare <field-list> [by-clause]`
+`rare [N] <field-list> [by-clause]`
+`rare_approx [N] <field-list> [by-clause]`
 
+* N: number of results to return. **Default**: 10
 * field-list: mandatory. comma-delimited list of field names.
 * by-clause: optional. one or more fields to group the results by.
+* rare_approx: approximate count of the rare (N) field values, estimated with the [HyperLogLog++ cardinality algorithm](https://spark.apache.org/docs/3.5.2/sql-ref-functions-builtin.html).
 
 ### Example 1: Find the least common values in a field
 
 The example finds least common gender of all the accounts.
 
 PPL query:
 
     os> source=accounts | rare gender;
+    os> source=accounts | rare_approx 10 gender;
+    os> source=accounts | rare_approx gender;
     fetched rows / total rows = 2/2
     +----------+
     | gender   |
@@ -34,7 +39,8 @@ The example finds least common age of all the accounts group by gender.
 
 PPL query:
 
-    os> source=accounts | rare age by gender;
+    os> source=accounts | rare 5 age by gender;
+    os> source=accounts | rare_approx 5 age by gender;
     fetched rows / total rows = 4/4
     +----------+-------+
     | gender   | age   |
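Both `rare_approx` and `top_approx` (like `distinct_count_approx`) delegate to Spark's `approx_count_distinct`, which estimates cardinality with HyperLogLog++ instead of materializing every distinct value. Below is a minimal, self-contained Spark sketch of that trade-off; the session, table, and column names are illustrative only and not part of this change:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{approx_count_distinct, countDistinct}

object ApproxCountSketch extends App {
  // Throwaway local session, purely for illustration.
  val spark = SparkSession.builder().master("local[*]").appName("approx-count").getOrCreate()
  import spark.implicits._

  val accounts = Seq(("M", 32), ("F", 28), ("M", 33), ("F", 28)).toDF("gender", "age")

  // Exact distinct count: precise, but must track every distinct value seen.
  accounts.agg(countDistinct($"age")).show()

  // HyperLogLog++ estimate: bounded memory, tunable relative standard deviation (5% here).
  accounts.agg(approx_count_distinct($"age", rsd = 0.05)).show()

  spark.stop()
}
```

On low-cardinality data like this the two results coincide; the estimate only drifts (within the configured `rsd`) at high cardinalities, which is exactly where the `_approx` variants become cheaper.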
diff --git a/docs/ppl-lang/ppl-top-command.md b/docs/ppl-lang/ppl-top-command.md
index 4ba56f692..93d3a7148 100644
--- a/docs/ppl-lang/ppl-top-command.md
+++ b/docs/ppl-lang/ppl-top-command.md
@@ -6,11 +6,12 @@ Using ``top`` command to find the most common tuple of values of all fields in the field list.
 
 ### Syntax
 `top [N] <field-list> [by-clause]`
+`top_approx [N] <field-list> [by-clause]`
 
 * N: number of results to return. **Default**: 10
 * field-list: mandatory. comma-delimited list of field names.
 * by-clause: optional. one or more fields to group the results by.
-
+* top_approx: approximate count of the top (N) field values, estimated with the [HyperLogLog++ cardinality algorithm](https://spark.apache.org/docs/3.5.2/sql-ref-functions-builtin.html).
 
 ### Example 1: Find the most common values in a field
 
 The example finds most common gender of all the accounts.
 
 PPL query:
 
     os> source=accounts | top gender;
+    os> source=accounts | top_approx gender;
     fetched rows / total rows = 2/2
     +----------+
     | gender   |
@@ -33,7 +35,7 @@ The example finds most common gender of all the accounts.
 
 PPL query:
 
-    os> source=accounts | top 1 gender;
+    os> source=accounts | top_approx 1 gender;
     fetched rows / total rows = 1/1
     +----------+
     | gender   |
@@ -48,6 +50,7 @@ The example finds most common age of all the accounts group by gender.
 
 PPL query:
 
     os> source=accounts | top 1 age by gender;
+    os> source=accounts | top_approx 1 age by gender;
     fetched rows / total rows = 2/2
     +----------+-------+
     | gender   | age   |
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala
index 0bebca9b0..aa96d0991 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala
@@ -494,4 +494,43 @@ class FlintSparkPPLAggregationWithSpanITSuite
     // Compare the two plans
     comparePlans(expectedPlan, logicalPlan, false)
   }
+
+  test(
+    "create ppl simple distinct count age by span of interval of 10 years query with state filter test using approximation") {
+    val frame = sql(s"""
+                       | source = $testTable | where state != 'Quebec' | stats distinct_count_approx(age) by span(age, 10) as age_span
+                       | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(1, 20L))
+
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val ageField = UnresolvedAttribute("age")
+    val stateField = UnresolvedAttribute("state")
+    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
+
+    val aggregateExpressions =
+      Alias(
+        UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(ageField), isDistinct = true),
+        "distinct_count_approx(age)")()
+    val span = Alias(
Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val filterExpr = Not(EqualTo(stateField, Literal("Quebec"))) + val filterPlan = Filter(filterExpr, table) + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala index bcfe22764..2275c775c 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala @@ -835,6 +835,43 @@ class FlintSparkPPLAggregationsITSuite comparePlans(expectedPlan, logicalPlan, false) } + test("create ppl simple country distinct_count using approximation ") { + val frame = sql(s""" + | source = $testTable| stats distinct_count_approx(country) + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + + // Define the expected results + val expectedResults: Array[Row] = Array(Row(2L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(countryField), isDistinct = true), + "distinct_count_approx(country)")() + + val aggregatePlan = + Aggregate(Seq.empty, Seq(aggregateExpressions), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + comparePlans(expectedPlan, logicalPlan, false) + } + test("create ppl simple age distinct_count group by country query test with sort") { val frame = sql(s""" | source = $testTable | stats distinct_count(age) by country | sort country @@ -881,6 +918,53 @@ class FlintSparkPPLAggregationsITSuite s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") } + test( + "create ppl simple age distinct_count group by country query test with sort using approximation") { + val frame = sql(s""" + | source = $testTable | stats distinct_count_approx(age) by country | sort country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(2L, "Canada"), Row(2L, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + 
val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(ageField), isDistinct = true), + "distinct_count_approx(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + // Compare the two plans + assert( + compareByString(expectedPlan) === compareByString(logicalPlan), + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") + } + test("create ppl simple age distinct_count group by country with state filter query test") { val frame = sql(s""" | source = $testTable | where state != 'Ontario' | stats distinct_count(age) by country @@ -920,6 +1004,46 @@ class FlintSparkPPLAggregationsITSuite assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + test( + "create ppl simple age distinct_count group by country with state filter query test using approximation") { + val frame = sql(s""" + | source = $testTable | where state != 'Ontario' | stats distinct_count_approx(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(1L, "Canada"), Row(2L, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val stateField = UnresolvedAttribute("state") + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val filterExpr = Not(EqualTo(stateField, Literal("Ontario"))) + val filterPlan = Filter(filterExpr, table) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(ageField), isDistinct = true), + "distinct_count_approx(age)")() + val productAlias = Alias(countryField, "country")() + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + test("two-level stats") { val frame = sql(s""" | source = $testTable| stats avg(age) as avg_age by state, country | stats avg(avg_age) as avg_state_age by country diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala index f10b6e2f5..4a1633035 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala @@ -84,6 +84,48 @@ class 
FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("create ppl rare address field query test with approximation") { + val frame = sql(s""" + | source = $testTable| rare_approx address + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val addressField = UnresolvedAttribute("address") + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(addressField), isDistinct = false), + "count_address")(), + addressField) + val aggregatePlan = + Aggregate( + Seq(addressField), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(addressField), + isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + test("create ppl rare address by age field query test") { val frame = sql(s""" | source = $testTable| rare address by age @@ -132,6 +174,104 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, false) } + test("create ppl rare 3 address by age field query test") { + val frame = sql(s""" + | source = $testTable| rare 3 address by age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + val expectedRow = Row(1, "Vancouver", 60) + assert( + results.head == expectedRow, + s"Expected least frequent result to be $expectedRow, but got ${results.head}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val addressField = UnresolvedAttribute("address") + val ageField = UnresolvedAttribute("age") + val ageAlias = Alias(ageField, "age")() + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val countExpr = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")() + + val aggregateExpressions = Seq(countExpr, addressField, ageAlias) + val aggregatePlan = + Aggregate( + Seq(addressField, ageAlias), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, false) + } + + test("create ppl rare 3 address by age field query test with approximation") { + val frame = sql(s""" + | source = $testTable| rare_approx 3 address by age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val addressField = UnresolvedAttribute("address") + val ageField = UnresolvedAttribute("age") + val ageAlias = Alias(ageField, 
"age")() + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val countExpr = Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(addressField), isDistinct = false), + "count_address")() + + val aggregateExpressions = Seq(countExpr, addressField, ageAlias) + val aggregatePlan = + Aggregate( + Seq(addressField, ageAlias), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(addressField), + isDistinct = false), + "count_address")(), + Ascending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, false) + } + test("create ppl top address field query test") { val frame = sql(s""" | source = $testTable| top address @@ -179,6 +319,48 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("create ppl top address field query test with approximation") { + val frame = sql(s""" + | source = $testTable| top_approx address + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val addressField = UnresolvedAttribute("address") + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(addressField), isDistinct = false), + "count_address")(), + addressField) + val aggregatePlan = + Aggregate( + Seq(addressField), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(addressField), + isDistinct = false), + "count_address")(), + Descending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + test("create ppl top 3 countries query test") { val frame = sql(s""" | source = $newTestTable| top 3 country @@ -226,6 +408,48 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("create ppl top 3 countries query test with approximation") { + val frame = sql(s""" + | source = $newTestTable| top_approx 3 country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val countryField = UnresolvedAttribute("country") + val countExpr = Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(countryField), isDistinct = false), + "count_country")() + val aggregateExpressions = Seq(countExpr, countryField) + val aggregatePlan = + Aggregate( + Seq(countryField), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(countryField), + isDistinct = false), + "count_country")(), + Descending)), + global = 
true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + test("create ppl top 2 countries by occupation field query test") { val frame = sql(s""" | source = $newTestTable| top 3 country by occupation @@ -277,4 +501,50 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + + test("create ppl top 2 countries by occupation field query test with approximation") { + val frame = sql(s""" + | source = $newTestTable| top_approx 3 country by occupation + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 3) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val countryField = UnresolvedAttribute("country") + val occupationField = UnresolvedAttribute("occupation") + val occupationFieldAlias = Alias(occupationField, "occupation")() + + val countExpr = Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(countryField), isDistinct = false), + "count_country")() + val aggregateExpressions = Seq(countExpr, countryField, occupationFieldAlias) + val aggregatePlan = + Aggregate( + Seq(countryField, occupationFieldAlias), + aggregateExpressions, + UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test"))) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction( + Seq("APPROX_COUNT_DISTINCT"), + Seq(countryField), + isDistinct = false), + "count_country")(), + Descending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + + } } diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 2c3344b3c..10b2e01b8 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -23,7 +23,9 @@ DEDUP: 'DEDUP'; SORT: 'SORT'; EVAL: 'EVAL'; HEAD: 'HEAD'; +TOP_APPROX: 'TOP_APPROX'; TOP: 'TOP'; +RARE_APPROX: 'RARE_APPROX'; RARE: 'RARE'; PARSE: 'PARSE'; METHOD: 'METHOD'; @@ -216,6 +218,7 @@ BIT_XOR_OP: '^'; AVG: 'AVG'; COUNT: 'COUNT'; DISTINCT_COUNT: 'DISTINCT_COUNT'; +DISTINCT_COUNT_APPROX: 'DISTINCT_COUNT_APPROX'; ESTDC: 'ESTDC'; ESTDC_ERROR: 'ESTDC_ERROR'; MAX: 'MAX'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 1cfd172f7..63efd8c6c 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -76,7 +76,9 @@ commandName | SORT | HEAD | TOP + | TOP_APPROX | RARE + | RARE_APPROX | EVAL | GROK | PARSE @@ -180,11 +182,11 @@ headCommand ; topCommand - : TOP (number = integerLiteral)? fieldList (byClause)? + : (TOP | TOP_APPROX) (number = integerLiteral)? fieldList (byClause)? ; rareCommand - : RARE fieldList (byClause)? + : (RARE | RARE_APPROX) (number = integerLiteral)? fieldList (byClause)? 
;

 grokCommand
@@ -400,7 +402,7 @@ statsAggTerm
 statsFunction
     : statsFunctionName LT_PRTHS valueExpression RT_PRTHS           # statsFunctionCall
     | COUNT LT_PRTHS RT_PRTHS                                       # countAllFunctionCall
-    | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression RT_PRTHS       # distinctCountFunctionCall
+    | (DISTINCT_COUNT | DC | DISTINCT_COUNT_APPROX) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall
     | percentileFunctionName = (PERCENTILE | PERCENTILE_APPROX) LT_PRTHS valueExpression COMMA percent = integerLiteral RT_PRTHS # percentileFunctionCall
     ;
@@ -1122,6 +1124,7 @@ keywordsCanBeId
 // AGGREGATIONS
     | statsFunctionName
     | DISTINCT_COUNT
+    | DISTINCT_COUNT_APPROX
     | PERCENTILE
     | PERCENTILE_APPROX
     | ESTDC
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/CountedAggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/CountedAggregation.java
new file mode 100644
index 000000000..9a4aa5d7d
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/CountedAggregation.java
@@ -0,0 +1,16 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.sql.ast.tree;
+
+import org.opensearch.sql.ast.expression.Literal;
+
+import java.util.Optional;
+
+/**
+ * Marker interface for count-limited aggregations (a specific number of returned results).
+ */
+public interface CountedAggregation {
+    Optional<Literal> getResults();
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java
index d5a637f3d..8e454685a 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java
@@ -6,21 +6,29 @@
 package org.opensearch.sql.ast.tree;
 
 import lombok.EqualsAndHashCode;
+import lombok.Getter;
 import lombok.ToString;
+import org.opensearch.sql.ast.expression.Literal;
 import org.opensearch.sql.ast.expression.UnresolvedExpression;
 
 import java.util.Collections;
 import java.util.List;
+import java.util.Optional;
 
 /** Logical plan node of Rare (Aggregation) command, the interface for building aggregation actions in queries. */
 @ToString
+@Getter
 @EqualsAndHashCode(callSuper = true)
-public class RareAggregation extends Aggregation {
+public class RareAggregation extends Aggregation implements CountedAggregation {
+  private final Optional<Literal> results;
+
   /** Aggregation Constructor without span and argument. */
   public RareAggregation(
+      Optional<Literal> results,
       List<UnresolvedExpression> aggExprList,
       List<UnresolvedExpression> sortExprList,
       List<UnresolvedExpression> groupExprList) {
     super(aggExprList, sortExprList, groupExprList, null, Collections.emptyList());
+    this.results = results;
   }
 }
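`CountedAggregation` exists so the Catalyst visitor can treat `top N`, `rare N`, and their `_approx` variants uniformly: it branches on the marker interface rather than on each concrete class. A self-contained Scala sketch of that dispatch pattern (the `TopAgg`/`RareAgg` stand-ins below are illustrative, not the real AST classes):

```scala
import java.util.Optional

object CountedAggregationSketch extends App {
  // Illustrative stand-ins for the Java AST nodes introduced in this patch.
  trait CountedAggregation { def getResults: Optional[Int] }
  case class TopAgg(n: Optional[Int]) extends CountedAggregation { def getResults: Optional[Int] = n }
  case class RareAgg(n: Optional[Int]) extends CountedAggregation { def getResults: Optional[Int] = n }

  // Same shape as the visitor's check: any CountedAggregation carrying a
  // present result count gets wrapped in a Limit; everything else is left alone.
  def maybeLimit(node: AnyRef): String = node match {
    case c: CountedAggregation if c.getResults.isPresent => s"Limit(${c.getResults.get})"
    case _ => "no limit"
  }

  println(maybeLimit(TopAgg(Optional.of(3))))    // Limit(3)
  println(maybeLimit(RareAgg(Optional.empty()))) // no limit
}
```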
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java
index e87a3b0b0..90aac5838 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java
@@ -20,7 +20,7 @@
 @ToString
 @Getter
 @EqualsAndHashCode(callSuper = true)
-public class TopAggregation extends Aggregation {
+public class TopAggregation extends Aggregation implements CountedAggregation {
   private final Optional<Literal> results;
 
   /** Aggregation Constructor without span and argument. */
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java
index 1959d0f6d..f039bf47f 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java
@@ -185,6 +185,7 @@ public enum BuiltinFunctionName {
   NESTED(FunctionName.of("nested")),
   PERCENTILE(FunctionName.of("percentile")),
   PERCENTILE_APPROX(FunctionName.of("percentile_approx")),
+  APPROX_COUNT_DISTINCT(FunctionName.of("approx_count_distinct")),
 
   /** Text Functions. */
   ASCII(FunctionName.of("ascii")),
@@ -332,6 +333,7 @@ public FunctionName getName() {
           .put("take", BuiltinFunctionName.TAKE)
           .put("percentile", BuiltinFunctionName.PERCENTILE)
           .put("percentile_approx", BuiltinFunctionName.PERCENTILE_APPROX)
+          .put("approx_count_distinct", BuiltinFunctionName.APPROX_COUNT_DISTINCT)
           .build();
 
   public static Optional<BuiltinFunctionName> of(String str) {
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java
index 53dc17576..1621e65d5 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java
@@ -26,6 +26,7 @@
 import java.util.Stack;
 import java.util.function.BiFunction;
 import java.util.function.Function;
+import java.util.function.UnaryOperator;
 import java.util.stream.Collectors;
 
 import static java.util.Collections.emptyList;
@@ -187,7 +188,7 @@ public LogicalPlan reduce(BiFunction<LogicalPlan, LogicalPlan, LogicalPlan> transformFunction) {
             return result;
         }).orElse(getPlan()));
     }
-    
+
     /**
      * apply for each plan with the given function
      *
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index d2ee46ae6..00a7905f0 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -14,13 +14,6 @@
 import org.apache.spark.sql.catalyst.expressions.Explode;
 import org.apache.spark.sql.catalyst.expressions.Expression;
 import org.apache.spark.sql.catalyst.expressions.GeneratorOuter;
-import org.apache.spark.sql.catalyst.expressions.In$;
-import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
-import org.apache.spark.sql.catalyst.expressions.InSubquery$;
-import org.apache.spark.sql.catalyst.expressions.LessThan;
-import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
-import org.apache.spark.sql.catalyst.expressions.ListQuery$;
-import org.apache.spark.sql.catalyst.expressions.MakeInterval$;
 import org.apache.spark.sql.catalyst.expressions.NamedExpression;
 import org.apache.spark.sql.catalyst.expressions.SortDirection;
 import org.apache.spark.sql.catalyst.expressions.SortOrder;
@@ -38,6 +31,7 @@
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 import org.opensearch.flint.spark.FlattenGenerator;
 import org.opensearch.sql.ast.AbstractNodeVisitor;
+import org.opensearch.sql.ast.Node;
 import org.opensearch.sql.ast.expression.Alias;
 import org.opensearch.sql.ast.expression.Argument;
 import org.opensearch.sql.ast.expression.Field;
@@ -53,6 +47,7 @@
 import 
org.opensearch.sql.ast.statement.Statement; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Correlation; +import org.opensearch.sql.ast.tree.CountedAggregation; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.DescribeRelation; import org.opensearch.sql.ast.tree.Eval; @@ -72,7 +67,6 @@ import org.opensearch.sql.ast.tree.Rename; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.SubqueryAlias; -import org.opensearch.sql.ast.tree.TopAggregation; import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.Window; import org.opensearch.sql.common.antlr.SyntaxCheckException; @@ -90,6 +84,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.function.BiConsumer; import java.util.stream.Collectors; import static java.util.Collections.emptyList; @@ -132,6 +127,10 @@ public LogicalPlan visitQuery(Query node, CatalystPlanContext context) { return node.getPlan().accept(this, context); } + public LogicalPlan visitFirstChild(Node node, CatalystPlanContext context) { + return node.getChild().get(0).accept(this, context); + } + @Override public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) { node.getStatement().accept(this, context); @@ -140,6 +139,7 @@ public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) { @Override public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) { + //relations doesnt have a visitFirstChild call since its the leaf of the AST tree if (node instanceof DescribeRelation) { TableIdentifier identifier = getTableIdentifier(node.getTableQualifiedName()); return context.with( @@ -159,7 +159,7 @@ public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) { @Override public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); return context.apply(p -> { Expression conditionExpression = visitExpression(node.getCondition(), context); Optional innerConditionExpression = context.popNamedParseExpressions(); @@ -173,8 +173,7 @@ public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) { */ @Override public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); - + visitFirstChild(node, context); return context.apply( searchSide -> { LogicalPlan lookupTable = node.getLookupRelation().accept(this, context); Expression lookupCondition = buildLookupMappingCondition(node, expressionAnalyzer, context); @@ -230,8 +229,7 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { @Override public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); - + visitFirstChild(node, context); node.getSortByField() .ifPresent(sortField -> { Expression sortFieldExpression = visitExpression(sortField, context); @@ -254,7 +252,7 @@ public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { @Override public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); context.reduce((left, right) -> { visitFieldList(node.getFieldsList().stream().map(Field::new).collect(Collectors.toList()), context); Seq fields = context.retainAllNamedParseExpressions(e -> e); @@ -272,7 +270,7 @@ public LogicalPlan 
visitCorrelation(Correlation node, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitJoin(Join node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         return context.apply(left -> {
             LogicalPlan right = node.getRight().accept(this, context);
             Optional<Expression> joinCondition = node.getJoinCondition()
@@ -285,7 +283,7 @@
 
     @Override
     public LogicalPlan visitSubqueryAlias(SubqueryAlias node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         return context.apply(p -> {
             var alias = org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias$.MODULE$.apply(node.getAlias(), p);
             context.withSubqueryAlias(alias);
@@ -296,7 +294,7 @@
 
     @Override
     public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         List<Expression> aggsExpList = visitExpressionList(node.getAggExprList(), context);
         List<Expression> groupExpList = visitExpressionList(node.getGroupExprList(), context);
         if (!groupExpList.isEmpty()) {
@@ -327,9 +325,9 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext context) {
             context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, logicalPlan));
         }
         //visit TopAggregation results limit
-        if ((node instanceof TopAggregation) && ((TopAggregation) node).getResults().isPresent()) {
+        if ((node instanceof CountedAggregation) && ((CountedAggregation) node).getResults().isPresent()) {
             context.apply(p -> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal(
-                ((TopAggregation) node).getResults().get().getValue(), org.apache.spark.sql.types.DataTypes.IntegerType), p));
+                ((CountedAggregation) node).getResults().get().getValue(), org.apache.spark.sql.types.DataTypes.IntegerType), p));
         }
         return logicalPlan;
     }
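Taken together with the grammar and AST changes, a query such as `top_approx 3 country` now translates to a sorted aggregate wrapped in a local/global limit. A compact sketch of that expected plan shape, built directly against Catalyst (assumes only `spark-catalyst` on the classpath; it mirrors the expected plans asserted by the new test suites in this patch):

```scala
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, Literal, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._

object TopApproxPlanShape extends App {
  val country = UnresolvedAttribute("country")
  // APPROX_COUNT_DISTINCT replaces COUNT as the aggregate for the *_approx commands.
  val countExpr: NamedExpression = Alias(
    UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(country), isDistinct = false),
    "count_country")()
  val aggregate =
    Aggregate(Seq(country), Seq(countExpr, country), UnresolvedRelation(Seq("accounts")))
  val sorted = Sort(Seq(SortOrder(countExpr, Descending)), global = true, aggregate)
  // The CountedAggregation result count (3) becomes a LocalLimit/GlobalLimit pair.
  val limited = GlobalLimit(Literal(3), LocalLimit(Literal(3), sorted))
  println(Project(Seq(UnresolvedStar(None)), limited).treeString)
}
```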
@@ -342,7 +340,7 @@ private static LogicalPlan extractedAggregation(CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitWindow(Window node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         List<Expression> windowFunctionExpList = visitExpressionList(node.getWindowFunctionList(), context);
         Seq<NamedExpression> windowFunctionExpressions = context.retainAllNamedParseExpressions(p -> p);
         List<Expression> partitionExpList = visitExpressionList(node.getPartExprList(), context);
@@ -372,10 +370,11 @@ public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitProject(Project node, CatalystPlanContext context) {
+        // update plan's context prior to visiting node children
         if (node.isExcluded()) {
             List<UnresolvedExpression> intersect = context.getProjectedFields().stream()
-                .filter(node.getProjectList()::contains)
-                .collect(Collectors.toList());
+                    .filter(node.getProjectList()::contains)
+                    .collect(Collectors.toList());
             if (!intersect.isEmpty()) {
                 // Fields in parent projection, but they have been excluded in the child. For example,
                 // source=t | fields - A, B | fields A, B, C will throw "[Field A, Field B] can't be resolved"
@@ -384,7 +383,7 @@ public LogicalPlan visitProject(Project node, CatalystPlanContext context) {
         } else {
             context.withProjectedFields(node.getProjectList());
         }
-        LogicalPlan child = node.getChild().get(0).accept(this, context);
+        LogicalPlan child = visitFirstChild(node, context);
         visitExpressionList(node.getProjectList(), context);
 
         // Create a projection list from the existing expressions
@@ -405,7 +404,7 @@ public LogicalPlan visitProject(Project node, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitSort(Sort node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         visitFieldList(node.getSortList(), context);
         Seq<SortOrder> sortElements = context.retainAllNamedParseExpressions(exp -> SortUtils.getSortDirection(node, (NamedExpression) exp));
         return context.apply(p -> (LogicalPlan) new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, p));
@@ -413,20 +412,20 @@ public LogicalPlan visitSort(Sort node, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitHead(Head node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         return context.apply(p -> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal(
                 node.getSize(), DataTypes.IntegerType), p));
     }
 
     @Override
     public LogicalPlan visitFieldSummary(FieldSummary fieldSummary, CatalystPlanContext context) {
-        fieldSummary.getChild().get(0).accept(this, context);
+        visitFirstChild(fieldSummary, context);
         return FieldSummaryTransformer.translate(fieldSummary, context);
     }
 
     @Override
     public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) {
-        fillNull.getChild().get(0).accept(this, context);
+        visitFirstChild(fillNull, context);
         List<UnresolvedExpression> aliases = new ArrayList<>();
         for(FillNull.NullableFieldFill nullableFieldFill : fillNull.getNullableFieldFills()) {
             Field field = nullableFieldFill.getNullableFieldReference();
@@ -457,7 +456,7 @@ public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context)
 
     @Override
     public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) {
-        flatten.getChild().get(0).accept(this, context);
+        visitFirstChild(flatten, context);
         if (context.getNamedParseExpressions().isEmpty()) {
             // Create an UnresolvedStar for all-fields projection
             context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.<Seq<UnresolvedAttribute>>empty()));
@@ -471,7 +470,7 @@ public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitExpand(org.opensearch.sql.ast.tree.Expand node, CatalystPlanContext context) {
-        node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         if (context.getNamedParseExpressions().isEmpty()) {
             // Create an UnresolvedStar for all-fields projection
             context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.<Seq<UnresolvedAttribute>>empty()));
@@ -507,7 +506,7 @@ private Expression visitExpression(UnresolvedExpression expression, CatalystPlanContext context) {
 
     @Override
     public LogicalPlan visitParse(Parse node, CatalystPlanContext context) {
-        LogicalPlan child = node.getChild().get(0).accept(this, context);
+        visitFirstChild(node, context);
         Expression sourceField = visitExpression(node.getSourceField(), context);
         ParseMethod parseMethod = node.getParseMethod();
         java.util.Map<String, Literal> arguments = node.getArguments();
@@ -517,7 +516,7 @@ public LogicalPlan 
visitParse(Parse node, CatalystPlanContext context) { @Override public LogicalPlan visitRename(Rename node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); if (context.getNamedParseExpressions().isEmpty()) { // Create an UnresolvedStar for all-fields projection context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.empty())); @@ -534,7 +533,7 @@ public LogicalPlan visitRename(Rename node, CatalystPlanContext context) { @Override public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); List aliases = new ArrayList<>(); List letExpressions = node.getExpressionList(); for (Let let : letExpressions) { @@ -548,8 +547,7 @@ public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { List expressionList = visitExpressionList(aliases, context); Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); // build the plan with the projection step - child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); - return child; + return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); } @Override @@ -574,7 +572,7 @@ public LogicalPlan visitWindowFunction(WindowFunction node, CatalystPlanContext @Override public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); + visitFirstChild(node, context); List options = node.getOptions(); Integer allowedDuplication = (Integer) options.get(0).getValue().getValue(); Boolean keepEmpty = (Boolean) options.get(1).getValue().getValue(); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index f6581016f..7d1cc072b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -432,8 +432,9 @@ private Trendline.TrendlineComputation toTrendlineComputation(OpenSearchPPLParse public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) { ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); ImmutableList.Builder groupListBuilder = new ImmutableList.Builder<>(); + String funcName = ctx.TOP_APPROX() != null ? "approx_count_distinct" : "count"; ctx.fieldList().fieldExpression().forEach(field -> { - UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field), + AggregateFunction aggExpression = new AggregateFunction(funcName,internalVisitExpression(field), Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); String name = field.qualifiedName().getText(); Alias alias = new Alias("count_"+name, aggExpression); @@ -458,14 +459,12 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) .collect(Collectors.toList())) .orElse(emptyList()) ); - UnresolvedExpression unresolvedPlan = (ctx.number != null ? internalVisitExpression(ctx.number) : null); - TopAggregation aggregation = - new TopAggregation( - Optional.ofNullable((Literal) unresolvedPlan), + UnresolvedExpression expectedResults = (ctx.number != null ? 
internalVisitExpression(ctx.number) : null); + return new TopAggregation( + Optional.ofNullable((Literal) expectedResults), aggListBuilder.build(), aggListBuilder.build(), groupListBuilder.build()); - return aggregation; } /** Fieldsummary command. */ @@ -479,8 +478,9 @@ public UnresolvedPlan visitFieldsummaryCommand(OpenSearchPPLParser.FieldsummaryC public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) { ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); ImmutableList.Builder groupListBuilder = new ImmutableList.Builder<>(); + String funcName = ctx.RARE_APPROX() != null ? "approx_count_distinct" : "count"; ctx.fieldList().fieldExpression().forEach(field -> { - UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field), + AggregateFunction aggExpression = new AggregateFunction(funcName,internalVisitExpression(field), Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); String name = field.qualifiedName().getText(); Alias alias = new Alias("count_"+name, aggExpression); @@ -505,12 +505,12 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct .collect(Collectors.toList())) .orElse(emptyList()) ); - RareAggregation aggregation = - new RareAggregation( + UnresolvedExpression expectedResults = (ctx.number != null ? internalVisitExpression(ctx.number) : null); + return new RareAggregation( + Optional.ofNullable((Literal) expectedResults), aggListBuilder.build(), aggListBuilder.build(), groupListBuilder.build()); - return aggregation; } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 4b7c8a1c1..36d9f9577 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -211,7 +211,8 @@ public UnresolvedExpression visitCountAllFunctionCall(OpenSearchPPLParser.CountA @Override public UnresolvedExpression visitDistinctCountFunctionCall(OpenSearchPPLParser.DistinctCountFunctionCallContext ctx) { - return new AggregateFunction("count", visit(ctx.valueExpression()), true); + String funcName = ctx.DISTINCT_COUNT_APPROX()!=null ? 
"approx_count_distinct" :"count"; + return new AggregateFunction(funcName, visit(ctx.valueExpression()), true); } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java index 9788ac1bc..c06f37aa3 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTransformer.java @@ -57,6 +57,8 @@ static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); case PERCENTILE_APPROX: return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), distinct, empty(),false); + case APPROX_COUNT_DISTINCT: + return new UnresolvedFunction(seq("APPROX_COUNT_DISTINCT"), seq(arg), distinct, empty(),false); } throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName()); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java index 0b0fb8314..0a4f19b53 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTransformer.java @@ -26,8 +26,10 @@ import java.util.Map; import java.util.function.Function; +import static org.opensearch.flint.spark.ppl.OpenSearchPPLLexer.DISTINCT_COUNT_APPROX; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.APPROX_COUNT_DISTINCT; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ARRAY_LENGTH; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATEDIFF; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATE_ADD; @@ -109,6 +111,7 @@ public interface BuiltinFunctionTransformer { .put(TO_JSON_STRING, "to_json") .put(JSON_KEYS, "json_object_keys") .put(JSON_EXTRACT, "get_json_object") + .put(APPROX_COUNT_DISTINCT, "approx_count_distinct") .build(); /** diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index 9946bff6a..42cc7ed10 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -754,6 +754,34 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("test approx distinct count product group by brand sorted") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats distinct_count_approx(product) by brand | sort brand"), + context) + val 
star = Seq(UnresolvedStar(None)) + val brandField = UnresolvedAttribute("brand") + val productField = UnresolvedAttribute("product") + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(brandField, "brand")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(productField), isDistinct = true), + "distinct_count_approx(product)")() + val brandAlias = Alias(brandField, "brand")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, brandAlias), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(brandField, Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + test("test distinct count product with alias and filter") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ -803,6 +831,34 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test( + "test distinct count age by span of interval of 10 years query with sort using approximation ") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats distinct_count_approx(age) by span(age, 10) as age_span | sort age"), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(ageField), isDistinct = true), + "distinct_count_approx(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + test("test distinct count status by week window and group by status with limit") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ -838,6 +894,42 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test( + "test distinct count status by week window and group by status with limit using approximation") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats distinct_count_approx(status) by span(@timestamp, 1w) as status_count_by_week, status | head 100"), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val status = Alias(UnresolvedAttribute("status"), "status")() + val statusCount = UnresolvedAttribute("status") + val table = UnresolvedRelation(Seq("table")) + + val windowExpression = Alias( + TimeWindow( + UnresolvedAttribute("`@timestamp`"), + TimeWindow.parseExpression(Literal("1 week")), + TimeWindow.parseExpression(Literal("1 week")), + 0), + "status_count_by_week")() + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(statusCount), isDistinct = true), + "distinct_count_approx(status)")() + val aggregatePlan = Aggregate( + Seq(status, windowExpression), + Seq(aggregateExpressions, status, windowExpression), + table) + val planWithLimit = GlobalLimit(Literal(100), 
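+    // `head 100` lowers to a Catalyst limit pair: LocalLimit(100) caps each partition's
+    // output early, and the enclosing GlobalLimit(100) enforces the final row bound.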
+    val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan))
+    val expectedPlan = Project(star, planWithLimit)
+    // Compare the two plans
+    comparePlans(expectedPlan, logPlan, false)
+  }
+
   test("multiple stats - test average price and average age") {
     val context = new CatalystPlanContext
     val logPlan =
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala
index 792a2dee6..106cba93a 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala
@@ -59,6 +59,42 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite
     comparePlans(expectedPlan, logPlan, checkAnalysis = false)
   }
 
+  test("test simple rare command with a single field approximation") {
+    // if successful build ppl logical plan and translate to catalyst logical plan
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, "source=accounts | rare_approx address"), context)
+    val addressField = UnresolvedAttribute("address")
+    val tableRelation = UnresolvedRelation(Seq("accounts"))
+
+    val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
+
+    val aggregateExpressions = Seq(
+      Alias(
+        UnresolvedFunction(Seq("APPROX_COUNT_DISTINCT"), Seq(addressField), isDistinct = false),
+        "count_address")(),
+      addressField)
+
+    val aggregatePlan =
+      Aggregate(Seq(addressField), aggregateExpressions, tableRelation)
+
+    val sortedPlan: LogicalPlan =
+      Sort(
+        Seq(
+          SortOrder(
+            Alias(
+              UnresolvedFunction(
+                Seq("APPROX_COUNT_DISTINCT"),
+                Seq(addressField),
+                isDistinct = false),
+              "count_address")(),
+            Ascending)),
+        global = true,
+        aggregatePlan)
+    val expectedPlan = Project(projectList, sortedPlan)
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+
   test("test simple rare command with a by field test") {
     // if successful build ppl logical plan and translate to catalyst logical plan
     val context = new CatalystPlanContext

From 98e1c0330af4eafe79558cebea3541545cb6a377 Mon Sep 17 00:00:00 2001
From: Louis Chu
Date: Tue, 12 Nov 2024 09:53:15 -0800
Subject: [PATCH 3/4] Apply shaded rules (#885)

Signed-off-by: Louis Chu
---
 build.sbt                                   | 31 +++++++++++++++++++
 .../apache/spark/sql/FlintJobExecutor.scala |  9 +++---
 2 files changed, 35 insertions(+), 5 deletions(-)

diff --git a/build.sbt b/build.sbt
index 8752d3bf9..781b4f51f 100644
--- a/build.sbt
+++ b/build.sbt
@@ -3,6 +3,7 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 import Dependencies._
+import sbtassembly.AssemblyPlugin.autoImport.ShadeRule
 
 lazy val scala212 = "2.12.14"
 lazy val sparkVersion = "3.5.1"
@@ -43,7 +44,35 @@ lazy val compileScalastyle = taskKey[Unit]("compileScalastyle")
 // Run as part of test task.
 lazy val testScalastyle = taskKey[Unit]("testScalastyle")
 
+// Explanation:
+// - ThisBuild / assemblyShadeRules sets the shading rules for the entire build
+// - ShadeRule.rename(...) creates a rule to rename multiple package patterns
+// - "shaded.flint.@0" means prepend "shaded.flint." to the original package name
+// - .inAll applies the rule to all dependencies, not just direct dependencies
+val packagesToShade = Seq(
+  "com.amazonaws.cloudwatch.**",
+  "com.fasterxml.jackson.core.**",
+  "com.fasterxml.jackson.dataformat.**",
+  "com.fasterxml.jackson.databind.**",
+  "com.sun.jna.**",
+  "com.thoughtworks.paranamer.**",
+  "javax.annotation.**",
+  "org.apache.commons.codec.**",
+  "org.apache.commons.logging.**",
+  "org.apache.hc.**",
+  "org.apache.http.**",
+  "org.glassfish.json.**",
+  "org.joda.time.**",
+  "org.reactivestreams.**",
+  "org.yaml.**",
+  "software.amazon.**"
+)
+
+ThisBuild / assemblyShadeRules := Seq(
+  ShadeRule.rename(
+    packagesToShade.map(_ -> "shaded.flint.@0"): _*
+  ).inAll
+)
 
 lazy val commonSettings = Seq(
   javacOptions ++= Seq("-source", "11"),
@@ -89,6 +118,8 @@ lazy val flintCore = (project in file("flint-core"))
       "com.amazonaws" % "aws-java-sdk-cloudwatch" % "1.12.593"
         exclude("com.fasterxml.jackson.core", "jackson-databind"),
       "software.amazon.awssdk" % "auth-crt" % "2.28.10",
+      "com.fasterxml.jackson.core" % "jackson-core" % jacksonVersion,
+      "com.fasterxml.jackson.core" % "jackson-databind" % jacksonVersion,
       "org.projectlombok" % "lombok" % "1.18.30" % "provided",
       "org.scalactic" %% "scalactic" % "3.2.15" % "test",
       "org.scalatest" %% "scalatest" % "3.2.15" % "test",
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
index c076f9974..8e037a53e 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
@@ -10,7 +10,6 @@ import java.util.Locale
 
 import com.amazonaws.services.glue.model.{AccessDeniedException, AWSGlueException}
 import com.amazonaws.services.s3.model.AmazonS3Exception
 import com.fasterxml.jackson.databind.ObjectMapper
-import com.fasterxml.jackson.module.scala.DefaultScalaModule
 import org.apache.commons.text.StringEscapeUtils.unescapeJava
 import org.opensearch.common.Strings
 import org.opensearch.flint.core.IRestHighLevelClient
@@ -45,7 +44,6 @@ trait FlintJobExecutor {
   this: Logging =>
 
   val mapper = new ObjectMapper()
-  mapper.registerModule(DefaultScalaModule)
 
   var currentTimeProvider: TimeProvider = new RealTimeProvider()
   var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory()
@@ -442,9 +440,10 @@ trait FlintJobExecutor {
       errorSource: Option[String] = None,
       statusCode: Option[Int] = None): String = {
     val errorMessage = s"$messagePrefix: ${e.getMessage}"
-    val errorDetails = Map("Message" -> errorMessage) ++
-      errorSource.map("ErrorSource" -> _) ++
-      statusCode.map(code => "StatusCode" -> code.toString)
+    val errorDetails = new java.util.LinkedHashMap[String, String]()
+    errorDetails.put("Message", errorMessage)
+    errorSource.foreach(es => errorDetails.put("ErrorSource", es))
+    statusCode.foreach(code => errorDetails.put("StatusCode", code.toString))
 
     val errorJson = mapper.writeValueAsString(errorDetails)

From dd9c0cfbab0ca327ba655c8aa9f9ee44404d4fce Mon Sep 17 00:00:00 2001
From: Sean Kao
Date: Tue, 12 Nov 2024 10:30:46 -0800
Subject: [PATCH 4/4] Fix bug for not able to get sourceTables from metadata
 (#883)

* add logs

Signed-off-by: Sean Kao

* match Array when reading sourceTables

Signed-off-by: Sean Kao

* add test cases

Signed-off-by: Sean Kao

* use ArrayList only

Signed-off-by: Sean Kao

---------

Signed-off-by: Sean Kao
---
 .../flint/spark/FlintSparkIndexFactory.scala  | 14 +---
 .../metadatacache/FlintMetadataCache.scala    |  9 +--
 .../FlintOpenSearchMetadataCacheWriter.scala  |  6 +-
 .../spark/mv/FlintSparkMaterializedView.scala | 44 ++++++++++--
 .../mv/FlintSparkMaterializedViewSuite.scala  |  4 +-
 .../FlintSparkMaterializedViewITSuite.scala   | 68 ++++++++++++++++++-
 ...OpenSearchMetadataCacheWriterITSuite.scala | 29 ++++++--
 7 files changed, 139 insertions(+), 35 deletions(-)

diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala
index ca659550d..3a12b63fe 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala
@@ -14,7 +14,7 @@ import org.opensearch.flint.common.metadata.FlintMetadata
 import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
 import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE
 import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
-import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE
+import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getSourceTablesFromMetadata, MV_INDEX_TYPE}
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind
@@ -141,9 +141,9 @@ object FlintSparkIndexFactory extends Logging {
   }
 
   private def getMvSourceTables(spark: SparkSession, metadata: FlintMetadata): Array[String] = {
-    val sourceTables = getArrayString(metadata.properties, "sourceTables")
+    val sourceTables = getSourceTablesFromMetadata(metadata)
     if (sourceTables.isEmpty) {
-      FlintSparkMaterializedView.extractSourceTableNames(spark, metadata.source)
+      FlintSparkMaterializedView.extractSourceTablesFromQuery(spark, metadata.source)
     } else {
       sourceTables
     }
@@ -161,12 +161,4 @@ object FlintSparkIndexFactory extends Logging {
       Some(value.asInstanceOf[String])
     }
   }
-
-  private def getArrayString(map: java.util.Map[String, AnyRef], key: String): Array[String] = {
-    map.get(key) match {
-      case list: java.util.ArrayList[_] =>
-        list.toArray.map(_.toString)
-      case _ => Array.empty[String]
-    }
-  }
 }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala
index e1c0f318c..86267c881 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala
@@ -10,7 +10,7 @@ import scala.collection.JavaConverters.mapAsScalaMapConverter
 import org.opensearch.flint.common.metadata.FlintMetadata
 import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry
 import org.opensearch.flint.spark.FlintSparkIndexOptions
-import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE
+import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getSourceTablesFromMetadata, MV_INDEX_TYPE}
 import org.opensearch.flint.spark.scheduler.util.IntervalSchedulerParser
 
 /**
@@ -61,12 +61,7 @@ object FlintMetadataCache {
       None
     }
     val sourceTables = metadata.kind match {
-      case MV_INDEX_TYPE =>
-        metadata.properties.get("sourceTables") match {
-          case list: java.util.ArrayList[_] =>
-            list.toArray.map(_.toString)
-          case _ => Array.empty[String]
-        }
+      case MV_INDEX_TYPE => getSourceTablesFromMetadata(metadata)
       case _ => Array(metadata.source)
     }
     val lastRefreshTime: Option[Long] = metadata.latestLogEntry.flatMap { entry =>
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriter.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriter.scala
index 2bc373792..f6fc0ba6f 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriter.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriter.scala
@@ -38,13 +38,15 @@ class FlintOpenSearchMetadataCacheWriter(options: FlintOptions)
     .isInstanceOf[FlintOpenSearchIndexMetadataService]
 
   override def updateMetadataCache(indexName: String, metadata: FlintMetadata): Unit = {
-    logInfo(s"Updating metadata cache for $indexName");
+    logInfo(s"Updating metadata cache for $indexName with $metadata");
     val osIndexName = OpenSearchClientUtils.sanitizeIndexName(indexName)
     var client: IRestHighLevelClient = null
     try {
       client = OpenSearchClientUtils.createClient(options)
       val request = new PutMappingRequest(osIndexName)
-      request.source(serialize(metadata), XContentType.JSON)
+      val serialized = serialize(metadata)
+      logInfo(s"Serialized: $serialized")
+      request.source(serialized, XContentType.JSON)
       client.updateIndexMapping(request, RequestOptions.DEFAULT)
     } catch {
       case e: Exception =>
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala
index e2a64d183..d5c450e7e 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala
@@ -7,7 +7,7 @@ package org.opensearch.flint.spark.mv
 
 import java.util.Locale
 
-import scala.collection.JavaConverters.mapAsJavaMapConverter
+import scala.collection.JavaConverters._
 import scala.collection.convert.ImplicitConversions.`map AsScala`
 
 import org.opensearch.flint.common.metadata.FlintMetadata
@@ -18,6 +18,7 @@ import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
 import org.opensearch.flint.spark.function.TumbleFunction
 import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getFlintIndexName, MV_INDEX_TYPE}
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedRelation}
@@ -64,10 +65,14 @@ case class FlintSparkMaterializedView(
     }.toArray
     val schema = generateSchema(outputSchema).asJava
 
+    // Convert Scala Array to Java ArrayList for consistency with OpenSearch JSON parsing.
+    // OpenSearch uses Jackson, which deserializes JSON arrays into ArrayLists.
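+    // Without this conversion, a freshly-built MV stores sourceTables as a Scala Array while
+    // metadata read back from OpenSearch yields an ArrayList, so a reader matching on a single
+    // type (see getSourceTablesFromMetadata) would silently fall back to an empty result.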
+    val sourceTablesProperty = new java.util.ArrayList[String](sourceTables.toSeq.asJava)
+
     metadataBuilder(this)
       .name(mvName)
       .source(query)
-      .addProperty("sourceTables", sourceTables)
+      .addProperty("sourceTables", sourceTablesProperty)
       .indexedColumns(indexColumnMaps)
       .schema(schema)
       .build()
@@ -153,7 +158,7 @@ case class FlintSparkMaterializedView(
   }
 }
 
-object FlintSparkMaterializedView {
+object FlintSparkMaterializedView extends Logging {
 
   /** MV index type name */
   val MV_INDEX_TYPE = "mv"
@@ -185,13 +190,40 @@ object FlintSparkMaterializedView {
    * @return
    *   source table names
    */
-  def extractSourceTableNames(spark: SparkSession, query: String): Array[String] = {
-    spark.sessionState.sqlParser
+  def extractSourceTablesFromQuery(spark: SparkSession, query: String): Array[String] = {
+    logInfo(s"Extracting source tables from query $query")
+    val sourceTables = spark.sessionState.sqlParser
       .parsePlan(query)
       .collect { case relation: UnresolvedRelation =>
         qualifyTableName(spark, relation.tableName)
       }
       .toArray
+    logInfo(s"Extracted tables: [${sourceTables.mkString(", ")}]")
+    sourceTables
+  }
+
+  /**
+   * Get source tables from Flint metadata properties field.
+   *
+   * @param metadata
+   *   Flint metadata
+   * @return
+   *   source table names
+   */
+  def getSourceTablesFromMetadata(metadata: FlintMetadata): Array[String] = {
+    logInfo(s"Getting source tables from metadata $metadata")
+    val sourceTables = metadata.properties.get("sourceTables")
+    sourceTables match {
+      case list: java.util.ArrayList[_] =>
+        logInfo(s"sourceTables is [${list.asScala.mkString(", ")}]")
+        list.toArray.map(_.toString)
+      case null =>
+        logInfo("sourceTables property does not exist")
+        Array.empty[String]
+      case _ =>
+        logInfo(s"sourceTables has unexpected type: ${sourceTables.getClass.getName}")
+        Array.empty[String]
+    }
   }
 
   /** Builder class for MV build */
@@ -223,7 +255,7 @@ object FlintSparkMaterializedView {
      */
    def query(query: String): Builder = {
       this.query = query
-      this.sourceTables = extractSourceTableNames(flint.spark, query)
+      this.sourceTables = extractSourceTablesFromQuery(flint.spark, query)
       this
     }
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala
index 1c9a9e83c..78d2eb09e 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala
@@ -64,7 +64,9 @@ class FlintSparkMaterializedViewSuite extends FlintSuite {
     metadata.kind shouldBe MV_INDEX_TYPE
     metadata.source shouldBe "SELECT 1"
     metadata.properties should contain key "sourceTables"
-    metadata.properties.get("sourceTables").asInstanceOf[Array[String]] should have size 0
+    metadata.properties
+      .get("sourceTables")
+      .asInstanceOf[java.util.ArrayList[String]] should have size 0
     metadata.indexedColumns shouldBe Array(
       Map("columnName" -> "test_col", "columnType" -> "integer").asJava)
     metadata.schema shouldBe Map("test_col" -> Map("type" -> "integer").asJava).asJava
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala
index fc77faaea..cf0347820 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala
@@ -18,7 +18,7 @@ import org.opensearch.flint.core.FlintOptions
 import org.opensearch.flint.core.storage.{FlintOpenSearchIndexMetadataService, OpenSearchClientUtils}
 import org.opensearch.flint.spark.FlintSparkIndex.quotedTableName
 import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
-import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{extractSourceTableNames, getFlintIndexName}
+import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{extractSourceTablesFromQuery, getFlintIndexName, getSourceTablesFromMetadata, MV_INDEX_TYPE}
 import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler
 import org.scalatest.matchers.must.Matchers._
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
@@ -65,14 +65,76 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite {
         |   FROM spark_catalog.default.`table/3`
         |   INNER JOIN spark_catalog.default.`table.4`
         |""".stripMargin
-    extractSourceTableNames(flint.spark, testComplexQuery) should contain theSameElementsAs
+    extractSourceTablesFromQuery(flint.spark, testComplexQuery) should contain theSameElementsAs
       Array(
         "spark_catalog.default.table1",
         "spark_catalog.default.table2",
         "spark_catalog.default.`table/3`",
         "spark_catalog.default.`table.4`")
 
-    extractSourceTableNames(flint.spark, "SELECT 1") should have size 0
+    extractSourceTablesFromQuery(flint.spark, "SELECT 1") should have size 0
+  }
+
+  test("get source table names from index metadata successfully") {
+    val mv = FlintSparkMaterializedView(
+      "spark_catalog.default.mv",
+      s"SELECT 1 FROM $testTable",
+      Array(testTable),
+      Map("1" -> "integer"))
+    val metadata = mv.metadata()
+    getSourceTablesFromMetadata(metadata) should contain theSameElementsAs Array(testTable)
+  }
+
+  test("get source table names from deserialized metadata successfully") {
+    val metadata = FlintOpenSearchIndexMetadataService.deserialize(s""" {
+        |   "_meta": {
+        |     "kind": "$MV_INDEX_TYPE",
+        |     "properties": {
+        |       "sourceTables": [
+        |         "$testTable"
+        |       ]
+        |     }
+        |   },
+        |   "properties": {
+        |     "age": {
+        |       "type": "integer"
+        |     }
+        |   }
+        | }
+        |""".stripMargin)
+    getSourceTablesFromMetadata(metadata) should contain theSameElementsAs Array(testTable)
+  }
+
+  test("get empty source tables from invalid field in metadata") {
+    val metadataWrongType = FlintOpenSearchIndexMetadataService.deserialize(s""" {
+        |   "_meta": {
+        |     "kind": "$MV_INDEX_TYPE",
+        |     "properties": {
+        |       "sourceTables": "$testTable"
+        |     }
+        |   },
+        |   "properties": {
+        |     "age": {
+        |       "type": "integer"
+        |     }
+        |   }
+        | }
+        |""".stripMargin)
+    val metadataMissingField = FlintOpenSearchIndexMetadataService.deserialize(s""" {
+        |   "_meta": {
+        |     "kind": "$MV_INDEX_TYPE",
+        |     "properties": { }
+        |   },
+        |   "properties": {
+        |     "age": {
+        |       "type": "integer"
+        |     }
+        |   }
+        | }
+        |""".stripMargin)
+
+    getSourceTablesFromMetadata(metadataWrongType) shouldBe empty
+    getSourceTablesFromMetadata(metadataMissingField) shouldBe empty
   }
 
   test("create materialized view with metadata successfully") {
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala
index c0d253fd3..5b4dd0208 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala
@@ -18,6 +18,7 @@ import org.opensearch.flint.core.FlintOptions
 import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintOpenSearchIndexMetadataService}
 import org.opensearch.flint.spark.{FlintSparkIndexOptions, FlintSparkSuite}
 import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE
+import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
 import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE}
 import org.scalatest.Entry
@@ -161,12 +162,29 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat
       val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties
       properties
         .get("sourceTables")
-        .asInstanceOf[List[String]]
-        .toArray should contain theSameElementsAs Array(testTable)
+        .asInstanceOf[java.util.ArrayList[String]] should contain theSameElementsAs Array(
+        testTable)
     }
   }
 
-  test(s"write metadata cache to materialized view index mappings with source tables") {
+  test("write metadata cache with source tables from index metadata") {
+    val mv = FlintSparkMaterializedView(
+      "spark_catalog.default.mv",
+      s"SELECT 1 FROM $testTable",
+      Array(testTable),
+      Map("1" -> "integer"))
+    val metadata = mv.metadata().copy(latestLogEntry = Some(flintMetadataLogEntry))
+
+    flintClient.createIndex(testFlintIndex, metadata)
+    flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata)
+
+    val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties
+    properties
+      .get("sourceTables")
+      .asInstanceOf[java.util.ArrayList[String]] should contain theSameElementsAs Array(testTable)
+  }
+
+  test("write metadata cache with source tables from deserialized metadata") {
     val testTable2 = "spark_catalog.default.metadatacache_test2"
     val content =
       s""" {
@@ -194,8 +212,9 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat
     val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties
     properties
       .get("sourceTables")
-      .asInstanceOf[List[String]]
-      .toArray should contain theSameElementsAs Array(testTable, testTable2)
+      .asInstanceOf[java.util.ArrayList[String]] should contain theSameElementsAs Array(
+      testTable,
+      testTable2)
   }
 
   test("write metadata cache to index mappings with refresh interval") {