diff --git a/lshmm/api.py b/lshmm/api.py index b5c0e10..9d2db7b 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -106,13 +106,16 @@ def check_inputs( raise ValueError(err_msg) # Check the recombination probability. - if not ( - isinstance(prob_recombination, (int, float)) - or ( - isinstance(prob_recombination, np.ndarray) - and len(prob_recombination) == num_sites - ) + if isinstance(prob_recombination, (int, float)): + pass + elif ( + isinstance(prob_recombination, np.ndarray) + and len(prob_recombination) == num_sites ): + if prob_recombination[0] != 0: + err_msg = "First value in the recombination probability array must be zero." + raise ValueError(err_msg) + else: err_msg = ( "Recombination probability is not a scalar or an array of expected length." )