Skip to content

Commit

Permalink
use SampledSims object in sim multi
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Aug 25, 2024
1 parent 9406855 commit e9534a1
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 19 deletions.
19 changes: 10 additions & 9 deletions neurodsp/sim/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from neurodsp.utils.core import counter
from neurodsp.sim.sims import Simulations
from neurodsp.sim.sims import Simulations, SampledSimulations

###################################################################################################
###################################################################################################
Expand Down Expand Up @@ -167,7 +167,7 @@ def sim_across_values(sim_func, sim_params, n_sims, output='dict'):
return sims


def sim_from_sampler(sim_func, sim_sampler, n_sims, return_params=False):
def sim_from_sampler(sim_func, sim_sampler, n_sims, return_type='object', return_params=False):
"""Simulate a set of signals from a parameter sampler.
Parameters
Expand All @@ -178,8 +178,11 @@ def sim_from_sampler(sim_func, sim_sampler, n_sims, return_params=False):
Parameter definition to sample from.
n_sims : int
Number of simulations to create per parameter definition.
return_params : bool, default: False
Whether to collect and return the parameters of all the generated simulations.
return_type : {'object', 'array'}
XX
#return_params : bool, default: False
# Whether to collect and return the parameters of all the generated simulations.
Returns
-------
Expand All @@ -205,11 +208,9 @@ def sim_from_sampler(sim_func, sim_sampler, n_sims, return_params=False):
sigs = np.zeros([n_sims, sim_sampler.params['n_seconds'] * sim_sampler.params['fs']])
for ind, (sig, params) in enumerate(sig_sampler(sim_func, sim_sampler, True, n_sims)):
sigs[ind, :] = sig
all_params[ind] = params

if return_params:
all_params[ind] = params

if return_params:
return sigs, all_params
if return_type == 'object':
return SampledSimulations(sigs, sim_func, all_params)
else:
return sigs
28 changes: 18 additions & 10 deletions neurodsp/tests/sim/test_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from neurodsp.sim.aperiodic import sim_powerlaw
from neurodsp.sim.sims import Simulations
from neurodsp.sim.sims import Simulations, SampledSimulations
from neurodsp.sim.update import create_updater, create_sampler, ParamSampler

from neurodsp.sim.multi import *
Expand Down Expand Up @@ -32,15 +32,17 @@ def test_sig_sampler():

def test_sim_multiple():

n_sims = 2
params = {'n_seconds' : 2, 'fs' : 250, 'exponent' : -1}

sims = sim_multiple(sim_powerlaw, params, 2, 'object')
assert isinstance(sims, Simulations)
assert sims.signals.shape[0] == 2
assert sims.params == params
sims_obj = sim_multiple(sim_powerlaw, params, n_sims, 'object')
assert isinstance(sims_obj, Simulations)
assert sims_obj.signals.shape[0] == n_sims
assert sims_obj.params == params

sigs = sim_multiple(sim_powerlaw, params, 2, 'array')
assert sigs.shape[0] == 2
sims_arr = sim_multiple(sim_powerlaw, params, n_sims, 'array')
assert isinstance(sims_arr, np.ndarray)
assert sims_arr.shape[0] == n_sims

def test_sim_across_values():

Expand All @@ -57,10 +59,16 @@ def test_sim_across_values():

def test_sim_from_sampler():

n_sims = 2
params = {'n_seconds' : 10, 'fs' : 250, 'exponent' : None}
samplers = {create_updater('exponent') : create_sampler([-2, -1, 0])}
psampler = ParamSampler(params, samplers)

sigs = sim_from_sampler(sim_powerlaw, psampler, 2)
assert isinstance(sigs, np.ndarray)
assert sigs.shape[0] == 2
sims_obj = sim_from_sampler(sim_powerlaw, psampler, n_sims, 'object')
assert isinstance(sims_obj, SampledSimulations)
assert sims_obj.signals.shape[0] == n_sims
assert len(sims_obj.params) == n_sims

sims_arr = sim_from_sampler(sim_powerlaw, psampler, n_sims, 'array')
assert isinstance(sims_arr, np.ndarray)
assert sims_arr.shape[0] == n_sims

0 comments on commit e9534a1

Please sign in to comment.