From fa519b1b3c846ec1375a965192d6f45ca092b3c4 Mon Sep 17 00:00:00 2001 From: Ricardo Righetto Date: Sun, 19 Feb 2017 13:11:23 +0100 Subject: [PATCH] Progress on CTF correction and filtering routines --- scripts/proc/SPR_ExtractParticles.py.new | 9 +-- scripts/proc/focus_ctf.py | 37 ++++++------ scripts/proc/focus_utilities.py | 71 ++++++++++++++++-------- 3 files changed, 73 insertions(+), 44 deletions(-) diff --git a/scripts/proc/SPR_ExtractParticles.py.new b/scripts/proc/SPR_ExtractParticles.py.new index 4a0cd07c..fb59a895 100644 --- a/scripts/proc/SPR_ExtractParticles.py.new +++ b/scripts/proc/SPR_ExtractParticles.py.new @@ -20,7 +20,8 @@ import matplotlib.cm as cm import matplotlib.patches as patches # import EMAN2 as e2 import ioMRC -import focus_utilities +import focus_utilities as util +import focus_ctf as CTF def main(): @@ -348,7 +349,7 @@ def main(): if normalize_box: # box = NormalizeStack([box], sigma)[0] - box = focus_utilities.NormalizeImg( box, std=sigma ) + box = util.NormalizeImg( box, std=sigma ) if calculate_defocus_tilted and not ctfcor: @@ -407,7 +408,7 @@ def main(): if normalize_box: # boxctfcor = NormalizeStack([boxctfcor], sigma)[0] - boxctfcor = focus_utilities.NormalizeImg( boxctfcor, std=sigma ) + boxctfcor = util.NormalizeImg( boxctfcor, std=sigma ) # Write image to the particle stack: # if idx == 0: @@ -431,7 +432,7 @@ def main(): if normalize_box: # boxctfcor = NormalizeStack([boxctfcor], sigma)[0] - boxctfcor = focus_utilities.NormalizeImg( boxctfcor, std=sigma ) + boxctfcor = util.NormalizeImg( boxctfcor, std=sigma ) # Write image to the particle stack: # if idx == 0: diff --git a/scripts/proc/focus_ctf.py b/scripts/proc/focus_ctf.py index 03100eb5..a5e104c8 100644 --- a/scripts/proc/focus_ctf.py +++ b/scripts/proc/focus_ctf.py @@ -21,28 +21,30 @@ def CTF( imsize = [100, 100], DF1 = 1000.0, DF2 = None, AST = 0.0, WGH = 0.10, C DF2 = DF1 - # NOTATION BELOW IS INVERTED DUE TO NUMPY CONVENTION: - df1 = DF2 - df2 = DF1 - ast = AST * np.pi / 180.0 + AST *= np.pi / 180.0 WL = ElectronWavelength( kV ) w1 = np.sqrt( 1 - WGH*WGH ) w2 = WGH - rmesh,amesh = focus_utilities.RadialIndices( imsize, rounding=True ) + rmesh,amesh = focus_utilities.RadialIndices( imsize, normalize=True ) + + rmesh = rmesh / apix - rmesh = rmesh / ( np.min( imsize ) * apix ) + rmesh2 = rmesh*rmesh + # NOTATION BELOW IS INVERTED DUE TO NUMPY CONVENTION: + DF = 0.5 * (DF1 + DF2 + (DF2 - DF1) * np.cos( 2.0 * (amesh - AST) ) ) - ast = np.radians( ast ) + import warnings + with warnings.catch_warnings(): + warnings.filterwarnings( "ignore", category=RuntimeWarning ) - df = 0.5 * (df1 + df2 + (df1 - df2) * np.cos( 2 * (amesh - ast) ) ) + Xr = np.pi * WL * rmesh2 * ( DF - 1 / (2 * WL*WL * rmesh2 * Cs) ) - Xr = np.pi * WL * rmesh*rmesh * ( df - 1 / (2 * WL*WL * rmesh*rmesh * Cs) ) Xr = np.nan_to_num( Xr ) CTFim = -w1 * np.sin( Xr ) - w2 * np.cos( Xr ) - CTFim = CTFim * np.exp( -B * ( rmesh*rmesh ) / 4 ) + CTFim = CTFim * np.exp( -B * ( rmesh2 ) / 4 ) return CTFim @@ -51,7 +53,7 @@ def ElectronWavelength( kV = 300.0 ): kV *= 1000.0 # ensure Kilovolts for below formula return 12.26 / np.sqrt( kV + 0.9785 * kV*kV / ( 10.0**6.0 ) ) -def CorrectCTF( img, DF1 = 1000.0, DF2 = None, AST = 0.0, WGH = 0.10, invert = False, Cs = 2.7, kV = 300.0, apix = 1.0, B = 0.0, ctftype = 0, C = 1.0, return_ctf = False ): +def CorrectCTF( img, DF1 = 1000.0, DF2 = None, AST = 0.0, WGH = 0.10, invert_contrast = False, Cs = 2.7, kV = 300.0, apix = 1.0, B = 0.0, ctftype = 0, C = 1.0, return_ctf = False ): # Applies CTF correction to image # Type can be one of the following: # 0 - Phase-flipping only @@ -59,13 +61,14 @@ def CorrectCTF( img, DF1 = 1000.0, DF2 = None, AST = 0.0, WGH = 0.10, invert = F # 2 - Wiener filtering with Wiener constant C # By default image should have dark proteins and bright background, otherwise set invert=True. - if invert: + imsize = img.shape - img *= -1.0 + # Direct CTF correction would invert the image contrast. By default we don't do that, hence the negative sign: + CTFim = -CTF( imsize, DF1, DF2, AST, WGH, Cs, kV, apix, B ) - imsize = img.shape + if invert_contrast: - CTFim = CTF( imsize, DF1, DF2, AST, WGH, Cs, kV, apix, B ) + CTFim *= -1.0 FT = fft.fftshift( fft.fftn( img ) ) @@ -91,11 +94,11 @@ def CorrectCTF( img, DF1 = 1000.0, DF2 = None, AST = 0.0, WGH = 0.10, invert = F if return_ctf: - return CTFcor, CTFim + return CTFcor.real, CTFim else: - return CTFcor + return CTFcor.real diff --git a/scripts/proc/focus_utilities.py b/scripts/proc/focus_utilities.py index 0edca15b..757c0a11 100644 --- a/scripts/proc/focus_utilities.py +++ b/scripts/proc/focus_utilities.py @@ -11,12 +11,13 @@ import numpy.fft as fft -def RadialIndices( imsize = [100, 100], rounding=True ): +def RadialIndices( imsize = [100, 100], rounding=False, normalize=False ): # Returns radius and angles for each pixel (or voxel) in a 2D image or 3D volume of shape = imsize # For 2D returns the angle with the horizontal x- axis # For 3D returns the angle with the horizontal x,y plane # If imsize is a scalar, will default to 2D. -# Rounding is to ensure "perfect" radial symmetry, desirable in most applications. +# Rounding is to ensure "perfect" radial symmetry, desirable for applications in real space. +# Norm will normalize the radius to values between 0 and 1. if np.isscalar(imsize): @@ -26,36 +27,58 @@ def RadialIndices( imsize = [100, 100], rounding=True ): raise ValueError ( "Object should not have dimensions larger than 3: len(imsize) = %d " % len(imsize)) - if len(imsize) == 2: + import warnings + with warnings.catch_warnings(): + warnings.filterwarnings( "ignore", category=RuntimeWarning ) - [xmesh, ymesh] = np.mgrid[-imsize[0]/2:imsize[0]/2, -imsize[1]/2:imsize[1]/2] + if len(imsize) == 2: - rmesh = np.sqrt( xmesh*xmesh + ymesh*ymesh ) - amesh = np.arctan( ymesh / xmesh ) - amesh = np.nan_to_num( amesh ) + if normalize: - return rmesh, amesh + [xmesh, ymesh] = np.mgrid[-imsize[0]/2:imsize[0]/2, -imsize[1]/2:imsize[1]/2].astype(np.float) - else: + xmesh /= imsize[0] + ymesh /= imsize[1] + + else: + + [xmesh, ymesh] = np.mgrid[-imsize[0]/2:imsize[0]/2, -imsize[1]/2:imsize[1]/2] + # xmesh += 1 + # ymesh += 1 + + rmesh = np.sqrt( xmesh*xmesh + ymesh*ymesh ) + amesh = np.arctan( ymesh / xmesh ) + + else: + + if normalize: + + [xmesh, ymesh, zmesh] = np.mgrid[-imsize[0]/2:imsize[0]/2, -imsize[1]/2:imsize[1]/2, -imsize[2]/2:imsize[2]/2].astype(np.float) + + xmesh /= imsize[0] + ymesh /= imsize[1] + zmesh /= imsize[2] + + else: - [xmesh, ymesh, zmesh] = np.mgrid[-imsize[0]/2:imsize[0]/2, -imsize[1]/2:imsize[1]/2, -imsize[2]/2:imsize[2]/2] + [xmesh, ymesh, zmesh] = np.mgrid[-imsize[0]/2:imsize[0]/2, -imsize[1]/2:imsize[1]/2, -imsize[2]/2:imsize[2]/2] - rmesh = np.sqrt( xmesh*xmesh + ymesh*ymesh + zmesh*zmesh ) - amesh = np.arccos( zmesh / rmesh ) - amesh[imsize[0]/2, imsize[1]/2, imsize[2]/2] = 0.0 + rmesh = np.sqrt( xmesh*xmesh + ymesh*ymesh + zmesh*zmesh ) + amesh = np.arccos( zmesh / rmesh ) + # amesh[imsize[0]/2, imsize[1]/2, imsize[2]/2] = 0.0 - if rounding: + if rounding and not normalize: - return np.round(rmesh), amesh + return np.round( rmesh ), np.nan_to_num( amesh ) else: - return rmesh,amesh + return rmesh, np.nan_to_num( amesh ) def RotationalAverage( img ): # Compute the rotational average of a 2D image or 3D volume - rmesh = RadialIndices( img.shape )[0] + rmesh = RadialIndices( img.shape, rounding=True )[0] rotavg = np.zeros( img.shape ) @@ -99,7 +122,7 @@ def SoftMask( imsize = [100, 100], radius = 0.5, width = 6.0 ): rii = radius + width/2 rih = radius - width/2 - rmesh = RadialIndices( imsize )[0] + rmesh = RadialIndices( imsize, rounding=True )[0] mask = np.zeros( imsize ) @@ -117,7 +140,8 @@ def SoftMask( imsize = [100, 100], radius = 0.5, width = 6.0 ): def FilterGauss( img, apix=1.0, lp=-1, hp=-1, return_filter=False ): # Gaussian band-pass filtering of images. - rmesh = RadialIndices( img.shape )[0] / ( np.min( img.shape ) * apix ) + rmesh = RadialIndices( img.shape, normalize=True )[0] / apix + rmesh2 = rmesh*rmesh if lp <= 0.0: @@ -125,7 +149,7 @@ def FilterGauss( img, apix=1.0, lp=-1, hp=-1, return_filter=False ): else: - lowpass = np.exp( - lp ** 2 * rmesh ** 2 / 2 ) + lowpass = np.exp( - lp ** 2 * rmesh2 / 2 ) if hp <= 0.0: @@ -133,7 +157,7 @@ def FilterGauss( img, apix=1.0, lp=-1, hp=-1, return_filter=False ): else: - highpass = 1.0 - np.exp( - hp ** 2 * rmesh ** 2 / 2 ) + highpass = 1.0 - np.exp( - hp ** 2 * rmesh2 / 2 ) bandpass = lowpass * highpass @@ -152,9 +176,10 @@ def FilterGauss( img, apix=1.0, lp=-1, hp=-1, return_filter=False ): def FilterBfactor( img, apix=1.0, B=0.0, return_filter=False ): # Applies a B-factor to images. B can be positive or negative. - rmesh = RadialIndices( img.shape )[0] / ( np.sqrt( np.prod( img.shape )) * apix ) + rmesh = RadialIndices( img.shape, normalize=True )[0] / apix + rmesh2 = rmesh*rmesh - bfac = np.exp( - (B * rmesh ** 2 ) / 4 ) + bfac = np.exp( - (B * rmesh2 ) / 4 ) ft = fft.fftshift( fft.fftn( img ) )