diff --git a/blinx/estimate.py b/blinx/estimate.py index c7d952e..cf24eac 100644 --- a/blinx/estimate.py +++ b/blinx/estimate.py @@ -15,7 +15,9 @@ from .utils import find_maximum -def estimate_y(traces, max_y, parameter_ranges=None, hyper_parameters=None, initial_parameters=None): +def estimate_y( + traces, max_y, parameter_ranges=None, hyper_parameters=None, initial_parameters=None +): """Infer the most likely number of fluorophores for the given traces. Args: @@ -89,7 +91,9 @@ def estimate_y(traces, max_y, parameter_ranges=None, hyper_parameters=None, init return max_likelihood_y[0], all_parameters, all_log_likelihoods, all_log_evidences -def estimate_parameters(traces, y, parameter_ranges, hyper_parameters, initial_parameters): +def estimate_parameters( + traces, y, parameter_ranges, hyper_parameters, initial_parameters +): """Fit the fluorescence and trace model to the given traces, assuming that `y` fluorophores are present in each trace. @@ -110,7 +114,7 @@ def estimate_parameters(traces, y, parameter_ranges, hyper_parameters, initial_p hyper_parameters (:class:`HyperParameters`): The hyper-parameters used for the maximum likelihood estimation. - + initial_parameters (:class: `Parameters`): Initial guesses for the parameters, if None guess them from a grid search over parameter_ranges diff --git a/blinx/extract.py b/blinx/extract.py index 54b9d80..f024d5c 100644 --- a/blinx/extract.py +++ b/blinx/extract.py @@ -52,7 +52,7 @@ def extract_traces(image_file_path, pick_file_path, drift_file_path, spot_size=0 background_size = int(4 * spot_size) # define image ROI to trim any ROIs that extend beyond bounds of frame - image_roi = fg.Roi((0,0), (image_sequence.shape[0], image_sequence.shape[1])) + image_roi = fg.Roi((0, 0), (image_sequence.shape[0], image_sequence.shape[1])) # all ROIs centered at (0, 0) spot_roi = fg.Roi((-spot_size, -spot_size), (2 * spot_size + 1, 2 * spot_size + 1)) @@ -73,11 +73,13 @@ def extract_traces(image_file_path, pick_file_path, drift_file_path, spot_size=0 # spot_data: rows of (frame, x_coordinate, y_coordinate, ...) spot_data = picked_spots[picked_spots["group"] == spot_num] - detected_frames = np.asarray(spot_data['frame']).astype(np.int32) + detected_frames = np.asarray(spot_data["frame"]).astype(np.int32) displacements = drifts[detected_frames, :] # keep only coordinates and correct for drift - spot_locations = np.asarray(spot_data[['x', 'y']]).astype(np.int32) + displacements + spot_locations = ( + np.asarray(spot_data[["x", "y"]]).astype(np.int32) + displacements + ) # (x, y) -> (y, x) spot_locations = spot_locations[:, ::-1] diff --git a/blinx/hyper_parameters.py b/blinx/hyper_parameters.py index d7d4324..eb38e14 100644 --- a/blinx/hyper_parameters.py +++ b/blinx/hyper_parameters.py @@ -61,13 +61,31 @@ class HyperParameters: p_outlier (float, default=0.1): - a weight to account for outlier, or out of distribution intensity + a weight to account for outliers, or out of distribution intensity measurements. Occasional measurements contain extreme noise and this sets a minimum possible probability + num_outliers (int, default=20): + + the number of outlier intensities to assign constant likelihoods, + removing them from contributing towards the difference in likelihoods between counts. + i.e. the 20 frames with the highest intensities will be omited + delta_t (float, defaul=200): the exposure time of a single frame in ms + + r_e_loc / r_bg_loc / g_loc / mu_loc /sigma_loc (float, default=None): + + Mean (loc) of the prior distribution on each of the fittable parameters. If None a uniform prior is assumed. + If either loc or scale is given for a parameter, the other must be given as well. + + r_e_scale / r_bg_scale / g_scale / mu_scale /sigma_scale (float, default=None): + + Variance (scale) of the prior distribution on each of the fittable parameters. If None a uniform prior is assumed. + If either loc or scale is given for a parameter, the other must be given as well. + + """ def __init__( @@ -84,8 +102,9 @@ def __init__( max_x=None, num_x_bins=1024, p_outlier=0.1, + num_outliers=20, delta_t=200.0, - param_min_max_scale=5, # how many sigmas away from mean should the model consider + num_traces=None, r_e_loc=None, r_e_scale=None, r_bg_loc=None, @@ -107,28 +126,15 @@ def __init__( self.max_x = max_x self.num_x_bins = num_x_bins self.p_outlier = p_outlier + self.num_outliers = num_outliers self.delta_t = delta_t # priors - self.prior_locs = Parameters( - r_e=r_e_loc, - r_bg=r_bg_loc, - mu_ro=mu_loc, - sigma_ro=sigma_loc, - gain=g_loc, - p_on=None, - p_off=None, - probs_are_logits=True, + self.prior_locs = self.reshape_priors( + r_e_loc, r_bg_loc, g_loc, mu_loc, sigma_loc, num_traces ) - self.prior_scales = Parameters( - r_e=r_e_scale, - r_bg=r_bg_scale, - mu_ro=mu_scale, - gain=g_scale, - sigma_ro=sigma_scale, - p_on=None, - p_off=None, - probs_are_logits=True, + self.prior_scales = self.reshape_priors( + r_e_scale, r_bg_scale, g_scale, mu_scale, sigma_scale, num_traces ) if sum([r_e_loc is None, r_e_scale is None]) == 1: @@ -144,30 +150,28 @@ def __init__( # below is experimental # ------------------------------------------------ - def check_length(self, val, target_length): + + def _reshape(self, val, target): if val is None: return val - elif len(val) == 1: - return jnp.repeat(val, target_length) - elif len(val) > 1 and len(val) != target_length: + else: + val = jnp.asarray(val) + if val.size == 1: + return jnp.repeat(val, target) + elif val.size != target: raise RuntimeError("not enough prior values provided") + elif val.size == target: + return val - def check_prior_shapes(self, target_length): - self.prior_locs = Parameters( - r_e=self.check_length(self.prior_locs.r_e, target_length), - r_bg=self.check_length(self.prior_locs.r_bg, target_length), - mu_ro=self.check_length(self.prior_locs.mu_ro, target_length), - sigma_ro=self.check_length(self.prior_locs.sigma_ro, target_length), - gain=self.check_length(self.prior_locs.gain, target_length), - p_on=self.check_length(self.prior_locs.p_on, target_length), - p_off=self.check_length(self.prior_locs.p_off, target_length), - ) - self.prior_scales = Parameters( - r_e=self.check_length(self.prior_scales.r_e, target_length), - r_bg=self.check_length(self.prior_scales.r_bg, target_length), - mu_ro=self.check_length(self.prior_scales.mu_ro, target_length), - sigma_ro=self.check_length(self.prior_scales.sigma_ro, target_length), - gain=self.check_length(self.prior_scales.gain, target_length), - p_on=self.check_length(self.prior_scales.p_on, target_length), - p_off=self.check_length(self.prior_scales.p_off, target_length), + def reshape_priors(self, r_e, r_bg, g, mu, sigma, num_traces): + # reshape each prior to size (num_traces,) + return Parameters( + r_e=self._reshape(r_e, num_traces), + r_bg=self._reshape(r_bg, num_traces), + gain=self._reshape(g, num_traces), + mu_ro=self._reshape(mu, num_traces), + sigma_ro=self._reshape(sigma, num_traces), + p_on=None, + p_off=None, + probs_are_logits=True, ) diff --git a/blinx/trace_model.py b/blinx/trace_model.py index ff0e267..d0f6970 100644 --- a/blinx/trace_model.py +++ b/blinx/trace_model.py @@ -14,7 +14,7 @@ ) -def log_p_parameters(parameters, hyper_parameters, locs, scales): +def log_p_parameters(parameters, locs, scales): """ the prior distribution p(parameters) """ @@ -41,22 +41,7 @@ def log_p_x_parameters(trace, y, parameters, hyper_parameters, locs, scales): """ return get_trace_log_likelihood( trace, y, parameters, hyper_parameters - ) + log_p_parameters( - parameters, - hyper_parameters, - locs, - scales - # r_e_loc, - # r_e_scale, - # r_bg_loc, - # r_bg_scale, - # g_loc, - # g_scale, - # mu_loc, - # mu_scale, - # sigma_loc, - # sigma_scale, - ) + ) + log_p_parameters(parameters, locs, scales) def get_trace_log_likelihood(trace, y, parameters, hyper_parameters): @@ -114,7 +99,15 @@ def get_trace_log_likelihood(trace, y, parameters, hyper_parameters): in_axes=(None, None, 0, None, None, None, None, None, None), )(x_left, x_right, zs, r_e, r_bg, mu_ro, sigma_ro, gain, hyper_parameters) - return get_measurement_log_likelihood(p_measurement.T, p_initial, p_transition) + # assign constant likelihood to "outlier" frames + outliers = jax.lax.top_k(x_right, hyper_parameters.num_outliers) + trimmed_p_measurement = p_measurement.at[:, outliers[1]].set( + 1 / hyper_parameters.num_x_bins + ) + + return get_measurement_log_likelihood( + trimmed_p_measurement.T, p_initial, p_transition + ) def single_optimal_trace(trace, y, parameters, hyper_parameters): diff --git a/tests/test_priors.py b/tests/test_priors.py new file mode 100644 index 0000000..539a47b --- /dev/null +++ b/tests/test_priors.py @@ -0,0 +1,36 @@ +import jax.numpy as jnp +import numpy as np +import pytest +from blinx.hyper_parameters import HyperParameters + + +def test_priors(): + hps = HyperParameters( + r_e_loc=None, # No priors for that parameter + r_e_scale=None, # if either loc or scale is None other must be as well + r_bg_loc=1, # does it convert an int to the correct shape jax array + r_bg_scale=1, + g_loc=[1], # does it convert a list to the correct shape jax array + g_scale=[1], + mu_loc=0.75, # does it convert a float to the correct shape jax array + mu_scale=0.1, + sigma_loc=jnp.array(1), # if the input is already a jax array + sigma_scale=1, + num_traces=5, + ) + + # check that None priors remain None after re-shaping + np.testing.assert_equal(hps.prior_locs.r_e, None) + np.testing.assert_equal(hps.prior_scales.r_e, None) + + # check that priors have the correct shape + np.testing.assert_equal(hps.prior_locs.r_bg.shape, (5,)) + np.testing.assert_equal(hps.prior_scales.r_bg.shape, (5,)) + np.testing.assert_equal(hps.prior_locs.gain.shape, (5,)) + np.testing.assert_equal(hps.prior_scales.gain.shape, (5,)) + np.testing.assert_equal(hps.prior_locs.mu_ro.shape, (5,)) + np.testing.assert_equal(hps.prior_scales.mu_ro.shape, (5,)) + np.testing.assert_equal(hps.prior_locs.sigma_ro.shape, (5,)) + np.testing.assert_equal(hps.prior_scales.sigma_ro.shape, (5,)) + + return