Skip to content

Commit

Permalink
Merge pull request #98 from hanjinliu/more-interface
Browse files Browse the repository at this point in the history
Implement interfaces for cmap/formatter/validator
  • Loading branch information
hanjinliu authored Jan 7, 2023
2 parents 0b2afe5 + 95511b1 commit 7801129
Show file tree
Hide file tree
Showing 16 changed files with 974 additions and 418 deletions.
141 changes: 66 additions & 75 deletions tabulous/commands/_colormap.py → tabulous/_colormap.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
from __future__ import annotations
from typing import Callable, Hashable, Sequence, TYPE_CHECKING, TypeVar
from typing import Iterable, TYPE_CHECKING, Union
import numpy as np
import pandas as pd
from tabulous.color import ColorTuple, normalize_color
from tabulous.types import ColorType
from tabulous._dtype import isna, get_converter

if TYPE_CHECKING:
from pandas.core.dtypes.dtypes import CategoricalDtype
from magicgui.widgets import Widget
import pandas as pd

_TimeLike = Union[pd.Timestamp, pd.Timedelta]


_ColorType = tuple[int, int, int, int]
_DEFAULT_MIN = "#697FD1"
_DEFAULT_MAX = "#FF696B"


def exec_colormap_dialog(ds: pd.Series, parent=None) -> Callable | None:
def exec_colormap_dialog(ds: pd.Series, parent=None):
"""Open a dialog to define a colormap for a series."""
from tabulous._qt._color_edit import ColorEdit
from magicgui.widgets import Dialog, LineEdit, Container
Expand All @@ -28,11 +31,7 @@ def exec_colormap_dialog(ds: pd.Series, parent=None) -> Callable | None:
dlg = Dialog(widgets=widgets)
dlg.native.setParent(parent, dlg.native.windowFlags())
if dlg.exec():
return _define_categorical_colormap(
dtype.categories,
[w.value for w in widgets],
dtype.kind,
)
return dict(zip(dtype.categories, (w.value for w in widgets)))

