Skip to content

Commit

Permalink
Merge pull request #3 from SMARTHEP/basic-implementation
Browse files Browse the repository at this point in the history
Basic implementation of python combinations
  • Loading branch information
GoodingJamie authored Oct 23, 2023
2 parents 132cffc + 5778e2d commit b888804
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 127 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- conda-forge
dependencies:
- python=3.9
- numba
- numpy
- typing
- root=6.26.06
Expand Down
38 changes: 21 additions & 17 deletions src/whisk/_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
# or submit itself to any jurisdiction. #
###############################################################################

from itertools import product as iterprod
from itertools import permutations
from numpy import prod
from ROOT import RDataFrame
from typing import Dict, List, Union
import awkward as ak
import itertools as it
import numba as nb
import numpy as np
import os

from typing import Dict, List, Union

# Categories should look like
# []
Expand All @@ -23,14 +24,15 @@
class recipe:
def __init__(
self,
data: RDataFrame,
data,
categories: Union[List[str], Dict[str, str]],
totals: bool = False
):
self.data = data
self.categories = categories
self.default_to_fractions = not(totals)
self.total_events = data.Count().GetValue()
self.total_events = len(data)
self.maximum = 0

self._recipe()

Expand All @@ -49,26 +51,28 @@ def __getitem__(

return sum(values)

def _parse_filter(
#def max(
# self
#):
# return self.maximum

def _filter(
self,
key : List[str]
):
assert len(self.categories.keys()) == len(key), "Category list and possible keys are different length"

filters = [f'{category[0]} == "{sub_key}"' for category, sub_key in zip(self.categories.items(), key)]

assert len(filters) > 0, "Filter not parsed correctly."
return " && ".join(filters)
return np.logical_and.reduce([self.data[category[0]] == sub_key for category, sub_key in zip(self.categories.items(), key)])

def _recipe(self):

self.proportions = {}
keys = iterprod(*self.categories.values())
keys = it.product(*self.categories.values())
for key in keys:
filt = self._parse_filter(key)
count = self.data.Filter(filt).Count().GetValue()
value = count / self.data.Count().GetValue() if self.default_to_fractions else count
filt = self._filter(key)
count = len(self.data[filt])
value = count / self.total_events if self.default_to_fractions else count
self.proportions[key] = value
if value > self.maximum: self.maximum = value


def _convert_to_fractions(self):
Expand Down
4 changes: 2 additions & 2 deletions src/whisk/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
###############################################################################

def table(
data: RDataFrame,
data,
categories: Union[List[str], Dict[str, str]],
absolute:
absolute,
output: bool = False,
totals: bool = False,

Expand Down
48 changes: 7 additions & 41 deletions src/whisk/_whisk.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,55 +9,21 @@
# or submit itself to any jurisdiction. #
###############################################################################

import awkward as ak
import numba as nb
import numpy as np
import os

from ROOT import RDataFrame
from tempfile import TemporaryDirectory
from typing import Dict, List, Union

from ._recipe import recipe

def whisk(
reference_data: RDataFrame,
combining_data: Dict[Union[str, List[str]], RDataFrame],
reference_data,
combining_data,#: Dict[Union[str, List[str]], ],
categories: Union[List[str], Dict[str, str]],
):

reference_recipe = recipe(reference_data, categories)
cdfs = []
with TemporaryDirectory() as tmpdir:
for n, (key, cdf) in enumerate(combining_data.items()):
cdf_path = os.path.join(
tmpdir,
f"combining_dataframe{n}.root"
)
cdf.Snapshot("temp_tree", cdf_path)
cdfs += [cdf_path]
whisked_data = RDataFrame("temp_tree", cdfs)

return whisked_data

def _calculate_proportions(
data: RDataFrame,
categories: Union[List[str], Dict[str, str]]
):
proportions = {}
if type(categories) == list:
return

for category, values in categories:
for value in values:
proportions.extend({
category: {
"total" : data.Filter(f"{category} == {value}").Count().getValue()
}
})

return
combine_data = (data[np.random.rand(len(data)) < reference_recipe[categories] / reference_recipe.maximum] for categories, data in combining_data.items())

{
"red" : {
"total" : 1,
"square" : {"total": 1}
}
}
return ak.concatenate(combine_data)
63 changes: 27 additions & 36 deletions tests/test_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,62 +9,52 @@
# or submit itself to any jurisdiction. #
###############################################################################

from ROOT import RDataFrame
from ROOT.Numba import Declare

import whisk

import awkward as ak
#import numba as nb
import numpy as np

from typing import List, Union

# To-do:
# - Fix instability in generated sample
import whisk

#@nb.njit
def generate_data(colours: Union[str, List[str]],
shapes: Union[str, List[str]],
size: int = 1000):

def generate_rdataframe(categories):

rdf = RDataFrame(1000)

colours = ["red", "yellow", "green", "blue"]
rdf = rdf.Define("colour_idx", f"std::floor(gRandom->Rndm() * {len(colours)})")
rdf = rdf.Define("colour",
f"""
std::vector<std::string> colours = {{"{'","'.join(colours)}"}};
return colours[colour_idx];
""")

shapes = ["triangle", "rectangle", "square"]
rdf = rdf.Define("shape_idx", f"std::floor(gRandom->Rndm() * {len(shapes)})")
rdf = rdf.Define("shape",
f"""
std::vector<std::string> shapes = {{"{'","'.join(shapes)}"}};
return shapes[shape_idx];
""")

return rdf
if type(colours) == str: colours = [colours]
if type(shapes) == str: shapes = [shapes]

data = {}

data["colour"] = np.random.choice(colours, size=size)
data["shape"] = np.random.choice(shapes, size=size)

return ak.Array(data)

def test_recipe():
categories = {
"colour" : ["red", "yellow", "green", "blue"],
"shape" : ["triangle", "rectangle", "square"],
}
rdf = generate_rdataframe(categories)
rec = whisk.recipe(rdf, categories)
data = generate_data(*categories.values())
recipe = whisk.recipe(data, categories)

total = rdf.Count().GetValue()
total = len(data)#, axis=None)
print(total)

print(rec["red"])
print(rec["triangle"])
print(rec["red", "triangle"])
print(f"{recipe['red']:.4f} +/- {np.sqrt(recipe['red'] / total):.4f}")
print(f"{recipe['triangle']:.4f} +/- {np.sqrt(recipe['triangle'] / total):.4f}")
print(f"{recipe['red', 'triangle']:.4f} +/- {np.sqrt(recipe['red', 'triangle'] / total):.4f}")
#assert rec["red"] == rdf.Filter('colour == "red"').Count().GetValue() / total
#assert rec["triangle"] == rdf.Filter('shape == "triangle"').Count().GetValue() / total
#assert rec["red", "triangle"] == rdf.Filter('colour == "red" && shape == "triangle"').Count().GetValue() / total
#assert rec["triangle", "red"] == rdf.Filter('colour == "red" && shape == "triangle"').Count().GetValue() / total
#assert rec["blue", "triangle"] == rdf.Filter('colour == "blue" && shape == "triangle"').Count().GetValue() / total
#assert rec["square", "red"] == rdf.Filter('colour == "red" && shape == "square"').Count().GetValue() / total



"""
def dummy(rec):
test_indices = (["red", "triangle"],
["triangle", "red"],
Expand All @@ -74,4 +64,5 @@ def dummy(rec):
for test_index in test_indices:
# [[lst[i] for i in pattern] for lst in a]
count = [[rec[i] for i in test_indices]]
print(f"{test_index}: {count}")
print(f"{test_index}: {count}")
"""
57 changes: 26 additions & 31 deletions tests/test_whisk.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,58 +9,53 @@
# or submit itself to any jurisdiction. #
###############################################################################

from itertools import product as iterprod
from ROOT import RDataFrame
from ROOT.Numba import Declare
import awkward as ak
import itertools as it
#import numba as nb
import numpy as np

from typing import List, Union

import whisk



# To-do:
# - Fix instability in generated sample

def generate_data(colours: Union[str, List[str]],
shapes: Union[str, List[str]],
size: int = 1000):

def generate_rdataframe(colours: Union[str, List[str]],
shapes: Union[str, List[str]],
size: int = 1000):

rdf = RDataFrame(size)

rdf = rdf.Define("colour_idx", f"std::floor(gRandom->Rndm() * {len(colours)})")
rdf = rdf.Define("colour",
f"""
std::vector<std::string> colours = {{"{'","'.join(colours)}"}};
return colours[colour_idx];
""")
if type(colours) == str: colours = [colours]
if type(shapes) == str: shapes = [shapes]

rdf = rdf.Define("shape_idx", f"std::floor(gRandom->Rndm() * {len(shapes)})")
rdf = rdf.Define("shape",
f"""
std::vector<std::string> shapes = {{"{'","'.join(shapes)}"}};
return shapes[shape_idx];
""")
data = {}

return rdf
data["colour"] = np.random.choice(colours, size=size)
data["shape"] = np.random.choice(shapes, size=size)

return ak.Array(data)

def test_whisk():
ref_categories = {
"colour" : ["red", "red", "red", "yellow", "yellow", "green", "blue"],
"shape" : ["triangle", "triangle", "rectangle", "square"],
}
ref_rdf = generate_rdataframe(ref_categories["colour"], ref_categories["shape"], size = 1000)
ref_data = generate_data(ref_categories["colour"], ref_categories["shape"], size = 1000)

raw_categories = {
"colour" : ["red", "yellow", "green", "blue"],
"shape" : ["triangle", "rectangle", "square"],
}

raw_rdfs = {}
for (colour, shape) in iterprod(*raw_categories.values()):
raw_rdfs[(colour,shape)] = generate_rdataframe(colour, shape, size=10000)
print(raw_rdfs["red", "triangle"])
print(raw_rdfs["red", "triangle"].Count().GetValue())
whisked_rdf = whisk.whisk(ref_rdf, raw_rdfs, raw_categories)
print(whisked_rdf.Count().GetValue())
print(whisked_rdf["red", "triangle"].Count().GetValue())
raw_data = {}
for (colour, shape) in it.product(*raw_categories.values()):
rd = generate_data(colour, shape, size=10000)
raw_data[(colour,shape)] = rd

print(len(raw_data["red", "triangle"]))
whisked_data = whisk.whisk(ref_data, raw_data, raw_categories)
print(len(whisked_data[whisked_data["colour"] == "red"]))
print(len(whisked_data[whisked_data["shape"] == "rectangle"]))
print(len(whisked_data[np.logical_and(whisked_data["colour"] == "red", whisked_data["shape"] == "rectangle")]))

0 comments on commit b888804

Please sign in to comment.