diff --git a/src/ert/gui/plottery/plots/histogram.py b/src/ert/gui/plottery/plots/histogram.py index f7e7ad0634d..51e5be44b12 100644 --- a/src/ert/gui/plottery/plots/histogram.py +++ b/src/ert/gui/plottery/plots/histogram.py @@ -1,7 +1,7 @@ from __future__ import annotations from math import ceil, floor, log10, sqrt -from typing import TYPE_CHECKING, Any, Dict, List +from typing import TYPE_CHECKING, Dict, List, Optional import numpy import pandas as pd @@ -15,7 +15,7 @@ from matplotlib.axes import Axes from matplotlib.figure import Figure - from ert.gui.plottery import PlotConfig, PlotContext + from ert.gui.plottery import PlotContext, PlotStyle class HistogramPlot: @@ -27,17 +27,17 @@ def plot( figure: Figure, plot_context: PlotContext, ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], - _observation_data: Any, - std_dev_images: Any, + observation_data: pd.DataFrame, + std_dev_images: Dict[str, bytes], ) -> None: - plotHistogram(figure, plot_context, ensemble_to_data_map, _observation_data) + plotHistogram(figure, plot_context, ensemble_to_data_map, observation_data) def plotHistogram( figure: Figure, plot_context: PlotContext, ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], - _observation_data: Any, + observation_data: pd.DataFrame, ) -> None: config = plot_context.plotConfig() @@ -58,18 +58,22 @@ def plotHistogram( plot_context.x_axis = plot_context.VALUE_AXIS plot_context.y_axis = plot_context.COUNT_AXIS - if config.xLabel() is None: - config.setXLabel("Value") + x_label = config.xLabel() + if x_label is None: + x_label = "Value" + config.setXLabel(x_label) - if config.yLabel() is None: - config.setYLabel("Count") + y_label = config.yLabel() + if y_label is None: + y_label = "Count" + config.setYLabel(y_label) use_log_scale = plot_context.log_scale data = {} minimum = None maximum = None - categories = set() + categories: set[str] = set() max_element_count = 0 categorical = False for ensemble, datas in ensemble_to_data_map.items(): @@ -101,7 +105,6 @@ def plotHistogram( maximum = current_max if maximum is None else max(maximum, current_max) max_element_count = max(max_element_count, len(data[ensemble.name].index)) - categories = sorted(categories) bin_count = int(ceil(sqrt(max_element_count))) axes = {} @@ -111,29 +114,35 @@ def plotHistogram( axes[ensemble.name].set_title( f"{config.title()} ({ensemble.experiment_name} : {ensemble.name})" ) + axes[ensemble.name].set_xlabel(x_label) + axes[ensemble.name].set_ylabel(y_label) if use_log_scale: axes[ensemble.name].set_xscale("log") if not data[ensemble.name].empty: if categorical: - _plotCategoricalHistogram( - axes[ensemble.name], - config, - data[ensemble.name], + config.addLegendItem( ensemble.name, - categories, + _plotCategoricalHistogram( + axes[ensemble.name], + config.histogramStyle(), + data[ensemble.name], + sorted(categories), + ), ) else: - _plotHistogram( - axes[ensemble.name], - config, - data[ensemble.name], + config.addLegendItem( ensemble.name, - bin_count, - use_log_scale, - minimum, - maximum, + _plotHistogram( + axes[ensemble.name], + config.histogramStyle(), + data[ensemble.name], + bin_count, + use_log_scale, + minimum, + maximum, + ), ) config.nextColor() @@ -157,16 +166,10 @@ def plotHistogram( def _plotCategoricalHistogram( axes: "Axes", - plot_config: "PlotConfig", + style: PlotStyle, data: pd.DataFrame, - label: str, categories: List[str], -): - axes.set_xlabel(plot_config.xLabel()) - axes.set_ylabel(plot_config.yLabel()) - - style = plot_config.histogramStyle() - +) -> Rectangle: counts = data.value_counts() freq = [counts.get(category, 0) for category in categories] pos = numpy.arange(len(categories)) @@ -176,27 +179,20 @@ def _plotCategoricalHistogram( axes.bar(pos, freq, alpha=style.alpha, color=style.color, width=width) - rectangle = Rectangle( + return Rectangle( (0, 0), 1, 1, color=style.color ) # creates rectangle patch for legend use. - plot_config.addLegendItem(label, rectangle) def _plotHistogram( axes: "Axes", - plot_config: "PlotConfig", + style: PlotStyle, data: pd.DataFrame, - label: str, - bin_count, - use_log_scale=False, - minimum=None, - maximum=None, -): - axes.set_xlabel(plot_config.xLabel()) - axes.set_ylabel(plot_config.yLabel()) - - style = plot_config.histogramStyle() - + bin_count: int, + use_log_scale: float = False, + minimum: Optional[float] = None, + maximum: Optional[float] = None, +) -> Rectangle: if minimum is not None and maximum is not None: if use_log_scale: bins = _histogramLogBins(bin_count, minimum, maximum) @@ -213,10 +209,9 @@ def _plotHistogram( axes.set_xlim(minimum, maximum) - rectangle = Rectangle( + return Rectangle( (0, 0), 1, 1, color=style.color ) # creates rectangle patch for legend use.' - plot_config.addLegendItem(label, rectangle) def _histogramLogBins(bin_count: int, minimum: float, maximum: float):