Skip to content

Commit

Permalink
Merge branch 'outliers'
Browse files Browse the repository at this point in the history
  • Loading branch information
ahillsley committed Feb 23, 2024
2 parents f3deb03 + 2a781fe commit 44a2b37
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 66 deletions.
10 changes: 7 additions & 3 deletions blinx/estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from .utils import find_maximum


def estimate_y(traces, max_y, parameter_ranges=None, hyper_parameters=None, initial_parameters=None):
def estimate_y(
traces, max_y, parameter_ranges=None, hyper_parameters=None, initial_parameters=None
):
"""Infer the most likely number of fluorophores for the given traces.
Args:
Expand Down Expand Up @@ -89,7 +91,9 @@ def estimate_y(traces, max_y, parameter_ranges=None, hyper_parameters=None, init
return max_likelihood_y[0], all_parameters, all_log_likelihoods, all_log_evidences


def estimate_parameters(traces, y, parameter_ranges, hyper_parameters, initial_parameters):
def estimate_parameters(
traces, y, parameter_ranges, hyper_parameters, initial_parameters
):
"""Fit the fluorescence and trace model to the given traces, assuming that
`y` fluorophores are present in each trace.
Expand All @@ -110,7 +114,7 @@ def estimate_parameters(traces, y, parameter_ranges, hyper_parameters, initial_p
hyper_parameters (:class:`HyperParameters`):
The hyper-parameters used for the maximum likelihood estimation.
initial_parameters (:class: `Parameters`):
Initial guesses for the parameters, if None guess them from a grid search over parameter_ranges
Expand Down
8 changes: 5 additions & 3 deletions blinx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def extract_traces(image_file_path, pick_file_path, drift_file_path, spot_size=0
background_size = int(4 * spot_size)

# define image ROI to trim any ROIs that extend beyond bounds of frame
image_roi = fg.Roi((0,0), (image_sequence.shape[0], image_sequence.shape[1]))
image_roi = fg.Roi((0, 0), (image_sequence.shape[0], image_sequence.shape[1]))

# all ROIs centered at (0, 0)
spot_roi = fg.Roi((-spot_size, -spot_size), (2 * spot_size + 1, 2 * spot_size + 1))
Expand All @@ -73,11 +73,13 @@ def extract_traces(image_file_path, pick_file_path, drift_file_path, spot_size=0
# spot_data: rows of (frame, x_coordinate, y_coordinate, ...)
spot_data = picked_spots[picked_spots["group"] == spot_num]

detected_frames = np.asarray(spot_data['frame']).astype(np.int32)
detected_frames = np.asarray(spot_data["frame"]).astype(np.int32)
displacements = drifts[detected_frames, :]

# keep only coordinates and correct for drift
spot_locations = np.asarray(spot_data[['x', 'y']]).astype(np.int32) + displacements
spot_locations = (
np.asarray(spot_data[["x", "y"]]).astype(np.int32) + displacements
)
# (x, y) -> (y, x)
spot_locations = spot_locations[:, ::-1]

Expand Down
88 changes: 46 additions & 42 deletions blinx/hyper_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,31 @@ class HyperParameters:
p_outlier (float, default=0.1):
a weight to account for outlier, or out of distribution intensity
a weight to account for outliers, or out of distribution intensity
measurements. Occasional measurements contain extreme noise and
this sets a minimum possible probability
num_outliers (int, default=20):
the number of outlier intensities to assign constant likelihoods,
removing them from contributing towards the difference in likelihoods between counts.
i.e. the 20 frames with the highest intensities will be omited
delta_t (float, defaul=200):
the exposure time of a single frame in ms
r_e_loc / r_bg_loc / g_loc / mu_loc /sigma_loc (float, default=None):
Mean (loc) of the prior distribution on each of the fittable parameters. If None a uniform prior is assumed.
If either loc or scale is given for a parameter, the other must be given as well.
r_e_scale / r_bg_scale / g_scale / mu_scale /sigma_scale (float, default=None):
Variance (scale) of the prior distribution on each of the fittable parameters. If None a uniform prior is assumed.
If either loc or scale is given for a parameter, the other must be given as well.
"""

def __init__(
Expand All @@ -84,8 +102,9 @@ def __init__(
max_x=None,
num_x_bins=1024,
p_outlier=0.1,
num_outliers=20,
delta_t=200.0,
param_min_max_scale=5, # how many sigmas away from mean should the model consider
num_traces=None,
r_e_loc=None,
r_e_scale=None,
r_bg_loc=None,
Expand All @@ -107,28 +126,15 @@ def __init__(
self.max_x = max_x
self.num_x_bins = num_x_bins
self.p_outlier = p_outlier
self.num_outliers = num_outliers
self.delta_t = delta_t

# priors
self.prior_locs = Parameters(
r_e=r_e_loc,
r_bg=r_bg_loc,
mu_ro=mu_loc,
sigma_ro=sigma_loc,
gain=g_loc,
p_on=None,
p_off=None,
probs_are_logits=True,
self.prior_locs = self.reshape_priors(
r_e_loc, r_bg_loc, g_loc, mu_loc, sigma_loc, num_traces
)
self.prior_scales = Parameters(
r_e=r_e_scale,
r_bg=r_bg_scale,
mu_ro=mu_scale,
gain=g_scale,
sigma_ro=sigma_scale,
p_on=None,
p_off=None,
probs_are_logits=True,
self.prior_scales = self.reshape_priors(
r_e_scale, r_bg_scale, g_scale, mu_scale, sigma_scale, num_traces
)

if sum([r_e_loc is None, r_e_scale is None]) == 1:
Expand All @@ -144,30 +150,28 @@ def __init__(

# below is experimental
# ------------------------------------------------
def check_length(self, val, target_length):

def _reshape(self, val, target):
if val is None:
return val
elif len(val) == 1:
return jnp.repeat(val, target_length)
elif len(val) > 1 and len(val) != target_length:
else:
val = jnp.asarray(val)
if val.size == 1:
return jnp.repeat(val, target)
elif val.size != target:
raise RuntimeError("not enough prior values provided")
elif val.size == target:
return val

def check_prior_shapes(self, target_length):
self.prior_locs = Parameters(
r_e=self.check_length(self.prior_locs.r_e, target_length),
r_bg=self.check_length(self.prior_locs.r_bg, target_length),
mu_ro=self.check_length(self.prior_locs.mu_ro, target_length),
sigma_ro=self.check_length(self.prior_locs.sigma_ro, target_length),
gain=self.check_length(self.prior_locs.gain, target_length),
p_on=self.check_length(self.prior_locs.p_on, target_length),
p_off=self.check_length(self.prior_locs.p_off, target_length),
)
self.prior_scales = Parameters(
r_e=self.check_length(self.prior_scales.r_e, target_length),
r_bg=self.check_length(self.prior_scales.r_bg, target_length),
mu_ro=self.check_length(self.prior_scales.mu_ro, target_length),
sigma_ro=self.check_length(self.prior_scales.sigma_ro, target_length),
gain=self.check_length(self.prior_scales.gain, target_length),
p_on=self.check_length(self.prior_scales.p_on, target_length),
p_off=self.check_length(self.prior_scales.p_off, target_length),
def reshape_priors(self, r_e, r_bg, g, mu, sigma, num_traces):
# reshape each prior to size (num_traces,)
return Parameters(
r_e=self._reshape(r_e, num_traces),
r_bg=self._reshape(r_bg, num_traces),
gain=self._reshape(g, num_traces),
mu_ro=self._reshape(mu, num_traces),
sigma_ro=self._reshape(sigma, num_traces),
p_on=None,
p_off=None,
probs_are_logits=True,
)
29 changes: 11 additions & 18 deletions blinx/trace_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)


def log_p_parameters(parameters, hyper_parameters, locs, scales):
def log_p_parameters(parameters, locs, scales):
"""
the prior distribution p(parameters)
"""
Expand All @@ -41,22 +41,7 @@ def log_p_x_parameters(trace, y, parameters, hyper_parameters, locs, scales):
"""
return get_trace_log_likelihood(
trace, y, parameters, hyper_parameters
) + log_p_parameters(
parameters,
hyper_parameters,
locs,
scales
# r_e_loc,
# r_e_scale,
# r_bg_loc,
# r_bg_scale,
# g_loc,
# g_scale,
# mu_loc,
# mu_scale,
# sigma_loc,
# sigma_scale,
)
) + log_p_parameters(parameters, locs, scales)


def get_trace_log_likelihood(trace, y, parameters, hyper_parameters):
Expand Down Expand Up @@ -114,7 +99,15 @@ def get_trace_log_likelihood(trace, y, parameters, hyper_parameters):
in_axes=(None, None, 0, None, None, None, None, None, None),
)(x_left, x_right, zs, r_e, r_bg, mu_ro, sigma_ro, gain, hyper_parameters)

return get_measurement_log_likelihood(p_measurement.T, p_initial, p_transition)
# assign constant likelihood to "outlier" frames
outliers = jax.lax.top_k(x_right, hyper_parameters.num_outliers)
trimmed_p_measurement = p_measurement.at[:, outliers[1]].set(
1 / hyper_parameters.num_x_bins
)

return get_measurement_log_likelihood(
trimmed_p_measurement.T, p_initial, p_transition
)


def single_optimal_trace(trace, y, parameters, hyper_parameters):
Expand Down
36 changes: 36 additions & 0 deletions tests/test_priors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import jax.numpy as jnp
import numpy as np
import pytest
from blinx.hyper_parameters import HyperParameters


def test_priors():
hps = HyperParameters(
r_e_loc=None, # No priors for that parameter
r_e_scale=None, # if either loc or scale is None other must be as well
r_bg_loc=1, # does it convert an int to the correct shape jax array
r_bg_scale=1,
g_loc=[1], # does it convert a list to the correct shape jax array
g_scale=[1],
mu_loc=0.75, # does it convert a float to the correct shape jax array
mu_scale=0.1,
sigma_loc=jnp.array(1), # if the input is already a jax array
sigma_scale=1,
num_traces=5,
)

# check that None priors remain None after re-shaping
np.testing.assert_equal(hps.prior_locs.r_e, None)
np.testing.assert_equal(hps.prior_scales.r_e, None)

# check that priors have the correct shape
np.testing.assert_equal(hps.prior_locs.r_bg.shape, (5,))
np.testing.assert_equal(hps.prior_scales.r_bg.shape, (5,))
np.testing.assert_equal(hps.prior_locs.gain.shape, (5,))
np.testing.assert_equal(hps.prior_scales.gain.shape, (5,))
np.testing.assert_equal(hps.prior_locs.mu_ro.shape, (5,))
np.testing.assert_equal(hps.prior_scales.mu_ro.shape, (5,))
np.testing.assert_equal(hps.prior_locs.sigma_ro.shape, (5,))
np.testing.assert_equal(hps.prior_scales.sigma_ro.shape, (5,))

return

0 comments on commit 44a2b37

Please sign in to comment.