Skip to content

Commit

Permalink
Merge pull request #6 from gyoge0/vmap_generate_trace
Browse files Browse the repository at this point in the history
add vmapped version of generate_trace
  • Loading branch information
ahillsley authored Jun 20, 2024
2 parents 27a6200 + 74ff46e commit 54b21f8
Showing 1 changed file with 51 additions and 2 deletions.
53 changes: 51 additions & 2 deletions blinx/trace_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,13 @@ def log_p_parameters(parameters, locs, scales):
if locs.sigma_ro is not None:
log_p += jnp.log(norm.pdf(parameters.sigma_ro, locs.sigma_ro, scales.sigma_ro))
if locs._p_on_logit is not None:
log_p += jnp.log(norm.pdf(parameters.p_on, locs._p_on_logit, scales._p_on_logit))
log_p += jnp.log(
norm.pdf(parameters.p_on, locs._p_on_logit, scales._p_on_logit)
)
if locs._p_off_logit is not None:
log_p += jnp.log(norm.pdf(parameters.p_off, locs._p_off_logit, scales._p_off_logit))
log_p += jnp.log(
norm.pdf(parameters.p_off, locs._p_off_logit, scales._p_off_logit)
)

return log_p

Expand Down Expand Up @@ -281,6 +285,51 @@ def sample_next_z(z, p_transition, key):
return z


def vmap_generate_trace(
num_traces, y, parameters, num_frames, hyper_parameters, seed=None
):
"""Create several simulated intensity traces.
Args:
num_traces (int):
- the number of traces to simulate
y (int):
- the total number of fluorescent emitters
parameters (:class:'Parameters'):
- the parameters of the fluoresent and trace model
num_frames (int):
- the number of observations to simulate
hyper_parameters (:class:`HyperParameters`):
- hypxer-parameters with `delta_t` set for the time between frames in the traces
seed (int, optional):
- random seed for the jax psudo rendom number generator
Returns:
trace (array):
- a num_traces x num_frames array containing traces with intensity values for each frame
states (array):
- array the same shape as trace, containing the number of 'on' emitters in each frame
"""

if seed is None:
seed = time.time_ns()
key = random.PRNGKey(seed)
subkeys = random.split(key, num_traces)
seeds = subkeys[:, 0]
mapped = jax.vmap(
generate_trace,
in_axes=(None, None, None, None, 0),
)
trace, zs = mapped(y, parameters, num_frames, hyper_parameters, seeds)
return jnp.squeeze(trace), jnp.squeeze(zs)


def create_transition_matrix(y, p_on, p_off):
"""Create a transition matrix for the number of active elements, given that
elements can randomly turn on and off.
Expand Down

0 comments on commit 54b21f8

Please sign in to comment.