forked from qnano/drift-estimation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdme_example.py
93 lines (73 loc) · 3.1 KB
/
dme_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# -*- coding: utf-8 -*-
"""
3D drift estimation example
Units for X,Y,Z are pixels, pixels, and microns resp.
"""
import numpy as np
import matplotlib.pyplot as plt
from dme.dme import dme_estimate
from dme.rcc import rcc3D
from dme.native_api import NativeAPI
# GPU switch for the native DME/RCC code further down (passed as `useCuda=`
# to dme_estimate and to `NativeAPI(...)`). Leaving it True requires CUDA.
use_cuda: bool = True
# Simulate an SMLM dataset in 3D with blinking molecules
def smlm_simulation(
    drift_trace,
    fov_width,  # field of view size in pixels
    loc_error,  # localization error XYZ
    n_sites,  # number of locations where molecules blink on and off
    n_frames,
    on_prob=0.1,  # probability of a binding site generating a localization in a frame
    rng=None,  # optional np.random.Generator/RandomState; None uses the global np.random state
):
    """
    Simulate localizations of static, randomly blinking binding sites under drift.

    ``n_sites`` binding sites are drawn uniformly with X and Y in
    ``[0, fov_width)`` and Z in ``[-1, 1)``. This models a typical 2D
    acquisition, with a large XY range and a small Z range. In each of
    ``n_frames`` frames, every site independently produces one localization
    with probability ``on_prob``. That localization is the site position plus
    ``drift_trace[frame]`` plus zero-mean Gaussian noise whose standard
    deviation is ``loc_error``. The localization precision is therefore set
    entirely by the caller through ``loc_error``.

    Parameters
    ----------
    drift_trace : array_like
        Per-frame drift. Row ``i`` is added to every localization of frame
        ``i``, so it needs at least ``n_frames`` rows that broadcast against
        ``(k, 3)``. It is typically of shape ``(n_frames, 3)``.
    fov_width : float
        Width of the square XY field of view (pixels).
    loc_error : float or array_like of length 3
        Standard deviation of the localization noise, per axis.
    n_sites : int
        Number of binding sites.
    n_frames : int
        Number of frames to simulate.
    on_prob : float
        Per-frame probability that a site produces a localization.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Random source. The default (None) uses the global ``np.random`` state
        with the same sequence of draws as before this parameter existed.

    Returns
    -------
    localizations : ndarray, shape (N, 3)
        All localizations, ordered by frame.
    framenum : ndarray of int32, shape (N,)
        Frame index of each localization.
    """
    if rng is None:
        rng = np.random
    # typical 2D acquisition with small Z range and large XY range
    binding_sites = rng.uniform([0, 0, -1], [fov_width, fov_width, 1], size=(n_sites, 3))

    localizations = []
    framenum = []
    for i in range(n_frames):
        # np.bool was removed in NumPy 1.24; the builtin bool works on every version.
        on = rng.binomial(1, on_prob, size=n_sites).astype(bool)
        # Boolean indexing already returns a copy, so in-place edits leave binding_sites intact.
        locs = binding_sites[on]
        # add drift and localization error
        locs += drift_trace[i] + rng.normal(0, loc_error, size=locs.shape)
        framenum.append(np.full(len(locs), i, dtype=np.int32))
        localizations.append(locs)

    if not localizations:
        # n_frames == 0: np.concatenate([]) would raise, so return empty arrays instead.
        return np.empty((0, 3)), np.empty(0, dtype=np.int32)
    return np.concatenate(localizations), np.concatenate(framenum)
# --- Simulation parameters (X/Y in pixels, Z in microns) ---
n_frames = 2000
fov_width = 200
drift_mean = (0.001, 0, 0)  # mean per-frame drift step (px, px, um)
drift_stdev = (0.02, 0.02, 0.02)  # stdev of the per-frame drift step (px, px, um)
loc_error = np.array((0.1, 0.1, 0.03))  # pixel, pixel, um

# Ground truth drift trace: cumulative sum of random per-frame steps (a biased
# random walk), with its mean subtracted so it is centred on zero.
drift_trace = np.cumsum(np.random.normal(drift_mean, drift_stdev, size=(n_frames, 3)), 0)
drift_trace -= drift_trace.mean(0)

localizations, framenum = smlm_simulation(
    drift_trace, fov_width, loc_error, n_sites=200, n_frames=n_frames
)
print(f"Total localizations: {len(localizations)}")

# Per-localization uncertainty handed to DME: every row is the simulated
# localization error, i.e. an array of shape (len(localizations), 3).
crlb = np.tile(loc_error, (len(localizations), 1))

estimated_drift, _ = dme_estimate(
    localizations,
    framenum,
    crlb,
    framesperbin=1,  # note that small frames per bin use many more iterations
    imgshape=[fov_width, fov_width],
    coarseFramesPerBin=200,
    coarseSigma=[0.2, 0.2, 0.2],  # run a coarse drift correction with large Z sigma
    useCuda=use_cuda,
    useDebugLibrary=False,
)

# RCC baseline for comparison, computed through the native library handle.
with NativeAPI(use_cuda) as dll:
    estimated_drift_rcc = rcc3D(localizations, framenum, timebins=10, zoom=1, dll=dll)

rmsd = np.sqrt(np.mean((estimated_drift - drift_trace) ** 2, 0))
print(f"RMSD of drift estimate compared to true drift: {rmsd}")

# The estimate curves are shifted vertically so they do not hide the true
# drift; the shift is stated in the legend so it is not mistaken for an error.
plot_offset = 0.2
axis_names = ['x', 'y', 'z']
axis_units = ['px', 'px', 'um']
fig, ax = plt.subplots(3, figsize=(7, 6))
for i, (name, unit) in enumerate(zip(axis_names, axis_units)):
    ax[i].plot(drift_trace[:, i], label='True drift')
    ax[i].plot(estimated_drift[:, i] + plot_offset,
               label=f'Estimated drift (DME, +{plot_offset} offset)')
    ax[i].plot(estimated_drift_rcc[:, i] - plot_offset,
               label=f'Estimated drift (RCC, -{plot_offset} offset)')
    ax[i].set_title(name)
    ax[i].set_ylabel(f'Drift [{unit}]')
ax[0].legend()
plt.tight_layout()
plt.show()