Skip to content

Commit

Permalink
Fix MAB 86 Reporting (westpa#435)
Browse files Browse the repository at this point in the history
* make mab reporting more accurate

* clean up a little bit

* fix docstring for mab

* add tests for MABBinMapper
  • Loading branch information
jeremyleung521 authored Jul 9, 2024
1 parent a2d390e commit 579af04
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/westpa/core/binning/mab.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def determine_total_bins(self, nbins_per_dim, direction, skip, bottleneck, **kwa
Parameters
----------
nbins_per_dim : int
Number of total bins in each direction.
nbins_per_dim : list of int
Number of total bins in each dimension within the linear portion.
direction : list of int
Direction in each dimension. See __init__ for more information.
skip : list of int
Expand All @@ -111,14 +111,19 @@ def determine_total_bins(self, nbins_per_dim, direction, skip, bottleneck, **kwa
Number of total bins.
"""
n_total_bins = np.prod(nbins_per_dim)
n_total_bins = np.prod(nbins_per_dim) # Number of Bins in the linear portion
ndim = len(nbins_per_dim)
for i in range(ndim):
if skip[i] == 0:
if direction[i] != 0:
n_total_bins += 1 + 1 * bottleneck
else:
if direction[i] == 0:
# Both directions (leading/trailing) + bottlenecks if enabled
n_total_bins += 2 + 2 * bottleneck
elif direction[i] == 86:
# No leading/trailing, 2 bottlenecks if enabled
n_total_bins += 2 * bottleneck
else:
# direction[i] == -1 or 1, Just one leading/trailing + 1 bottleneck
n_total_bins += 1 + 1 * bottleneck
else:
n_total_bins -= nbins_per_dim[i] - 1
n_total_bins += 1 * ndim # or else it will be one bin short
Expand Down
11 changes: 11 additions & 0 deletions tests/test_binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
RecursiveBinMapper,
)
from westpa.core.binning.assign import coord_dtype
from westpa.core.binning.mab import MABBinMapper


class TestRectilinearBinMapper:
Expand Down Expand Up @@ -324,3 +325,13 @@ def test2dRectilinearRecursion(self):
print('EXPECTED', expected)
print('OUTPUT ', assignments)
assert (assignments == expected).all()


class TestMABBinMapper:
def test_init(self):
mab = MABBinMapper([5])
assert mab.nbins == 9

def test_determine_total_bins(self):
mab = MABBinMapper([5])
assert mab.determine_total_bins(nbins_per_dim=[5, 1], direction=[1, 86], skip=[0, 0], bottleneck=True) == 9

0 comments on commit 579af04

Please sign in to comment.