diff --git a/neurodsp/sim/multi.py b/neurodsp/sim/multi.py index 741665d6..8cfa303f 100644 --- a/neurodsp/sim/multi.py +++ b/neurodsp/sim/multi.py @@ -5,7 +5,7 @@ import numpy as np from neurodsp.utils.core import counter -from neurodsp.sim.sims import Simulations +from neurodsp.sim.sims import Simulations, SampledSimulations ################################################################################################### ################################################################################################### @@ -167,7 +167,7 @@ def sim_across_values(sim_func, sim_params, n_sims, output='dict'): return sims -def sim_from_sampler(sim_func, sim_sampler, n_sims, return_params=False): +def sim_from_sampler(sim_func, sim_sampler, n_sims, return_type='object', return_params=False): """Simulate a set of signals from a parameter sampler. Parameters @@ -178,8 +178,11 @@ def sim_from_sampler(sim_func, sim_sampler, n_sims, return_params=False): Parameter definition to sample from. n_sims : int Number of simulations to create per parameter definition. - return_params : bool, default: False - Whether to collect and return the parameters of all the generated simulations. + + return_type : {'object', 'array'} + XX + #return_params : bool, default: False + # Whether to collect and return the parameters of all the generated simulations. Returns ------- @@ -205,11 +208,9 @@ def sim_from_sampler(sim_func, sim_sampler, n_sims, return_params=False): sigs = np.zeros([n_sims, sim_sampler.params['n_seconds'] * sim_sampler.params['fs']]) for ind, (sig, params) in enumerate(sig_sampler(sim_func, sim_sampler, True, n_sims)): sigs[ind, :] = sig + all_params[ind] = params - if return_params: - all_params[ind] = params - - if return_params: - return sigs, all_params + if return_type == 'object': + return SampledSimulations(sigs, sim_func, all_params) else: return sigs diff --git a/neurodsp/tests/sim/test_multi.py b/neurodsp/tests/sim/test_multi.py index 0bc34506..d7c63c99 100644 --- a/neurodsp/tests/sim/test_multi.py +++ b/neurodsp/tests/sim/test_multi.py @@ -3,7 +3,7 @@ import numpy as np from neurodsp.sim.aperiodic import sim_powerlaw -from neurodsp.sim.sims import Simulations +from neurodsp.sim.sims import Simulations, SampledSimulations from neurodsp.sim.update import create_updater, create_sampler, ParamSampler from neurodsp.sim.multi import * @@ -32,15 +32,17 @@ def test_sig_sampler(): def test_sim_multiple(): + n_sims = 2 params = {'n_seconds' : 2, 'fs' : 250, 'exponent' : -1} - sims = sim_multiple(sim_powerlaw, params, 2, 'object') - assert isinstance(sims, Simulations) - assert sims.signals.shape[0] == 2 - assert sims.params == params + sims_obj = sim_multiple(sim_powerlaw, params, n_sims, 'object') + assert isinstance(sims_obj, Simulations) + assert sims_obj.signals.shape[0] == n_sims + assert sims_obj.params == params - sigs = sim_multiple(sim_powerlaw, params, 2, 'array') - assert sigs.shape[0] == 2 + sims_arr = sim_multiple(sim_powerlaw, params, n_sims, 'array') + assert isinstance(sims_arr, np.ndarray) + assert sims_arr.shape[0] == n_sims def test_sim_across_values(): @@ -57,10 +59,16 @@ def test_sim_across_values(): def test_sim_from_sampler(): + n_sims = 2 params = {'n_seconds' : 10, 'fs' : 250, 'exponent' : None} samplers = {create_updater('exponent') : create_sampler([-2, -1, 0])} psampler = ParamSampler(params, samplers) - sigs = sim_from_sampler(sim_powerlaw, psampler, 2) - assert isinstance(sigs, np.ndarray) - assert sigs.shape[0] == 2 + sims_obj = sim_from_sampler(sim_powerlaw, psampler, n_sims, 'object') + assert isinstance(sims_obj, SampledSimulations) + assert sims_obj.signals.shape[0] == n_sims + assert len(sims_obj.params) == n_sims + + sims_arr = sim_from_sampler(sim_powerlaw, psampler, n_sims, 'array') + assert isinstance(sims_arr, np.ndarray) + assert sims_arr.shape[0] == n_sims