Skip to content

Commit

Permalink
drop return type option from sim multi
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Sep 7, 2024
1 parent bd9c3ad commit 7348a50
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 65 deletions.
62 changes: 17 additions & 45 deletions neurodsp/sim/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
###################################################################################################
###################################################################################################

def sim_multiple(sim_func, sim_params, n_sims, return_type='object'):
def sim_multiple(sim_func, sim_params, n_sims):
"""Simulate multiple samples of a specified simulation.
Parameters
Expand All @@ -21,16 +21,12 @@ def sim_multiple(sim_func, sim_params, n_sims, return_type='object'):
The parameters for the simulated signal, passed into `sim_func`.
n_sims : int
Number of simulations to create.
return_type : {'object', 'array'}
Specifies the return type of the simulations.
If 'object', returns simulations and metadata in a 'Simulations' object.
If 'array', returns the simulations (no metadata) in an array.
Returns
-------
sims : Simulations or 2d array
Simulations, return type depends on `return_type` argument.
Simulated time series are organized as [n_sims, sig length].
sims : Simulations
Simulations object with simulated time series and metadata.
Simulated signals are in the 'signals' attribute with shape [n_sims, sig_length].
Examples
--------
Expand All @@ -45,13 +41,10 @@ def sim_multiple(sim_func, sim_params, n_sims, return_type='object'):
for ind, sig in enumerate(sig_yielder(sim_func, sim_params, n_sims)):
sims.add_signal(sig, index=ind)

if return_type == 'array':
sims = sims.signals

return sims


def sim_across_values(sim_func, sim_params, return_type='object'):
def sim_across_values(sim_func, sim_params):
"""Simulate signals across different parameter values.
Parameters
Expand All @@ -60,16 +53,12 @@ def sim_across_values(sim_func, sim_params, return_type='object'):
Function to create the simulated time series.
sim_params : ParamIter or iterable or list of dict
Simulation parameters for `sim_func`.
return_type : {'object', 'array'}
Specifies the return type of the simulations.
If 'object', returns simulations and metadata in a 'VariableSimulations' object.
If 'array', returns the simulations (no metadata) in an array.
Returns
-------
sims : VariableSimulations or array
Simulations, return type depends on `return_type` argument.
If array, signals are collected together as [n_sims, sig_length].
sims : VariableSimulations
Simulations object with simulated time series and metadata.
Simulated signals are in the 'signals' attribute with shape [n_sims, sig_length].
Examples
--------
Expand All @@ -95,13 +84,10 @@ def sim_across_values(sim_func, sim_params, return_type='object'):
for ind, cur_sim_params in enumerate(sim_params):
sims.add_signal(sim_func(**cur_sim_params), cur_sim_params, index=ind)

if return_type == 'array':
sims = sims.signals

return sims


def sim_multi_across_values(sim_func, sim_params, n_sims, return_type='object'):
def sim_multi_across_values(sim_func, sim_params, n_sims):
"""Simulate multiple signals across different parameter values.
Parameters
Expand All @@ -112,16 +98,12 @@ def sim_multi_across_values(sim_func, sim_params, n_sims, return_type='object'):
Simulation parameters for `sim_func`.
n_sims : int
Number of simulations to create per parameter definition.
return_type : {'object', 'array'}
Specifies the return type of the simulations.
If 'object', returns simulations and metadata in a 'MultiSimulations' object.
If 'array', returns the simulations (no metadata) in an array.
Returns
-------
sims : MultiSimulations or array
Simulations, return type depends on `return_type` argument.
If array, signals are collected together as [n_sets, n_sims, sig_length].
sims : MultiSimulations
Simulations object with simulated time series and metadata.
Simulated signals are in the 'signals' attribute with shape [n_sets, n_sims, sig_length].
Examples
--------
Expand All @@ -143,15 +125,12 @@ def sim_multi_across_values(sim_func, sim_params, n_sims, return_type='object'):
sims = MultiSimulations(update=getattr(sim_params, 'update', None),
component=getattr(sim_params, 'component', None))
for cur_sim_params in sim_params:
sims.add_signals(sim_multiple(sim_func, cur_sim_params, n_sims, 'object'))

if return_type == 'array':
sims = np.squeeze(np.array([el.signals for el in sims]))
sims.add_signals(sim_multiple(sim_func, cur_sim_params, n_sims))

return sims


def sim_from_sampler(sim_func, sim_sampler, n_sims, return_type='object'):
def sim_from_sampler(sim_func, sim_sampler, n_sims):
"""Simulate a set of signals from a parameter sampler.
Parameters
Expand All @@ -162,16 +141,12 @@ def sim_from_sampler(sim_func, sim_sampler, n_sims, return_type='object'):
Parameter definition to sample from.
n_sims : int
Number of simulations to create per parameter definition.
return_type : {'object', 'array'}
Specifies the return type of the simulations.
If 'object', returns simulations and metadata in a 'VariableSimulations' object.
If 'array', returns the simulations (no metadata) in an array.
Returns
-------
sims : VariableSimulations or 2d array
Simulations, return type depends on `return_type` argument.
If array, simulations are organized as [n_sims, sig length].
sims : VariableSimulations
Simulations object with simulated time series and metadata.
Simulated signals are in the 'signals' attribute with shape [n_sims, sig_length].
Examples
--------
Expand All @@ -189,7 +164,4 @@ def sim_from_sampler(sim_func, sim_sampler, n_sims, return_type='object'):
for ind, (sig, params) in enumerate(sig_sampler(sim_func, sim_sampler, True, n_sims)):
sims.add_signal(sim_func(**params), params, index=ind)

