From 1ae82b6e5f4726e1ee5105d5625b08fc21c47540 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michal=20Bel=C3=A1k?= Date: Thu, 2 Nov 2023 10:36:21 +0100 Subject: [PATCH] feat: make contingency table cell size flexible by default (#189) Resolves #55 --- edvart/report_sections/bivariate_analysis.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/edvart/report_sections/bivariate_analysis.py b/edvart/report_sections/bivariate_analysis.py index eb33f0c..385fb15 100644 --- a/edvart/report_sections/bivariate_analysis.py +++ b/edvart/report_sections/bivariate_analysis.py @@ -1,6 +1,6 @@ import itertools from enum import IntEnum -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import matplotlib.pyplot as plt import nbformat.v4 as nbfv4 @@ -792,6 +792,7 @@ def required_imports(self) -> List[str]: "from edvart import utils", "import seaborn as sns", "import matplotlib.pyplot as plt", + "from typing import Literal", ] def add_cells(self, cells: List[Dict[str, Any]], df: pd.DataFrame) -> None: @@ -927,7 +928,7 @@ def contingency_table( hide_zeros: bool = True, scaling_func: Callable[[np.ndarray], np.ndarray] = np.cbrt, colormap: Any = "Blues", - size_factor: float = 0.7, + size_factor: Union[float, Literal["auto"]] = "auto", fontsize: float = 15, ) -> None: """ @@ -950,8 +951,9 @@ def contingency_table( Cube root is used by default. colormap : Any (default = "Blues") Colormap compatible with matplotlib/seaborn. - size_factor : float (default = 0.7) + size_factor : float or "auto" Size of each cell in the table. + If "auto", the cell size is automatically adjusted so that the numbers fit in the cells. fontsize : float (default = 15) Size of the font for axis labels. """ @@ -980,6 +982,13 @@ def contingency_table( annot_kws={"fontsize": fontsize}, square=True, ) + if size_factor == "auto": + n_digits_max = 1 + np.floor(np.log10(table.max().max())) + # Constants chosen empirically to make the numbers fit in the cells + size_factor = max( + 0.72, + 0.18 * n_digits_max, + ) ax.figure.set_size_inches(size_factor * len(table.columns), size_factor * len(table)) # Set y axis ax.set_ylabel(ax.get_ylabel(), fontsize=fontsize)