From a30d46b47ea9dc059c2b75c1a4c4c64f4f9ff8cb Mon Sep 17 00:00:00 2001 From: Fabien Arcellier Date: Mon, 18 Nov 2024 21:44:19 +0100 Subject: [PATCH 1/2] chore: clean up core * chore: move all functions that start with writer_event_handler_* to EventHandlerExecutor class as static --- src/writer/core.py | 153 +++++++++++++++++++++++---------------------- 1 file changed, 78 insertions(+), 75 deletions(-) diff --git a/src/writer/core.py b/src/writer/core.py index d3e8c9e6f..8fbd4186a 100644 --- a/src/writer/core.py +++ b/src/writer/core.py @@ -464,7 +464,7 @@ def __setitem__(self, key: str, raw_value: Any) -> None: "new_value": raw_value } - writer_event_handler_invoke(local_mutation.handler, { + EventHandlerExecutor.invoke(local_mutation.handler, { "state": local_mutation.state, "context": context_data, "payload": payload, @@ -834,7 +834,7 @@ def subscribe_mutation(self, # existing states. To cause this, we trigger manually # the handler. if initial_triggered is True: - writer_event_handler_invoke(handler, { + EventHandlerExecutor.invoke(handler, { "state": self, "context": {"event": "init"}, "payload": {}, @@ -1083,7 +1083,7 @@ def __init__(self, middleware: Callable): @contextlib.contextmanager def execute(self, args: dict): - middleware_args = writer_event_handler_build_arguments(self.middleware, args) + middleware_args = EventHandlerExecutor.build_arguments(self.middleware, args) it = self.middleware(*middleware_args) try: yield from it @@ -1845,7 +1845,7 @@ def _call_handler_callable( with core_ui.use_component_tree(self.session.session_component_tree), \ contextlib.redirect_stdout(io.StringIO()) as f: middlewares_executors = current_app_process.middleware_registry.executors() - result = writer_event_handler_invoke_with_middlewares(middlewares_executors, handler_callable, writer_args) + result = EventHandlerExecutor.invoke_with_middlewares(middlewares_executors, handler_callable, writer_args) captured_stdout = f.getvalue() if captured_stdout: @@ -1886,6 +1886,80 @@ def handle(self, ev: WriterEvent) -> WriterEventResult: return {"ok": ok, "result": result} +class EventHandlerExecutor: + + @staticmethod + def build_arguments(func: Callable, writer_args: dict) -> List[Any]: + """ + Constructs the list of arguments based on the signature of the function + which can be a handler or middleware. + + >>> def my_event_handler(state, context): + >>> yield + + >>> args = EventHandlerExecutor.build_arguments(my_event_handler, {'state': {}, 'payload': {}, 'context': {"target": '11'}, 'session': None, 'ui': None}) + >>> [{}, {"target": '11'}] + + :param func: the function that will be called + :param writer_args: the possible arguments in writer (state, payload, ...) + """ + handler_args = inspect.getfullargspec(func).args + func_args = [] + for arg in handler_args: + if arg in writer_args: + func_args.append(writer_args[arg]) + + return func_args + + @staticmethod + def invoke(callable_handler: Callable, writer_args: dict) -> Any: + """ + Runs a handler based on its signature. + + If the handler is asynchronous, it is executed asynchronously. + If the handler only has certain parameters, only these are passed as arguments + + >>> def my_handler(state): + >>> state['a'] = 2 + >>> + >>> EventHandlerExecutor.invoke(my_handler, {'state': {'a': 1}, 'payload': None, 'context': None, 'session': None, 'ui': None}) + """ + is_async_handler = inspect.iscoroutinefunction(callable_handler) + if (not callable(callable_handler) and not is_async_handler): + raise ValueError("Invalid handler. The handler isn't a callable object.") + + handler_args = EventHandlerExecutor.build_arguments(callable_handler, writer_args) + + if is_async_handler: + async_wrapper = _async_wrapper_internal(callable_handler, handler_args) + result = asyncio.run(async_wrapper) + else: + result = callable_handler(*handler_args) + + return result + + @staticmethod + def invoke_with_middlewares(middlewares_executors: List[MiddlewareExecutor], callable_handler: Callable, writer_args: dict) -> Any: + """ + Runs the middlewares then the handler. This function allows you to manage exceptions that are triggered in middleware + + :param middlewares_executors: The list of middleware to run + :param callable_handler: The target handler + + >>> @wf.middleware() + >>> def my_middleware(state, payload, context, session, ui): + >>> yield + + >>> executor = MiddlewareExecutor(my_middleware, {'state': {}, 'payload': None, 'context': None, 'session': None, 'ui': None}) + >>> EventHandlerExecutor.invoke_with_middlewares([executor], my_handler, {'state': {}, 'payload': None, 'context': None, 'session': None, 'ui': None} + """ + if len(middlewares_executors) == 0: + return EventHandlerExecutor.invoke(callable_handler, writer_args) + else: + executor = middlewares_executors[0] + with executor.execute(writer_args): + return EventHandlerExecutor.invoke_with_middlewares(middlewares_executors[1:], callable_handler, writer_args) + class DictPropertyProxy: """ @@ -2523,77 +2597,6 @@ def parse_state_variable_expression(p: str): return parts -def writer_event_handler_build_arguments(func: Callable, writer_args: dict) -> List[Any]: - """ - Constructs the list of arguments based on the signature of the function - which can be a handler or middleware. - - >>> def my_event_handler(state, context): - >>> yield - - >>> args = writer_event_handler_build_arguments(my_event_handler, {'state': {}, 'payload': {}, 'context': {"target": '11'}, 'session': None, 'ui': None}) - >>> [{}, {"target": '11'}] - - :param func: the function that will be called - :param writer_args: the possible arguments in writer (state, payload, ...) - """ - handler_args = inspect.getfullargspec(func).args - func_args = [] - for arg in handler_args: - if arg in writer_args: - func_args.append(writer_args[arg]) - - return func_args - - -def writer_event_handler_invoke(callable_handler: Callable, writer_args: dict) -> Any: - """ - Runs a handler based on its signature. - - If the handler is asynchronous, it is executed asynchronously. - If the handler only has certain parameters, only these are passed as arguments - - >>> def my_handler(state): - >>> state['a'] = 2 - >>> - >>> writer_event_handler_invoke(my_handler, {'state': {'a': 1}, 'payload': None, 'context': None, 'session': None, 'ui': None}) - """ - is_async_handler = inspect.iscoroutinefunction(callable_handler) - if (not callable(callable_handler) and not is_async_handler): - raise ValueError("Invalid handler. The handler isn't a callable object.") - - handler_args = writer_event_handler_build_arguments(callable_handler, writer_args) - - if is_async_handler: - async_wrapper = _async_wrapper_internal(callable_handler, handler_args) - result = asyncio.run(async_wrapper) - else: - result = callable_handler(*handler_args) - - return result - -def writer_event_handler_invoke_with_middlewares(middlewares_executors: List[MiddlewareExecutor], callable_handler: Callable, writer_args: dict) -> Any: - """ - Runs the middlewares then the handler. This function allows you to manage exceptions that are triggered in middleware - - :param middlewares_executors: The list of middleware to run - :param callable_handler: The target handler - - >>> @wf.middleware() - >>> def my_middleware(state, payload, context, session, ui): - >>> yield - - >>> executor = MiddlewareExecutor(my_middleware, {'state': {}, 'payload': None, 'context': None, 'session': None, 'ui': None}) - >>> writer_event_handler_invoke_with_middlewares([executor], my_handler, {'state': {}, 'payload': None, 'context': None, 'session': None, 'ui': None} - """ - if len(middlewares_executors) == 0: - return writer_event_handler_invoke(callable_handler, writer_args) - else: - executor = middlewares_executors[0] - with executor.execute(writer_args): - return writer_event_handler_invoke_with_middlewares(middlewares_executors[1:], callable_handler, writer_args) - - async def _async_wrapper_internal(callable_handler: Callable, arg_values: List[Any]) -> Any: result = await callable_handler(*arg_values) return result From 0e379de40b3649655def3c2a207f48f97801b7c1 Mon Sep 17 00:00:00 2001 From: Fabien Arcellier Date: Mon, 18 Nov 2024 21:56:20 +0100 Subject: [PATCH 2/2] chore: clean up core * chore: move dataframe dedicated method into core_df --- src/writer/__init__.py | 2 +- src/writer/core.py | 471 +--------------------------------------- src/writer/core_df.py | 479 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 482 insertions(+), 470 deletions(-) create mode 100644 src/writer/core_df.py diff --git a/src/writer/__init__.py b/src/writer/__init__.py index 29fb3548e..b0ff35089 100644 --- a/src/writer/__init__.py +++ b/src/writer/__init__.py @@ -7,7 +7,6 @@ from writer.core import ( BytesWrapper, Config, - EditableDataframe, FileWrapper, Readable, State, @@ -22,6 +21,7 @@ from writer.core import ( writerproperty as property, ) +from writer.core_df import EditableDataframe try: from writer.ui import WriterUIManager diff --git a/src/writer/core.py b/src/writer/core.py index 8fbd4186a..827bda987 100644 --- a/src/writer/core.py +++ b/src/writer/core.py @@ -18,7 +18,6 @@ import time import traceback import urllib.request -from abc import ABCMeta from contextvars import ContextVar from multiprocessing.process import BaseProcess from types import ModuleType @@ -48,9 +47,6 @@ from writer import core_ui from writer.core_ui import Component from writer.ss_types import ( - DataframeRecordAdded, - DataframeRecordRemoved, - DataframeRecordUpdated, InstancePath, InstancePathItem, Readable, @@ -63,7 +59,6 @@ if TYPE_CHECKING: import pandas - import polars from writer.app_runner import AppProcess from writer.ss_types import AppProcessServerRequest @@ -264,6 +259,8 @@ class StateSerialiser: """ def serialise(self, v: Any) -> Union[Dict, List, str, bool, int, float, None]: from writer.ai import Conversation + from writer.core_df import EditableDataframe + if isinstance(v, State): return self._serialise_dict_recursively(v.to_dict()) if isinstance(v, Conversation): @@ -2007,412 +2004,6 @@ def __set__(self, instance, value): proxy[self.key] = value -class DataframeRecordRemove: - pass - - -class DataframeRecordProcessor(): - """ - This interface defines the signature of the methods to process the events of a - dataframe compatible with EditableDataframe. - - A Dataframe can be any structure composed of tabular data. - - This class defines the signature of the methods to be implemented. - """ - __metaclass__ = ABCMeta - - @staticmethod - def match(df: Any) -> bool: - """ - This method checks if the dataframe is compatible with the processor. - """ - raise NotImplementedError - - @staticmethod - def record(df: Any, record_index: int) -> dict: - """ - This method read a record at the given line and get it back as dictionary - - >>> edf = EditableDataframe(df) - >>> r = edf.record(1) - """ - raise NotImplementedError - - @staticmethod - def record_add(df: Any, payload: DataframeRecordAdded) -> Any: - """ - signature of the methods to be implemented to process wf-dataframe-add event - - >>> edf = EditableDataframe(df) - >>> edf.record_add({"record": {"a": 1, "b": 2}}) - """ - raise NotImplementedError - - @staticmethod - def record_update(df: Any, payload: DataframeRecordUpdated) -> Any: - """ - signature of the methods to be implemented to process wf-dataframe-update event - - >>> edf = EditableDataframe(df) - >>> edf.record_update({"record_index": 12, "record": {"a": 1, "b": 2}}) - """ - raise NotImplementedError - - @staticmethod - def record_remove(df: Any, payload: DataframeRecordRemoved) -> Any: - """ - signature of the methods to be implemented to process wf-dataframe-action event - - >>> edf = EditableDataframe(df) - >>> edf.record_remove({"record_index": 12}) - """ - raise NotImplementedError - - @staticmethod - def pyarrow_table(df: Any) -> pyarrow.Table: - """ - Serializes the dataframe into a pyarrow table - """ - raise NotImplementedError - - -class PandasRecordProcessor(DataframeRecordProcessor): - """ - PandasRecordProcessor processes records from a pandas dataframe saved into an EditableDataframe - - >>> df = pandas.DataFrame({"a": [1, 2], "b": [3, 4]}) - >>> edf = EditableDataframe(df) - >>> edf.record_add({"a": 1, "b": 2}) - """ - - @staticmethod - @import_failure(rvalue=False) - def match(df: Any) -> bool: - import pandas - return True if isinstance(df, pandas.DataFrame) else False - - @staticmethod - def record(df: 'pandas.DataFrame', record_index: int) -> dict: - """ - - >>> edf = EditableDataframe(df) - >>> r = edf.record(1) - """ - import pandas - - record = df.iloc[record_index] - if not isinstance(df.index, pandas.RangeIndex): - index_list = df.index.tolist() - record_index_content = index_list[record_index] - if isinstance(record_index_content, tuple): - for i, n in enumerate(df.index.names): - record[n] = record_index_content[i] - else: - record[df.index.names[0]] = record_index_content - - return dict(record) - - @staticmethod - def record_add(df: 'pandas.DataFrame', payload: DataframeRecordAdded) -> 'pandas.DataFrame': - """ - >>> edf = EditableDataframe(df) - >>> edf.record_add({"record": {"a": 1, "b": 2}}) - """ - import pandas - - _assert_record_match_pandas_df(df, payload['record']) - - record, index = _split_record_as_pandas_record_and_index(payload['record'], df.index.names) - - if isinstance(df.index, pandas.RangeIndex): - new_df = pandas.DataFrame([record]) - return pandas.concat([df, new_df], ignore_index=True) - else: - new_df = pandas.DataFrame([record], index=[index]) - return pandas.concat([df, new_df]) - - @staticmethod - def record_update(df: 'pandas.DataFrame', payload: DataframeRecordUpdated) -> 'pandas.DataFrame': - """ - >>> edf = EditableDataframe(df) - >>> edf.record_update({"record_index": 12, "record": {"a": 1, "b": 2}}) - """ - import pandas - - _assert_record_match_pandas_df(df, payload['record']) - - record: dict - record, index = _split_record_as_pandas_record_and_index(payload['record'], df.index.names) - - record_index = payload['record_index'] - - if isinstance(df.index, pandas.RangeIndex): - df.iloc[record_index] = record # type: ignore - else: - df.iloc[record_index] = record # type: ignore - index_list = df.index.tolist() - index_list[record_index] = index - df.index = index_list # type: ignore - - return df - - @staticmethod - def record_remove(df: 'pandas.DataFrame', payload: DataframeRecordRemoved) -> 'pandas.DataFrame': - """ - >>> edf = EditableDataframe(df) - >>> edf.record_remove({"record_index": 12}) - """ - record_index: int = payload['record_index'] - idx = df.index[record_index] - df = df.drop(idx) - - return df - - @staticmethod - def pyarrow_table(df: 'pandas.DataFrame') -> pyarrow.Table: - """ - Serializes the dataframe into a pyarrow table - """ - table = pyarrow.Table.from_pandas(df=df) - return table - - -class PolarRecordProcessor(DataframeRecordProcessor): - """ - PolarRecordProcessor processes records from a polar dataframe saved into an EditableDataframe - - >>> df = polars.DataFrame({"a": [1, 2], "b": [3, 4]}) - >>> edf = EditableDataframe(df) - >>> edf.record_add({"record": {"a": 1, "b": 2}}) - """ - - @staticmethod - @import_failure(rvalue=False) - def match(df: Any) -> bool: - import polars - return True if isinstance(df, polars.DataFrame) else False - - @staticmethod - def record(df: 'polars.DataFrame', record_index: int) -> dict: - """ - - >>> edf = EditableDataframe(df) - >>> r = edf.record(1) - """ - record = {} - r = df[record_index] - for c in r.columns: - record[c] = df[record_index, c] - - return record - - - @staticmethod - def record_add(df: 'polars.DataFrame', payload: DataframeRecordAdded) -> 'polars.DataFrame': - _assert_record_match_polar_df(df, payload['record']) - - import polars - new_df = polars.DataFrame([payload['record']]) - return polars.concat([df, new_df]) - - @staticmethod - def record_update(df: 'polars.DataFrame', payload: DataframeRecordUpdated) -> 'polars.DataFrame': - # This implementation works but is not optimal. - # I didn't find a better way to update a record in polars - # - # https://github.com/pola-rs/polars/issues/5973 - _assert_record_match_polar_df(df, payload['record']) - - record = payload['record'] - record_index = payload['record_index'] - for r in record: - df[record_index, r] = record[r] - - return df - - @staticmethod - def record_remove(df: 'polars.DataFrame', payload: DataframeRecordRemoved) -> 'polars.DataFrame': - import polars - - record_index: int = payload['record_index'] - df_filtered = polars.concat([df[:record_index], df[record_index + 1:]]) - return df_filtered - - @staticmethod - def pyarrow_table(df: 'polars.DataFrame') -> pyarrow.Table: - """ - Serializes the dataframe into a pyarrow table - """ - import pyarrow.interchange - table: pyarrow.Table = pyarrow.interchange.from_dataframe(df) - return table - -class RecordListRecordProcessor(DataframeRecordProcessor): - """ - RecordListRecordProcessor processes records from a list of record saved into an EditableDataframe - - >>> df = [{"a": 1, "b": 2}, {"a": 3, "b": 4}] - >>> edf = EditableDataframe(df) - >>> edf.record_add({"record": {"a": 1, "b": 2}}) - """ - - @staticmethod - def match(df: Any) -> bool: - return True if isinstance(df, list) else False - - - @staticmethod - def record(df: List[Dict[str, Any]], record_index: int) -> dict: - """ - - >>> edf = EditableDataframe(df) - >>> r = edf.record(1) - """ - r = df[record_index] - return copy.copy(r) - - @staticmethod - def record_add(df: List[Dict[str, Any]], payload: DataframeRecordAdded) -> List[Dict[str, Any]]: - _assert_record_match_list_of_records(df, payload['record']) - df.append(payload['record']) - return df - - @staticmethod - def record_update(df: List[Dict[str, Any]], payload: DataframeRecordUpdated) -> List[Dict[str, Any]]: - _assert_record_match_list_of_records(df, payload['record']) - - record_index = payload['record_index'] - record = payload['record'] - - df[record_index] = record - return df - - @staticmethod - def record_remove(df: List[Dict[str, Any]], payload: DataframeRecordRemoved) -> List[Dict[str, Any]]: - del(df[payload['record_index']]) - return df - - @staticmethod - def pyarrow_table(df: List[Dict[str, Any]]) -> pyarrow.Table: - """ - Serializes the dataframe into a pyarrow table - """ - column_names = list(df[0].keys()) - columns = {key: [record[key] for record in df] for key in column_names} - - pyarrow_columns = {key: pyarrow.array(values) for key, values in columns.items()} - schema = pyarrow.schema([(key, pyarrow_columns[key].type) for key in pyarrow_columns]) - table = pyarrow.Table.from_arrays( - [pyarrow_columns[key] for key in column_names], - schema=schema - ) - - return table - -class EditableDataframe(MutableValue): - """ - Editable Dataframe makes it easier to process events from components - that modify a dataframe like the dataframe editor. - - >>> initial_state = wf.init_state({ - >>> "df": wf.EditableDataframe(df) - >>> }) - - Editable Dataframe is compatible with a pandas, thrillers or record list dataframe - """ - processors = [PandasRecordProcessor, PolarRecordProcessor, RecordListRecordProcessor] - - def __init__(self, df: Union['pandas.DataFrame', 'polars.DataFrame', List[dict]]): - super().__init__() - self._df = df - self.processor: Type[DataframeRecordProcessor] - for processor in self.processors: - if processor.match(self.df): - self.processor = processor - break - - if self.processor is None: - raise ValueError("The dataframe must be a pandas, polar Dataframe or a list of record") - - @property - def df(self) -> Union['pandas.DataFrame', 'polars.DataFrame', List[dict]]: - return self._df - - @df.setter - def df(self, value: Union['pandas.DataFrame', 'polars.DataFrame', List[dict]]) -> None: - self._df = value - self.mutate() - - def record_add(self, payload: DataframeRecordAdded) -> None: - """ - Adds a record to the dataframe - - >>> df = pandas.DataFrame({"a": [1, 2], "b": [3, 4]}) - >>> edf = EditableDataframe(df) - >>> edf.record_add({"record": {"a": 1, "b": 2}}) - """ - assert self.processor is not None - - self._df = self.processor.record_add(self.df, payload) - self.mutate() - - def record_update(self, payload: DataframeRecordUpdated) -> None: - """ - Updates a record in the dataframe - - The record must be complete otherwise an error is raised (ValueError). - It must a value for each index / column. - - >>> df = pandas.DataFrame({"a": [1, 2], "b": [3, 4]}) - >>> edf = EditableDataframe(df) - >>> edf.record_update({"record_index": 0, "record": {"a": 2, "b": 2}}) - """ - assert self.processor is not None - - self._df = self.processor.record_update(self.df, payload) - self.mutate() - - def record_remove(self, payload: DataframeRecordRemoved) -> None: - """ - Removes a record from the dataframe - - >>> df = pandas.DataFrame({"a": [1, 2], "b": [3, 4]}) - >>> edf = EditableDataframe(df) - >>> edf.record_remove({"record_index": 0}) - """ - assert self.processor is not None - - self._df = self.processor.record_remove(self.df, payload) - self.mutate() - - def pyarrow_table(self) -> pyarrow.Table: - """ - Serializes the dataframe into a pyarrow table - - This mechanism is used for serializing data for transmission to the frontend. - - >>> df = pandas.DataFrame({"a": [1, 2], "b": [3, 4]}) - >>> edf = EditableDataframe(df) - >>> pa_table = edf.pyarrow_table() - """ - assert self.processor is not None - - pa_table = self.processor.pyarrow_table(self.df) - return pa_table - - def record(self, record_index: int): - """ - Retrieves a specific record in dictionary form. - - :param record_index: - :return: - """ - assert self.processor is not None - - record = self.processor.record(self.df, record_index) - return record - S = TypeVar("S", bound=WriterState) def new_initial_state(klass: Type[S], raw_state: dict) -> S: @@ -2601,44 +2192,6 @@ async def _async_wrapper_internal(callable_handler: Callable, arg_values: List[A result = await callable_handler(*arg_values) return result -def _assert_record_match_pandas_df(df: 'pandas.DataFrame', record: Dict[str, Any]) -> None: - """ - Asserts that the record matches the dataframe columns & index - - >>> _assert_record_match_pandas_df(pandas.DataFrame({"a": [1, 2], "b": [3, 4]}), {"a": 1, "b": 2}) - """ - import pandas - - columns = set(list(df.columns.values) + df.index.names) if isinstance(df.index, pandas.RangeIndex) is False else set(df.columns.values) - columns_record = set(record.keys()) - if columns != columns_record: - raise ValueError(f"Columns mismatch. Expected {columns}, got {columns_record}") - -def _assert_record_match_polar_df(df: 'polars.DataFrame', record: Dict[str, Any]) -> None: - """ - Asserts that the record matches the columns of polar dataframe - - >>> _assert_record_match_pandas_df(polars.DataFrame({"a": [1, 2], "b": [3, 4]}), {"a": 1, "b": 2}) - """ - columns = set(df.columns) - columns_record = set(record.keys()) - if columns != columns_record: - raise ValueError(f"Columns mismatch. Expected {columns}, got {columns_record}") - -def _assert_record_match_list_of_records(df: List[Dict[str, Any]], record: Dict[str, Any]) -> None: - """ - Asserts that the record matches the key in the record list (it use the first record to check) - - >>> _assert_record_match_list_of_records([{"a": 1, "b": 2}, {"a": 3, "b": 4}], {"a": 1, "b": 2}) - """ - if len(df) == 0: - return - - columns = set(df[0].keys()) - columns_record = set(record.keys()) - if columns != columns_record: - raise ValueError(f"Columns mismatch. Expected {columns}, got {columns_record}") - def _event_handler_session_info() -> Dict[str, Any]: """ Returns the session information for the current event handler. @@ -2664,26 +2217,6 @@ def _event_handler_ui_manager(): else: raise RuntimeError(_get_ui_runtime_error_message()) - -def _split_record_as_pandas_record_and_index(param: dict, index_columns: list) -> Tuple[dict, tuple]: - """ - Separates a record into the record part and the index part to be able to - create or update a row in a dataframe. - - >>> record, index = _split_record_as_pandas_record_and_index({"a": 1, "b": 2}, ["a"]) - >>> print(record) # {"b": 2} - >>> print(index) # (1,) - """ - final_record = {} - final_index = [] - for key, value in param.items(): - if key in index_columns: - final_index.append(value) - else: - final_record[key] = value - - return final_record, tuple(final_index) - def _deserialize_bigint_format(payload: Optional[Union[dict, list]]): """ Decodes the payload of a big int serialization diff --git a/src/writer/core_df.py b/src/writer/core_df.py new file mode 100644 index 000000000..f5e694798 --- /dev/null +++ b/src/writer/core_df.py @@ -0,0 +1,479 @@ +""" +`core_df` contains classes and functions that allow you to manipulate editable dataframes. +""" +import copy +from abc import ABCMeta +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union + +import pyarrow # type: ignore + +from .core import MutableValue, import_failure +from .ss_types import DataframeRecordAdded, DataframeRecordRemoved, DataframeRecordUpdated + +if TYPE_CHECKING: + import pandas + import polars + + +class DataframeRecordProcessor(): + """ + This interface defines the signature of the methods to process the events of a + dataframe compatible with EditableDataframe. + + A Dataframe can be any structure composed of tabular data. + + This class defines the signature of the methods to be implemented. + """ + __metaclass__ = ABCMeta + + @staticmethod + def match(df: Any) -> bool: + """ + This method checks if the dataframe is compatible with the processor. + """ + raise NotImplementedError + + @staticmethod + def record(df: Any, record_index: int) -> dict: + """ + This method read a record at the given line and get it back as dictionary + + >>> edf = EditableDataframe(df) + >>> r = edf.record(1) + """ + raise NotImplementedError + + @staticmethod + def record_add(df: Any, payload: DataframeRecordAdded) -> Any: + """ + signature of the methods to be implemented to process wf-dataframe-add event + + >>> edf = EditableDataframe(df) + >>> edf.record_add({"record": {"a": 1, "b": 2}}) + """ + raise NotImplementedError + + @staticmethod + def record_update(df: Any, payload: DataframeRecordUpdated) -> Any: + """ + signature of the methods to be implemented to process wf-dataframe-update event + + >>> edf = EditableDataframe(df) + >>> edf.record_update({"record_index": 12, "record": {"a": 1, "b": 2}}) + """ + raise NotImplementedError + + @staticmethod + def record_remove(df: Any, payload: DataframeRecordRemoved) -> Any: + """ + signature of the methods to be implemented to process wf-dataframe-action event + + >>> edf = EditableDataframe(df) + >>> edf.record_remove({"record_index": 12}) + """ + raise NotImplementedError + + @staticmethod + def pyarrow_table(df: Any) -> pyarrow.Table: + """ + Serializes the dataframe into a pyarrow table + """ + raise NotImplementedError + + +class PandasRecordProcessor(DataframeRecordProcessor): + """ + PandasRecordProcessor processes records from a pandas dataframe saved into an EditableDataframe + + >>> df = pandas.DataFrame({"a": [1, 2], "b": [3, 4]}) + >>> edf = EditableDataframe(df) + >>> edf.record_add({"a": 1, "b": 2}) + """ + + @staticmethod + @import_failure(rvalue=False) + def match(df: Any) -> bool: + import pandas + return True if isinstance(df, pandas.DataFrame) else False + + @staticmethod + def record(df: 'pandas.DataFrame', record_index: int) -> dict: + """ + + >>> edf = EditableDataframe(df) + >>> r = edf.record(1) + """ + import pandas + + record = df.iloc[record_index] + if not isinstance(df.index, pandas.RangeIndex): + index_list = df.index.tolist() + record_index_content = index_list[record_index] + if isinstance(record_index_content, tuple): + for i, n in enumerate(df.index.names): + record[n] = record_index_content[i] + else: + record[df.index.names[0]] = record_index_content + + return dict(record) + + @staticmethod + def record_add(df: 'pandas.DataFrame', payload: DataframeRecordAdded) -> 'pandas.DataFrame': + """ + >>> edf = EditableDataframe(df) + >>> edf.record_add({"record": {"a": 1, "b": 2}}) + """ + import pandas + + _assert_record_match_pandas_df(df, payload['record']) + + record, index = _split_record_as_pandas_record_and_index(payload['record'], df.index.names) + + if isinstance(df.index, pandas.RangeIndex): + new_df = pandas.DataFrame([record]) + return pandas.concat([df, new_df], ignore_index=True) + else: + new_df = pandas.DataFrame([record], index=[index]) + return pandas.concat([df, new_df]) + + @staticmethod + def record_update(df: 'pandas.DataFrame', payload: DataframeRecordUpdated) -> 'pandas.DataFrame': + """ + >>> edf = EditableDataframe(df) + >>> edf.record_update({"record_index": 12, "record": {"a": 1, "b": 2}}) + """ + import pandas + + _assert_record_match_pandas_df(df, payload['record']) + + record: dict + record, index = _split_record_as_pandas_record_and_index(payload['record'], df.index.names) + + record_index = payload['record_index'] + + if isinstance(df.index, pandas.RangeIndex): + df.iloc[record_index] = record # type: ignore + else: + df.iloc[record_index] = record # type: ignore + index_list = df.index.tolist() + index_list[record_index] = index + df.index = index_list # type: ignore + + return df + + @staticmethod + def record_remove(df: 'pandas.DataFrame', payload: DataframeRecordRemoved) -> 'pandas.DataFrame': + """ + >>> edf = EditableDataframe(df) + >>> edf.record_remove({"record_index": 12}) + """ + record_index: int = payload['record_index'] + idx = df.index[record_index] + df = df.drop(idx) + + return df + + @staticmethod + def pyarrow_table(df: 'pandas.DataFrame') -> pyarrow.Table: + """ + Serializes the dataframe into a pyarrow table + """ + table = pyarrow.Table.from_pandas(df=df) + return table + + +class PolarRecordProcessor(DataframeRecordProcessor): + """ + PolarRecordProcessor processes records from a polar dataframe saved into an EditableDataframe + + >>> df = polars.DataFrame({"a": [1, 2], "b": [3, 4]}) + >>> edf = EditableDataframe(df) + >>> edf.record_add({"record": {"a": 1, "b": 2}}) + """ + + @staticmethod + @import_failure(rvalue=False) + def match(df: Any) -> bool: + import polars + return True if isinstance(df, polars.DataFrame) else False + + @staticmethod + def record(df: 'polars.DataFrame', record_index: int) -> dict: + """ + + >>> edf = EditableDataframe(df) + >>> r = edf.record(1) + """ + record = {} + r = df[record_index] + for c in r.columns: + record[c] = df[record_index, c] + + return record + + + @staticmethod + def record_add(df: 'polars.DataFrame', payload: DataframeRecordAdded) -> 'polars.DataFrame': + _assert_record_match_polar_df(df, payload['record']) + + import polars + new_df = polars.DataFrame([payload['record']]) + return polars.concat([df, new_df]) + + @staticmethod + def record_update(df: 'polars.DataFrame', payload: DataframeRecordUpdated) -> 'polars.DataFrame': + # This implementation works but is not optimal. + # I didn't find a better way to update a record in polars + # + # https://github.com/pola-rs/polars/issues/5973 + _assert_record_match_polar_df(df, payload['record']) + + record = payload['record'] + record_index = payload['record_index'] + for r in record: + df[record_index, r] = record[r] + + return df + + @staticmethod + def record_remove(df: 'polars.DataFrame', payload: DataframeRecordRemoved) -> 'polars.DataFrame': + import polars + + record_index: int = payload['record_index'] + df_filtered = polars.concat([df[:record_index], df[record_index + 1:]]) + return df_filtered + + @staticmethod + def pyarrow_table(df: 'polars.DataFrame') -> pyarrow.Table: + """ + Serializes the dataframe into a pyarrow table + """ + import pyarrow.interchange # type: ignore + table: pyarrow.Table = pyarrow.interchange.from_dataframe(df) + return table + +class RecordListRecordProcessor(DataframeRecordProcessor): + """ + RecordListRecordProcessor processes records from a list of record saved into an EditableDataframe + + >>> df = [{"a": 1, "b": 2}, {"a": 3, "b": 4}] + >>> edf = EditableDataframe(df) + >>> edf.record_add({"record": {"a": 1, "b": 2}}) + """ + + @staticmethod + def match(df: Any) -> bool: + return True if isinstance(df, list) else False + + + @staticmethod + def record(df: List[Dict[str, Any]], record_index: int) -> dict: + """ + + >>> edf = EditableDataframe(df) + >>> r = edf.record(1) + """ + r = df[record_index] + return copy.copy(r) + + @staticmethod + def record_add(df: List[Dict[str, Any]], payload: DataframeRecordAdded) -> List[Dict[str, Any]]: + _assert_record_match_list_of_records(df, payload['record']) + df.append(payload['record']) + return df + + @staticmethod + def record_update(df: List[Dict[str, Any]], payload: DataframeRecordUpdated) -> List[Dict[str, Any]]: + _assert_record_match_list_of_records(df, payload['record']) + + record_index = payload['record_index'] + record = payload['record'] + + df[record_index] = record + return df + + @staticmethod + def record_remove(df: List[Dict[str, Any]], payload: DataframeRecordRemoved) -> List[Dict[str, Any]]: + del(df[payload['record_index']]) + return df + + @staticmethod + def pyarrow_table(df: List[Dict[str, Any]]) -> pyarrow.Table: + """ + Serializes the dataframe into a pyarrow table + """ + column_names = list(df[0].keys()) + columns = {key: [record[key] for record in df] for key in column_names} + + pyarrow_columns = {key: pyarrow.array(values) for key, values in columns.items()} + schema = pyarrow.schema([(key, pyarrow_columns[key].type) for key in pyarrow_columns]) + table = pyarrow.Table.from_arrays( + [pyarrow_columns[key] for key in column_names], + schema=schema + ) + + return table + +class EditableDataframe(MutableValue): + """ + Editable Dataframe makes it easier to process events from components + that modify a dataframe like the dataframe editor. + + >>> initial_state = wf.init_state({ + >>> "df": wf.EditableDataframe(df) + >>> }) + + Editable Dataframe is compatible with a pandas, thrillers or record list dataframe + """ + processors = [PandasRecordProcessor, PolarRecordProcessor, RecordListRecordProcessor] + + def __init__(self, df: Union['pandas.DataFrame', 'polars.DataFrame', List[dict]]): + super().__init__() + self._df = df + self.processor: Type[DataframeRecordProcessor] + for processor in self.processors: + if processor.match(self.df): + self.processor = processor + break + + if self.processor is None: + raise ValueError("The dataframe must be a pandas, polar Dataframe or a list of record") + + @property + def df(self) -> Union['pandas.DataFrame', 'polars.DataFrame', List[dict]]: + return self._df + + @df.setter + def df(self, value: Union['pandas.DataFrame', 'polars.DataFrame', List[dict]]) -> None: + self._df = value + self.mutate() + + def record_add(self, payload: DataframeRecordAdded) -> None: + """ + Adds a record to the dataframe + + >>> df = pandas.DataFrame({"a": [1, 2], "b": [3, 4]}) + >>> edf = EditableDataframe(df) + >>> edf.record_add({"record": {"a": 1, "b": 2}}) + """ + assert self.processor is not None + + self._df = self.processor.record_add(self.df, payload) + self.mutate() + + def record_update(self, payload: DataframeRecordUpdated) -> None: + """ + Updates a record in the dataframe + + The record must be complete otherwise an error is raised (ValueError). + It must a value for each index / column. + + >>> df = pandas.DataFrame({"a": [1, 2], "b": [3, 4]}) + >>> edf = EditableDataframe(df) + >>> edf.record_update({"record_index": 0, "record": {"a": 2, "b": 2}}) + """ + assert self.processor is not None + + self._df = self.processor.record_update(self.df, payload) + self.mutate() + + def record_remove(self, payload: DataframeRecordRemoved) -> None: + """ + Removes a record from the dataframe + + >>> df = pandas.DataFrame({"a": [1, 2], "b": [3, 4]}) + >>> edf = EditableDataframe(df) + >>> edf.record_remove({"record_index": 0}) + """ + assert self.processor is not None + + self._df = self.processor.record_remove(self.df, payload) + self.mutate() + + def pyarrow_table(self) -> pyarrow.Table: + """ + Serializes the dataframe into a pyarrow table + + This mechanism is used for serializing data for transmission to the frontend. + + >>> df = pandas.DataFrame({"a": [1, 2], "b": [3, 4]}) + >>> edf = EditableDataframe(df) + >>> pa_table = edf.pyarrow_table() + """ + assert self.processor is not None + + pa_table = self.processor.pyarrow_table(self.df) + return pa_table + + def record(self, record_index: int): + """ + Retrieves a specific record in dictionary form. + + :param record_index: + :return: + """ + assert self.processor is not None + + record = self.processor.record(self.df, record_index) + return record + + + +def _assert_record_match_pandas_df(df: 'pandas.DataFrame', record: Dict[str, Any]) -> None: + """ + Asserts that the record matches the dataframe columns & index + + >>> _assert_record_match_pandas_df(pandas.DataFrame({"a": [1, 2], "b": [3, 4]}), {"a": 1, "b": 2}) + """ + import pandas + + columns = set(list(df.columns.values) + df.index.names) if isinstance(df.index, pandas.RangeIndex) is False else set(df.columns.values) + columns_record = set(record.keys()) + if columns != columns_record: + raise ValueError(f"Columns mismatch. Expected {columns}, got {columns_record}") + +def _assert_record_match_polar_df(df: 'polars.DataFrame', record: Dict[str, Any]) -> None: + """ + Asserts that the record matches the columns of polar dataframe + + >>> _assert_record_match_pandas_df(polars.DataFrame({"a": [1, 2], "b": [3, 4]}), {"a": 1, "b": 2}) + """ + columns = set(df.columns) + columns_record = set(record.keys()) + if columns != columns_record: + raise ValueError(f"Columns mismatch. Expected {columns}, got {columns_record}") + +def _assert_record_match_list_of_records(df: List[Dict[str, Any]], record: Dict[str, Any]) -> None: + """ + Asserts that the record matches the key in the record list (it use the first record to check) + + >>> _assert_record_match_list_of_records([{"a": 1, "b": 2}, {"a": 3, "b": 4}], {"a": 1, "b": 2}) + """ + if len(df) == 0: + return + + columns = set(df[0].keys()) + columns_record = set(record.keys()) + if columns != columns_record: + raise ValueError(f"Columns mismatch. Expected {columns}, got {columns_record}") + + + +def _split_record_as_pandas_record_and_index(param: dict, index_columns: list) -> Tuple[dict, tuple]: + """ + Separates a record into the record part and the index part to be able to + create or update a row in a dataframe. + + >>> record, index = _split_record_as_pandas_record_and_index({"a": 1, "b": 2}, ["a"]) + >>> print(record) # {"b": 2} + >>> print(index) # (1,) + """ + final_record = {} + final_index = [] + for key, value in param.items(): + if key in index_columns: + final_index.append(value) + else: + final_record[key] = value + + return final_record, tuple(final_index)