From 5bef8676e7e095055da1cfdbfb0c318e716f8795 Mon Sep 17 00:00:00 2001 From: Jamie Gooding Date: Mon, 23 Oct 2023 12:19:00 +0200 Subject: [PATCH] Basic python implementation of core functionality --- environment.yml | 1 + src/whisk/_recipe.py | 38 ++++++++++++++------------ src/whisk/_table.py | 4 +-- src/whisk/_whisk.py | 48 +++++---------------------------- tests/test_recipe.py | 63 +++++++++++++++++++------------------------- tests/test_whisk.py | 57 ++++++++++++++++++--------------------- 6 files changed, 84 insertions(+), 127 deletions(-) diff --git a/environment.yml b/environment.yml index 688f9af..a37a719 100644 --- a/environment.yml +++ b/environment.yml @@ -4,6 +4,7 @@ channels: - conda-forge dependencies: - python=3.9 + - numba - numpy - typing - root=6.26.06 diff --git a/src/whisk/_recipe.py b/src/whisk/_recipe.py index dc9e2fa..4f11139 100644 --- a/src/whisk/_recipe.py +++ b/src/whisk/_recipe.py @@ -9,12 +9,13 @@ # or submit itself to any jurisdiction. # ############################################################################### -from itertools import product as iterprod -from itertools import permutations -from numpy import prod -from ROOT import RDataFrame -from typing import Dict, List, Union +import awkward as ak +import itertools as it +import numba as nb +import numpy as np +import os +from typing import Dict, List, Union # Categories should look like # [] @@ -23,14 +24,15 @@ class recipe: def __init__( self, - data: RDataFrame, + data, categories: Union[List[str], Dict[str, str]], totals: bool = False ): self.data = data self.categories = categories self.default_to_fractions = not(totals) - self.total_events = data.Count().GetValue() + self.total_events = len(data) + self.maximum = 0 self._recipe() @@ -49,26 +51,28 @@ def __getitem__( return sum(values) - def _parse_filter( + #def max( + # self + #): + # return self.maximum + + def _filter( self, key : List[str] ): assert len(self.categories.keys()) == len(key), "Category list and possible keys are different length" - - filters = [f'{category[0]} == "{sub_key}"' for category, sub_key in zip(self.categories.items(), key)] - - assert len(filters) > 0, "Filter not parsed correctly." - return " && ".join(filters) + return np.logical_and.reduce([self.data[category[0]] == sub_key for category, sub_key in zip(self.categories.items(), key)]) def _recipe(self): self.proportions = {} - keys = iterprod(*self.categories.values()) + keys = it.product(*self.categories.values()) for key in keys: - filt = self._parse_filter(key) - count = self.data.Filter(filt).Count().GetValue() - value = count / self.data.Count().GetValue() if self.default_to_fractions else count + filt = self._filter(key) + count = len(self.data[filt]) + value = count / self.total_events if self.default_to_fractions else count self.proportions[key] = value + if value > self.maximum: self.maximum = value def _convert_to_fractions(self): diff --git a/src/whisk/_table.py b/src/whisk/_table.py index cde3889..7080157 100644 --- a/src/whisk/_table.py +++ b/src/whisk/_table.py @@ -10,9 +10,9 @@ ############################################################################### def table( - data: RDataFrame, + data, categories: Union[List[str], Dict[str, str]], - absolute: + absolute, output: bool = False, totals: bool = False, diff --git a/src/whisk/_whisk.py b/src/whisk/_whisk.py index 7e47ea7..7fb3aa0 100644 --- a/src/whisk/_whisk.py +++ b/src/whisk/_whisk.py @@ -9,55 +9,21 @@ # or submit itself to any jurisdiction. # ############################################################################### +import awkward as ak +import numba as nb +import numpy as np import os - -from ROOT import RDataFrame -from tempfile import TemporaryDirectory from typing import Dict, List, Union from ._recipe import recipe def whisk( - reference_data: RDataFrame, - combining_data: Dict[Union[str, List[str]], RDataFrame], + reference_data, + combining_data,#: Dict[Union[str, List[str]], ], categories: Union[List[str], Dict[str, str]], ): reference_recipe = recipe(reference_data, categories) - cdfs = [] - with TemporaryDirectory() as tmpdir: - for n, (key, cdf) in enumerate(combining_data.items()): - cdf_path = os.path.join( - tmpdir, - f"combining_dataframe{n}.root" - ) - cdf.Snapshot("temp_tree", cdf_path) - cdfs += [cdf_path] - whisked_data = RDataFrame("temp_tree", cdfs) - - return whisked_data - -def _calculate_proportions( - data: RDataFrame, - categories: Union[List[str], Dict[str, str]] -): - proportions = {} - if type(categories) == list: - return - - for category, values in categories: - for value in values: - proportions.extend({ - category: { - "total" : data.Filter(f"{category} == {value}").Count().getValue() - } - }) - - return + combine_data = (data[np.random.rand(len(data)) < reference_recipe[categories] / reference_recipe.maximum] for categories, data in combining_data.items()) - { - "red" : { - "total" : 1, - "square" : {"total": 1} - } - } \ No newline at end of file + return ak.concatenate(combine_data) \ No newline at end of file diff --git a/tests/test_recipe.py b/tests/test_recipe.py index e1bf274..9638577 100644 --- a/tests/test_recipe.py +++ b/tests/test_recipe.py @@ -9,62 +9,52 @@ # or submit itself to any jurisdiction. # ############################################################################### -from ROOT import RDataFrame -from ROOT.Numba import Declare - -import whisk - +import awkward as ak +#import numba as nb +import numpy as np +from typing import List, Union -# To-do: -# - Fix instability in generated sample +import whisk +#@nb.njit +def generate_data(colours: Union[str, List[str]], + shapes: Union[str, List[str]], + size: int = 1000): -def generate_rdataframe(categories): - - rdf = RDataFrame(1000) - - colours = ["red", "yellow", "green", "blue"] - rdf = rdf.Define("colour_idx", f"std::floor(gRandom->Rndm() * {len(colours)})") - rdf = rdf.Define("colour", - f""" - std::vector colours = {{"{'","'.join(colours)}"}}; - return colours[colour_idx]; - """) - - shapes = ["triangle", "rectangle", "square"] - rdf = rdf.Define("shape_idx", f"std::floor(gRandom->Rndm() * {len(shapes)})") - rdf = rdf.Define("shape", - f""" - std::vector shapes = {{"{'","'.join(shapes)}"}}; - return shapes[shape_idx]; - """) - - return rdf + if type(colours) == str: colours = [colours] + if type(shapes) == str: shapes = [shapes] + data = {} + data["colour"] = np.random.choice(colours, size=size) + data["shape"] = np.random.choice(shapes, size=size) + return ak.Array(data) def test_recipe(): categories = { "colour" : ["red", "yellow", "green", "blue"], "shape" : ["triangle", "rectangle", "square"], } - rdf = generate_rdataframe(categories) - rec = whisk.recipe(rdf, categories) + data = generate_data(*categories.values()) + recipe = whisk.recipe(data, categories) - total = rdf.Count().GetValue() + total = len(data)#, axis=None) + print(total) - print(rec["red"]) - print(rec["triangle"]) - print(rec["red", "triangle"]) + print(f"{recipe['red']:.4f} +/- {np.sqrt(recipe['red'] / total):.4f}") + print(f"{recipe['triangle']:.4f} +/- {np.sqrt(recipe['triangle'] / total):.4f}") + print(f"{recipe['red', 'triangle']:.4f} +/- {np.sqrt(recipe['red', 'triangle'] / total):.4f}") #assert rec["red"] == rdf.Filter('colour == "red"').Count().GetValue() / total #assert rec["triangle"] == rdf.Filter('shape == "triangle"').Count().GetValue() / total #assert rec["red", "triangle"] == rdf.Filter('colour == "red" && shape == "triangle"').Count().GetValue() / total #assert rec["triangle", "red"] == rdf.Filter('colour == "red" && shape == "triangle"').Count().GetValue() / total #assert rec["blue", "triangle"] == rdf.Filter('colour == "blue" && shape == "triangle"').Count().GetValue() / total #assert rec["square", "red"] == rdf.Filter('colour == "red" && shape == "shape"').Count().GetValue() / total - + + +""" def dummy(rec): test_indices = (["red", "triangle"], ["triangle", "red"], @@ -74,4 +64,5 @@ def dummy(rec): for test_index in test_indices: # [[lst[i] for i in pattern] for lst in a] count = [[rec[i] for i in test_indices]] - print(f"{test_index}: {count}") \ No newline at end of file + print(f"{test_index}: {count}") +""" \ No newline at end of file diff --git a/tests/test_whisk.py b/tests/test_whisk.py index c043e01..37df2e8 100644 --- a/tests/test_whisk.py +++ b/tests/test_whisk.py @@ -9,58 +9,53 @@ # or submit itself to any jurisdiction. # ############################################################################### -from itertools import product as iterprod -from ROOT import RDataFrame -from ROOT.Numba import Declare +import awkward as ak +import itertools as it +#import numba as nb +import numpy as np + from typing import List, Union import whisk + # To-do: # - Fix instability in generated sample +def generate_data(colours: Union[str, List[str]], + shapes: Union[str, List[str]], + size: int = 1000): -def generate_rdataframe(colours: Union[str, List[str]], - shapes: Union[str, List[str]], - size: int = 1000): - - rdf = RDataFrame(size) - - rdf = rdf.Define("colour_idx", f"std::floor(gRandom->Rndm() * {len(colours)})") - rdf = rdf.Define("colour", - f""" - std::vector colours = {{"{'","'.join(colours)}"}}; - return colours[colour_idx]; - """) + if type(colours) == str: colours = [colours] + if type(shapes) == str: shapes = [shapes] - rdf = rdf.Define("shape_idx", f"std::floor(gRandom->Rndm() * {len(shapes)})") - rdf = rdf.Define("shape", - f""" - std::vector shapes = {{"{'","'.join(shapes)}"}}; - return shapes[shape_idx]; - """) + data = {} - return rdf + data["colour"] = np.random.choice(colours, size=size) + data["shape"] = np.random.choice(shapes, size=size) + return ak.Array(data) def test_whisk(): ref_categories = { "colour" : ["red", "red", "red", "yellow", "yellow", "green", "blue"], "shape" : ["triangle", "triangle", "rectangle", "square"], } - ref_rdf = generate_rdataframe(ref_categories["colour"], ref_categories["shape"], size = 1000) + ref_data = generate_data(ref_categories["colour"], ref_categories["shape"], size = 1000) raw_categories = { "colour" : ["red", "yellow", "green", "blue"], "shape" : ["triangle", "rectangle", "square"], } - raw_rdfs = {} - for (colour, shape) in iterprod(*raw_categories.values()): - raw_rdfs[(colour,shape)] = generate_rdataframe(colour, shape, size=10000) - print(raw_rdfs["red", "triangle"]) - print(raw_rdfs["red", "triangle"].Count().GetValue()) - whisked_rdf = whisk.whisk(ref_rdf, raw_rdfs, raw_categories) - print(whisked_rdf.Count().GetValue()) - print(whisked_rdf["red", "triangle"].Count().GetValue()) + raw_data = {} + for (colour, shape) in it.product(*raw_categories.values()): + rd = generate_data(colour, shape, size=10000) + raw_data[(colour,shape)] = rd + + print(len(raw_data["red", "triangle"])) + whisked_data = whisk.whisk(ref_data, raw_data, raw_categories) + print(len(whisked_data[whisked_data["colour"] == "red"])) + print(len(whisked_data[whisked_data["shape"] == "rectangle"])) + print(len(whisked_data[np.logical_and(whisked_data["colour"] == "red", whisked_data["shape"] == "rectangle")])) \ No newline at end of file