diff --git a/lshmm/api.py b/lshmm/api.py index ab76286..e5e2b0b 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -108,12 +108,9 @@ def check_inputs( if isinstance(prob_recombination, (int, float)): prob_recombination = np.zeros(num_sites, dtype=np.float64) + prob_recombination prob_recombination[0] = 0.0 - elif ( - isinstance(prob_recombination, np.ndarray) - and len(prob_recombination) == num_sites - ): - if prob_recombination[0] != 0: - err_msg = "First value in the recombination probability array must be zero." + elif isinstance(prob_recombination, np.ndarray): + if len(prob_recombination) != num_sites: + err_msg = "Recombination probability is an array of unexpected length." raise ValueError(err_msg) else: err_msg = (