Skip to content

Commit

Permalink
improvement: ibis support in mo.ui.dataframe (#2188)
Browse files Browse the repository at this point in the history
* improvement: ibis support in mo.ui.dataframe

* test dep

* fixes

* cr comments

* widen pyarrow

* > 3.8

* > 3.9

* cleanup pyarrow stubs

* types

* fix tests

* fixes
  • Loading branch information
mscolnick authored Sep 3, 2024
1 parent 2a85406 commit 4d6c9b5
Show file tree
Hide file tree
Showing 13 changed files with 531 additions and 82 deletions.
29 changes: 17 additions & 12 deletions frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import { TooltipProvider } from "@/components/ui/tooltip";
import { Banner, ErrorBanner } from "../common/error-banner";
import type { DataType } from "../vega/vega-loader";
import type { FieldTypesWithExternalType } from "@/components/data-table/types";
import { Spinner } from "@/components/icons/spinner";

type CsvURL = string;
type TableData<T> = T[] | CsvURL;
Expand Down Expand Up @@ -170,7 +171,7 @@ export const DataFrameComponent = memo(
get_column_values,
search,
}: DataTableProps): JSX.Element => {
const { data, error } = useAsyncData(
const { data, error, loading } = useAsyncData(
() => get_dataframe({}),
[value?.transforms],
);
Expand Down Expand Up @@ -206,18 +207,22 @@ export const DataFrameComponent = memo(
return (
<div>
<Tabs defaultValue="transform">
<TabsList className="h-8">
<TabsTrigger value="transform" className="text-xs py-1">
<FunctionSquareIcon className="w-3 h-3 mr-2" />
Transform
</TabsTrigger>
{supports_code_sample && (
<TabsTrigger value="code" className="text-xs py-1">
<Code2Icon className="w-3 h-3 mr-2" />
Code
<div className="flex items-center gap-2">
<TabsList className="h-8">
<TabsTrigger value="transform" className="text-xs py-1">
<FunctionSquareIcon className="w-3 h-3 mr-2" />
Transform
</TabsTrigger>
)}
</TabsList>
{supports_code_sample && (
<TabsTrigger value="code" className="text-xs py-1">
<Code2Icon className="w-3 h-3 mr-2" />
Code
</TabsTrigger>
)}
<div className="flex-grow" />
</TabsList>
{loading && <Spinner size="small" />}
</div>
<TabsContent
value="transform"
className="mt-1 border rounded-t overflow-hidden"
Expand Down
1 change: 1 addition & 0 deletions marimo/_dependencies/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ class DependencyManager:

pandas = Dependency("pandas")
polars = Dependency("polars")
ibis = Dependency("ibis")
numpy = Dependency("numpy")
altair = Dependency("altair", min_version="5.3.0", max_version="6.0.0")
duckdb = Dependency("duckdb")
Expand Down
2 changes: 1 addition & 1 deletion marimo/_plugins/ui/_impl/dataframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def get_dataframe(self, _args: EmptyArgs) -> GetDataFrameResponse:
if self._error is not None:
raise GetDataFrameError(self._error)

manager = get_table_manager(self._data)
manager = get_table_manager(self._value)
response = self.search(SearchTableArgs(page_size=10, page_number=0))
return GetDataFrameResponse(
url=str(response.data),
Expand Down
17 changes: 12 additions & 5 deletions marimo/_plugins/ui/_impl/dataframes/transforms/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from marimo._dependencies.dependencies import DependencyManager
from marimo._plugins.ui._impl.dataframes.transforms.handlers import (
IbisTransformHandler,
PandasTransformHandler,
PolarsTransformHandler,
)
Expand All @@ -19,7 +20,7 @@
T = TypeVar("T")


def handle(df: T, handler: TransformHandler[T], transform: Transform) -> T:
def _handle(df: T, handler: TransformHandler[T], transform: Transform) -> T:
if transform.type is TransformType.COLUMN_CONVERSION:
return handler.handle_column_conversion(df, transform)
elif transform.type is TransformType.RENAME_COLUMN:
Expand All @@ -46,13 +47,13 @@ def handle(df: T, handler: TransformHandler[T], transform: Transform) -> T:
assert_never(transform.type)


def apply_transforms(
def _apply_transforms(
df: T, handler: TransformHandler[T], transforms: Transformations
) -> T:
if not transforms.transforms:
return df
for transform in transforms.transforms:
df = handle(df, handler, transform)
df = _handle(df, handler, transform)
return df


Expand All @@ -75,6 +76,12 @@ def get_handler_for_dataframe(
if isinstance(df, pl.DataFrame):
return PolarsTransformHandler()

if DependencyManager.ibis.has():
import ibis # type: ignore

if isinstance(df, ibis.Table):
return IbisTransformHandler()

raise ValueError(
"Unsupported dataframe type. Must be Pandas or Polars."
f" Got: {type(df)}"
Expand Down Expand Up @@ -102,7 +109,7 @@ def apply(self, transform: Transformations) -> T:
# then we can just apply the new ones to the snapshot dataframe.
if self._is_superset(transform):
transforms_to_apply = self._get_next_transformations(transform)
self._snapshot_df = apply_transforms(
self._snapshot_df = _apply_transforms(
self._snapshot_df, self._handler, transforms_to_apply
)
self._transforms = transform.transforms
Expand All @@ -111,7 +118,7 @@ def apply(self, transform: Transformations) -> T:
# If the new transformations are not a superset of the existing ones,
# then we need to start from the original dataframe.
else:
self._snapshot_df = apply_transforms(
self._snapshot_df = _apply_transforms(
self._original_df, self._handler, transform
)
self._transforms = transform.transforms
Expand Down
196 changes: 195 additions & 1 deletion marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional, Sequence, cast
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, cast

from marimo._plugins.ui._impl.dataframes.transforms.types import (
AggregateTransform,
Expand All @@ -20,6 +20,8 @@
from marimo._utils.assert_never import assert_never

if TYPE_CHECKING:
import ibis # type: ignore
import ibis.expr.types as ir # type: ignore
import pandas as pd
import polars as pl

Expand Down Expand Up @@ -409,6 +411,198 @@ def handle_expand_dict(
return df.hstack(pl.DataFrame(column.to_list()))


class IbisTransformHandler(TransformHandler["ibis.Table"]):
    """Apply dataframe transforms to an ``ibis.Table``.

    Every handler takes the current table plus a transform description
    and returns a new table expression. ``ibis`` is imported inside the
    methods that need it, so it stays an optional dependency.
    """

    @staticmethod
    def supports_code_sample() -> bool:
        """Return False: no code sample is generated for ibis tables."""
        return False

    @staticmethod
    def handle_column_conversion(
        df: "ibis.Table", transform: ColumnConversionTransform
    ) -> "ibis.Table":
        """Cast ``transform.column_id`` to ``transform.data_type``.

        With ``errors == "ignore"`` the original value is kept wherever
        the cast result is null, and the table is returned unchanged if
        ibis raises ``IbisTypeError``. Any other ``errors`` value lets
        cast errors propagate.
        """
        import ibis

        if transform.errors == "ignore":
            try:
                # Use coalesce to handle conversion errors: fall back to
                # the original column value where the cast yields null.
                # NOTE(review): ibis expressions are typically lazy, so
                # this presumably only catches type errors raised while
                # building the expression - confirm.
                return df.mutate(
                    **{
                        transform.column_id: ibis.coalesce(
                            df[transform.column_id].cast(
                                ibis.dtype(transform.data_type)
                            ),
                            df[transform.column_id],
                        )
                    }
                )
            except ibis.common.exceptions.IbisTypeError:
                return df
        else:
            # Default behavior (raise errors)
            return df.mutate(
                **{
                    transform.column_id: df[transform.column_id].cast(
                        ibis.dtype(transform.data_type)
                    )
                }
            )

    @staticmethod
    def handle_rename_column(
        df: "ibis.Table", transform: RenameColumnTransform
    ) -> "ibis.Table":
        """Rename ``transform.column_id`` to ``transform.new_column_id``."""
        # The mapping passed to ``rename`` is keyed by the NEW name.
        return df.rename({transform.new_column_id: transform.column_id})

    @staticmethod
    def handle_sort_column(
        df: "ibis.Table", transform: SortColumnTransform
    ) -> "ibis.Table":
        """Order rows by ``transform.column_id``, ascending or descending."""
        return df.order_by(
            [
                df[transform.column_id].asc()
                if transform.ascending
                else df[transform.column_id].desc()
            ]
        )

    @staticmethod
    def handle_filter_rows(
        df: "ibis.Table", transform: FilterRowsTransform
    ) -> "ibis.Table":
        """Keep or remove the rows matching ALL conditions in ``where``.

        Raises:
            ValueError: if ``transform.operation`` is neither
                ``"keep_rows"`` nor ``"remove_rows"``.
        """
        import ibis

        filter_conditions: list[ir.BooleanValue] = []
        for condition in transform.where:
            column = df[str(condition.column_id)]
            value = condition.value
            if condition.operator == "==":
                filter_conditions.append(column == value)
            elif condition.operator == "!=":
                filter_conditions.append(column != value)
            elif condition.operator == ">":
                filter_conditions.append(column > value)
            elif condition.operator == "<":
                filter_conditions.append(column < value)
            elif condition.operator == ">=":
                filter_conditions.append(column >= value)
            elif condition.operator == "<=":
                filter_conditions.append(column <= value)
            elif condition.operator == "is_true":
                # The column itself is used as the predicate.
                filter_conditions.append(column)
            elif condition.operator == "is_false":
                filter_conditions.append(~column)
            elif condition.operator == "is_nan":
                filter_conditions.append(column.isnull())
            elif condition.operator == "is_not_nan":
                filter_conditions.append(column.notnull())
            elif condition.operator == "equals":
                filter_conditions.append(column == value)
            elif condition.operator == "does_not_equal":
                filter_conditions.append(column != value)
            elif condition.operator == "contains":
                filter_conditions.append(column.contains(value))
            elif condition.operator == "regex":
                filter_conditions.append(column.re_search(value))
            elif condition.operator == "starts_with":
                filter_conditions.append(column.startswith(value))
            elif condition.operator == "ends_with":
                filter_conditions.append(column.endswith(value))
            elif condition.operator == "in":
                filter_conditions.append(column.isin(value))
            else:
                assert_never(condition.operator)

        # All conditions must hold at once (logical AND).
        combined_condition = ibis.and_(*filter_conditions)

        if transform.operation == "keep_rows":
            return df.filter(combined_condition)
        elif transform.operation == "remove_rows":
            return df.filter(~combined_condition)
        else:
            raise ValueError(f"Unsupported operation: {transform.operation}")

    @staticmethod
    def handle_group_by(
        df: "ibis.Table", transform: GroupByTransform
    ) -> "ibis.Table":
        """Group by ``transform.column_ids`` and aggregate the other columns.

        Every column that is not a grouping key is aggregated with
        ``transform.aggregation`` and named ``"<column>_<aggregation>"``.
        """
        aggs: list[ir.Expr] = []

        group_by_column_id_set = set(transform.column_ids)
        agg_columns = [
            column_id
            for column_id in df.columns
            if column_id not in group_by_column_id_set
        ]
        for column_id in agg_columns:
            agg_func = transform.aggregation
            if agg_func == "count":
                aggs.append(df[column_id].count().name(f"{column_id}_count"))
            elif agg_func == "sum":
                aggs.append(df[column_id].sum().name(f"{column_id}_sum"))
            elif agg_func == "mean":
                aggs.append(df[column_id].mean().name(f"{column_id}_mean"))
            elif agg_func == "median":
                aggs.append(df[column_id].median().name(f"{column_id}_median"))
            elif agg_func == "min":
                aggs.append(df[column_id].min().name(f"{column_id}_min"))
            elif agg_func == "max":
                aggs.append(df[column_id].max().name(f"{column_id}_max"))
            else:
                assert_never(agg_func)

        return df.group_by(transform.column_ids).aggregate(aggs)

    @staticmethod
    def handle_aggregate(
        df: "ibis.Table", transform: AggregateTransform
    ) -> "ibis.Table":
        """Aggregate the selected columns over the whole table.

        One output column named ``"<column>_<aggregation>"`` is produced
        for every (aggregation, column) pair.
        """
        agg_dict: Dict[str, Any] = {}
        for agg_func in transform.aggregations:
            for column_id in transform.column_ids:
                name = f"{column_id}_{agg_func}"
                # The aggregation name is looked up as a method on the
                # column expression (e.g. ``.sum()``).
                agg_dict[name] = getattr(df[column_id], agg_func)()
        return df.aggregate(**agg_dict)

    @staticmethod
    def handle_select_columns(
        df: "ibis.Table", transform: SelectColumnsTransform
    ) -> "ibis.Table":
        """Keep only the columns listed in ``transform.column_ids``."""
        return df.select(transform.column_ids)

    @staticmethod
    def handle_shuffle_rows(
        df: "ibis.Table", transform: ShuffleRowsTransform
    ) -> "ibis.Table":
        """Order the rows by a random value; ``transform`` is unused."""
        del transform
        import ibis

        return df.order_by(ibis.random())

    @staticmethod
    def handle_sample_rows(
        df: "ibis.Table", transform: SampleRowsTransform
    ) -> "ibis.Table":
        """Sample roughly ``transform.n`` rows from the table.

        The requested row count is converted to a fraction of the table,
        which requires executing a ``count()`` query. An empty table is
        returned unchanged, and a request for more rows than exist is
        treated as a request for the whole table.
        """
        total_rows = df.count().execute()
        if not total_rows:
            # Nothing to sample from; also avoids a ZeroDivisionError.
            return df

        # The sampling fraction must not exceed 1, so asking for more
        # rows than the table holds is clamped to "all rows".
        fraction = min(1.0, transform.n / total_rows)
        return df.sample(
            fraction,
            method="row",
            seed=transform.seed,
        )

    @staticmethod
    def handle_explode_columns(
        df: "ibis.Table", transform: ExplodeColumnsTransform
    ) -> "ibis.Table":
        """Unnest each column in ``transform.column_ids``, one at a time."""
        for column_id in transform.column_ids:
            df = df.unnest(column_id)
        return df

    @staticmethod
    def handle_expand_dict(
        df: "ibis.Table", transform: ExpandDictTransform
    ) -> "ibis.Table":
        """Unpack ``transform.column_id`` via ``Table.unpack``."""
        return df.unpack(transform.column_id)


def _coerce_value(dtype: Any, value: Any) -> Any:
import numpy as np

Expand Down
6 changes: 4 additions & 2 deletions marimo/_plugins/ui/_impl/tables/default_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,12 @@ def _as_table_manager(self) -> TableManager[Any]:

if isinstance(self.data, dict):
return PyArrowTableManagerFactory.create()(
pa.Table.from_pydict(self.data)
pa.Table.from_pydict(cast(Any, self.data))
)
return PyArrowTableManagerFactory.create()(
pa.Table.from_pylist(self._normalize_data(self.data))
pa.Table.from_pylist(
cast(Any, self._normalize_data(self.data))
)
)

raise ValueError("No supported table libraries found.")
Expand Down
4 changes: 2 additions & 2 deletions marimo/_plugins/ui/_impl/tables/df_protocol_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _get_field_type(column: Column) -> Tuple[FieldType, ExternalDataType]:
# https://github.com/vega/altair/blob/18a2c3c237014591d172284560546a2f0ac1a883/altair/utils/data.py#L343
def arrow_table_from_dataframe_protocol(
dfi_df: DataFrameLike,
) -> "pa.lib.Table":
) -> "pa.Table":
"""
Convert a DataFrame Interchange Protocol compatible object
to an Arrow Table
Expand All @@ -176,4 +176,4 @@ def arrow_table_from_dataframe_protocol(
if isinstance(result, pa.Table):
return result

return pi.from_dataframe(dfi_df) # type: ignore[no-any-return]
return pi.from_dataframe(dfi_df) # type: ignore
Loading

0 comments on commit 4d6c9b5

Please sign in to comment.