Skip to content

Commit

Permalink
Simplify and speed up duplicate rate checking (pynucastro#692)
Browse files Browse the repository at this point in the history
This speeds up the sensitivity analysis example code by ~10-15%.
  • Loading branch information
yut23 authored Nov 10, 2023
1 parent eff3c6f commit 57efb55
Showing 1 changed file with 17 additions and 34 deletions.
51 changes: 17 additions & 34 deletions pynucastro/rates/known_duplicates.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,25 @@
import collections

# there are some exceptions to the no-duplicate rates restriction. We
# list them here by class name and then fname
ALLOWED_DUPLICATES = [
(("ReacLibRate: p_p__d__weak__bet_pos_"),
("ReacLibRate: p_p__d__weak__electron_capture"))
{"ReacLibRate: p_p__d__weak__bet_pos_",
"ReacLibRate: p_p__d__weak__electron_capture"}
]


def find_duplicate_rates(rate_list):
"""given a list of rates, return a list of groups of duplicate
rates"""
"""given a list of rates, return a list of groups of duplicate rates"""

duplicates = []
# Group the rates into lists of potential duplicates, keyed by their
# reactants and products.
grouped_rates = collections.defaultdict(list)
for rate in rate_list:
same_links = [q for q in rate_list
if q != rate and
sorted(q.reactants) == sorted(rate.reactants) and
sorted(q.products) == sorted(rate.products)]

if same_links:
new_entry = [rate] + same_links
already_found = False
# we may have already found this pair
for dupe in duplicates:
if new_entry[0] in dupe:
already_found = True
break
if not already_found:
duplicates.append(new_entry)
grouped_rates[tuple(sorted(rate.reactants)),
tuple(sorted(rate.products))].append(rate)

# any entry in grouped_rates containing more than one rate is a duplicate
duplicates = [entry for entry in grouped_rates.values() if len(entry) > 1]

return duplicates

Expand All @@ -35,17 +28,7 @@ def is_allowed_dupe(rate_list):
"""rate_list is a list of rates that provide the same connection
in a network. Return True if this is an allowed duplicate"""

for allowed_dupe in ALLOWED_DUPLICATES:
found = 0
if len(rate_list) == len(allowed_dupe):
found = 1
for r in rate_list:
rate_key = f"{r.__class__.__name__}: {r.fname}"
if rate_key in allowed_dupe:
found *= 1
else:
found *= 0
if found:
return True

return False
# make rate_list into a set of strings in the same format as
# ALLOWED_DUPLICATES, then check if it matches any of the allowed sets
key_set = {f"{r.__class__.__name__}: {r.fname}" for r in rate_list}
return key_set in ALLOWED_DUPLICATES

0 comments on commit 57efb55

Please sign in to comment.