if return_type == 'array':
sims = sims.signals

return sims
24 changes: 4 additions & 20 deletions neurodsp/tests/sim/test_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,23 @@ def test_sim_multiple():
n_sims = 2
params = {'n_seconds' : 2, 'fs' : 250, 'exponent' : -1}

sims_obj = sim_multiple(sim_powerlaw, params, n_sims, 'object')
sims_obj = sim_multiple(sim_powerlaw, params, n_sims)
assert isinstance(sims_obj, Simulations)
assert sims_obj.signals.shape[0] == n_sims
assert sims_obj.params == params

sims_arr = sim_multiple(sim_powerlaw, params, n_sims, 'array')
assert isinstance(sims_arr, np.ndarray)
assert sims_arr.shape[0] == n_sims

def test_sim_across_values(tsim_iters):

params = [{'n_seconds' : 2, 'fs' : 250, 'exponent' : -2},
{'n_seconds' : 2, 'fs' : 250, 'exponent' : -1}]

sims_obj = sim_across_values(sim_powerlaw, params, 'object')
sims_obj = sim_across_values(sim_powerlaw, params)
assert isinstance(sims_obj, VariableSimulations)
assert len(sims_obj) == len(params)
for csim, cparams, oparams in zip(sims_obj, sims_obj.params, params):
assert isinstance(csim, np.ndarray)
assert cparams == oparams

sims_arr = sim_across_values(sim_powerlaw, params, 'array')
assert isinstance(sims_arr, np.ndarray)
assert sims_arr.shape[0] == len(params)

# Test with ParamIter input
siter = tsim_iters['pl_exp']
sims_iter = sim_across_values(sim_powerlaw, siter)
Expand All @@ -54,17 +46,13 @@ def test_sim_multi_across_values(tsim_iters):
params = [{'n_seconds' : 2, 'fs' : 250, 'exponent' : -2},
{'n_seconds' : 2, 'fs' : 250, 'exponent' : -1}]

sims_obj = sim_multi_across_values(sim_powerlaw, params, n_sims, 'object')
sims_obj = sim_multi_across_values(sim_powerlaw, params, n_sims)
assert isinstance(sims_obj, MultiSimulations)
for sims, cparams in zip(sims_obj, params):
assert isinstance(sims, Simulations)
assert len(sims) == n_sims
assert sims.params == cparams

sims_arr = sim_multi_across_values(sim_powerlaw, params, n_sims, 'array')
assert isinstance(sims_arr, np.ndarray)
assert sims_arr.shape[0:2] == (len(params), n_sims)

# Test with ParamIter input
siter = tsim_iters['pl_exp']
sims_iter = sim_multi_across_values(sim_powerlaw, siter, n_sims)
Expand All @@ -79,11 +67,7 @@ def test_sim_from_sampler():
samplers = {create_updater('exponent') : create_sampler([-2, -1, 0])}
psampler = ParamSampler(params, samplers)

sims_obj = sim_from_sampler(sim_powerlaw, psampler, n_sims, 'object')
sims_obj = sim_from_sampler(sim_powerlaw, psampler, n_sims)
assert isinstance(sims_obj, VariableSimulations)
assert sims_obj.signals.shape[0] == n_sims
assert len(sims_obj.params) == n_sims

sims_arr = sim_from_sampler(sim_powerlaw, psampler, n_sims, 'array')
assert isinstance(sims_arr, np.ndarray)
assert sims_arr.shape[0] == n_sims

0 comments on commit 7348a50

Please sign in to comment.