Skip to content

Commit

Permalink
Merge pull request #72 from kreczko/kreczko-fast-plotter-1
Browse files Browse the repository at this point in the history
Towards 1.0.0 alpha 0
  • Loading branch information
kreczko authored Oct 21, 2022
2 parents 2f41a89 + 73d2cc0 commit 5529cab
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 7 deletions.
65 changes: 65 additions & 0 deletions fast_plotter/v1/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
""" Features for version 1 """
from typing import Any, Dict, List
import uproot

from .hist_collections import EfficiencyHistCollection


def create_collection(name, config, style):
LOOKUP = {
"efficiency": EfficiencyHistCollection,
}
collection_class = LOOKUP[config["type"]]
# return collection_class(**config)
return collection_class(
name=name,
title=config["title"],
style=style,
)


def _workaround_uproot_issue38():
# workaround for issue reading TEfficiency
# https://github.com/scikit-hep/uproot5/issues/38
import skhep_testdata
with uproot.open(skhep_testdata.data_path("uproot-issue38c.root")) as fp:
hist = fp["TEfficiencyName"]
# now all TEfficiency objects should be readable
return hist


def read_histogram_file(input_file, histname):
with uproot.open(input_file) as fp:
hist = fp[histname]
# TODO: use filter_dict to get > 1 hist
return hist


def make_plots(plot_config: Dict[str, Any], input_files: List[str], output_dir: str):
_workaround_uproot_issue38()
input_file = input_files[0]

plotter_version = plot_config.pop("plotter-version", "0")
styles = plot_config.pop("styles", {})
collections = plot_config.pop("collections", {})
named_styles = {}
for style in styles:
named_styles[style["name"]] = style

for name, config in collections.items():
# TODO: needs to me safer
style = named_styles[config.pop("style")]
collection = create_collection(name, config, style)
sources = config.pop("sources")
for source in sources:
label = source.pop("label")
path = source.pop("path")
hist = read_histogram_file(input_file, path)
collection.add_hist(
name=label,
numerator=hist.members["fPassedHistogram"].to_numpy()[0],
denominator=hist.members["fTotalHistogram"].to_numpy()[0],
)

collection.plot()
collection.save(output_dir)
71 changes: 71 additions & 0 deletions fast_plotter/v1/hist_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os
import numpy as np

from hist.intervals import ratio_uncertainty
import mplhep as hep
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


class EfficiencyHistCollection():

def __init__(self, name, title, style):
self.name = name
self.title = title
self.style = style
self.hists = []

def add_hist(self, name, numerator, denominator, **kwargs):
self.hists.append(Efficiency(name, numerator, denominator, **kwargs))

def plot(self, **kwargs):
hep.style.use("CMS")
fig = plt.figure(figsize=(10, 8))
for hist in self.hists:
hep.histplot(hist.eff, yerr=hist.eff_err, **kwargs)

def save(self, output_dir):
output_file = os.path.join(output_dir, f"{self.name}.png")
print(f"Saving {output_file}")
plt.savefig(output_file)


class Efficiency:
num: np.ndarray
den: np.ndarray
name: str

def __init__(self, name, num, den):
self.name = name
self.num = num
self.den = den
self._eff = None
self._eff_err = None

@property
def eff(self):
if self._eff is None:
old_settings = np.seterr()
np.seterr(divide='ignore', invalid='ignore')
self._eff = np.divide(self.num, self.den, dtype=np.float64)
np.seterr(**old_settings)
self._eff[np.isnan(self._eff)] = 0.0

return self._eff

@property
def eff_err(self):
if self._eff_err is None:
if np.any(self.num > self.den):
raise ValueError(
"Found numerator larger than denominator while calculating binomial uncertainty"
)
self._eff_err = ratio_uncertainty(self.num, self.den, uncertainty_type="efficiency")
return self._eff_err

def plot(self, **kwargs):
hep.histplot(self.eff, yerr=self.eff_err, **kwargs)

def __repr__(self):
return f"EfficiencyHist(num={self.num}, den={self.den})"
2 changes: 1 addition & 1 deletion fast_plotter/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ def split_version(version):
return tuple(result)


__version__ = '0.10.4'
__version__ = '1.0.0a0'
version_info = split_version(__version__) # noqa
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.10.4
current_version = 1.0.0a0
commit = True
tag = False

Expand Down
11 changes: 6 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ def get_version():
'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache Software License',
'Natural Language :: English',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
],
description="F.A.S.T. plotter package",
entry_points={
Expand Down

0 comments on commit 5529cab

Please sign in to comment.