Skip to content

Commit

Permalink
ENH: added auto merge for cartesian_chunk (#699)
Browse files Browse the repository at this point in the history
  • Loading branch information
qinxuye authored Sep 18, 2023
1 parent dd999f0 commit 14d64ae
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 122 deletions.
109 changes: 54 additions & 55 deletions python/xorbits/_mars/dataframe/base/cartesian_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import numpy as np
import pandas as pd

Expand All @@ -22,75 +24,45 @@
from ...serialization.serializables import (
DictField,
FunctionField,
Int32Field,
KeyField,
StringField,
TupleField,
)
from ...utils import enter_current_session, has_unknown_shape, quiet_stdio
from ..operands import DataFrameOperand, DataFrameOperandMixin, OutputType
from ..operands import DataFrameOperand, OutputType
from ..utils import (
build_df,
build_empty_df,
build_series,
parse_index,
validate_output_types,
)
from .core import DataFrameAutoMergeMixin

logger = logging.getLogger(__name__)


class DataFrameCartesianChunk(DataFrameOperand, DataFrameOperandMixin):
class DataFrameCartesianChunk(DataFrameOperand, DataFrameAutoMergeMixin):
_op_type_ = opcodes.CARTESIAN_CHUNK

_left = KeyField("left")
_right = KeyField("right")
_func = FunctionField("func")
_args = TupleField("args")
_kwargs = DictField("kwargs")
left = KeyField("left")
right = KeyField("right")
func = FunctionField("func")
args = TupleField("args")
kwargs = DictField("kwargs")
auto_merge = StringField("auto_merge")
auto_merge_threshold = Int32Field("auto_merge_threshold")

def __init__(
self,
left=None,
right=None,
func=None,
args=None,
kwargs=None,
output_types=None,
**kw
):
super().__init__(
_left=left,
_right=right,
_func=func,
_args=args,
_kwargs=kwargs,
_output_types=output_types,
**kw
)
def __init__(self, output_types=None, **kw):
super().__init__(_output_types=output_types, **kw)
if self.memory_scale is None:
self.memory_scale = 2.0

@property
def left(self):
return self._left

@property
def right(self):
return self._right

@property
def func(self):
return self._func

@property
def args(self):
return self._args

@property
def kwargs(self):
return self._kwargs

def _set_inputs(self, inputs):
super()._set_inputs(inputs)
self._left = self._inputs[0]
self._right = self._inputs[1]
self.left = self.inputs[0]
self.right = self.inputs[1]

@staticmethod
def _build_test_obj(obj):
Expand All @@ -103,15 +75,15 @@ def _build_test_obj(obj):
def __call__(self, left, right, index=None, dtypes=None):
test_left = self._build_test_obj(left)
test_right = self._build_test_obj(right)
output_type = self._output_types[0] if self._output_types else None
output_type = self.output_types[0] if self.output_types else None

if output_type == OutputType.df_or_series:
return self.new_df_or_series([left, right])

# try run to infer meta
try:
with np.errstate(all="ignore"), quiet_stdio():
obj = self._func(test_left, test_right, *self._args, **self._kwargs)
obj = self.func(test_left, test_right, *self.args, **self.kwargs)
except: # noqa: E722 # nosec # pylint: disable=bare-except
if output_type == OutputType.series:
obj = pd.Series([], dtype=np.dtype(object))
Expand All @@ -126,11 +98,11 @@ def __call__(self, left, right, index=None, dtypes=None):
)