elif dtype.kind in "uif": # unsigned int, int, float
lmin = LineEdit(value=str(ds.min()))
Expand All @@ -50,8 +49,8 @@ def exec_colormap_dialog(ds: pd.Series, parent=None) -> Callable | None:
dlg = Dialog(widgets=[min_, max_])
dlg.native.setParent(parent, dlg.native.windowFlags())
if dlg.exec():
return _define_continuous_colormap(
float(lmin.value), float(lmax.value), cmin.value, cmax.value
return segment_by_float(
[(float(lmin.value), cmin.value), (float(lmax.value), cmax.value)]
)

elif dtype.kind == "b": # boolean
Expand All @@ -60,96 +59,88 @@ def exec_colormap_dialog(ds: pd.Series, parent=None) -> Callable | None:
dlg = Dialog(widgets=[false_, true_])
dlg.native.setParent(parent, dlg.native.windowFlags())
if dlg.exec():
return _define_categorical_colormap(
[False, True], [false_.value, true_.value], dtype.kind
)
converter = get_converter("b")
_dict = {False: false_.value, True: true_.value}
return lambda val: _dict.get(converter(val), None)

elif dtype.kind in "mM": # time stamp or time delta
min_ = ColorEdit(value=_DEFAULT_MIN, label="Min")
max_ = ColorEdit(value=_DEFAULT_MAX, label="Max")
dlg = Dialog(widgets=[min_, max_])
dlg.native.setParent(parent, dlg.native.windowFlags())
if dlg.exec():
return _define_time_colormap(
ds.min(), ds.max(), min_.value, max_.value, dtype.kind
)
return segment_by_time([(ds.min(), min_.value), (ds.max(), max_.value)])

else:
raise NotImplementedError(
f"Dtype {dtype!r} not supported. Please set colormap programmatically."
)

return None

def _random_color() -> list[int]:
return list(np.random.randint(256, size=3)) + [255]

def _define_continuous_colormap(
min: float, max: float, min_color: _ColorType, max_color: _ColorType
):
converter = get_converter("f")

def _colormap(value: float) -> _ColorType:
nonlocal min_color, max_color
if isna(value):
return None
value = converter(value)
if value < min:
return min_color
elif value > max:
return max_color
else:
min_color = np.array(min_color, dtype=np.float64)
max_color = np.array(max_color, dtype=np.float64)
return (value - min) / (max - min) * (max_color - min_color) + min_color

return _colormap
def _where(x, border: Iterable[float]) -> int:
for i, v in enumerate(border):
if x < v:
return i - 1
return len(border) - 1


def _define_categorical_colormap(
values: Sequence[Hashable],
colors: Sequence[_ColorType],
kind: str,
):
map = dict(zip(values, colors))
def segment_by_float(maps: list[tuple[float, ColorType]], kind: str = "f"):
converter = get_converter(kind)
borders: list[float] = []
colors: list[ColorTuple] = []
for v, c in maps:
borders.append(v)
colors.append(normalize_color(c))
idx_max = len(borders) - 1

def _colormap(value: Hashable) -> _ColorType:
return map.get(converter(value), None)

return _colormap
# check is sorted
if not all(borders[i] <= borders[i + 1] for i in range(len(borders) - 1)):
raise ValueError("Borders must be sorted")

def _colormap(value: float) -> _ColorType:
if isna(value):
return None
value = converter(value)
idx = _where(value, borders)
if idx == -1 or idx == idx_max:
return colors[idx]
min_color = np.array(colors[idx], dtype=np.float64)
max_color = np.array(colors[idx + 1], dtype=np.float64)
min = borders[idx]
max = borders[idx + 1]
return (value - min) / (max - min) * (max_color - min_color) + min_color

_T = TypeVar("_T", pd.Timestamp, pd.Timedelta)
return _colormap


def _define_time_colormap(
min: _T,
max: _T,
min_color: _ColorType,
max_color: _ColorType,
kind: str,
):
min_t = min.value
max_t = max.value
def segment_by_time(maps: list[tuple[_TimeLike, ColorType]], kind: str):
converter = get_converter(kind)

def _colormap(value: _T) -> _ColorType:
nonlocal min_color, max_color
borders: list[_TimeLike] = []
colors: list[ColorTuple] = []
for v, c in maps:
borders.append(v)
colors.append(normalize_color(c))
idx_max = len(borders) - 1

# check is sorted
if not all(borders[i] <= borders[i + 1] for i in range(len(borders) - 1)):
raise ValueError("Borders must be sorted")

def _colormap(value: _TimeLike) -> _ColorType:
if isna(value):
return None
value = converter(value).value
if value < min_t:
return min_color
elif value > max_t:
return max_color
else:
min_color = np.array(min_color, dtype=np.float64)
max_color = np.array(max_color, dtype=np.float64)
return (value - min_t) / (max_t - min_t) * (
max_color - min_color
) + min_color
value = converter(value)
idx = _where(value, borders)
if idx == -1 or idx == idx_max:
return colors[idx]
min_color = np.array(colors[idx], dtype=np.float64)
max_color = np.array(colors[idx + 1], dtype=np.float64)
min = borders[idx].value
max = borders[idx + 1].value
return (value.value - min) / (max - min) * (max_color - min_color) + min_color

return _colormap


def _random_color() -> list[int]:
return list(np.random.randint(256, size=3)) + [255]
23 changes: 23 additions & 0 deletions tabulous/_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,29 @@ def get_converter(kind: str) -> Callable[[Any], Any]:
return _DTYPE_CONVERTER[kind]


def get_converter_from_type(tp: type | str) -> Callable[[Any], Any]:
if not isinstance(tp, str):
tp = tp.__name__

if tp == "int":
kind = "i"
elif tp == "float":
kind = "f"
elif tp == "str":
kind = "U"
elif tp == "bool":
kind = "b"
elif tp == "complex":
kind = "c"
elif tp == "datetime":
kind = "M"
elif tp == "timedelta":
kind = "m"
else:
kind = "O"
return get_converter(kind)


class DefaultValidator:
"""
The default validator function.
Expand Down
11 changes: 9 additions & 2 deletions tabulous/_qt/_table/_base/_table_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from functools import partial
from pathlib import Path
from typing import Any, Callable, TYPE_CHECKING, Iterable, Tuple, TypeVar
from typing import Any, Callable, TYPE_CHECKING, Iterable, Tuple, TypeVar, overload
import warnings
from qtpy import QtWidgets as QtW, QtGui, QtCore
from qtpy.QtCore import Signal, Qt
Expand Down Expand Up @@ -181,7 +181,14 @@ def createQTableView(self) -> None:
def getDataFrame(self) -> pd.DataFrame:
raise NotImplementedError()

def _get_sub_frame(self, columns: list[str]):
# fmt: off
@overload
def _get_sub_frame(self, columns: list[str]) -> pd.DataFrame: ...
@overload
def _get_sub_frame(self, columns: str) -> pd.Series: ...
# fmt: on

def _get_sub_frame(self, columns):
return self.getDataFrame()[columns]

def setDataFrame(self, df: pd.DataFrame) -> None:
Expand Down
Loading

0 comments on commit 7801129

Please sign in to comment.