diff --git a/lshmm/api.py b/lshmm/api.py index 7b32f38..2bb73f5 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -105,14 +105,17 @@ def check_inputs( err_msg = "Mutation probability is not a scalar or an array of expected length." raise ValueError(err_msg) - # Check the recombination probability. - if not ( - isinstance(prob_recombination, (int, float)) - or ( - isinstance(prob_recombination, np.ndarray) - and len(prob_recombination) == num_sites - ) +# Check the recombination probability. + if isinstance(prob_recombination, (int, float)): + pass + elif ( + isinstance(prob_recombination, np.ndarray) + and len(prob_recombination) == num_sites ): + if prob_recombination[0] != 0: + err_msg = "First value in the recombination probability array must be zero." + raise ValueError(err_msg) + else: err_msg = ( "Recombination probability is not a scalar or an array of expected length." )