diff --git a/spaceprime/demography.py b/spaceprime/demography.py index c08b056..b88e590 100644 --- a/spaceprime/demography.py +++ b/spaceprime/demography.py @@ -5,8 +5,7 @@ import numpy as np from typing import Union, List, Optional import msprime -from . import utilities as ut -from .utilities import split_landscape_by_pop, calc_migration_matrix +from .utilities import calc_migration_matrix ## not sure if I should add a subclass or not. There are only two new functions, so might not be worth it. @@ -82,6 +81,8 @@ def stepping_stone_2d( if len(d.shape) == 3: model = add_landscape_change(model, d, timestep, rate, scale) + model.sort_events() + return model @@ -196,6 +197,9 @@ def add_landscape_change( dest=f"deme_{i + di}_{j + dj}", ) + # sort the events to make sure everything runs in order + model.sort_events() + return model @@ -208,7 +212,7 @@ def add_landscape_change( def add_ancestral_populations( model: msprime.Demography, anc_sizes: List[float], - merge_time: float, + merge_time: Union[float, int], anc_id: Optional[np.ndarray] = None, anc_merge_times: Optional[List[float]] = None, anc_merge_sizes: Optional[List[float]] = None, @@ -220,7 +224,7 @@ def add_ancestral_populations( Parameters: model (msprime.Demography): The demographic model to which ancestral populations will be added. anc_sizes (List[float]): A list of ancestral population sizes. - merge_time (float): The time at which all demes in the spatial simulation merge into one or more ancestral populations. + merge_time (Union[float, int]): The time at which all demes in the spatial simulation merge into one or more ancestral populations. anc_id (Optional[np.ndarray], optional): An array of ancestral population IDs- the output of [split_landscape_by_pop][utilities.split_landscape_by_pop]. Defaults to None. anc_merge_times (Optional[List[float]], optional): A list of merge times for ancestral populations. Defaults to None. @@ -251,6 +255,11 @@ def add_ancestral_populations( ) # Rest of the code... + if isinstance(merge_time, list): + raise ValueError("merge_time should be a single float or int, not a list.") + elif not isinstance(merge_time, (float, int)): + raise TypeError("merge_time should be a float or int.") + if anc_id is None: # add an ancestral population model.add_population(name="ANC_1", initial_size=anc_sizes[0]) @@ -343,4 +352,7 @@ def add_ancestral_populations( rate=migration_rate, ) + # sort the events to make sure everything runs in order + model.sort_events() + return model