diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py index 0bd588d2a..8edb978e7 100644 --- a/ibis-server/app/model/__init__.py +++ b/ibis-server/app/model/__init__.py @@ -43,6 +43,10 @@ class QueryPostgresDTO(QueryDTO): connection_info: ConnectionUrl | PostgresConnectionInfo = connection_info_field +class QueryPySparkDTO(QueryDTO): + connection_info: ConnectionUrl | PySparkConnectionInfo = connection_info_field + + class QuerySnowflakeDTO(QueryDTO): connection_info: SnowflakeConnectionInfo = connection_info_field @@ -109,6 +113,17 @@ class PostgresConnectionInfo(BaseModel): password: SecretStr +class PySparkConnectionInfo(BaseModel): + app_name: SecretStr = Field(examples=["wrenai"]) + master: SecretStr = Field( + default="local[*]", + description="Spark master URL (e.g., 'local[*]', 'spark://master:7077')", + ) + configs: dict[str, str] | None = Field( + default=None, description="Additional Spark configurations" + ) + + class SnowflakeConnectionInfo(BaseModel): user: SecretStr password: SecretStr @@ -137,6 +152,7 @@ class TrinoConnectionInfo(BaseModel): | MSSqlConnectionInfo | MySqlConnectionInfo | PostgresConnectionInfo + | PySparkConnectionInfo | SnowflakeConnectionInfo | TrinoConnectionInfo ) diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index ea271f6aa..ef95ff4ba 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -7,6 +7,7 @@ import ibis from google.oauth2 import service_account from ibis import BaseBackend +from pyspark.sql import SparkSession from app.model import ( BigQueryConnectionInfo, @@ -16,6 +17,7 @@ MSSqlConnectionInfo, MySqlConnectionInfo, PostgresConnectionInfo, + PySparkConnectionInfo, QueryBigQueryDTO, QueryCannerDTO, QueryClickHouseDTO, @@ -23,6 +25,7 @@ QueryMSSqlDTO, QueryMySqlDTO, QueryPostgresDTO, + QueryPySparkDTO, QuerySnowflakeDTO, QueryTrinoDTO, SnowflakeConnectionInfo, @@ -37,6 +40,7 @@ class DataSource(StrEnum): mssql = auto() mysql = auto() postgres = auto() + pyspark = auto() snowflake = auto() trino = auto() @@ -60,6 +64,7 @@ class DataSourceExtension(Enum): mssql = QueryMSSqlDTO mysql = QueryMySqlDTO postgres = QueryPostgresDTO + pyspark = QueryPySparkDTO snowflake = QuerySnowflakeDTO trino = QueryTrinoDTO @@ -143,6 +148,20 @@ def get_postgres_connection(info: PostgresConnectionInfo) -> BaseBackend: password=info.password.get_secret_value(), ) + @staticmethod + def get_pyspark_connection(info: PySparkConnectionInfo) -> BaseBackend: + builder = SparkSession.builder.appName(info.app_name.get_secret_value()).master( + info.master.get_secret_value() + ) + + if info.configs: + for key, value in info.configs.items(): + builder = builder.config(key, value) + + # Create or get existing Spark session + spark_session = builder.getOrCreate() + return ibis.pyspark.connect(session=spark_session) + @staticmethod def get_snowflake_connection(info: SnowflakeConnectionInfo) -> BaseBackend: return ibis.snowflake.connect( diff --git a/ibis-server/pyproject.toml b/ibis-server/pyproject.toml index 594d5a77a..66df718ad 100644 --- a/ibis-server/pyproject.toml +++ b/ibis-server/pyproject.toml @@ -16,6 +16,7 @@ ibis-framework = { version = "9.5.0", extras = [ "mssql", "mysql", "postgres", + "pyspark", "snowflake", "trino", ] } @@ -42,6 +43,7 @@ sqlalchemy = "2.0.36" pre-commit = "4.0.1" ruff = "0.8.0" trino = ">=0.321,<1" +pyspark = "3.5.1" psycopg2 = ">=2.8.4,<3" clickhouse-connect = "0.8.7" @@ -54,6 +56,7 @@ markers = [ "mssql: mark a test as a mssql test", "mysql: mark a test as a mysql test", "postgres: mark a test as a postgres test", + "pyspark: mark a test as a pyspark test", "snowflake: mark a test as a snowflake test", "trino: mark a test as a trino test", "beta: mark a test as a test for beta versions of the engine", diff --git a/ibis-server/tests/routers/v2/connector/test_pyspark.py b/ibis-server/tests/routers/v2/connector/test_pyspark.py new file mode 100644 index 000000000..6e74479c7 --- /dev/null +++ b/ibis-server/tests/routers/v2/connector/test_pyspark.py @@ -0,0 +1,191 @@ +import base64 + +# import os +import orjson +import pytest +from fastapi.testclient import TestClient + +from app.main import app +from app.model.validator import rules + +pytestmark = pytest.mark.pyspark + +base_url = "/v2/connector/pyspark" + +connection_info = { + "app_name": "MyApp", + "master": "local", +} + +manifest = { + "catalog": "my_catalog", + "schema": "my_schema", + "models": [ + { + "name": "Orders", + "properties": {}, + "refSql": "select * from tpch.orders", + "columns": [ + {"name": "orderkey", "expression": "O_ORDERKEY", "type": "integer"}, + {"name": "custkey", "expression": "O_CUSTKEY", "type": "integer"}, + { + "name": "orderstatus", + "expression": "O_ORDERSTATUS", + "type": "varchar", + }, + { + "name": "totalprice", + "expression": "O_TOTALPRICE", + "type": "float", + }, + {"name": "orderdate", "expression": "O_ORDERDATE", "type": "date"}, + { + "name": "order_cust_key", + "expression": "concat(O_ORDERKEY, '_', O_CUSTKEY)", + "type": "varchar", + }, + { + "name": "timestamp", + "expression": "cast('2024-01-01T23:59:59' as timestamp)", + "type": "timestamp", + }, + { + "name": "timestamptz", + "expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)", + "type": "timestamp", + }, + { + "name": "test_null_time", + "expression": "cast(NULL as timestamp)", + "type": "timestamp", + }, + ], + "primaryKey": "orderkey", + }, + ], +} + + +@pytest.fixture +def manifest_str(): + return base64.b64encode(orjson.dumps(manifest)).decode("utf-8") + + +with TestClient(app) as client: + # def test_query(manifest_str): + # response = client.post( + # url=f"{base_url}/query", + # json={ + # "connectionInfo": connection_info, + # "manifestStr": manifest_str, + # "sql": 'SELECT * FROM "Orders" ORDER BY "orderkey" LIMIT 1', + # }, + # ) + # assert response.status_code == 200 + # result = response.json() + # assert len(result["columns"]) == len(manifest["models"][0]["columns"]) + # assert len(result["data"]) == 1 + # assert result["data"][0] == [ + # 1, + # 36901, + # "O", + # "173665.47", + # "1996-01-02", + # "1_36901", + # "2024-01-01 23:59:59.000000", + # "2024-01-01 23:59:59.000000 UTC", + # None, + # ] + # assert result["dtypes"] == { + # "orderkey": "int64", + # "custkey": "int64", + # "orderstatus": "object", + # "totalprice": "object", + # "orderdate": "object", + # "order_cust_key": "object", + # "timestamp": "object", + # "timestamptz": "object", + # "test_null_time": "datetime64[ns]", + # } + + def test_query_without_manifest(): + response = client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, + ) + assert response.status_code == 422 + result = response.json() + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "manifestStr"] + assert result["detail"][0]["msg"] == "Field required" + + def test_query_without_sql(manifest_str): + response = client.post( + url=f"{base_url}/query", + json={"connectionInfo": connection_info, "manifestStr": manifest_str}, + ) + assert response.status_code == 422 + result = response.json() + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "sql"] + assert result["detail"][0]["msg"] == "Field required" + + def test_query_without_connection_info(manifest_str): + response = client.post( + url=f"{base_url}/query", + json={ + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, + ) + assert response.status_code == 422 + result = response.json() + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "connectionInfo"] + assert result["detail"][0]["msg"] == "Field required" + + # def test_query_with_dry_run(manifest_str): + # response = client.post( + # url=f"{base_url}/query", + # params={"dryRun": True}, + # json={ + # "connectionInfo": connection_info, + # "manifestStr": manifest_str, + # "sql": 'SELECT * FROM "Orders" LIMIT 1', + # }, + # ) + # assert response.status_code == 204 + + def test_query_with_dry_run_and_invalid_sql(manifest_str): + response = client.post( + url=f"{base_url}/query", + params={"dryRun": True}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM X", + }, + ) + assert response.status_code == 422 + assert response.text is not None + + def test_validate_with_unknown_rule(manifest_str): + response = client.post( + url=f"{base_url}/validate/unknown_rule", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": {"modelName": "Orders", "columnName": "orderkey"}, + }, + ) + assert response.status_code == 404 + assert ( + response.text + == f"The rule `unknown_rule` is not in the rules, rules: {rules}" + )