diff --git a/ibis-server/app/model/validator.py b/ibis-server/app/model/validator.py index 5f899c089..b287e88c9 100644 --- a/ibis-server/app/model/validator.py +++ b/ibis-server/app/model/validator.py @@ -1,10 +1,13 @@ from __future__ import annotations +import base64 +import json + from app.mdl.rewriter import Rewriter from app.model import UnprocessableEntityError from app.model.connector import Connector -rules = ["column_is_valid"] +rules = ["column_is_valid", "relationship_is_valid"] class Validator: @@ -12,17 +15,17 @@ def __init__(self, connector: Connector, rewriter: Rewriter): self.connector = connector self.rewriter = rewriter - def validate(self, rule: str, parameters: dict[str, str]): + def validate(self, rule: str, parameters: dict[str, str], manifest_str: str): if rule not in rules: raise RuleNotFoundError(rule) try: - getattr(self, f"_validate_{rule}")(parameters) + getattr(self, f"_validate_{rule}")(parameters, manifest_str) except ValidationError as e: raise e except Exception as e: raise ValidationError(f"Unknown exception: {type(e)}, message: {e!s}") - def _validate_column_is_valid(self, parameters: dict[str, str]): + def _validate_column_is_valid(self, parameters: dict[str, str], manifest_str: str): model_name = parameters.get("modelName") column_name = parameters.get("columnName") if model_name is None: @@ -37,6 +40,116 @@ def _validate_column_is_valid(self, parameters: dict[str, str]): except Exception as e: raise ValidationError(f"Exception: {type(e)}, message: {e!s}") + def _validate_relationship_is_valid( + self, parameters: dict[str, str], manifest_str: str + ): + relationship_name = parameters.get("relationshipName") + if relationship_name is None: + raise MissingRequiredParameterError("relationship") + decoded_manifest = base64.b64decode(manifest_str).decode("utf-8") + manifest = json.loads(decoded_manifest) + + relationship = list( + filter(lambda r: r["name"] == relationship_name, manifest["relationships"]) + ) + + if len(relationship) == 0: + raise ValidationError( + f"Relationship {relationship_name} not found in manifest" + ) + + left_model = self._get_model(manifest, relationship[0]["models"][0]) + right_model = self._get_model(manifest, relationship[0]["models"][1]) + relationship_type = relationship[0]["joinType"].lower() + condition = relationship[0]["condition"] + columns = condition.split("=") + left_column = columns[0].strip().split(".")[1] + right_column = columns[1].strip().split(".")[1] + + def generate_column_is_unique_sql(model_name, column_name): + return f'SELECT count(*) = count(distinct {column_name}) AS result FROM "{model_name}"' + + def generate_is_exist_join_sql( + left_model, right_model, left_column, right_column + ): + return f'SELECT count(*) > 0 AS result FROM "{left_model}" JOIN "{right_model}" ON "{left_model}"."{left_column}" = "{right_model}"."{right_column}"' + + def generate_sql_from_type( + relationship_type, left_model, right_model, left_column, right_column + ): + if relationship_type == "one_to_one": + return f"""WITH + lefttable AS ({generate_column_is_unique_sql(left_model, left_column)}), + righttable AS ({generate_column_is_unique_sql(right_model, right_column)}), + joinexist AS ({generate_is_exist_join_sql(left_model, right_model, left_column, right_column)}) + SELECT lefttable.result AND righttable.result AND joinexist.result result, + lefttable.result left_table_unique, + righttable.result right_table_unique, + joinexist.result is_related + FROM lefttable, righttable, joinexist""" + elif relationship_type == "many_to_one": + return f"""WITH + righttable AS ({generate_column_is_unique_sql(right_model, right_column)}), + joinexist AS ({generate_is_exist_join_sql(left_model, right_model, left_column, right_column)}) + SELECT righttable.result AND joinexist.result result, + righttable.result right_table_unique, + joinexist.result is_related + FROM righttable, joinexist""" + elif relationship_type == "one_to_many": + return f"""WITH + lefttable AS ({generate_column_is_unique_sql(left_model, left_column)}), + joinexist AS ({generate_is_exist_join_sql(left_model, right_model, left_column, right_column)}) + SELECT lefttable.result AND joinexist.result result, + lefttable.result left_table_unique, + joinexist.result is_related + FROM lefttable, joinexist""" + elif relationship_type == "many_to_many": + return f"""WITH + joinexist AS ({generate_is_exist_join_sql(left_model, right_model, left_column, right_column)}) + SELECT joinexist.result result, + joinexist.result is_related + FROM joinexist""" + else: + raise ValidationError(f"Unknown relationship type: {relationship_type}") + + def format_result(result): + output = {} + output["result"] = str(result.get("result").get(0)) + output["is_related"] = str(result.get("is_related").get(0)) + if result.get("left_table_unique") is not None: + output["left_table_unique"] = str( + result.get("left_table_unique").get(0) + ) + if result.get("right_table_unique") is not None: + output["right_table_unique"] = str( + result.get("right_table_unique").get(0) + ) + return output + + sql = generate_sql_from_type( + relationship_type, + left_model["name"], + right_model["name"], + left_column, + right_column, + ) + try: + rewritten_sql = self.rewriter.rewrite(sql) + result = self.connector.query(rewritten_sql, limit=1) + if not result.get("result").get(0): + raise ValidationError( + f"Relationship {relationship_name} is not valid: {format_result(result)}" + ) + + except Exception as e: + raise ValidationError(f"Exception: {type(e)}, message: {e!s}") + + def _get_model(self, manifest, model_name): + models = list(filter(lambda m: m["name"] == model_name, manifest["models"])) + if len(models) == 0: + raise ValidationError(f"Model {model_name} not found in manifest") + return models[0] + class ValidationError(UnprocessableEntityError): pass diff --git a/ibis-server/app/routers/v2/connector.py b/ibis-server/app/routers/v2/connector.py index 263a46b75..51aea4eb5 100644 --- a/ibis-server/app/routers/v2/connector.py +++ b/ibis-server/app/routers/v2/connector.py @@ -41,7 +41,7 @@ def validate(data_source: DataSource, rule_name: str, dto: ValidateDTO) -> Respo Connector(data_source, dto.connection_info, dto.manifest_str), Rewriter(dto.manifest_str, data_source=data_source), ) - validator.validate(rule_name, dto.parameters) + validator.validate(rule_name, dto.parameters, dto.manifest_str) return Response(status_code=204) diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index 664ab6e9f..fc5e43727 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -53,5 +53,5 @@ def validate(data_source: DataSource, rule_name: str, dto: ValidateDTO) -> Respo Connector(data_source, dto.connection_info, dto.manifest_str), Rewriter(dto.manifest_str, data_source=data_source, experiment=True), ) - validator.validate(rule_name, dto.parameters) + validator.validate(rule_name, dto.parameters, dto.manifest_str) return Response(status_code=204) diff --git a/ibis-server/tests/routers/v2/connector/test_bigquery.py b/ibis-server/tests/routers/v2/connector/test_bigquery.py index f24ff56b0..227b70f1f 100644 --- a/ibis-server/tests/routers/v2/connector/test_bigquery.py +++ b/ibis-server/tests/routers/v2/connector/test_bigquery.py @@ -6,6 +6,7 @@ from fastapi.testclient import TestClient from app.main import app +from app.model.validator import rules pytestmark = pytest.mark.bigquery @@ -185,6 +186,20 @@ def test_query_with_dry_run_and_invalid_sql(): assert response.text is not None +def test_query_values(): + response = client.post( + url=f"{base_url}/query", + params={"dryRun": True}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM (VALUES (1, 2), (3, 4))", + }, + ) + + assert response.status_code == 204 + + def test_validate_with_unknown_rule(): response = client.post( url=f"{base_url}/validate/unknown_rule", @@ -194,10 +209,10 @@ def test_validate_with_unknown_rule(): "parameters": {"modelName": "Orders", "columnName": "orderkey"}, }, ) + assert response.status_code == 422 assert ( - response.text - == "The rule `unknown_rule` is not in the rules, rules: ['column_is_valid']" + response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" ) diff --git a/ibis-server/tests/routers/v2/connector/test_clickhouse.py b/ibis-server/tests/routers/v2/connector/test_clickhouse.py index 00b24bc0b..8d0022de1 100644 --- a/ibis-server/tests/routers/v2/connector/test_clickhouse.py +++ b/ibis-server/tests/routers/v2/connector/test_clickhouse.py @@ -8,6 +8,7 @@ from testcontainers.clickhouse import ClickHouseContainer from app.main import app +from app.model.validator import rules from tests.confest import file_path pytestmark = pytest.mark.clickhouse @@ -410,8 +411,7 @@ def test_validate_with_unknown_rule(clickhouse: ClickHouseContainer): ) assert response.status_code == 422 assert ( - response.text - == "The rule `unknown_rule` is not in the rules, rules: ['column_is_valid']" + response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" ) diff --git a/ibis-server/tests/routers/v2/connector/test_mssql.py b/ibis-server/tests/routers/v2/connector/test_mssql.py index f57d1ffa3..d858f0a82 100644 --- a/ibis-server/tests/routers/v2/connector/test_mssql.py +++ b/ibis-server/tests/routers/v2/connector/test_mssql.py @@ -9,6 +9,7 @@ from testcontainers.mssql import SqlServerContainer from app.main import app +from app.model.validator import rules from tests.confest import file_path pytestmark = pytest.mark.mssql @@ -251,8 +252,7 @@ def test_validate_with_unknown_rule(mssql: SqlServerContainer): ) assert response.status_code == 422 assert ( - response.text - == "The rule `unknown_rule` is not in the rules, rules: ['column_is_valid']" + response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" ) diff --git a/ibis-server/tests/routers/v2/connector/test_mysql.py b/ibis-server/tests/routers/v2/connector/test_mysql.py index 62014822c..37ae693bd 100644 --- a/ibis-server/tests/routers/v2/connector/test_mysql.py +++ b/ibis-server/tests/routers/v2/connector/test_mysql.py @@ -9,6 +9,7 @@ from testcontainers.mysql import MySqlContainer from app.main import app +from app.model.validator import rules from tests.confest import file_path pytestmark = pytest.mark.mysql @@ -257,8 +258,7 @@ def test_validate_with_unknown_rule(mysql: MySqlContainer): ) assert response.status_code == 422 assert ( - response.text - == "The rule `unknown_rule` is not in the rules, rules: ['column_is_valid']" + response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" ) diff --git a/ibis-server/tests/routers/v2/connector/test_postgres.py b/ibis-server/tests/routers/v2/connector/test_postgres.py index 2ab6a962e..523439448 100644 --- a/ibis-server/tests/routers/v2/connector/test_postgres.py +++ b/ibis-server/tests/routers/v2/connector/test_postgres.py @@ -11,6 +11,7 @@ from testcontainers.postgres import PostgresContainer from app.main import app +from app.model.validator import rules from tests.confest import file_path pytestmark = pytest.mark.postgres @@ -288,8 +289,7 @@ def test_validate_with_unknown_rule(postgres: PostgresContainer): ) assert response.status_code == 422 assert ( - response.text - == "The rule `unknown_rule` is not in the rules, rules: ['column_is_valid']" + response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" ) diff --git a/ibis-server/tests/routers/v2/connector/test_snowflake.py b/ibis-server/tests/routers/v2/connector/test_snowflake.py index d6eaa858f..e108f4fbb 100644 --- a/ibis-server/tests/routers/v2/connector/test_snowflake.py +++ b/ibis-server/tests/routers/v2/connector/test_snowflake.py @@ -6,6 +6,7 @@ from fastapi.testclient import TestClient from app.main import app +from app.model.validator import rules pytestmark = pytest.mark.snowflake @@ -192,8 +193,7 @@ def test_validate_with_unknown_rule(): ) assert response.status_code == 422 assert ( - response.text - == "The rule `unknown_rule` is not in the rules, rules: ['column_is_valid']" + response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" ) diff --git a/ibis-server/tests/routers/v2/connector/test_trino.py b/ibis-server/tests/routers/v2/connector/test_trino.py index 1f3689bd5..bb2f9deb1 100644 --- a/ibis-server/tests/routers/v2/connector/test_trino.py +++ b/ibis-server/tests/routers/v2/connector/test_trino.py @@ -8,6 +8,7 @@ from trino.dbapi import connect from app.main import app +from app.model.validator import rules pytestmark = pytest.mark.trino @@ -268,8 +269,7 @@ def test_validate_with_unknown_rule(trino: TrinoContainer): ) assert response.status_code == 422 assert ( - response.text - == "The rule `unknown_rule` is not in the rules, rules: ['column_is_valid']" + response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" ) diff --git a/ibis-server/tests/routers/v2/test_relationship_valid.py b/ibis-server/tests/routers/v2/test_relationship_valid.py new file mode 100644 index 000000000..2055ae8cd --- /dev/null +++ b/ibis-server/tests/routers/v2/test_relationship_valid.py @@ -0,0 +1,162 @@ +import base64 + +import orjson +import pytest +from fastapi.testclient import TestClient +from testcontainers.postgres import PostgresContainer + +from app.main import app + +client = TestClient(app) + +relationship_test_manifest = { + "catalog": "wrenai", + "schema": "public", + "models": [ + { + "name": "t1", + "refSql": "select * from (values (1, 2), (2, 3), (3, 3)) as t1(id, many_col)", + "columns": [ + {"name": "id", "type": "integer"}, + {"name": "many_col", "type": "integer"}, + ], + "primaryKey": "id", + }, + { + "name": "t2", + "refSql": "select * from (values (1, 2), (2, 3), (3, 3)) as t2(id, many_col)", + "columns": [ + {"name": "id", "type": "integer"}, + {"name": "many_col", "type": "integer"}, + ], + }, + ], + "relationships": [ + { + "name": "t1_id_t2_id", + "joinType": "ONE_TO_ONE", + "models": ["t1", "t2"], + "condition": "t1.id = t2.id", + }, + { + "name": "t1_id_t2_many", + "joinType": "ONE_TO_MANY", + "models": ["t1", "t2"], + "condition": "t1.id = t2.many_col", + }, + { + "name": "t1_many_t2_id", + "joinType": "MANY_TO_ONE", + "models": ["t1", "t2"], + "condition": "t1.many_col = t2.id", + }, + { + "name": "invalid_t1_many_t2_id", + "joinType": "ONE_TO_ONE", + "models": ["t1", "t2"], + "condition": "t1.many_col = t2.id", + }, + ], +} + +manifest_str = base64.b64encode(orjson.dumps(relationship_test_manifest)).decode( + "utf-8" +) + +base_url = "/v2/connector/postgres" + + +@pytest.fixture(scope="module") +def postgres(request) -> PostgresContainer: + pg = PostgresContainer("postgres:16-alpine").start() + request.addfinalizer(pg.stop) + return pg + + +def test_validation_relationship(postgres: PostgresContainer): + connection_info = _to_connection_info(postgres) + response = client.post( + url=f"{base_url}/validate/relationship_is_valid", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": {"relationshipName": "t1_id_t2_id"}, + }, + ) + assert response.status_code == 204 + + response = client.post( + url=f"{base_url}/validate/relationship_is_valid", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": {"relationshipName": "t1_id_t2_many"}, + }, + ) + assert response.status_code == 204 + + response = client.post( + url=f"{base_url}/validate/relationship_is_valid", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": {"relationshipName": "t1_many_t2_id"}, + }, + ) + assert response.status_code == 204 + + +def test_validation_relationship_not_found(postgres: PostgresContainer): + connection_info = _to_connection_info(postgres) + response = client.post( + url=f"{base_url}/validate/relationship_is_valid", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": {"relationshipName": "not_found"}, + }, + ) + + assert response.status_code == 422 + assert response.text == "Relationship not_found not found in manifest" + + connection_info = _to_connection_info(postgres) + response = client.post( + url=f"{base_url}/validate/relationship_is_valid", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": {}, + }, + ) + + assert response.status_code == 422 + assert response.text == "Missing required parameter: `relationship`" + + +def test_validation_faliure(postgres: PostgresContainer): + connection_info = _to_connection_info(postgres) + response = client.post( + url=f"{base_url}/validate/relationship_is_valid", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": {"relationshipName": "invalid_t1_many_t2_id"}, + }, + ) + + assert response.status_code == 422 + assert ( + response.content + == b"Exception: , message: Relationship invalid_t1_many_t2_id is not valid: {'result': 'False', 'is_related': 'True', 'left_table_unique': 'False', 'right_table_unique': 'True'}" + ) + + +def _to_connection_info(pg: PostgresContainer): + return { + "host": pg.get_container_host_ip(), + "port": pg.get_exposed_port(pg.port), + "user": pg.username, + "password": pg.password, + "database": pg.dbname, + } diff --git a/ibis-server/tests/routers/v3/connector/test_postgres.py b/ibis-server/tests/routers/v3/connector/test_postgres.py index 13960e115..8f7f2e8cc 100644 --- a/ibis-server/tests/routers/v3/connector/test_postgres.py +++ b/ibis-server/tests/routers/v3/connector/test_postgres.py @@ -8,6 +8,7 @@ from testcontainers.postgres import PostgresContainer from app.main import app +from app.model.validator import rules from tests.confest import file_path pytestmark = pytest.mark.beta @@ -273,8 +274,7 @@ def test_validate_with_unknown_rule(postgres: PostgresContainer): ) assert response.status_code == 422 assert ( - response.text - == "The rule `unknown_rule` is not in the rules, rules: ['column_is_valid']" + response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" ) diff --git a/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java b/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java index d9fb1ddc1..741ce3e5a 100644 --- a/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java +++ b/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java @@ -153,8 +153,7 @@ protected String visitNode(Node node, Void context) @Override protected String visitRow(Row node, Void context) { - String rowPrefix = (dialect == DEFAULT || dialect == BIGQUERY) ? "ROW" : ""; - return rowPrefix + " (" + Joiner.on(", ").join(node.getItems().stream() + return "(" + Joiner.on(", ").join(node.getItems().stream() .map(child -> process(child, context)) .collect(toList())) + ")"; } diff --git a/wren-tests/src/test/java/io/wren/testing/duckdb/TestDuckDBSqlConverter.java b/wren-tests/src/test/java/io/wren/testing/duckdb/TestDuckDBSqlConverter.java index b1e38622e..2785ab0cc 100644 --- a/wren-tests/src/test/java/io/wren/testing/duckdb/TestDuckDBSqlConverter.java +++ b/wren-tests/src/test/java/io/wren/testing/duckdb/TestDuckDBSqlConverter.java @@ -90,7 +90,7 @@ public void testValues() FROM ( VALUES\s - (1, 2, ARRAY[1,2,3]) + (1, 2, ARRAY[1,2,3]) )\s """); }