Skip to content

Commit

Permalink
Add more type checking to histogram.py
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Jun 11, 2024
1 parent 69e42b8 commit c260c4f
Showing 1 changed file with 44 additions and 49 deletions.
93 changes: 44 additions & 49 deletions src/ert/gui/plottery/plots/histogram.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from math import ceil, floor, log10, sqrt
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Dict, List, Optional

import numpy
import pandas as pd
Expand All @@ -15,7 +15,7 @@
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from ert.gui.plottery import PlotConfig, PlotContext
from ert.gui.plottery import PlotContext, PlotStyle


class HistogramPlot:
Expand All @@ -27,17 +27,17 @@ def plot(
figure: Figure,
plot_context: PlotContext,
ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame],
_observation_data: Any,
std_dev_images: Any,
observation_data: pd.DataFrame,
std_dev_images: Dict[str, bytes],
) -> None:
plotHistogram(figure, plot_context, ensemble_to_data_map, _observation_data)
plotHistogram(figure, plot_context, ensemble_to_data_map, observation_data)


def plotHistogram(
figure: Figure,
plot_context: PlotContext,
ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame],
_observation_data: Any,
observation_data: pd.DataFrame,
) -> None:
config = plot_context.plotConfig()

Expand All @@ -58,18 +58,22 @@ def plotHistogram(
plot_context.x_axis = plot_context.VALUE_AXIS
plot_context.y_axis = plot_context.COUNT_AXIS

if config.xLabel() is None:
config.setXLabel("Value")
x_label = config.xLabel()
if x_label is None:
x_label = "Value"
config.setXLabel(x_label)

if config.yLabel() is None:
config.setYLabel("Count")
y_label = config.yLabel()
if y_label is None:
y_label = "Count"
config.setYLabel(y_label)

use_log_scale = plot_context.log_scale

data = {}
minimum = None
maximum = None
categories = set()
categories: set[str] = set()
max_element_count = 0
categorical = False
for ensemble, datas in ensemble_to_data_map.items():
Expand Down Expand Up @@ -101,7 +105,6 @@ def plotHistogram(
maximum = current_max if maximum is None else max(maximum, current_max)
max_element_count = max(max_element_count, len(data[ensemble.name].index))

categories = sorted(categories)
bin_count = int(ceil(sqrt(max_element_count)))

axes = {}
Expand All @@ -111,29 +114,35 @@ def plotHistogram(
axes[ensemble.name].set_title(
f"{config.title()} ({ensemble.experiment_name} : {ensemble.name})"
)
axes[ensemble.name].set_xlabel(x_label)
axes[ensemble.name].set_ylabel(y_label)

if use_log_scale:
axes[ensemble.name].set_xscale("log")

if not data[ensemble.name].empty:
if categorical:
_plotCategoricalHistogram(
axes[ensemble.name],
config,
data[ensemble.name],
config.addLegendItem(
ensemble.name,
categories,
_plotCategoricalHistogram(
axes[ensemble.name],
config.histogramStyle(),
data[ensemble.name],
sorted(categories),
),
)
else:
_plotHistogram(
axes[ensemble.name],
config,
data[ensemble.name],
config.addLegendItem(
ensemble.name,
bin_count,
use_log_scale,
minimum,
maximum,
_plotHistogram(
axes[ensemble.name],
config.histogramStyle(),
data[ensemble.name],
bin_count,
use_log_scale,
minimum,
maximum,
),
)

config.nextColor()
Expand All @@ -157,16 +166,10 @@ def plotHistogram(

def _plotCategoricalHistogram(
axes: "Axes",
plot_config: "PlotConfig",
style: PlotStyle,
data: pd.DataFrame,
label: str,
categories: List[str],
):
axes.set_xlabel(plot_config.xLabel())
axes.set_ylabel(plot_config.yLabel())

style = plot_config.histogramStyle()

) -> Rectangle:
counts = data.value_counts()
freq = [counts.get(category, 0) for category in categories]
pos = numpy.arange(len(categories))
Expand All @@ -176,27 +179,20 @@ def _plotCategoricalHistogram(

axes.bar(pos, freq, alpha=style.alpha, color=style.color, width=width)

rectangle = Rectangle(
return Rectangle(
(0, 0), 1, 1, color=style.color
) # creates rectangle patch for legend use.
plot_config.addLegendItem(label, rectangle)


def _plotHistogram(
axes: "Axes",
plot_config: "PlotConfig",
style: PlotStyle,
data: pd.DataFrame,
label: str,
bin_count,
use_log_scale=False,
minimum=None,
maximum=None,
):
axes.set_xlabel(plot_config.xLabel())
axes.set_ylabel(plot_config.yLabel())

style = plot_config.histogramStyle()

bin_count: int,
use_log_scale: float = False,
minimum: Optional[float] = None,
maximum: Optional[float] = None,
) -> Rectangle:
if minimum is not None and maximum is not None:
if use_log_scale:
bins = _histogramLogBins(bin_count, minimum, maximum)
Expand All @@ -213,10 +209,9 @@ def _plotHistogram(

axes.set_xlim(minimum, maximum)

rectangle = Rectangle(
return Rectangle(
(0, 0), 1, 1, color=style.color
) # creates rectangle patch for legend use.'
plot_config.addLegendItem(label, rectangle)


def _histogramLogBins(bin_count: int, minimum: float, maximum: float):
Expand Down

0 comments on commit c260c4f

Please sign in to comment.