if getattr(obj, "ndim", 0) == 1 or output_type == OutputType.series:
shape = self._kwargs.pop("shape", (np.nan,))
shape = self.kwargs.pop("shape", (np.nan,))
if index is None:
index = obj.index
index_value = parse_index(
index, left, right, self._func, self._args, self._kwargs
index, left, right, self.func, self.args, self.kwargs
)
return self.new_series(
[left, right],
Expand All @@ -147,7 +119,7 @@ def __call__(self, left, right, index=None, dtypes=None):
if index is None:
index = obj.index
index_value = parse_index(
index, left, right, self._func, self._args, self._kwargs
index, left, right, self.func, self.args, self.kwargs
)
return self.new_dataframe(
[left, right],
Expand All @@ -164,6 +136,13 @@ def tile(cls, op: "DataFrameCartesianChunk"):
out = op.outputs[0]
out_type = op.output_types[0]

auto_merge_threshold = op.auto_merge_threshold
auto_merge_before, auto_merge_after = cls._get_auto_merge_options(op.auto_merge)

yield from cls._merge_before(
op, auto_merge_before, auto_merge_threshold, left, right, logger
)

if left.ndim == 2 and left.chunk_shape[1] > 1:
if has_unknown_shape(left):
yield
Expand Down Expand Up @@ -240,7 +219,12 @@ def tile(cls, op: "DataFrameCartesianChunk"):
params["nsplits"] = tuple(tuple(ns) for ns in nsplits) if nsplits else nsplits
params["chunks"] = out_chunks
new_op = op.copy()
return new_op.new_tileables(op.inputs, kws=[params])
ret = new_op.new_tileables(op.inputs, kws=[params])

ret = yield from cls._merge_after(
op, auto_merge_after, auto_merge_threshold, ret, logger
)
return ret

@classmethod
@redirect_custom_log
Expand All @@ -250,7 +234,16 @@ def execute(cls, ctx, op: "DataFrameCartesianChunk"):
ctx[op.outputs[0].key] = op.func(left, right, *op.args, **(op.kwargs or dict()))


def cartesian_chunk(left, right, func, skip_infer=False, args=(), **kwargs):
def cartesian_chunk(
left,
right,
func,
skip_infer=False,
args=(),
auto_merge: str = "both",
auto_merge_threshold: int = 8,
**kwargs,
):
output_type = kwargs.pop("output_type", None)
output_types = kwargs.pop("output_types", None)
object_type = kwargs.pop("object_type", None)
Expand All @@ -265,6 +258,10 @@ def cartesian_chunk(left, right, func, skip_infer=False, args=(), **kwargs):
index = kwargs.pop("index", None)
dtypes = kwargs.pop("dtypes", None)
memory_scale = kwargs.pop("memory_scale", None)
if auto_merge not in ["both", "none", "before", "after"]: # pragma: no cover
raise ValueError(
f"auto_merge can only be `both`, `none`, `before` or `after`, got {auto_merge}"
)

op = DataFrameCartesianChunk(
left=left,
Expand All @@ -274,5 +271,7 @@ def cartesian_chunk(left, right, func, skip_infer=False, args=(), **kwargs):
kwargs=kwargs,
output_types=output_types,
memory_scale=memory_scale,
auto_merge=auto_merge,
auto_merge_threshold=auto_merge_threshold,
)
return op(left, right, index=index, dtypes=dtypes)
98 changes: 98 additions & 0 deletions python/xorbits/_mars/dataframe/base/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging

from ...core import TileStatus
from ...core.context import get_context
from ...serialization.serializables import KeyField
from ...typing import OperandType, TileableType
from ..core import DATAFRAME_TYPE, SERIES_TYPE
from ..operands import DataFrameOperand, DataFrameOperandMixin
from ..utils import auto_merge_chunks


class DataFrameDeviceConversionBase(DataFrameOperand, DataFrameOperandMixin):
Expand Down Expand Up @@ -63,3 +71,93 @@ def tile(cls, op):
return new_op.new_tileables(
op.inputs, chunks=out_chunks, nsplits=op.inputs[0].nsplits, **out.params
)


class DataFrameAutoMergeMixin(DataFrameOperandMixin):
    """Operand mixin with helpers to auto-merge chunks around a two-input tile.

    The helpers are generators meant to be driven with ``yield from`` inside
    an operand's ``tile`` classmethod.
    """

    @classmethod
    def _get_auto_merge_options(cls, auto_merge: str) -> tuple[bool, bool]:
        """Translate the ``auto_merge`` option into a pair of flags.

        Parameters
        ----------
        auto_merge : str
            One of ``"both"``, ``"none"``, ``"before"`` or ``"after"``.

        Returns
        -------
        tuple[bool, bool]
            ``(auto_merge_before, auto_merge_after)``.

        Raises
        ------
        ValueError
            If ``auto_merge`` is not one of the supported values.
        """
        options = {
            "both": (True, True),
            "none": (False, False),
            "before": (True, False),
            "after": (False, True),
        }
        try:
            return options[auto_merge]
        except (KeyError, TypeError):
            # An explicit raise (instead of ``assert``) keeps the validation
            # alive under ``python -O``; otherwise an invalid value would
            # silently behave like "after".
            raise ValueError(
                f"auto_merge can only be `both`, `none`, `before` or `after`, got {auto_merge}"
            ) from None

    @classmethod
    def _merge_before(
        cls,
        op: OperandType,
        auto_merge_before: bool,
        auto_merge_threshold: int,
        left: TileableType,
        right: TileableType,
        logger: logging.Logger,
    ):
        """Optionally merge the chunks of both inputs before tiling ``op``.

        Merging only happens when ``auto_merge_before`` is true and the total
        chunk count of ``left`` and ``right`` exceeds ``auto_merge_threshold``.
        In that case a ``TileStatus`` holding both inputs and all of their
        chunks is yielded first (presumably so the chunks get executed before
        ``auto_merge_chunks`` inspects them -- confirm against the tiling
        framework), and each input is then passed through
        ``auto_merge_chunks``.

        Parameters
        ----------
        op : OperandType
            Operand being tiled; only used in log messages.
        auto_merge_before : bool
            Whether merging before the operation is enabled.
        auto_merge_threshold : int
            Merging is attempted only if the combined chunk count is greater
            than this value.
        left, right : TileableType
            The two input tileables.
        logger : logging.Logger
            Logger used to report whether a merge took place.

        Returns
        -------
        tuple
            ``(left, right)``: the merged tileables, or the inputs unchanged
            when the merge was skipped. This is the generator's return value,
            so callers must capture it, e.g.
            ``left, right = yield from cls._merge_before(...)``.
        """
        ctx = get_context()

        if (
            auto_merge_before
            and len(left.chunks) + len(right.chunks) > auto_merge_threshold
        ):
            yield TileStatus([left, right] + left.chunks + right.chunks, progress=0.2)
            # Remember the original counts so the log can show before -> after.
            left_chunk_size = len(left.chunks)
            right_chunk_size = len(right.chunks)
            left = auto_merge_chunks(ctx, left)
            right = auto_merge_chunks(ctx, right)
            logger.info(
                "Auto merge before %s, left data shape: %s, chunk count: %s -> %s, "
                "right data shape: %s, chunk count: %s -> %s.",
                op,
                left.shape,
                left_chunk_size,
                len(left.chunks),
                right.shape,
                right_chunk_size,
                len(right.chunks),
            )
        else:
            logger.info(
                "Skip auto merge before %s, left data shape: %s, chunk count: %d, "
                "right data shape: %s, chunk count: %d.",
                op,
                left.shape,
                len(left.chunks),
                right.shape,
                len(right.chunks),
            )

        # Hand the (possibly merged) inputs back; without this the rebinding
        # above would be lost when the generator finishes.
        return left, right

    @classmethod
    def _merge_after(
        cls,
        op: OperandType,
        auto_merge_after: bool,
        auto_merge_threshold: int,
        ret: list[TileableType],
        logger: logging.Logger,
    ):
        """Optionally merge the chunks of the tiled result of ``op``.

        Merging only happens when ``auto_merge_after`` is true and the first
        tileable in ``ret`` has more than ``auto_merge_threshold`` chunks.

        Parameters
        ----------
        op : OperandType
            Operand being tiled; only used in log messages.
        auto_merge_after : bool
            Whether merging after the operation is enabled.
        auto_merge_threshold : int
            Merging is attempted only if the result's chunk count is greater
            than this value.
        ret : list of TileableType
            Tiled outputs; only ``ret[0]`` is inspected and merged.
        logger : logging.Logger
            Logger used to report whether a merge took place.

        Returns
        -------
        list
            ``[merged]`` when a merge took place, otherwise ``ret`` unchanged.
            This is the generator's return value; capture it with
            ``ret = yield from cls._merge_after(...)``.
        """
        result = ret[0]
        if auto_merge_after and len(result.chunks) > auto_merge_threshold:
            # The operation may leave many small output chunks; execute them
            # first, then let auto_merge_chunks combine them.
            yield TileStatus(
                result.chunks, progress=0.8
            )  # trigger execution for chunks
            merged = auto_merge_chunks(get_context(), result)
            logger.info(
                "Auto merge after %s, data shape: %s, chunk count: %s -> %s.",
                op,
                merged.shape,
                len(result.chunks),
                len(merged.chunks),
            )
            return [merged]
        else:
            logger.info(
                "Skip auto merge after %s, data shape: %s, chunk count: %d.",
                op,
                result.shape,
                len(result.chunks),
            )
            return ret
Loading

0 comments on commit 14d64ae

Please sign in to comment.