From 4d6c9b5c047e94a60ef083b78af81ac3d611cac6 Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Tue, 3 Sep 2024 19:16:29 +0200 Subject: [PATCH] improvement: ibis support in mo.ui.dataframe (#2188) * improvement: ibis support in mo.ui.dataframe * test dep * fixes * cr comments * widen pyarrow * > 3.8 * > 3.9 * cleanup pyarrow stubs * types * fix tests * fixes --- .../impl/data-frames/DataFramePlugin.tsx | 29 ++- marimo/_dependencies/dependencies.py | 1 + .../_plugins/ui/_impl/dataframes/dataframe.py | 2 +- .../ui/_impl/dataframes/transforms/apply.py | 17 +- .../_impl/dataframes/transforms/handlers.py | 196 +++++++++++++++- .../_plugins/ui/_impl/tables/default_table.py | 6 +- .../ui/_impl/tables/df_protocol_table.py | 4 +- .../_plugins/ui/_impl/tables/pyarrow_table.py | 30 ++- marimo/_server/api/endpoints/ai.py | 4 +- marimo/_smoke_tests/ibis_example.py | 85 +++++++ pyproject.toml | 8 +- .../ui/_impl/dataframes/test_handlers.py | 217 +++++++++++++++--- tests/_sql/test_sql.py | 14 +- 13 files changed, 531 insertions(+), 82 deletions(-) create mode 100644 marimo/_smoke_tests/ibis_example.py diff --git a/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx b/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx index c53a7e89c60..84b94e0ae7a 100644 --- a/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx +++ b/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx @@ -22,6 +22,7 @@ import { TooltipProvider } from "@/components/ui/tooltip"; import { Banner, ErrorBanner } from "../common/error-banner"; import type { DataType } from "../vega/vega-loader"; import type { FieldTypesWithExternalType } from "@/components/data-table/types"; +import { Spinner } from "@/components/icons/spinner"; type CsvURL = string; type TableData = T[] | CsvURL; @@ -170,7 +171,7 @@ export const DataFrameComponent = memo( get_column_values, search, }: DataTableProps): JSX.Element => { - const { data, error } = useAsyncData( + const { data, error, loading } = useAsyncData( () => get_dataframe({}), [value?.transforms], ); @@ -206,18 +207,22 @@ export const DataFrameComponent = memo( return (
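            // `loading` (added to the useAsyncData destructuring above) is
            // true while get_dataframe({}) is still resolving; the header
            // below uses it to show the newly imported Spinner.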
        [JSX garbled in extraction. Old hunk lines (-): the tab header — a
        "Transform" tab and, when supports_code_sample is set, a "Code" tab.
        New hunk lines (+): the same tabs re-indented inside a wrapper, plus
        a spinner rendered while the dataframe request is in flight:
        {loading && <Spinner ... />}, using the Spinner import added above.]

        [Also lost in extraction: the remainder of the DataFramePlugin.tsx
        hunk; the one-line hunk for marimo/_dependencies/dependencies.py
        (per the diffstat, registering the ibis dependency consumed by
        DependencyManager.ibis.has()); and the file/hunk header for
        marimo/_plugins/ui/_impl/dataframes/dataframe.py — the hunk that
        follows is from its get_dataframe method, beginning mid-signature.]
GetDataFrameResponse: if self._error is not None: raise GetDataFrameError(self._error) - manager = get_table_manager(self._data) + manager = get_table_manager(self._value) response = self.search(SearchTableArgs(page_size=10, page_number=0)) return GetDataFrameResponse( url=str(response.data), diff --git a/marimo/_plugins/ui/_impl/dataframes/transforms/apply.py b/marimo/_plugins/ui/_impl/dataframes/transforms/apply.py index 2637bd14efb..c6d4b0f96c7 100644 --- a/marimo/_plugins/ui/_impl/dataframes/transforms/apply.py +++ b/marimo/_plugins/ui/_impl/dataframes/transforms/apply.py @@ -5,6 +5,7 @@ from marimo._dependencies.dependencies import DependencyManager from marimo._plugins.ui._impl.dataframes.transforms.handlers import ( + IbisTransformHandler, PandasTransformHandler, PolarsTransformHandler, ) @@ -19,7 +20,7 @@ T = TypeVar("T") -def handle(df: T, handler: TransformHandler[T], transform: Transform) -> T: +def _handle(df: T, handler: TransformHandler[T], transform: Transform) -> T: if transform.type is TransformType.COLUMN_CONVERSION: return handler.handle_column_conversion(df, transform) elif transform.type is TransformType.RENAME_COLUMN: @@ -46,13 +47,13 @@ def handle(df: T, handler: TransformHandler[T], transform: Transform) -> T: assert_never(transform.type) -def apply_transforms( +def _apply_transforms( df: T, handler: TransformHandler[T], transforms: Transformations ) -> T: if not transforms.transforms: return df for transform in transforms.transforms: - df = handle(df, handler, transform) + df = _handle(df, handler, transform) return df @@ -75,6 +76,12 @@ def get_handler_for_dataframe( if isinstance(df, pl.DataFrame): return PolarsTransformHandler() + if DependencyManager.ibis.has(): + import ibis # type: ignore + + if isinstance(df, ibis.Table): + return IbisTransformHandler() + raise ValueError( "Unsupported dataframe type. Must be Pandas or Polars." f" Got: {type(df)}" @@ -102,7 +109,7 @@ def apply(self, transform: Transformations) -> T: # then we can just apply the new ones to the snapshot dataframe. if self._is_superset(transform): transforms_to_apply = self._get_next_transformations(transform) - self._snapshot_df = apply_transforms( + self._snapshot_df = _apply_transforms( self._snapshot_df, self._handler, transforms_to_apply ) self._transforms = transform.transforms @@ -111,7 +118,7 @@ def apply(self, transform: Transformations) -> T: # If the new transformations are not a superset of the existing ones, # then we need to start from the original dataframe. else: - self._snapshot_df = apply_transforms( + self._snapshot_df = _apply_transforms( self._original_df, self._handler, transform ) self._transforms = transform.transforms diff --git a/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py b/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py index 89b76a1ab62..07fffc2a18f 100644 --- a/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py +++ b/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py @@ -1,7 +1,7 @@ # Copyright 2024 Marimo. All rights reserved. 
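# How these handlers are wired up (a sketch, based on transforms/apply.py
# above): _handle() dispatches a single Transform to the matching handle_*
# static method, _apply_transforms() folds a Transformations list over the
# dataframe, and TransformsContainer caches a snapshot so that a request
# whose transform list is a superset of the last applied one only runs the
# new tail. Hypothetical usage:
#
#     handler = get_handler_for_dataframe(df)  # pandas / polars / ibis
#     container = TransformsContainer(df, handler)
#     df2 = container.apply(transformations)
#     df3 = container.apply(transformations_plus_one_more)  # incremental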
from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, Sequence, cast +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, cast from marimo._plugins.ui._impl.dataframes.transforms.types import ( AggregateTransform, @@ -20,6 +20,8 @@ from marimo._utils.assert_never import assert_never if TYPE_CHECKING: + import ibis # type: ignore + import ibis.expr.types as ir # type: ignore import pandas as pd import polars as pl @@ -409,6 +411,198 @@ def handle_expand_dict( return df.hstack(pl.DataFrame(column.to_list())) +class IbisTransformHandler(TransformHandler["ibis.Table"]): + @staticmethod + def supports_code_sample() -> bool: + return False + + @staticmethod + def handle_column_conversion( + df: "ibis.Table", transform: ColumnConversionTransform + ) -> "ibis.Table": + import ibis + + if transform.errors == "ignore": + try: + # Use coalesce to handle conversion errors + return df.mutate( + **{ + transform.column_id: ibis.coalesce( + df[transform.column_id].cast( + ibis.dtype(transform.data_type) + ), + df[transform.column_id], + ) + } + ) + except ibis.common.exceptions.IbisTypeError: + return df + else: + # Default behavior (raise errors) + return df.mutate( + **{ + transform.column_id: df[transform.column_id].cast( + ibis.dtype(transform.data_type) + ) + } + ) + + @staticmethod + def handle_rename_column( + df: "ibis.Table", transform: RenameColumnTransform + ) -> "ibis.Table": + return df.rename({transform.new_column_id: transform.column_id}) + + @staticmethod + def handle_sort_column( + df: "ibis.Table", transform: SortColumnTransform + ) -> "ibis.Table": + return df.order_by( + [ + df[transform.column_id].asc() + if transform.ascending + else df[transform.column_id].desc() + ] + ) + + @staticmethod + def handle_filter_rows( + df: "ibis.Table", transform: FilterRowsTransform + ) -> "ibis.Table": + import ibis + + filter_conditions: list[ir.BooleanValue] = [] + for condition in transform.where: + column = df[str(condition.column_id)] + value = condition.value + if condition.operator == "==": + filter_conditions.append(column == value) + elif condition.operator == "!=": + filter_conditions.append(column != value) + elif condition.operator == ">": + filter_conditions.append(column > value) + elif condition.operator == "<": + filter_conditions.append(column < value) + elif condition.operator == ">=": + filter_conditions.append(column >= value) + elif condition.operator == "<=": + filter_conditions.append(column <= value) + elif condition.operator == "is_true": + filter_conditions.append(column) + elif condition.operator == "is_false": + filter_conditions.append(~column) + elif condition.operator == "is_nan": + filter_conditions.append(column.isnull()) + elif condition.operator == "is_not_nan": + filter_conditions.append(column.notnull()) + elif condition.operator == "equals": + filter_conditions.append(column == value) + elif condition.operator == "does_not_equal": + filter_conditions.append(column != value) + elif condition.operator == "contains": + filter_conditions.append(column.contains(value)) + elif condition.operator == "regex": + filter_conditions.append(column.re_search(value)) + elif condition.operator == "starts_with": + filter_conditions.append(column.startswith(value)) + elif condition.operator == "ends_with": + filter_conditions.append(column.endswith(value)) + elif condition.operator == "in": + filter_conditions.append(column.isin(value)) + else: + assert_never(condition.operator) + + combined_condition = ibis.and_(*filter_conditions) + + 
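        # The predicates collected above are AND-ed into one boolean
        # expression; "keep_rows" filters on it directly and "remove_rows"
        # on its negation. A minimal sketch (hypothetical table `t`):
        #
        #     t = ibis.memtable({"A": [1, 2, 3]})
        #     cond = ibis.and_(t.A >= 2, t.A <= 3)
        #     t.filter(cond)    # keep_rows   -> rows with A in {2, 3}
        #     t.filter(~cond)   # remove_rows -> rows with A in {1}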
if transform.operation == "keep_rows": + return df.filter(combined_condition) + elif transform.operation == "remove_rows": + return df.filter(~combined_condition) + else: + raise ValueError(f"Unsupported operation: {transform.operation}") + + @staticmethod + def handle_group_by( + df: "ibis.Table", transform: GroupByTransform + ) -> "ibis.Table": + aggs: list[ir.Expr] = [] + + group_by_column_id_set = set(transform.column_ids) + agg_columns = [ + column_id + for column_id in df.columns + if column_id not in group_by_column_id_set + ] + for column_id in agg_columns: + agg_func = transform.aggregation + if agg_func == "count": + aggs.append(df[column_id].count().name(f"{column_id}_count")) + elif agg_func == "sum": + aggs.append(df[column_id].sum().name(f"{column_id}_sum")) + elif agg_func == "mean": + aggs.append(df[column_id].mean().name(f"{column_id}_mean")) + elif agg_func == "median": + aggs.append(df[column_id].median().name(f"{column_id}_median")) + elif agg_func == "min": + aggs.append(df[column_id].min().name(f"{column_id}_min")) + elif agg_func == "max": + aggs.append(df[column_id].max().name(f"{column_id}_max")) + else: + assert_never(agg_func) + + return df.group_by(transform.column_ids).aggregate(aggs) + + @staticmethod + def handle_aggregate( + df: "ibis.Table", transform: AggregateTransform + ) -> "ibis.Table": + agg_dict: Dict[str, Any] = {} + for agg_func in transform.aggregations: + for column_id in transform.column_ids: + name = f"{column_id}_{agg_func}" + agg_dict[name] = getattr(df[column_id], agg_func)() + return df.aggregate(**agg_dict) + + @staticmethod + def handle_select_columns( + df: "ibis.Table", transform: SelectColumnsTransform + ) -> "ibis.Table": + return df.select(transform.column_ids) + + @staticmethod + def handle_shuffle_rows( + df: "ibis.Table", transform: ShuffleRowsTransform + ) -> "ibis.Table": + del transform + import ibis + + return df.order_by(ibis.random()) + + @staticmethod + def handle_sample_rows( + df: "ibis.Table", transform: SampleRowsTransform + ) -> "ibis.Table": + return df.sample( + transform.n / df.count().execute(), + method="row", + seed=transform.seed, + ) + + @staticmethod + def handle_explode_columns( + df: "ibis.Table", transform: ExplodeColumnsTransform + ) -> "ibis.Table": + for column_id in transform.column_ids: + df = df.unnest(column_id) + return df + + @staticmethod + def handle_expand_dict( + df: "ibis.Table", transform: ExpandDictTransform + ) -> "ibis.Table": + return df.unpack(transform.column_id) + + def _coerce_value(dtype: Any, value: Any) -> Any: import numpy as np diff --git a/marimo/_plugins/ui/_impl/tables/default_table.py b/marimo/_plugins/ui/_impl/tables/default_table.py index fd6428c9f49..18f028ead9c 100644 --- a/marimo/_plugins/ui/_impl/tables/default_table.py +++ b/marimo/_plugins/ui/_impl/tables/default_table.py @@ -208,10 +208,12 @@ def _as_table_manager(self) -> TableManager[Any]: if isinstance(self.data, dict): return PyArrowTableManagerFactory.create()( - pa.Table.from_pydict(self.data) + pa.Table.from_pydict(cast(Any, self.data)) ) return PyArrowTableManagerFactory.create()( - pa.Table.from_pylist(self._normalize_data(self.data)) + pa.Table.from_pylist( + cast(Any, self._normalize_data(self.data)) + ) ) raise ValueError("No supported table libraries found.") diff --git a/marimo/_plugins/ui/_impl/tables/df_protocol_table.py b/marimo/_plugins/ui/_impl/tables/df_protocol_table.py index b364d5ff191..63064c364c6 100644 --- a/marimo/_plugins/ui/_impl/tables/df_protocol_table.py +++ 
b/marimo/_plugins/ui/_impl/tables/df_protocol_table.py @@ -154,7 +154,7 @@ def _get_field_type(column: Column) -> Tuple[FieldType, ExternalDataType]: # https://github.com/vega/altair/blob/18a2c3c237014591d172284560546a2f0ac1a883/altair/utils/data.py#L343 def arrow_table_from_dataframe_protocol( dfi_df: DataFrameLike, -) -> "pa.lib.Table": +) -> "pa.Table": """ Convert a DataFrame Interchange Protocol compatible object to an Arrow Table @@ -176,4 +176,4 @@ def arrow_table_from_dataframe_protocol( if isinstance(result, pa.Table): return result - return pi.from_dataframe(dfi_df) # type: ignore[no-any-return] + return pi.from_dataframe(dfi_df) # type: ignore diff --git a/marimo/_plugins/ui/_impl/tables/pyarrow_table.py b/marimo/_plugins/ui/_impl/tables/pyarrow_table.py index db76aed9e1d..2b8e5af9f6f 100644 --- a/marimo/_plugins/ui/_impl/tables/pyarrow_table.py +++ b/marimo/_plugins/ui/_impl/tables/pyarrow_table.py @@ -62,18 +62,21 @@ def apply_formatting( else: # pa.RecordBatch column_names = _data.schema.names - transformed_columns: list[pa.Array[Any, Any]] = [] + transformed_columns: list[pa.Array[Any]] = [] for i, col in enumerate(column_names): + transformed_column: pa.Array[Any] if isinstance(_data, pa.Table): - transformed_column = _data.column(i).chunk(0) + transformed_column = _data.column(i).chunks[0] else: transformed_column = _data.column(i) if col in format_mapping: - transformed_values = [ + transformed_values: list[Any] = [ format_value(col, value.as_py(), format_mapping) for value in transformed_column ] - formatted_type = pa.array(transformed_values).type + formatted_type = cast( + Any, pa.array(transformed_values) + ).type transformed_column = pa.array( transformed_values, type=formatted_type ) # type: ignore @@ -88,17 +91,22 @@ def apply_formatting( transformed_columns.append(transformed_column) if isinstance(_data, pa.Table): - _data = pa.table(transformed_columns, names=column_names) + _data = pa.table( + cast(Any, transformed_columns), names=column_names + ) else: # pa.RecordBatch new_schema = pa.schema( [ - pa.field(col, transformed_columns[i].type) + pa.field( + col, + cast(Any, transformed_columns[i]).type, + ) for i, col in enumerate(column_names) ] ) - _data = pa.RecordBatch.from_arrays( + _data = pa.record_batch( transformed_columns, schema=new_schema - ) # type: ignore + ) return _data @@ -119,7 +127,7 @@ def select_columns( ) -> PyArrowTableManager: if isinstance(self.data, pa.RecordBatch): return PyArrowTableManager( - pa.RecordBatch.from_arrays( + pa.record_batch( [ self.data.column( self.data.schema.get_field_index(col) @@ -147,7 +155,7 @@ def is_type(value: Any) -> bool: def get_field_types(self) -> FieldTypes: return { column: PyArrowTableManager._get_field_type( - cast(Any, self.data)[idx] + cast(Any, self.data.column(idx)) ) for idx, column in enumerate(self.data.schema.names) } @@ -263,7 +271,7 @@ def sort_values( @staticmethod def _get_field_type( - column: pa.Array[Any, Any], + column: pa.ChunkedArray[Any], ) -> Tuple[FieldType, ExternalDataType]: dtype_string = str(column.type) if isinstance(column, pa.NullArray): diff --git a/marimo/_server/api/endpoints/ai.py b/marimo/_server/api/endpoints/ai.py index 89d963127ca..ef2e207d717 100644 --- a/marimo/_server/api/endpoints/ai.py +++ b/marimo/_server/api/endpoints/ai.py @@ -234,7 +234,7 @@ async def ai_completion( # If the model starts with claude, use anthropic if model.startswith("claude"): anthropic_client = get_anthropic_client(config) - response = anthropic_client.messages.create( + anthropic_response = 
anthropic_client.messages.create( model=model, max_tokens=1000, messages=[ @@ -249,7 +249,7 @@ async def ai_completion( ) return StreamingResponse( - content=make_stream_response(response), + content=make_stream_response(anthropic_response), media_type="application/json", ) diff --git a/marimo/_smoke_tests/ibis_example.py b/marimo/_smoke_tests/ibis_example.py new file mode 100644 index 00000000000..872ef3e3def --- /dev/null +++ b/marimo/_smoke_tests/ibis_example.py @@ -0,0 +1,85 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "marimo", +# ] +# /// + +import marimo + +__generated_with = "0.8.7" +app = marimo.App(width="medium") + + +@app.cell +def __(): + import marimo as mo + import ibis + return ibis, mo + + +@app.cell +def __(ibis): + df = ibis.read_csv( + "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/penguins.csv", + table_name="penguins", + ) + df + return df, + + +@app.cell +def __(df): + # Print Ibis data in a pretty table + df.to_polars() + return + + +@app.cell +def __(df): + # Transform using the python API + _res = df.group_by("species", "island").agg(count=df.count()).order_by("count") + df.to_polars() + return + + +@app.cell +def __(df): + # Transform using SQL + _res = df.sql( + "SELECT species, island, count(*) AS count FROM penguins GROUP BY 1, 2" + ) + _res.to_polars() + return + + +@app.cell +def __(df, mo): + # Transform using the ui.dataframe GUI + mo.ui.dataframe(df) + return + + +@app.cell +def __(ibis): + # Unnest + ibis.memtable( + { + "x": [[0, 1, 2], [], [], [3, 4]], + "y": [["a", "b", "c"], [], [], ["d", "e"]], + } + ).unnest("x").to_polars() + return + + +@app.cell +def __(ibis): + # Unpack + ibis.memtable({"A": [{"foo": 1, "bar": "hello"}], "B": [1]}).unpack( + "A" + ).to_polars() + return + + +if __name__ == "__main__": + app.run() diff --git a/pyproject.toml b/pyproject.toml index d40815fbc7f..40b4bf3dc9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,7 @@ dev = [ # For testing mo.ui.chart, dataframes, tables "pandas>=1.5.3", "pandas-stubs>=1.5.3.230321", - "pyarrow>=15.0.2,<16", + "pyarrow>=15.0.2,<18", "pyarrow-stubs>=10.0.1.9", # For tracing debugging "opentelemetry-api~=1.26.0", @@ -135,14 +135,16 @@ testoptional = [ "pandas-stubs>=1.5.3.230321", "geopandas~=0.13.0", "matplotlib~=3.9.2; python_version > \"3.8\"", - "pyarrow>=15.0.2,<16", - "pyarrow-stubs>=10.0.1.9", + "pyarrow>=15.0.2,<18", + "pyarrow-stubs>=20240830", + "pyarrow-hotfix", "pillow~=10.4.0", "types-Pillow~=10.2.0.20240520", "polars~=1.5.0", "anywidget~=0.9.13", "ipython~=8.12.3", "openai~=1.41.1", + "ibis-framework[duckdb]~=9.3.0; python_version > \"3.9\"", "anthropic==0.34.1", # exporting as ipynb "nbformat >=5.10.4", diff --git a/tests/_plugins/ui/_impl/dataframes/test_handlers.py b/tests/_plugins/ui/_impl/dataframes/test_handlers.py index 4b210cd4fa8..5fbc62f99c7 100644 --- a/tests/_plugins/ui/_impl/dataframes/test_handlers.py +++ b/tests/_plugins/ui/_impl/dataframes/test_handlers.py @@ -8,7 +8,7 @@ from marimo._dependencies.dependencies import DependencyManager from marimo._plugins.ui._impl.dataframes.transforms.apply import ( TransformsContainer, - apply_transforms, + _apply_transforms, get_handler_for_dataframe, ) from marimo._plugins.ui._impl.dataframes.transforms.types import ( @@ -30,9 +30,14 @@ TransformType, ) -HAS_DEPS = DependencyManager.pandas.has() and DependencyManager.polars.has() +HAS_DEPS = ( + DependencyManager.pandas.has() + and DependencyManager.polars.has() + and DependencyManager.ibis.has() +) if HAS_DEPS: + import 
ibis import numpy as np import pandas as pd import polars as pl @@ -41,11 +46,12 @@ pd = Mock() pl = Mock() np = Mock() + ibis = Mock() def apply(df: DataFrameType, transform: Transform) -> DataFrameType: handler = get_handler_for_dataframe(df) - return apply_transforms( + return _apply_transforms( df, handler, Transformations(transforms=[transform]) ) @@ -60,9 +66,27 @@ def assert_frame_equal(df1: DataFrameType, df2: DataFrameType) -> None: if isinstance(df1, pl.DataFrame) and isinstance(df2, pl.DataFrame): pl_testing.assert_frame_equal(df1, df2) return + if isinstance(df1, ibis.Expr) and isinstance(df2, ibis.Expr): + pl_testing.assert_frame_equal(df1.to_polars(), df2.to_polars()) + return pytest.fail("DataFrames are not of the same type") +def assert_frame_not_equal(df1: DataFrameType, df2: DataFrameType) -> None: + with pytest.raises(AssertionError): + assert_frame_equal(df1, df2) + + +def df_size(df: DataFrameType) -> int: + if isinstance(df, pd.DataFrame): + return df.size + if isinstance(df, pl.DataFrame): + return df.shape[0] + if isinstance(df, ibis.Table): + return df.count().execute() + raise ValueError("Unsupported dataframe type") + + @pytest.mark.skipif(not HAS_DEPS, reason="optional dependencies not installed") class TestTransformHandler: @staticmethod @@ -77,6 +101,10 @@ class TestTransformHandler: pl.DataFrame({"A": ["1", "2", "3"]}), pl.DataFrame({"A": [1, 2, 3]}), ), + ( + ibis.memtable({"A": ["1", "2", "3"]}), + ibis.memtable({"A": [1, 2, 3]}), + ), ], ) def test_handle_column_conversion_string_to_int( @@ -103,6 +131,10 @@ def test_handle_column_conversion_string_to_int( pl.DataFrame({"A": [1.1, 2.2, 3.3]}), pl.DataFrame({"A": ["1.1", "2.2", "3.3"]}), ), + ( + ibis.memtable({"A": [1.1, 2.2, 3.3]}), + ibis.memtable({"A": ["1.1", "2.2", "3.3"]}), + ), ], ) def test_handle_column_conversion_float_to_string( @@ -129,6 +161,10 @@ def test_handle_column_conversion_float_to_string( pl.DataFrame({"A": ["1", "2", "3", "a"]}), pl.DataFrame({"A": [1, 2, 3, None]}), ), + ( + ibis.memtable({"A": ["1", "2", "3", "a"]}), + ibis.memtable({"A": ["1", "2", "3", "a"]}), + ), ], ) def test_handle_column_conversion_ignore_errors( @@ -155,6 +191,10 @@ def test_handle_column_conversion_ignore_errors( pl.DataFrame({"A": [1, 2, 3]}), pl.DataFrame({"B": [1, 2, 3]}), ), + ( + ibis.memtable({"A": [1, 2, 3]}), + ibis.memtable({"B": [1, 2, 3]}), + ), ], ) def test_handle_rename_column( @@ -180,6 +220,11 @@ def test_handle_rename_column( pl.DataFrame({"A": [1, 2, 3]}), pl.DataFrame({"A": [3, 2, 1]}), ), + ( + ibis.memtable({"A": [3, 1, 2]}), + ibis.memtable({"A": [1, 2, 3]}), + ibis.memtable({"A": [3, 2, 1]}), + ), ], ) def test_handle_sort_column( @@ -217,6 +262,10 @@ def test_handle_sort_column( pl.DataFrame({"A": [1, 2, 3]}), pl.DataFrame({"A": [2, 3]}), ), + ( + ibis.memtable({"A": [1, 2, 3]}), + ibis.memtable({"A": [2, 3]}), + ), ], ) def test_handle_filter_rows_1( @@ -242,6 +291,10 @@ def test_handle_filter_rows_1( pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}), pl.DataFrame({"A": [2], "B": [5]}), ), + ( + ibis.memtable({"A": [1, 2, 3], "B": [4, 5, 6]}), + ibis.memtable({"A": [2], "B": [5]}), + ), ], ) def test_handle_filter_rows_2( @@ -267,6 +320,10 @@ def test_handle_filter_rows_2( pl.DataFrame({"A": [1, 2, 3, 4, 5]}), pl.DataFrame({"A": [1, 2, 3]}), ), + ( + ibis.memtable({"A": [1, 2, 3, 4, 5]}), + ibis.memtable({"A": [1, 2, 3]}), + ), ], ) def test_handle_filter_rows_3( @@ -292,6 +349,10 @@ def test_handle_filter_rows_3( pl.DataFrame({"A": [1, 2, 3]}), pl.DataFrame({"A": [1, 3]}), ), + ( + 
ibis.memtable({"A": [1, 2, 3]}), + ibis.memtable({"A": [1, 3]}), + ), ], ) def test_handle_filter_rows_4( @@ -317,6 +378,10 @@ def test_handle_filter_rows_4( pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}), pl.DataFrame({"A": [2, 3], "B": [5, 6]}), ), + ( + ibis.memtable({"A": [1, 2, 3], "B": [4, 5, 6]}), + ibis.memtable({"A": [2, 3], "B": [5, 6]}), + ), ], ) def test_handle_filter_rows_5( @@ -342,6 +407,10 @@ def test_handle_filter_rows_5( pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}), pl.DataFrame({"A": [3], "B": [6]}), ), + ( + ibis.memtable({"A": [1, 2, 3], "B": [4, 5, 6]}), + ibis.memtable({"A": [3], "B": [6]}), + ), ], ) def test_handle_filter_rows_6( @@ -367,6 +436,10 @@ def test_handle_filter_rows_6( pl.DataFrame({"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1]}), pl.DataFrame({"A": [3, 4, 5], "B": [3, 2, 1]}), ), + ( + ibis.memtable({"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1]}), + ibis.memtable({"A": [3, 4, 5], "B": [3, 2, 1]}), + ), ], ) def test_handle_filter_rows_multiple_conditions_1( @@ -395,6 +468,10 @@ def test_handle_filter_rows_multiple_conditions_1( pl.DataFrame({"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1]}), pl.DataFrame({"A": [1, 3, 4, 5], "B": [5, 3, 2, 1]}), ), + ( + ibis.memtable({"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1]}), + ibis.memtable({"A": [1, 3, 4, 5], "B": [5, 3, 2, 1]}), + ), ], ) def test_handle_filter_rows_multiple_conditions_2( @@ -423,6 +500,10 @@ def test_handle_filter_rows_multiple_conditions_2( pl.DataFrame({"A": [True, False, True, False]}), pl.DataFrame({"A": [True, True]}), ), + ( + ibis.memtable({"A": [True, False, True, False]}), + ibis.memtable({"A": [True, True]}), + ), ], ) def test_handle_filter_rows_boolean( @@ -456,6 +537,10 @@ def test_handle_filter_rows_boolean( pl.DataFrame({"A": [1, 2, 3]}), pl.exceptions.ColumnNotFoundError, ), + ( + ibis.memtable({"A": [1, 2, 3]}), + ibis.common.exceptions.IbisTypeError, + ), ], ) def test_handle_filter_rows_unknown_column( @@ -481,6 +566,10 @@ def test_handle_filter_rows_unknown_column( pl.DataFrame({"1": [1, 2, 3], "2": [4, 5, 6]}), pl.DataFrame({"1": [2, 3], "2": [5, 6]}), ), + ( + ibis.memtable({"1": [1, 2, 3], "2": [4, 5, 6]}), + ibis.memtable({"1": [2, 3], "2": [5, 6]}), + ), ], ) def test_handle_filter_rows_number_columns( @@ -512,6 +601,22 @@ def test_handle_filter_rows_number_columns( pl.DataFrame({"A": ["foo", "foo", "bar"], "B": [1, 2, 4]}), pl.DataFrame({"A": ["foo", "bar"], "B_sum": [3, 4]}), ), + ( + pl.DataFrame( + {"A": ["foo", "foo", "bar", "bar"], "B": [1, 2, 3, 4]} + ), + pl.DataFrame({"A": ["foo", "bar"], "B_sum": [3, 7]}), + ), + ( + ibis.memtable({"A": ["foo", "foo", "bar"], "B": [1, 2, 4]}), + ibis.memtable({"A": ["foo", "bar"], "B_sum": [3, 4]}), + ), + ( + ibis.memtable( + {"A": ["foo", "foo", "bar", "bar"], "B": [1, 2, 3, 4]} + ), + ibis.memtable({"A": ["foo", "bar"], "B_sum": [3, 7]}), + ), ], ) def test_handle_group_by( @@ -524,6 +629,14 @@ def test_handle_group_by( aggregation="sum", ) result = apply(df, transform) + if not isinstance(result, pd.DataFrame): + order_by_a = SortColumnTransform( + type=TransformType.SORT_COLUMN, + column_id="A", + ascending=False, + na_position="last", + ) + result = apply(result, order_by_a) assert_frame_equal(result, expected) @staticmethod @@ -543,6 +656,10 @@ def test_handle_group_by( } ), ), + ( + ibis.memtable({"A": [1, 2, 3], "B": [4, 5, 6]}), + ibis.memtable({"A_sum": [6], "B_sum": [15]}), + ), ], ) def test_handle_aggregate_sum( @@ -575,6 +692,17 @@ def test_handle_aggregate_sum( } ), ), + ( + ibis.memtable({"A": [1, 2, 3], "B": [4, 5, 
6]}), + ibis.memtable( + { + "A_min": [1], + "B_min": [4], + "A_max": [3], + "B_max": [6], + } + ), + ), ], ) def test_handle_aggregate_min_max( @@ -600,6 +728,10 @@ def test_handle_aggregate_min_max( pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}), pl.DataFrame({"A": [1, 2, 3]}), ), + ( + ibis.memtable({"A": [1, 2, 3], "B": [4, 5, 6]}), + ibis.memtable({"A": [1, 2, 3]}), + ), ], ) def test_handle_select_columns_single( @@ -623,6 +755,10 @@ def test_handle_select_columns_single( pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}), pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}), ), + ( + ibis.memtable({"A": [1, 2, 3], "B": [4, 5, 6]}), + ibis.memtable({"A": [1, 2, 3], "B": [4, 5, 6]}), + ), ], ) def test_handle_select_columns_multiple( @@ -646,6 +782,10 @@ def test_handle_select_columns_multiple( pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}), pl.DataFrame({"A": [2, 3, 1], "B": [5, 6, 4]}), ), + ( + ibis.memtable({"A": [1, 2, 3], "B": [4, 5, 6]}), + ibis.memtable({"A": [2, 3, 1], "B": [5, 6, 4]}), + ), ], ) def test_shuffle_rows(df: DataFrameType, expected: DataFrameType) -> None: @@ -653,9 +793,9 @@ def test_shuffle_rows(df: DataFrameType, expected: DataFrameType) -> None: type=TransformType.SHUFFLE_ROWS, seed=42 ) result = apply(df, transform) - assert len(result) == len(expected) - assert "A" in result.columns - assert "B" in result.columns + assert df_size(result) == df_size(expected) + assert "A" in result + assert "B" in result @staticmethod @pytest.mark.parametrize( @@ -669,6 +809,10 @@ def test_shuffle_rows(df: DataFrameType, expected: DataFrameType) -> None: pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}), pl.DataFrame({"A": [1, 3], "B": [4, 6]}), ), + ( + ibis.memtable({"A": [1, 2, 3], "B": [4, 5, 6]}), + ibis.memtable({"A": [1, 3], "B": [4, 6]}), + ), ], ) def test_sample_rows(df: DataFrameType, expected: DataFrameType) -> None: @@ -676,7 +820,7 @@ def test_sample_rows(df: DataFrameType, expected: DataFrameType) -> None: type=TransformType.SAMPLE_ROWS, n=2, seed=42, replace=False ) result = apply(df, transform) - assert len(result) == len(expected) + assert df_size(result) == df_size(expected) assert "A" in result.columns assert "B" in result.columns @@ -698,48 +842,42 @@ def test_sample_rows(df: DataFrameType, expected: DataFrameType) -> None: "C": [["a", "b", "c"], [np.nan], [], ["d", "e"]], }, ), + ibis.memtable( + { + "A": [[0, 1, 2], [], [], [3, 4]], + "B": [1, 1, 1, 1], + "C": [["a", "b", "c"], [np.nan], [], ["d", "e"]], + } + ), ], ) def test_explode_columns(df: DataFrameType) -> None: + import ibis + transform = ExplodeColumnsTransform( type=TransformType.EXPLODE_COLUMNS, column_ids=["A", "C"] ) result = apply(df, transform) - assert_frame_equal(result, df.explode(["A", "C"])) + if isinstance(result, ibis.Table): + assert_frame_equal(result, df.unnest("A").unnest("C")) + else: + assert_frame_equal(result, df.explode(["A", "C"])) @staticmethod @pytest.mark.parametrize( ("df", "expected"), [ ( - pd.DataFrame( - { - "A": [{"foo": 1, "bar": "hello"}], - "B": [1], - } - ), - pd.DataFrame( - { - "B": [1], - "foo": [1], - "bar": ["hello"], - } - ), + pd.DataFrame({"A": [{"foo": 1, "bar": "hello"}], "B": [1]}), + pd.DataFrame({"B": [1], "foo": [1], "bar": ["hello"]}), ), ( - pl.DataFrame( - { - "A": [{"foo": 1, "bar": "hello"}], - "B": [1], - } - ), - pl.DataFrame( - { - "B": [1], - "foo": [1], - "bar": ["hello"], - } - ), + pl.DataFrame({"A": [{"foo": 1, "bar": "hello"}], "B": [1]}), + pl.DataFrame({"B": [1], "foo": [1], "bar": ["hello"]}), + ), + ( + ibis.memtable({"A": [{"foo": 1, 
"bar": "hello"}], "B": [1]}), + ibis.memtable({"B": [1], "foo": [1], "bar": ["hello"]}), ), ], ) @@ -748,7 +886,11 @@ def test_expand_dict(df: DataFrameType, expected: DataFrameType) -> None: type=TransformType.EXPAND_DICT, column_id="A" ) result = apply(df, transform) - assert_frame_equal(result, expected) + assert_frame_equal( + # Sort the columns because the order is not guaranteed + expected[sorted(expected.columns)], + result[sorted(result.columns)], + ) @staticmethod @pytest.mark.parametrize( @@ -764,6 +906,11 @@ def test_expand_dict(df: DataFrameType, expected: DataFrameType) -> None: pl.DataFrame({"A": [3, 2], "B": [5, 6]}), pl.DataFrame({"A": [2], "B": [6]}), ), + ( + ibis.memtable({"A": [1, 2, 3], "B": [4, 6, 5]}), + ibis.memtable({"A": [3, 2], "B": [5, 6]}), + ibis.memtable({"A": [2], "B": [6]}), + ), ], ) def test_transforms_container( diff --git a/tests/_sql/test_sql.py b/tests/_sql/test_sql.py index 02d3bdda022..db78478aa1d 100644 --- a/tests/_sql/test_sql.py +++ b/tests/_sql/test_sql.py @@ -12,7 +12,7 @@ HAS_DEPS = DependencyManager.duckdb.has() and DependencyManager.polars.has() -@pytest.mark.skipif(not HAS_DEPS, reason="pandas or polars is required") +@pytest.mark.skipif(not HAS_DEPS, reason="polars and duckdb is required") def test_query_includes_limit(): assert _query_includes_limit("SELECT * FROM t LIMIT 10") is True assert _query_includes_limit("SELECT * FROM t LIMIT\n10") is True @@ -32,7 +32,7 @@ def test_query_includes_limit(): @patch("marimo._sql.sql.output.replace") -@pytest.mark.skipif(not HAS_DEPS, reason="pandas or polars is required") +@pytest.mark.skipif(not HAS_DEPS, reason="polars and duckdb is required") def test_applies_limit(mock_replace: MagicMock) -> None: import duckdb @@ -50,7 +50,7 @@ def test_applies_limit(mock_replace: MagicMock) -> None: assert table._component_args["total-rows"] == "too_many" assert table._component_args["pagination"] is True assert len(table._data) == 300 - assert table._filtered_manager.get_num_rows() == 300 + assert table._searched_manager.get_num_rows() == 300 # Limit 10 mock_replace.reset_mock() @@ -60,7 +60,7 @@ def test_applies_limit(mock_replace: MagicMock) -> None: assert table._component_args["total-rows"] == 10 assert table._component_args["pagination"] is True assert len(table._data) == 10 - assert table._filtered_manager.get_num_rows() == 10 + assert table._searched_manager.get_num_rows() == 10 # Limit 400 mock_replace.reset_mock() @@ -70,7 +70,7 @@ def test_applies_limit(mock_replace: MagicMock) -> None: assert table._component_args["total-rows"] == 400 assert table._component_args["pagination"] is True assert len(table._data) == 400 - assert table._filtered_manager.get_num_rows() == 400 + assert table._searched_manager.get_num_rows() == 400 # Limit above 20_0000 (which is the mo.ui.table cutoff) mock_replace.reset_mock() @@ -81,9 +81,7 @@ def test_applies_limit(mock_replace: MagicMock) -> None: assert table._component_args["total-rows"] == 25_000 assert table._component_args["pagination"] is True assert len(table._data) == 25_000 - assert ( - table._filtered_manager.get_num_rows() == 20_000 - ) # cutoff by mo.ui.table DEFAULT_ROW_LIMIT + assert table._searched_manager.get_num_rows() == 25_000 finally: del os.environ["MARIMO_SQL_DEFAULT_LIMIT"]