diff --git a/fast_plotter/v1/__init__.py b/fast_plotter/v1/__init__.py new file mode 100644 index 0000000..0939319 --- /dev/null +++ b/fast_plotter/v1/__init__.py @@ -0,0 +1,65 @@ +""" Features for version 1 """ +from typing import Any, Dict, List +import uproot + +from .hist_collections import EfficiencyHistCollection + + +def create_collection(name, config, style): + LOOKUP = { + "efficiency": EfficiencyHistCollection, + } + collection_class = LOOKUP[config["type"]] +# return collection_class(**config) + return collection_class( + name=name, + title=config["title"], + style=style, + ) + + +def _workaround_uproot_issue38(): + # workaround for issue reading TEfficiency + # https://github.com/scikit-hep/uproot5/issues/38 + import skhep_testdata + with uproot.open(skhep_testdata.data_path("uproot-issue38c.root")) as fp: + hist = fp["TEfficiencyName"] + # now all TEfficiency objects should be readable + return hist + + +def read_histogram_file(input_file, histname): + with uproot.open(input_file) as fp: + hist = fp[histname] + # TODO: use filter_dict to get > 1 hist + return hist + + +def make_plots(plot_config: Dict[str, Any], input_files: List[str], output_dir: str): + _workaround_uproot_issue38() + input_file = input_files[0] + + plotter_version = plot_config.pop("plotter-version", "0") + styles = plot_config.pop("styles", {}) + collections = plot_config.pop("collections", {}) + named_styles = {} + for style in styles: + named_styles[style["name"]] = style + + for name, config in collections.items(): + # TODO: needs to me safer + style = named_styles[config.pop("style")] + collection = create_collection(name, config, style) + sources = config.pop("sources") + for source in sources: + label = source.pop("label") + path = source.pop("path") + hist = read_histogram_file(input_file, path) + collection.add_hist( + name=label, + numerator=hist.members["fPassedHistogram"].to_numpy()[0], + denominator=hist.members["fTotalHistogram"].to_numpy()[0], + ) + + collection.plot() + collection.save(output_dir) diff --git a/fast_plotter/v1/hist_collections.py b/fast_plotter/v1/hist_collections.py new file mode 100644 index 0000000..cd1ccad --- /dev/null +++ b/fast_plotter/v1/hist_collections.py @@ -0,0 +1,71 @@ +import os +import numpy as np + +from hist.intervals import ratio_uncertainty +import mplhep as hep +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + + +class EfficiencyHistCollection(): + + def __init__(self, name, title, style): + self.name = name + self.title = title + self.style = style + self.hists = [] + + def add_hist(self, name, numerator, denominator, **kwargs): + self.hists.append(Efficiency(name, numerator, denominator, **kwargs)) + + def plot(self, **kwargs): + hep.style.use("CMS") + fig = plt.figure(figsize=(10, 8)) + for hist in self.hists: + hep.histplot(hist.eff, yerr=hist.eff_err, **kwargs) + + def save(self, output_dir): + output_file = os.path.join(output_dir, f"{self.name}.png") + print(f"Saving {output_file}") + plt.savefig(output_file) + + +class Efficiency: + num: np.ndarray + den: np.ndarray + name: str + + def __init__(self, name, num, den): + self.name = name + self.num = num + self.den = den + self._eff = None + self._eff_err = None + + @property + def eff(self): + if self._eff is None: + old_settings = np.seterr() + np.seterr(divide='ignore', invalid='ignore') + self._eff = np.divide(self.num, self.den, dtype=np.float64) + np.seterr(**old_settings) + self._eff[np.isnan(self._eff)] = 0.0 + + return self._eff + + @property + def eff_err(self): + if self._eff_err is None: + if np.any(self.num > self.den): + raise ValueError( + "Found numerator larger than denominator while calculating binomial uncertainty" + ) + self._eff_err = ratio_uncertainty(self.num, self.den, uncertainty_type="efficiency") + return self._eff_err + + def plot(self, **kwargs): + hep.histplot(self.eff, yerr=self.eff_err, **kwargs) + + def __repr__(self): + return f"EfficiencyHist(num={self.num}, den={self.den})" diff --git a/fast_plotter/version.py b/fast_plotter/version.py index b8fe69f..6c7023a 100644 --- a/fast_plotter/version.py +++ b/fast_plotter/version.py @@ -12,5 +12,5 @@ def split_version(version): return tuple(result) -__version__ = '0.10.4' +__version__ = '1.0.0a0' version_info = split_version(__version__) # noqa diff --git a/setup.cfg b/setup.cfg index dbde0d2..a972315 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.10.4 +current_version = 1.0.0a0 commit = True tag = False diff --git a/setup.py b/setup.py index d19ac0d..eee0eb9 100644 --- a/setup.py +++ b/setup.py @@ -37,11 +37,12 @@ def get_version(): 'Intended Audience :: Science/Research', 'License :: OSI Approved :: Apache Software License', 'Natural Language :: English', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", ], description="F.A.S.T. plotter package", entry_points={