diff --git a/src/westpa/core/binning/mab.py b/src/westpa/core/binning/mab.py index e298213c..42caccf6 100644 --- a/src/westpa/core/binning/mab.py +++ b/src/westpa/core/binning/mab.py @@ -94,8 +94,8 @@ def determine_total_bins(self, nbins_per_dim, direction, skip, bottleneck, **kwa Parameters ---------- - nbins_per_dim : int - Number of total bins in each direction. + nbins_per_dim : list of int + Number of total bins in each dimension within the linear portion. direction : list of int Direction in each dimension. See __init__ for more information. skip : list of int @@ -111,14 +111,19 @@ def determine_total_bins(self, nbins_per_dim, direction, skip, bottleneck, **kwa Number of total bins. """ - n_total_bins = np.prod(nbins_per_dim) + n_total_bins = np.prod(nbins_per_dim) # Number of Bins in the linear portion ndim = len(nbins_per_dim) for i in range(ndim): if skip[i] == 0: - if direction[i] != 0: - n_total_bins += 1 + 1 * bottleneck - else: + if direction[i] == 0: + # Both directions (leading/trailing) + bottlenecks if enabled n_total_bins += 2 + 2 * bottleneck + elif direction[i] == 86: + # No leading/trailing, 2 bottlenecks if enabled + n_total_bins += 2 * bottleneck + else: + # direction[i] == -1 or 1, Just one leading/trailing + 1 bottleneck + n_total_bins += 1 + 1 * bottleneck else: n_total_bins -= nbins_per_dim[i] - 1 n_total_bins += 1 * ndim # or else it will be one bin short diff --git a/tests/test_binning.py b/tests/test_binning.py index 8120f4e0..77712205 100644 --- a/tests/test_binning.py +++ b/tests/test_binning.py @@ -12,6 +12,7 @@ RecursiveBinMapper, ) from westpa.core.binning.assign import coord_dtype +from westpa.core.binning.mab import MABBinMapper class TestRectilinearBinMapper: @@ -324,3 +325,13 @@ def test2dRectilinearRecursion(self): print('EXPECTED', expected) print('OUTPUT ', assignments) assert (assignments == expected).all() + + +class TestMABBinMapper: + def test_init(self): + mab = MABBinMapper([5]) + assert mab.nbins == 9 + + def test_determine_total_bins(self): + mab = MABBinMapper([5]) + assert mab.determine_total_bins(nbins_per_dim=[5, 1], direction=[1, 86], skip=[0, 0], bottleneck=True) == 9