From 9fda2c2454d18507e659840cb34bff125251b892 Mon Sep 17 00:00:00 2001 From: amoodie Date: Thu, 7 May 2020 21:17:48 -0500 Subject: [PATCH] set rng in numba, and make all calls to random in numba jitted functions --- pyDeltaRCM/init_tools.py | 6 +++++- pyDeltaRCM/sed_tools.py | 8 ++++---- pyDeltaRCM/utils.py | 10 ++++++++++ tests/test_yaml_parsing.py | 11 ++++++----- 4 files changed, 25 insertions(+), 10 deletions(-) diff --git a/pyDeltaRCM/init_tools.py b/pyDeltaRCM/init_tools.py index 614e4bfc..bc7fd204 100644 --- a/pyDeltaRCM/init_tools.py +++ b/pyDeltaRCM/init_tools.py @@ -23,6 +23,10 @@ import logging import time import yaml + +from . import utils + + # tools for initiating deltaRCM model domain @@ -98,7 +102,7 @@ def import_files(self): if self.seed is not None: if self.verbose >= 2: print("setting random seed to %s " % str(self.seed)) - np.random.seed(self.seed) + utils.set_random_seed(self.seed) def set_constants(self): diff --git a/pyDeltaRCM/sed_tools.py b/pyDeltaRCM/sed_tools.py index 39931fe3..b2fd2bb7 100644 --- a/pyDeltaRCM/sed_tools.py +++ b/pyDeltaRCM/sed_tools.py @@ -240,8 +240,8 @@ def sand_route(self): theta_sed = self.theta_sand num_starts = int(self.Np_sed * self.f_bedload) - start_indices = [self.random_pick_inlet( - self.inlet) for x in range(num_starts)] + typed_inlet = self.inlet_typed + start_indices = [utils.random_pick_inlet(typed_inlet) for x in range(num_starts)] for np_sed in range(num_starts): @@ -282,8 +282,8 @@ def mud_route(self): theta_sed = self.theta_mud num_starts = int(self.Np_sed * (1 - self.f_bedload)) - start_indices = [self.random_pick_inlet( - self.inlet) for x in range(num_starts)] + typed_inlet = self.inlet_typed + start_indices = [utils.random_pick_inlet(typed_inlet) for x in range(num_starts)] for np_sed in range(num_starts): diff --git a/pyDeltaRCM/utils.py b/pyDeltaRCM/utils.py index 40c1bf42..7f539a96 100644 --- a/pyDeltaRCM/utils.py +++ b/pyDeltaRCM/utils.py @@ -7,6 +7,16 @@ # utilities used in various places in the model and docs +@numba.njit +def set_random_seed(_seed): + np.random.seed(_seed) + + +@numba.njit +def get_random_uniform(N): + return np.random.uniform(0, 1, N) + + @numba.njit def random_pick(probs): """ diff --git a/tests/test_yaml_parsing.py b/tests/test_yaml_parsing.py index 67b6a6da..9ce61180 100644 --- a/tests/test_yaml_parsing.py +++ b/tests/test_yaml_parsing.py @@ -6,6 +6,7 @@ from pyDeltaRCM.deltaRCM_driver import pyDeltaRCM +from pyDeltaRCM.utils import set_random_seed, get_random_uniform # utilities for file writing def create_temporary_file(tmp_path, file_name): @@ -105,13 +106,13 @@ def test_random_seed_settings_value(tmp_path): p, f = create_temporary_file(tmp_path, file_name) write_parameter_to_file(f, 'seed', 9999) f.close() - np.random.seed(9999) - _preval_same = np.random.uniform() - np.random.seed(5) - _preval_diff = np.random.uniform(1000) + set_random_seed(9999) + _preval_same = get_random_uniform(1) + set_random_seed(5) + _preval_diff = get_random_uniform(1000) delta = pyDeltaRCM(input_file=p) assert delta.seed == 9999 - _postval_same = np.random.uniform() + _postval_same = get_random_uniform(1) assert _preval_same == _postval_same