Skip to content

Commit

Permalink
Bring filter parsing into recipe class
Browse files Browse the repository at this point in the history
  • Loading branch information
GoodingJamie committed Sep 28, 2023
1 parent aa70613 commit e6c3ed7
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
**.pyc
*.pyc
**/__pycache__/

build/
Expand Down
74 changes: 45 additions & 29 deletions src/whisk/_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,22 @@

class recipe:
def __init__(
self,
data: RDataFrame,
categories: Union[List[str], Dict[str, str]],
totals: bool = False
self,
data: RDataFrame,
categories: Union[List[str], Dict[str, str]],
totals: bool = False
):
self.data = data
self.categories = categories
self.totals = totals
self.default_to_fractions = not(totals)
self.total_events = data.Count().GetValue()

self._recipe()

def __getitem__(self,
key: Union[str, List[str]]):
def __getitem__(
self,
key: Union[str, List[str]]
):

values = []
if type(key) == str:
Expand All @@ -45,41 +48,54 @@ def __getitem__(self,
for proportions_key in self.proportions.keys():
if all([each_key in proportions_key for each_key in key]):
values += [self.proportions[proportions_key]]

for proportions_key in self.proportions.keys():
print(f"{key} | {proportions_key} | {key in proportions_key}")

print(values)

return sum(values)



def _parse_filter(
self,
key : List[str]
):
assert len(self.categories.keys()) == len(key), "Category list and possible keys are different length"

filters = [f'{category[0]} == "{sub_key}"' for category, sub_key in zip(self.categories.items(), key)]

assert len(filters) > 0, "Filter not parsed correctly."
return " && ".join(filters)

def _recipe(self):

self.proportions = defaultdict(dict)
keys = iterprod(*self.categories.values())
for key in keys:
filt = _parse_filter(self.categories, key)
assert filt != "", "Filter not parsed correctly."
filt = self._parse_filter(key)
count = self.data.Filter(filt).Count().GetValue()
value = count if self.totals else count / self.data.Count().GetValue()
value = count / self.data.Count().GetValue() if self.default_to_fractions else count
self.proportions[key] = value

return

def _parse_filter(
categories: Dict[str, str],
key : List[str]
):
assert len(categories.keys()) == len(key), "Category list and possible keys are different length"
def _convert_to_fractions(self):

#assert all([sub_key in category for category, sub_key in zip(categories.items(), key)]), "Category list and possible keys do not match"
filters = []
for n, (category, sub_key) in enumerate(zip(categories.items(), key)):
filters += [f'{category[0]} == "{sub_key}"']
for key in self.proportions.keys():
self.proportions[key] /= self.total_events

def default_to_fractions(
self,
value: bool = True
):

self.default_to_fractions = value
_convert_to_fractions()

return " && ".join(filters)

def _convert_to_totals(self):

for key in self.proportions.keys():
self.proportions[key] *= self.total_events

def default_to_totals(
self,
value: bool = True
):

self.default_to_fractions = not(value)
_convert_to_totals()

0 comments on commit e6c3ed7

Please sign in to comment.