Skip to content

Commit

Permalink
feat: make contingency table cell size flexible by default (#189)
Browse files Browse the repository at this point in the history
Resolves #55
  • Loading branch information
mbelak-dtml authored Nov 2, 2023
1 parent f6421a7 commit 1ae82b6
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions edvart/report_sections/bivariate_analysis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import itertools
from enum import IntEnum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import matplotlib.pyplot as plt
import nbformat.v4 as nbfv4
Expand Down Expand Up @@ -792,6 +792,7 @@ def required_imports(self) -> List[str]:
"from edvart import utils",
"import seaborn as sns",
"import matplotlib.pyplot as plt",
"from typing import Literal",
]

def add_cells(self, cells: List[Dict[str, Any]], df: pd.DataFrame) -> None:
Expand Down Expand Up @@ -927,7 +928,7 @@ def contingency_table(
hide_zeros: bool = True,
scaling_func: Callable[[np.ndarray], np.ndarray] = np.cbrt,
colormap: Any = "Blues",
size_factor: float = 0.7,
size_factor: Union[float, Literal["auto"]] = "auto",
fontsize: float = 15,
) -> None:
"""
Expand All @@ -950,8 +951,9 @@ def contingency_table(
Cube root is used by default.
colormap : Any (default = "Blues")
Colormap compatible with matplotlib/seaborn.
size_factor : float (default = 0.7)
size_factor : float or "auto"
Size of each cell in the table.
If "auto", the cell size is automatically adjusted so that the numbers fit in the cells.
fontsize : float (default = 15)
Size of the font for axis labels.
"""
Expand Down Expand Up @@ -980,6 +982,13 @@ def contingency_table(
annot_kws={"fontsize": fontsize},
square=True,
)
if size_factor == "auto":
n_digits_max = 1 + np.floor(np.log10(table.max().max()))
# Constants chosen empirically to make the numbers fit in the cells
size_factor = max(
0.72,
0.18 * n_digits_max,
)
ax.figure.set_size_inches(size_factor * len(table.columns), size_factor * len(table))
# Set y axis
ax.set_ylabel(ax.get_ylabel(), fontsize=fontsize)
Expand Down

0 comments on commit 1ae82b6

Please sign in to comment.