From dd9faa2dc61ca9a9aa721a7a9c3d11b33ada9121 Mon Sep 17 00:00:00 2001 From: Tara Prasad Mishra Date: Mon, 6 Feb 2023 10:19:41 -0800 Subject: [PATCH 01/25] changes to crystal phase --- .../io/datastructure/py4dstem/datacube_fns.py | 2 +- py4DSTEM/process/diffraction/crystal_phase.py | 144 +++++++++++++----- 2 files changed, 103 insertions(+), 43 deletions(-) diff --git a/py4DSTEM/io/datastructure/py4dstem/datacube_fns.py b/py4DSTEM/io/datastructure/py4dstem/datacube_fns.py index fbe2f38fb..98532093e 100644 --- a/py4DSTEM/io/datastructure/py4dstem/datacube_fns.py +++ b/py4DSTEM/io/datastructure/py4dstem/datacube_fns.py @@ -1075,7 +1075,7 @@ def find_Bragg_disks( ml_num_attempts = ml_num_attempts, ml_batch_size = ml_batch_size, - _qt_progress_bar = _qt_progress_bar, + # _qt_progress_bar = _qt_progress_bar, ) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index b51380cfa..44a97891d 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -59,8 +59,8 @@ def plot_all_phase_maps( phase_maps.append(self.orientation_maps[m].corr[:,:,index] / corr_sum) show_image_grid(lambda i:phase_maps[i], 1, len(phase_maps), cmap = 'inferno') return - - def plot_phase_map( + # Plot the correlation maps. Plots the maps with the best correlation scores. + def plot_correlation_map( self, index = 0, cmap = None @@ -159,6 +159,62 @@ def quantify_phase( return TypeError('pointlistarray must be of type pointlistarray.') return + def compare_intensitylists( + self, + masterpointlist, + masterintensitylist, + bragg_peaks_fit, + tolerance_distance, + intensity_power + ): + """ + Function to compare the exisiting point list enteries with the array. + """ + # Add a column of zeros in the master intensity list to make way for the new fitted intensity list + zeros = np.zeros((masterintensitylist.shape[0], 1)) + masterintensitylist = np.concatenate((masterintensitylist,zeros),axis=1) + + # Compare with the exisiting bragg_peaks_fit with the masterpointlist. + # Make a temporary intensity list to store the intensities of the the bragg_peaks_fit. + + if intensity_power == 0: + temporary_pl_intensities = np.ones(bragg_peaks_fit['intensity'].shape) + else: + temporary_pl_intensities = bragg_peaks_fit['intensity']**intensity_power + + + # Go through the bragg_peaks_fit to find if the master list has an entry or not. + for d in range(bragg_peaks_fit['qx'].shape[0]): + distances = [] + # Making a numpy array of the fitted bragg peak + bragg_peak_point=np.array([bragg_peaks_fit['qx'][d],bragg_peaks_fit['qy'][d]]) + for p in range(masterpointlist.shape[0]): + distances.append(np.linalg.norm(bragg_peak_point-masterpointlist[p]) + ) + ind = np.where(distances == np.min(distances))[0][0] + # Potentially loop over to find the best tolerance distance. + if distances[ind] <= tolerance_distance: + columns_masterintensitylist = len(masterintensitylist[0]) + masterintensitylist[ind][columns_masterintensitylist-1]=temporary_pl_intensities[d] + + else: + continue + ## The point list is not in the mega list of point list so the point list last row of masterpointlist + masterpointlist = np.vstack((masterpointlist,bragg_peak_point)) + ## Add a row to the intensity list such that all the remaining intensity lists should be 0 but only the new bragg intensity list is non zero but intensity power + new_intensity_list_row = np.zeros((1, masterintensitylist.shape[1]-1)) + new_intensity_list_row = np.append(new_intensity_list_row, [temporary_pl_intensities[d]]) + new_intensity_list_row = new_intensity_list_row.reshape((1,-1)) + masterintensitylist = np.concatenate((masterintensitylist,new_intensity_list_row),axis=0) + + + + + return masterpointlist,masterintensitylist + + + + def quantify_phase_pointlist( self, pointlistarray, @@ -191,8 +247,10 @@ def quantify_phase_pointlist( # Things to add: # 1. Better cost for distance from peaks in pointlists # 2. Iterate through multiple tolerance_distance values to find best value. Cost function residuals, or something else? + # 3. Make a flag variable for the experimental dataset which turns 1 if it is encountered in the simulated dataset. pointlist = pointlistarray.get_pointlist(position[0], position[1]) + ## Remove the central beam pl_mask = np.where((pointlist['qx'] == 0) & (pointlist['qy'] == 0), 1, 0) pointlist.remove(pl_mask) # False Negatives (exp peak with no match in crystal instances) will appear here, already coded in @@ -201,74 +259,76 @@ def quantify_phase_pointlist( pl_intensities = np.ones(pointlist['intensity'].shape) else: pl_intensities = pointlist['intensity']**intensity_power + #Prepare matches for modeling pointlist_peak_matches = [] crystal_identity = [] - + ## Initialize the megapointlist and master intensity list with the experimental intensity + masterpointlist = np.column_stack((pointlist['qx'],pointlist['qy'])) + masterintensitylist = pl_intensities + ## Convert masterintensitylist to a 2D array + masterintensitylist = np.array(masterintensitylist, ndmin=2).T + ## Loop over the number of crystals. for c in range(len(self.crystals)): + ## Loop over the number of num matches which is the number of orientation candidates. + # This value of num matches was supplied when the orientation map was created. for m in range(self.orientation_maps[c].num_matches): + # Set crystal identity crystal_identity.append([c,m]) - phase_peak_match_intensities = np.zeros((pointlist['intensity'].shape)) + # For a given crystal class generate a diffraction pattern given a orientation crystal and given num match bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( self.orientation_maps[c].get_orientation(position[0], position[1]), ind_orientation = m ) - #Find the best match peak within tolerance_distance and add value in the right position - for d in range(pointlist['qx'].shape[0]): - distances = [] - for p in range(bragg_peaks_fit['qx'].shape[0]): - distances.append( - np.sqrt((pointlist['qx'][d] - bragg_peaks_fit['qx'][p])**2 + - (pointlist['qy'][d]-bragg_peaks_fit['qy'][p])**2) - ) - ind = np.where(distances == np.min(distances))[0][0] - - #Potentially for-loop over multiple values for 'tolerance_distance' to find best tolerance_distance value - if distances[ind] <= tolerance_distance: - ## Somewhere in this if statement is probably where better distances from the peak should be coded in - if intensity_power == 0: #This could potentially be a different intensity_power arg - phase_peak_match_intensities[d] = 1**((tolerance_distance-distances[ind])/tolerance_distance) - else: - phase_peak_match_intensities[d] = bragg_peaks_fit['intensity'][ind]**((tolerance_distance-distances[ind])/tolerance_distance) - else: - ## This is probably where the false positives (peaks in crystal but not in experiment) should be handled - continue - - pointlist_peak_matches.append(phase_peak_match_intensities) - pointlist_peak_intensity_matches = np.dstack(pointlist_peak_matches) - pointlist_peak_intensity_matches = pointlist_peak_intensity_matches.reshape( - pl_intensities.shape[0], - pointlist_peak_intensity_matches.shape[-1] - ) - - if len(pointlist['qx']) > 0: + # Check if there are any experimental intensity observed at all. + if len(masterpointlist !=0): + # Send this bragg_peaks_fit to the compare function to be compared with mega point list and master intensity list. + masterpointlist,masterintensitylist=self.compare_intensitylists(masterpointlist,masterintensitylist,bragg_peaks_fit,tolerance_distance,intensity_power) + else: + continue + + + ### The intensity and point lists are accumulated in the masterintensitylist and masterpointlist. + # The first column of the intensity lists are the observed experimental intensities. + observed_intensities = masterintensitylist[:,0] + expected_intensities = masterintensitylist[:,1:] + + if len(observed_intensities) > 0: if mask_peaks is not None: for i in range(len(mask_peaks)): if mask_peaks[i] == None: continue - inds_mask = np.where(pointlist_peak_intensity_matches[:,mask_peaks[i]] != 0)[0] + inds_mask = np.where(expected_intensities[:,mask_peaks[i]] != 0)[0] for mask in range(len(inds_mask)): - pointlist_peak_intensity_matches[inds_mask[mask],i] = 0 - + expected_intensities[inds_mask[mask],i] = 0 if method == 'nnls': phase_weights, phase_residuals = nnls( - pointlist_peak_intensity_matches, - pl_intensities + expected_intensities, + observed_intensities ) elif method == 'lstsq': phase_weights, phase_residuals, rank, singluar_vals = lstsq( - pointlist_peak_intensity_matches, - pl_intensities, + expected_intensities, + observed_intensities, rcond = -1 ) phase_residuals = np.sum(phase_residuals) else: raise ValueError(method + ' Not yet implemented. Try nnls or lstsq.') else: - phase_weights = np.zeros((pointlist_peak_intensity_matches.shape[1],)) + # Find the number of expected phases + number_expected_phases=0 + for c in range(len(self.crystals)): + for m in range(self.orientation_maps[c].num_matches): + number_expected_phases+=1 + + # If there are no diffraction patterns + phase_weights = np.zeros(number_expected_phases) phase_residuals = np.NaN - return pointlist_peak_intensity_matches, phase_weights, phase_residuals, crystal_identity + + return expected_intensities, phase_weights, phase_residuals, crystal_identity + # def plot_peak_matches( # self, From 35556a88c3f37f5d8b6ef08961e2e6cd19ac99ad Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 6 Feb 2023 15:47:41 -0500 Subject: [PATCH 02/25] adds test_Crystal.py skeleton --- py4DSTEM/test/test_classes/test_crystal.py | 41 ++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 py4DSTEM/test/test_classes/test_crystal.py diff --git a/py4DSTEM/test/test_classes/test_crystal.py b/py4DSTEM/test/test_classes/test_crystal.py new file mode 100644 index 000000000..e8d644675 --- /dev/null +++ b/py4DSTEM/test/test_classes/test_crystal.py @@ -0,0 +1,41 @@ +import numpy as np +#from py4DSTEM.classes import ( +# Crystal +#) + + + + +class TestCrystal: + + def setup_cls(self): + pass + + def teardown_cls(self): + pass + + def setup_method(self): + pass + + def teardown_method(self): + pass + + + + def test_Crystal(self): + + #crystal = Crystal( **args ) + #assert(isinstance(crystal,Crystal)) + + pass + + + def test_Crystal2(self): + + #crystal = Crystal( **args ) + #assert(isinstance(crystal,Crystal)) + + pass + + + From 8132882a2102b69e452c1ef176f7c32e03674fa1 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 6 Feb 2023 18:07:03 -0500 Subject: [PATCH 03/25] allows mp-api >= 0.24.1 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d6b1558f9..560e98b61 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ extras_require={ 'ipyparallel': ['ipyparallel >= 6.2.4', 'dill >= 0.3.3'], 'cuda': ['cupy'], - 'acom': ['pymatgen >= 2022', 'mp-api == 0.24.1'], + 'acom': ['pymatgen >= 2022', 'mp-api >= 0.24.1'], 'aiml': ['tensorflow == 2.4.1','tensorflow-addons <= 0.14.0','crystal4D'], 'aiml-cuda': ['tensorflow == 2.4.1','tensorflow-addons <= 0.14.0','crystal4D','cupy'], 'numba': ['numba >= 0.49.1'] From 4028c0d967cfeb467fbb964fc8d8d34da82fa4c9 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 6 Feb 2023 18:26:40 -0500 Subject: [PATCH 04/25] add test data, test .from_CIF --- py4DSTEM/test/test_classes/test_crystal.py | 20 ++++++++++++------- .../test_nonnativefilereaders/test_dm.py | 17 ++++++++++++++++ 2 files changed, 30 insertions(+), 7 deletions(-) create mode 100644 py4DSTEM/test/test_io/test_nonnativefilereaders/test_dm.py diff --git a/py4DSTEM/test/test_classes/test_crystal.py b/py4DSTEM/test/test_classes/test_crystal.py index e8d644675..93ea745f7 100644 --- a/py4DSTEM/test/test_classes/test_crystal.py +++ b/py4DSTEM/test/test_classes/test_crystal.py @@ -1,7 +1,13 @@ -import numpy as np -#from py4DSTEM.classes import ( -# Crystal -#) +from py4DSTEM.process.diffraction import Crystal +from py4DSTEM import _TESTPATH +from os.path import join + + +# Set filepaths +filepath_braggpeaks = join(_TESTPATH, "crystal/braggpeaks_cali.h5") +filepath_cif1 = join(_TESTPATH, "crystal/LCO.cif") +filepath_cif2 = join(_TESTPATH, "crystal/Li2MnO3.cif") +filepath_cif3 = join(_TESTPATH, "crystal/LiMn2O4.cif") @@ -22,10 +28,10 @@ def teardown_method(self): - def test_Crystal(self): + def test_instantiation_from_cif(self): - #crystal = Crystal( **args ) - #assert(isinstance(crystal,Crystal)) + crystal = Crystal.from_CIF(filepath_cif1) + assert(isinstance(crystal,Crystal)) pass diff --git a/py4DSTEM/test/test_io/test_nonnativefilereaders/test_dm.py b/py4DSTEM/test/test_io/test_nonnativefilereaders/test_dm.py new file mode 100644 index 000000000..ffc9c0cb6 --- /dev/null +++ b/py4DSTEM/test/test_io/test_nonnativefilereaders/test_dm.py @@ -0,0 +1,17 @@ +import py4DSTEM +from os.path import join + + +# Set filepaths +filepath_dm = join(py4DSTEM._TESTPATH, "small_dm3.dm3") + + +def test_dmfile_3Darray(): + data = py4DSTEM.import_file( filepath_dm ) + assert isinstance(data, py4DSTEM.emd.Array) + + +# TODO +# def test_dmfile_4Darray(): +# def test_dmfile_multiple_datablocks(): + From 70e0af8317e76d2be15134d54744c1b2656254bd Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 6 Feb 2023 18:28:30 -0500 Subject: [PATCH 05/25] updates --- .../test_nonnativefilereaders/test_dm.py | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 py4DSTEM/test/test_io/test_nonnativefilereaders/test_dm.py diff --git a/py4DSTEM/test/test_io/test_nonnativefilereaders/test_dm.py b/py4DSTEM/test/test_io/test_nonnativefilereaders/test_dm.py deleted file mode 100644 index ffc9c0cb6..000000000 --- a/py4DSTEM/test/test_io/test_nonnativefilereaders/test_dm.py +++ /dev/null @@ -1,17 +0,0 @@ -import py4DSTEM -from os.path import join - - -# Set filepaths -filepath_dm = join(py4DSTEM._TESTPATH, "small_dm3.dm3") - - -def test_dmfile_3Darray(): - data = py4DSTEM.import_file( filepath_dm ) - assert isinstance(data, py4DSTEM.emd.Array) - - -# TODO -# def test_dmfile_4Darray(): -# def test_dmfile_multiple_datablocks(): - From 0e69796099989674c7ec545a40eb0d9176ad9ce9 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 6 Feb 2023 18:35:14 -0500 Subject: [PATCH 06/25] updates --- py4DSTEM/test/test_classes/test_crystal.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/py4DSTEM/test/test_classes/test_crystal.py b/py4DSTEM/test/test_classes/test_crystal.py index 93ea745f7..caf5ce258 100644 --- a/py4DSTEM/test/test_classes/test_crystal.py +++ b/py4DSTEM/test/test_classes/test_crystal.py @@ -15,6 +15,11 @@ class TestCrystal: def setup_cls(self): + self.braggpeaks = read( + filepath_braggpeaks, + data_id='braggpeaks_cal_raw' + ) + self.crystal = Crystal.from_CIF(filepath_cif1) pass def teardown_cls(self): @@ -33,15 +38,16 @@ def test_instantiation_from_cif(self): crystal = Crystal.from_CIF(filepath_cif1) assert(isinstance(crystal,Crystal)) - pass + def test_generate_diffraction_pattern(self): + + self.crystal.generate_diffraction_pattern( + zone_axis_lattice = [1,1,2], + sigma_excitation_error = 0.2 + ) - def test_Crystal2(self): - #crystal = Crystal( **args ) - #assert(isinstance(crystal,Crystal)) - pass From 5c55c5be3b9b2c470dce97e27a1179eb557e746a Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 6 Feb 2023 19:14:43 -0500 Subject: [PATCH 07/25] crystal + crystal phase tests written + passing --- py4DSTEM/test/test_classes/test_crystal.py | 150 ++++++++++++++++++++- 1 file changed, 143 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/test/test_classes/test_crystal.py b/py4DSTEM/test/test_classes/test_crystal.py index caf5ce258..4f49c558b 100644 --- a/py4DSTEM/test/test_classes/test_crystal.py +++ b/py4DSTEM/test/test_classes/test_crystal.py @@ -1,28 +1,49 @@ from py4DSTEM.process.diffraction import Crystal -from py4DSTEM import _TESTPATH +from py4DSTEM.process.diffraction import Crystal_Phase as CrystalPhase +from py4DSTEM import _TESTPATH,read from os.path import join # Set filepaths -filepath_braggpeaks = join(_TESTPATH, "crystal/braggpeaks_cali.h5") +filepath_braggpeaks = join(_TESTPATH, "crystal/braggdisks_cali.h5") filepath_cif1 = join(_TESTPATH, "crystal/LCO.cif") -filepath_cif2 = join(_TESTPATH, "crystal/Li2MnO3.cif") -filepath_cif3 = join(_TESTPATH, "crystal/LiMn2O4.cif") +filepath_cif2 = join(_TESTPATH, "crystal/LiMn2O4.cif") class TestCrystal: - def setup_cls(self): + def setup_class(self): + + # get bragg peaks self.braggpeaks = read( filepath_braggpeaks, data_id='braggpeaks_cal_raw' ) + + # make a Crystal self.crystal = Crystal.from_CIF(filepath_cif1) - pass - def teardown_cls(self): + # get structure factors + self.q, self.inten = self.crystal.calculate_structure_factors( + k_max = 1.7, + tol_structure_factor = 0.1, + return_intensities = True + ) + + # set up the orientation plan + self.crystal.orientation_plan( + angle_step_zone_axis=1.0, + angle_step_in_plane=5.0, + accel_voltage=300e3, + zone_axis_range='fiber', + fiber_axis=self.crystal.hexagonal_to_lattice([1,0,-1,0]), + fiber_angles=[5,90], + intensity_power=2.5, + ) + + def teardown_class(self): pass def setup_method(self): @@ -46,8 +67,123 @@ def test_generate_diffraction_pattern(self): sigma_excitation_error = 0.2 ) + def test_match_single_pattern(self): + + xind,yind = 30,40 + + # match the pattern + orientation = self.crystal.match_single_pattern( + self.braggpeaks[xind,yind], + num_matches_return=1, + verbose=True + ) + + # compute the predicted peaks at this orientation + braggpeaks_fit = self.crystal.generate_diffraction_pattern( + orientation, + sigma_excitation_error=0.02 + ) + + + def test_match_orientations(self): + + orientation_map = self.crystal.match_orientations( + self.braggpeaks + ) + + + + + + +class TestPhaseMapping: + + def setup_class(self): + + # get bragg peaks + self.braggpeaks = read( + filepath_braggpeaks, + data_id='braggpeaks_cal_raw' + ) + + # make Crystals + self.crystal1 = Crystal.from_CIF(filepath_cif1) + self.crystal2 = Crystal.from_CIF(filepath_cif2) + + # get structure factors + self.crystal1.calculate_structure_factors( + k_max = 1.7, + tol_structure_factor = 0.1, + return_intensities = False + ) + self.crystal2.calculate_structure_factors( + k_max = 1.7, + tol_structure_factor = 0.1, + return_intensities = False + ) + + # set up orientation plans + self.crystal1.orientation_plan( + angle_step_zone_axis=1.0, + angle_step_in_plane=5.0, + accel_voltage=300e3, + zone_axis_range='fiber', + fiber_axis=self.crystal1.hexagonal_to_lattice([1,0,-1,0]), + fiber_angles=[5,90], + intensity_power=2.5, + ) + self.crystal2.orientation_plan( + angle_step_zone_axis=1.0, + angle_step_in_plane=5.0, + accel_voltage=300e3, + zone_axis_range='fiber', + fiber_axis=[1,1,2], + fiber_angles=[5,180], + intensity_power=2.5, + ) + + # get orientation maps + self.orientation_map1 = self.crystal1.match_orientations( + self.braggpeaks + ) + self.orientation_map2 = self.crystal2.match_orientations( + self.braggpeaks + ) + + def teardown_class(self): + pass + + def setup_method(self): + pass + + def teardown_method(self): + pass + def test_crystal_phase(self): + + # make the CrystalPhase instance + self.crystal_phase = CrystalPhase( + name = 'pristine_spinel_phases', + crystals = [ + self.crystal1, + self.crystal2 + ], + orientation_maps = [ + self.orientation_map1, + self.orientation_map2 + ] + + ) + + # quantify the phases + self.crystal_phase.quantify_phase( + self.braggpeaks, + tolerance_distance = 0.035, + method = 'nnls', + intensity_power = 0, + ) + From 54a33000773bd5d154df13fb1fbe9f800fd65354 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 6 Feb 2023 20:22:13 -0500 Subject: [PATCH 08/25] finishes moving orientation_map -> crystal. nbs pass, tests pass --- py4DSTEM/process/diffraction/crystal_ACOM.py | 7 ++- py4DSTEM/process/diffraction/crystal_phase.py | 35 +++++++-------- py4DSTEM/process/diffraction/crystal_viz.py | 44 +++++++++---------- py4DSTEM/test/test_classes/test_crystal.py | 7 +-- 4 files changed, 43 insertions(+), 50 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index af2d77c1d..766705683 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -763,11 +763,13 @@ def match_orientations( This function computes the orientation of any number of PointLists stored in a PointListArray, and returns an OrienationMap. ''' + # instantiate object orientation_map = OrientationMap( num_x=bragg_peaks_array.shape[0], num_y=bragg_peaks_array.shape[1], num_matches=num_matches_return) + # loop for rx, ry in tqdmnd( *bragg_peaks_array.shape, desc="Matching Orientations", @@ -775,6 +777,7 @@ def match_orientations( disable=not progress_bar, ): + # get orientation matches orientation = self.match_single_pattern( bragg_peaks_array.get_pointlist(rx, ry), num_matches_return=num_matches_return, @@ -786,8 +789,10 @@ def match_orientations( ) orientation_map.set_orientation(orientation,rx,ry) + + # assign and return self.orientation_map = orientation_map - + if return_orientation: return orientation_map else: diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index b51380cfa..1b6dc9c4a 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -8,7 +8,7 @@ from scipy.optimize import nnls from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern -class Crystal_Phase: +class CrystalPhase: """ A class storing multiple crystal structures, and associated diffraction data. Must be initialized after matching orientations to a pointlistarray??? @@ -17,29 +17,24 @@ class Crystal_Phase: def __init__( self, crystals, - orientation_maps, name, ): """ Args: crystals (list): List of crystal instances - orientation_maps (list): List of orientation maps name (str): Name of Crystal_Phase instance """ - if isinstance(crystals, list): - self.crystals = crystals - self.num_crystals = len(crystals) - else: - raise TypeError('crystals must be a list of crystal instances.') - if isinstance(orientation_maps, list): - if len(self.crystals) != len(orientation_maps): - raise ValueError('Orientation maps must have the same number of entries as crystals.') - self.orientation_maps = orientation_maps - else: - raise TypeError('orientation_maps must be a list of orientation maps.') - self.name = name - return - + # validate inputs + assert(isinstance(crystals,list)), '`crystals` must be a list of crystal instances' + for xtal in crystals: + assert(hasattr(xtal,'orientation_map')), '`crystals` elements must be Crystal instances with a .orientation_map - try running .match_orientations' + + # assign variables + self.num_crystals = len(crystals) + self.crystals = crystals + self.orientation_maps = [xtal.orientation_map for xtal in crystals] + + def plot_all_phase_maps( self, map_scale_values = None, @@ -48,7 +43,7 @@ def plot_all_phase_maps( """ Visualize phase maps of dataset. - Args: + Args: map_scale_values (float): Value to scale correlations by """ phase_maps = [] @@ -59,7 +54,7 @@ def plot_all_phase_maps( phase_maps.append(self.orientation_maps[m].corr[:,:,index] / corr_sum) show_image_grid(lambda i:phase_maps[i], 1, len(phase_maps), cmap = 'inferno') return - + def plot_phase_map( self, index = 0, @@ -300,4 +295,4 @@ def quantify_phase_pointlist( # fig, (ax1, ax2) = plt.subplots(2,1,figsize = figsize) # ax1 = plot_diffraction_pattern(pointlist,) # return - \ No newline at end of file + diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index f0ed95663..f393dfe7c 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -955,7 +955,6 @@ def overline(x): def plot_orientation_maps( self, - orientation_map, orientation_ind: int = 0, dir_in_plane_degrees: float = 0.0, corr_range: np.ndarray = np.array([0, 5]), @@ -975,7 +974,6 @@ def plot_orientation_maps( Plot the orientation maps. Args: - orientation_map (OrientationMap): Class containing orientation matrices, correlation values, etc. orientation_ind (int): Which orientation match to plot if num_matches > 1 dir_in_plane_degrees (float): In-plane angle to plot in degrees. Default is 0 / x-axis / vertical down. corr_range (np.ndarray): Correlation intensity range for the plot @@ -1002,6 +1000,8 @@ def plot_orientation_maps( """ + assert(hasattr(self,'orientation_map')), "No orientation map found - try running .match_orientations" + # Inputs # Legend size leg_size = np.array([300, 300], dtype="int") @@ -1051,17 +1051,17 @@ def plot_orientation_maps( dir_in_plane = np.deg2rad(dir_in_plane_degrees) ct = np.cos(dir_in_plane) st = np.sin(dir_in_plane) - basis_x = np.zeros((orientation_map.num_x, orientation_map.num_y, 3)) - basis_y = np.zeros((orientation_map.num_x, orientation_map.num_y, 3)) - basis_z = np.zeros((orientation_map.num_x, orientation_map.num_y, 3)) - rgb_x = np.zeros((orientation_map.num_x, orientation_map.num_y, 3)) - rgb_z = np.zeros((orientation_map.num_x, orientation_map.num_y, 3)) + basis_x = np.zeros((self.orientation_map.num_x, self.orientation_map.num_y, 3)) + basis_y = np.zeros((self.orientation_map.num_x, self.orientation_map.num_y, 3)) + basis_z = np.zeros((self.orientation_map.num_x, self.orientation_map.num_y, 3)) + rgb_x = np.zeros((self.orientation_map.num_x, self.orientation_map.num_y, 3)) + rgb_z = np.zeros((self.orientation_map.num_x, self.orientation_map.num_y, 3)) # Basis for fitting orientation projections A = np.linalg.inv(self.orientation_zone_axis_range).T # Correlation masking - corr = orientation_map.corr[:, :, orientation_ind] + corr = self.orientation_map.corr[:, :, orientation_ind] if corr_normalize: corr = corr / np.mean(corr) mask = (corr - corr_range[0]) / (corr_range[1] - corr_range[0]) @@ -1069,21 +1069,21 @@ def plot_orientation_maps( # Generate images for rx, ry in tqdmnd( - orientation_map.num_x, - orientation_map.num_y, + self.orientation_map.num_x, + self.orientation_map.num_y, desc="Generating orientation maps", unit=" PointList", disable=not progress_bar, ): if self.pymatgen_available: - basis_x[rx,ry,:] = A @ orientation_map.family[rx,ry,orientation_ind,:,0] - basis_y[rx,ry,:] = A @ orientation_map.family[rx,ry,orientation_ind,:,1] + basis_x[rx,ry,:] = A @ self.orientation_map.family[rx,ry,orientation_ind,:,0] + basis_y[rx,ry,:] = A @ self.orientation_map.family[rx,ry,orientation_ind,:,1] basis_x[rx,ry,:] = basis_x[rx,ry,:]*ct + basis_y[rx,ry,:]*st - basis_z[rx,ry,:] = A @ orientation_map.family[rx,ry,orientation_ind,:,2] + basis_z[rx,ry,:] = A @ self.orientation_map.family[rx,ry,orientation_ind,:,2] else: - basis_z[rx,ry,:] = A @ orientation_map.matrix[rx,ry,orientation_ind,:,2] + basis_z[rx,ry,:] = A @ self.orientation_map.matrix[rx,ry,orientation_ind,:,2] basis_x = np.clip(basis_x,0,1) basis_z = np.clip(basis_z,0,1) @@ -1390,8 +1390,8 @@ def plot_orientation_maps( plt.show() images_orientation = np.zeros(( - orientation_map.num_x, - orientation_map.num_y, + self.orientation_map.num_x, + self.orientation_map.num_y, 3,2)) if self.pymatgen_available: images_orientation[:,:,:,0] = rgb_x @@ -1407,7 +1407,6 @@ def plot_orientation_maps( def plot_fiber_orientation_maps( self, - orientation_map, orientation_ind: int = 0, symmetry_order: int = None, symmetry_mirror: bool = False, @@ -1426,7 +1425,6 @@ def plot_fiber_orientation_maps( Generate and plot the orientation maps from fiber texture plots. Args: - orientation_map (OrientationMap): Class containing orientation matrices, correlation values, etc. orientation_ind (int): Which orientation match to plot if num_matches > 1 dir_in_plane_degrees (float): Reference in-plane angle (degrees). Default is 0 / x-axis / vertical down. corr_range (np.ndarray): Correlation intensity range for the plot @@ -1459,7 +1457,7 @@ def plot_fiber_orientation_maps( ) # Correlation masking - corr = orientation_map.corr[:, :, orientation_ind] + corr = self.orientation_map.corr[:, :, orientation_ind] if corr_normalize: corr = corr / np.mean(corr) if medfilt_size is not None: @@ -1474,7 +1472,7 @@ def plot_fiber_orientation_maps( symmetry_order = 2 * symmetry_order # Generate out-of-plane orientation signal - ang_op = orientation_map.angles[:, :, orientation_ind, 1] + ang_op = self.orientation_map.angles[:, :, orientation_ind, 1] if self.orientation_fiber_angles[0] > 0: sig_op = ang_op / np.deg2rad(self.orientation_fiber_angles[0]) else: @@ -1484,8 +1482,8 @@ def plot_fiber_orientation_maps( # Generate in-plane orientation signal ang_ip = ( - orientation_map.angles[:, :, orientation_ind, 0] - + orientation_map.angles[:, :, orientation_ind, 2] + self.orientation_map.angles[:, :, orientation_ind, 0] + + self.orientation_map.angles[:, :, orientation_ind, 2] ) sig_ip = np.mod((symmetry_order / (2 * np.pi)) * ang_ip, 1.0) if symmetry_mirror: @@ -1609,7 +1607,7 @@ def plot_fiber_orientation_maps( else: ax_op_l.axis("off") - images_orientation = np.zeros((orientation_map.num_x, orientation_map.num_y, 3, 2)) + images_orientation = np.zeros((self.orientation_map.num_x, self.orientation_map.num_y, 3, 2)) images_orientation[:, :, :, 0] = im_ip images_orientation[:, :, :, 1] = im_op diff --git a/py4DSTEM/test/test_classes/test_crystal.py b/py4DSTEM/test/test_classes/test_crystal.py index 4f49c558b..f87eddbdc 100644 --- a/py4DSTEM/test/test_classes/test_crystal.py +++ b/py4DSTEM/test/test_classes/test_crystal.py @@ -1,5 +1,5 @@ from py4DSTEM.process.diffraction import Crystal -from py4DSTEM.process.diffraction import Crystal_Phase as CrystalPhase +from py4DSTEM.process.diffraction import CrystalPhase from py4DSTEM import _TESTPATH,read from os.path import join @@ -170,12 +170,7 @@ def test_crystal_phase(self): crystals = [ self.crystal1, self.crystal2 - ], - orientation_maps = [ - self.orientation_map1, - self.orientation_map2 ] - ) # quantify the phases From d558e946f4a5651bcf071da1b0cf24147d6c91dd Mon Sep 17 00:00:00 2001 From: Tara Prasad Mishra Date: Wed, 8 Feb 2023 14:23:58 -0800 Subject: [PATCH 09/25] change the continue --- py4DSTEM/process/diffraction/crystal_phase.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 44a97891d..a1f2d3a0f 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -198,7 +198,6 @@ def compare_intensitylists( masterintensitylist[ind][columns_masterintensitylist-1]=temporary_pl_intensities[d] else: - continue ## The point list is not in the mega list of point list so the point list last row of masterpointlist masterpointlist = np.vstack((masterpointlist,bragg_peak_point)) ## Add a row to the intensity list such that all the remaining intensity lists should be 0 but only the new bragg intensity list is non zero but intensity power From 4c1dff12b1c7102c80dea98de13ef83c0bc5cf22 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 16 Feb 2023 16:40:02 -0800 Subject: [PATCH 10/25] refactor of phase matching --- py4DSTEM/process/diffraction/crystal_phase.py | 730 +++++++++++++----- py4DSTEM/process/diffraction/crystal_viz.py | 4 +- 2 files changed, 549 insertions(+), 185 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 1b6dc9c4a..43746084a 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -1,28 +1,37 @@ import numpy as np +from numpy.linalg import lstsq +from scipy.optimize import nnls import matplotlib as mpl import matplotlib.pyplot as plt + from py4DSTEM.utils.tqdmnd import tqdmnd from py4DSTEM.visualize import show, show_image_grid -from py4DSTEM.io.datastructure.emd.pointlistarray import PointListArray -from numpy.linalg import lstsq -from scipy.optimize import nnls -from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern +# from py4DSTEM.io.datastructure.emd.pointlistarray import PointListArray +# from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern +from py4DSTEM.io.datastructure import PointList, PointListArray + +from dataclasses import dataclass, field +@dataclass class CrystalPhase: """ A class storing multiple crystal structures, and associated diffraction data. Must be initialized after matching orientations to a pointlistarray??? """ + + name: str + num_crystals: int + def __init__( self, crystals, - name, + names = None, ): """ Args: crystals (list): List of crystal instances - name (str): Name of Crystal_Phase instance + name (str): Name of CrystalPhase instance """ # validate inputs assert(isinstance(crystals,list)), '`crystals` must be a list of crystal instances' @@ -32,13 +41,365 @@ def __init__( # assign variables self.num_crystals = len(crystals) self.crystals = crystals - self.orientation_maps = [xtal.orientation_map for xtal in crystals] + # self.orientation_maps = [xtal.orientation_map for xtal in crystals] + + # Get some attributes from crystals + self.k_max = np.zeros(self.num_crystals, dtype='int') + self.num_matches = np.zeros(self.num_crystals, dtype='int') + self.crystal_identity = np.zeros((0,2), dtype='int') + for a0 in range(self.num_crystals): + self.k_max[a0] = self.crystals[a0].k_max + self.num_matches[a0] = self.crystals[a0].orientation_map.num_matches + for a1 in range(self.num_matches[a0]): + self.crystal_identity = np.append(self.crystal_identity,np.array((a0,a1),dtype='int')[None,:], axis=0) + + self.num_fits = np.sum(self.num_matches) + # for a0 in range(self.crystals): + + + if names is not None: + self.names = names + else: + self.names = ['crystal'] * self.num_crystals + + + + def quantify_single_pattern( + self, + pointlistarray: PointListArray, + xy_position = (0,0), + corr_kernel_size = 0.04, + include_false_positives = True, + sigma_excitation_error = 0.02, + power_experiment = 0.5, + power_calculated = 0.5, + plot_result = True, + scale_markers_experiment = 10, + scale_markers_calculated = 4000, + crystal_inds_plot = None, + figsize = (12,8), + returnfig = False, + ): + """ + Quantify the phase for a single diffraction pattern. + """ + + # tolerance + tol2 = 4e-4 + + # Experimental values + bragg_peaks = pointlistarray.get_pointlist(xy_position[0],xy_position[1]).copy() + keep = bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2 > tol2 + # ind_center_beam = np.argmin( + # bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2) + # mask = np.ones_like(bragg_peaks.data["qx"], dtype='bool') + # mask[ind_center_beam] = False + # bragg_peaks.remove(ind_center_beam) + qx = bragg_peaks.data["qx"][keep] + qy = bragg_peaks.data["qy"][keep] + qx0 = bragg_peaks.data["qx"][np.logical_not(keep)] + qy0 = bragg_peaks.data["qy"][np.logical_not(keep)] + if power_experiment == 0: + intensity = np.ones_like(qx) + intensity0 = np.ones_like(qx0) + else: + intensity = bragg_peaks.data["intensity"][keep]**power_experiment + intensity0 = bragg_peaks.data["intensity"][np.logical_not(keep)]**power_experiment + + # init basis array + if include_false_positives: + basis = np.zeros((intensity.shape[0], self.num_fits)) + unpaired_peaks = [] + else: + basis = np.zeros((intensity.shape[0], self.num_fits)) + + # kernel radius squared + radius_max_2 = corr_kernel_size**2 + + # init for plotting + if plot_result: + library_peaks = [] + library_int = [] + library_matches = [] + + # Generate point list data, match to experimental peaks + for a0 in range(self.num_fits): + c = self.crystal_identity[a0,0] + m = self.crystal_identity[a0,1] + # for c in range(self.num_crystals): + # for m in range(self.num_matches[c]): + # ind_match += 1 + + # Generate simulated peaks + bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( + self.crystals[c].orientation_map.get_orientation( + xy_position[0], xy_position[1] + ), + ind_orientation = m, + sigma_excitation_error = sigma_excitation_error, + ) + del_peak = bragg_peaks_fit.data["qx"]**2 \ + + bragg_peaks_fit.data["qy"]**2 < tol2 + bragg_peaks_fit.remove(del_peak) + + # peak intensities + if power_calculated == 0: + int_fit = np.ones_like(bragg_peaks_fit.data["qx"]) + else: + int_fit = bragg_peaks_fit.data['intensity']**power_calculated + + # Pair peaks to experiment + if plot_result: + matches = np.zeros((bragg_peaks_fit.data.shape[0]),dtype='bool') + + for a1 in range(bragg_peaks_fit.data.shape[0]): + dist2 = (bragg_peaks_fit.data['qx'][a1] - qx)**2 \ + + (bragg_peaks_fit.data['qy'][a1] - qy)**2 + ind_min = np.argmin(dist2) + val_min = dist2[ind_min] + + if val_min < radius_max_2: + weight = 1 - np.sqrt(dist2[ind_min]) / corr_kernel_size + basis[ind_min,a0] = weight * int_fit[a1] + if plot_result: + matches[a1] = True + elif include_false_positives: + unpaired_peaks.append([a0,int_fit[a1]]) + + if plot_result: + library_peaks.append(bragg_peaks_fit) + library_int.append(int_fit) + library_matches.append(matches) + + # If needed, augment basis and observations with false positives + if include_false_positives: + basis_aug = np.zeros((len(unpaired_peaks),self.num_fits)) + for a0 in range(len(unpaired_peaks)): + basis_aug[a0,unpaired_peaks[a0][0]] = unpaired_peaks[a0][1] + + basis = np.vstack((basis, basis_aug)) + obs = np.hstack((intensity, np.zeros(len(unpaired_peaks)))) + else: + obs = intensity + + # Solve for phase coefficients + phase_weights, phase_residual = nnls( + basis, + obs, + ) + + print(np.round(phase_weights,decimals=2)) + # print() + # print(np.array(unpaired_peaks)) + # print() + + # initialize matching array + + + # phase_peak_match_intensities = np.zeros((pointlist['intensity'].shape)) + # bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( + # self.orientation_maps[c].get_orientation(position[0], position[1]), + # ind_orientation = m + # ) + + + # Plotting + if plot_result: + # fig, ax = plt.subplots(figsize=figsize) + fig = plt.figure(figsize=figsize) + # if plot_layout == 0: + # ax_x = fig.add_axes( + # [0.0+figbound[0], 0.0, 0.4-2*+figbound[0], 1.0]) + ax = fig.add_axes([0.0, 0.0, 0.66, 1.0]) + ax_leg = fig.add_axes([0.68, 0.0, 0.3, 1.0]) + + # plot the experimental radii + t = np.linspace(0,2*np.pi,91,endpoint=True) + ct = np.cos(t) * corr_kernel_size + st = np.sin(t) * corr_kernel_size + for a0 in range(qx.shape[0]): + ax.plot( + qy[a0] + st, + qx[a0] + ct, + color = 'k', + linewidth = 1, + ) + + + # plot the experimental peaks + ax.scatter( + qy0, + qx0, + s = scale_markers_experiment * intensity0, + marker = "o", + facecolor = [0.0, 0.0, 0.0], + ) + ax.scatter( + qy, + qx, + s = scale_markers_experiment * intensity, + marker = "o", + facecolor = [0.0, 0.0, 0.0], + ) + # legend + k_max = np.max(self.k_max) + dx_leg = -0.05*k_max + dy_leg = 0.04*k_max + text_params = { + "va": "center", + "ha": "left", + "family": "sans-serif", + "fontweight": "normal", + "color": "k", + "size": 14, + } + ax_leg.plot( + 0 + st*0.5, + -dx_leg + ct*0.5, + color = 'k', + linewidth = 1, + ) + ax_leg.scatter( + 0, + 0, + s = 200, + marker = "o", + facecolor = [0.0, 0.0, 0.0], + ) + ax_leg.text( + dy_leg, + 0, + 'Experimental peaks', + **text_params) + ax_leg.text( + dy_leg, + -dx_leg, + 'Correlation radius', + **text_params) + + + + # plot calculated diffraction patterns + # Currently just hardcoded for 6 max phases + cvals = np.array(( + (1.0,0.0,0.0,1.0), + (0.0,0.8,1.0,1.0), + (0.0,0.6,0.0,1.0), + (1.0,0.0,1.0,1.0), + (0.0,0.2,1.0,1.0), + (1.0,0.8,0.0,1.0), + )) + uvals = np.array(( + (1.0,0.0,0.0,0.2), + (0.0,0.8,1.0,0.2), + (0.0,0.6,0.0,0.2), + (1.0,0.0,1.0,0.2), + (0.0,0.2,1.0,0.2), + (1.0,0.8,0.0,0.2), + )) + mvals = ['v','^','<','>','d','s',] + + for a0 in range(self.num_fits): + c = self.crystal_identity[a0,0] + m = self.crystal_identity[a0,1] + + if crystal_inds_plot == None or np.min(np.abs(c - crystal_inds_plot)) == 0: + + qx_fit = library_peaks[a0].data['qx'] + qy_fit = library_peaks[a0].data['qy'] + int_fit = library_int[a0] + matches_fit = library_matches[a0] + + if np.mod(m,2) == 0: + ax.scatter( + qy_fit[matches_fit], + qx_fit[matches_fit], + s = scale_markers_calculated * int_fit[matches_fit], + marker = mvals[c], + facecolor = cvals[c,:], + ) + ax.scatter( + qy_fit[np.logical_not(matches_fit)], + qx_fit[np.logical_not(matches_fit)], + s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], + marker = mvals[c], + facecolor = uvals[c,:], + ) + + # legend + ax_leg.scatter( + 0, + dx_leg*(a0+1), + s = 200, + marker = mvals[c], + facecolor = cvals[c,:], + ) + else: + ax.scatter( + qy_fit[matches_fit], + qx_fit[matches_fit], + s = scale_markers_calculated * int_fit[matches_fit], + marker = mvals[c], + edgecolors = cvals[c,:], + facecolors = (1,1,1,0.5), + linewidth = 2, + ) + ax.scatter( + qy_fit[np.logical_not(matches_fit)], + qx_fit[np.logical_not(matches_fit)], + s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], + marker = mvals[c], + edgecolors = uvals[c,:], + facecolors = (1,1,1,0.5), + linewidth = 2, + ) + + # legend + ax_leg.scatter( + 0, + dx_leg*(a0+1), + s = 200, + marker = mvals[c], + edgecolors = cvals[c,:], + facecolors = (1,1,1,0.5), + ) + + # legend text + ax_leg.text( + dy_leg, + (a0+1)*dx_leg, + self.names[c], + **text_params) + + + # appearance + ax.set_xlim((-k_max, k_max)) + ax.set_ylim((-k_max, k_max)) + + ax_leg.set_xlim((-0.1*k_max, 0.4*k_max)) + ax_leg.set_ylim((-0.5*k_max, 0.5*k_max)) + ax_leg.set_axis_off() + + + if returnfig: + return phase_weights, phase_residual, fig, ax + else: + return phase_weights, phase_residual + + def quantify_phase( + + + ): + + + def plot_all_phase_maps( self, map_scale_values = None, - index = 0 + index = 0, + layout = 0, ): """ Visualize phase maps of dataset. @@ -52,7 +413,10 @@ def plot_all_phase_maps( corr_sum = np.sum([(self.orientation_maps[m].corr[:,:,index] * map_scale_values[m]) for m in range(len(self.orientation_maps))]) for m in range(len(self.orientation_maps)): phase_maps.append(self.orientation_maps[m].corr[:,:,index] / corr_sum) - show_image_grid(lambda i:phase_maps[i], 1, len(phase_maps), cmap = 'inferno') + if layout == 0: + show_image_grid(lambda i:phase_maps[i], 1, len(phase_maps), cmap = 'inferno') + elif layout == 1: + show_image_grid(lambda i:phase_maps[i], len(phase_maps), 1, cmap = 'inferno') return def plot_phase_map( @@ -107,192 +471,192 @@ def plot_phase_map( # ): # return - def quantify_phase( - self, - pointlistarray, - tolerance_distance = 0.08, - method = 'nnls', - intensity_power = 0, - mask_peaks = None - ): - """ - Quantification of the phase of a crystal based on the crystal instances and the pointlistarray. + # def quantify_phase( + # self, + # pointlistarray, + # tolerance_distance = 0.08, + # method = 'nnls', + # intensity_power = 0, + # mask_peaks = None + # ): + # """ + # Quantification of the phase of a crystal based on the crystal instances and the pointlistarray. - Args: - pointlisarray (pointlistarray): Pointlistarray to quantify phase of - tolerance_distance (float): Distance allowed between a peak and match - method (str): Numerical method used to quantify phase - intensity_power (float): ... - mask_peaks (list, optional): A pointer of which positions to mask peaks from + # Args: + # pointlisarray (pointlistarray): Pointlistarray to quantify phase of + # tolerance_distance (float): Distance allowed between a peak and match + # method (str): Numerical method used to quantify phase + # intensity_power (float): ... + # mask_peaks (list, optional): A pointer of which positions to mask peaks from - Details: - """ - if isinstance(pointlistarray, PointListArray): - - phase_weights = np.zeros(( - pointlistarray.shape[0], - pointlistarray.shape[1], - np.sum([map.num_matches for map in self.orientation_maps]) - )) - phase_residuals = np.zeros(pointlistarray.shape) - for Rx, Ry in tqdmnd(pointlistarray.shape[0], pointlistarray.shape[1]): - _, phase_weight, phase_residual, crystal_identity = self.quantify_phase_pointlist( - pointlistarray, - position = [Rx, Ry], - tolerance_distance=tolerance_distance, - method = method, - intensity_power = intensity_power, - mask_peaks = mask_peaks - ) - phase_weights[Rx,Ry,:] = phase_weight - phase_residuals[Rx,Ry] = phase_residual - self.phase_weights = phase_weights - self.phase_residuals = phase_residuals - self.crystal_identity = crystal_identity - return - else: - return TypeError('pointlistarray must be of type pointlistarray.') - return + # Details: + # """ + # if isinstance(pointlistarray, PointListArray): + + # phase_weights = np.zeros(( + # pointlistarray.shape[0], + # pointlistarray.shape[1], + # np.sum([map.num_matches for map in self.orientation_maps]) + # )) + # phase_residuals = np.zeros(pointlistarray.shape) + # for Rx, Ry in tqdmnd(pointlistarray.shape[0], pointlistarray.shape[1]): + # _, phase_weight, phase_residual, crystal_identity = self.quantify_phase_pointlist( + # pointlistarray, + # position = [Rx, Ry], + # tolerance_distance=tolerance_distance, + # method = method, + # intensity_power = intensity_power, + # mask_peaks = mask_peaks + # ) + # phase_weights[Rx,Ry,:] = phase_weight + # phase_residuals[Rx,Ry] = phase_residual + # self.phase_weights = phase_weights + # self.phase_residuals = phase_residuals + # self.crystal_identity = crystal_identity + # return + # else: + # return TypeError('pointlistarray must be of type pointlistarray.') + # return - def quantify_phase_pointlist( - self, - pointlistarray, - position, - method = 'nnls', - tolerance_distance = 0.08, - intensity_power = 0, - mask_peaks = None - ): - """ - Args: - pointlisarray (pointlistarray): Pointlistarray to quantify phase of - position (tuple/list): Position of pointlist in pointlistarray - tolerance_distance (float): Distance allowed between a peak and match - method (str): Numerical method used to quantify phase - intensity_power (float): ... - mask_peaks (list, optional): A pointer of which positions to mask peaks from - - Returns: - pointlist_peak_intensity_matches (np.ndarray): Peak matches in the rows of array and the crystals in the columns - phase_weights (np.ndarray): Weights of each phase - phase_residuals (np.ndarray): Residuals - crystal_identity (list): List of lists, where the each entry represents the position in the - crystal and orientation match that is associated with the phase - weights. for example, if the output was [[0,0], [0,1], [1,0], [0,1]], - the first entry [0,0] in phase weights is associated with the first crystal - the first match within that crystal. [0,1] is the first crystal and the - second match within that crystal. - """ - # Things to add: - # 1. Better cost for distance from peaks in pointlists - # 2. Iterate through multiple tolerance_distance values to find best value. Cost function residuals, or something else? + # def quantify_phase_pointlist( + # self, + # pointlistarray, + # position, + # method = 'nnls', + # tolerance_distance = 0.08, + # intensity_power = 0, + # mask_peaks = None + # ): + # """ + # Args: + # pointlisarray (pointlistarray): Pointlistarray to quantify phase of + # position (tuple/list): Position of pointlist in pointlistarray + # tolerance_distance (float): Distance allowed between a peak and match + # method (str): Numerical method used to quantify phase + # intensity_power (float): ... + # mask_peaks (list, optional): A pointer of which positions to mask peaks from + + # Returns: + # pointlist_peak_intensity_matches (np.ndarray): Peak matches in the rows of array and the crystals in the columns + # phase_weights (np.ndarray): Weights of each phase + # phase_residuals (np.ndarray): Residuals + # crystal_identity (list): List of lists, where the each entry represents the position in the + # crystal and orientation match that is associated with the phase + # weights. for example, if the output was [[0,0], [0,1], [1,0], [0,1]], + # the first entry [0,0] in phase weights is associated with the first crystal + # the first match within that crystal. [0,1] is the first crystal and the + # second match within that crystal. + # """ + # # Things to add: + # # 1. Better cost for distance from peaks in pointlists + # # 2. Iterate through multiple tolerance_distance values to find best value. Cost function residuals, or something else? - pointlist = pointlistarray.get_pointlist(position[0], position[1]) - pl_mask = np.where((pointlist['qx'] == 0) & (pointlist['qy'] == 0), 1, 0) - pointlist.remove(pl_mask) - # False Negatives (exp peak with no match in crystal instances) will appear here, already coded in + # pointlist = pointlistarray.get_pointlist(position[0], position[1]) + # pl_mask = np.where((pointlist['qx'] == 0) & (pointlist['qy'] == 0), 1, 0) + # pointlist.remove(pl_mask) + # # False Negatives (exp peak with no match in crystal instances) will appear here, already coded in - if intensity_power == 0: - pl_intensities = np.ones(pointlist['intensity'].shape) - else: - pl_intensities = pointlist['intensity']**intensity_power - #Prepare matches for modeling - pointlist_peak_matches = [] - crystal_identity = [] + # if intensity_power == 0: + # pl_intensities = np.ones(pointlist['intensity'].shape) + # else: + # pl_intensities = pointlist['intensity']**intensity_power + # #Prepare matches for modeling + # pointlist_peak_matches = [] + # crystal_identity = [] - for c in range(len(self.crystals)): - for m in range(self.orientation_maps[c].num_matches): - crystal_identity.append([c,m]) - phase_peak_match_intensities = np.zeros((pointlist['intensity'].shape)) - bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( - self.orientation_maps[c].get_orientation(position[0], position[1]), - ind_orientation = m - ) - #Find the best match peak within tolerance_distance and add value in the right position - for d in range(pointlist['qx'].shape[0]): - distances = [] - for p in range(bragg_peaks_fit['qx'].shape[0]): - distances.append( - np.sqrt((pointlist['qx'][d] - bragg_peaks_fit['qx'][p])**2 + - (pointlist['qy'][d]-bragg_peaks_fit['qy'][p])**2) - ) - ind = np.where(distances == np.min(distances))[0][0] + # for c in range(len(self.crystals)): + # for m in range(self.orientation_maps[c].num_matches): + # crystal_identity.append([c,m]) + # phase_peak_match_intensities = np.zeros((pointlist['intensity'].shape)) + # bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( + # self.orientation_maps[c].get_orientation(position[0], position[1]), + # ind_orientation = m + # ) + # #Find the best match peak within tolerance_distance and add value in the right position + # for d in range(pointlist['qx'].shape[0]): + # distances = [] + # for p in range(bragg_peaks_fit['qx'].shape[0]): + # distances.append( + # np.sqrt((pointlist['qx'][d] - bragg_peaks_fit['qx'][p])**2 + + # (pointlist['qy'][d]-bragg_peaks_fit['qy'][p])**2) + # ) + # ind = np.where(distances == np.min(distances))[0][0] - #Potentially for-loop over multiple values for 'tolerance_distance' to find best tolerance_distance value - if distances[ind] <= tolerance_distance: - ## Somewhere in this if statement is probably where better distances from the peak should be coded in - if intensity_power == 0: #This could potentially be a different intensity_power arg - phase_peak_match_intensities[d] = 1**((tolerance_distance-distances[ind])/tolerance_distance) - else: - phase_peak_match_intensities[d] = bragg_peaks_fit['intensity'][ind]**((tolerance_distance-distances[ind])/tolerance_distance) - else: - ## This is probably where the false positives (peaks in crystal but not in experiment) should be handled - continue + # #Potentially for-loop over multiple values for 'tolerance_distance' to find best tolerance_distance value + # if distances[ind] <= tolerance_distance: + # ## Somewhere in this if statement is probably where better distances from the peak should be coded in + # if intensity_power == 0: #This could potentially be a different intensity_power arg + # phase_peak_match_intensities[d] = 1**((tolerance_distance-distances[ind])/tolerance_distance) + # else: + # phase_peak_match_intensities[d] = bragg_peaks_fit['intensity'][ind]**((tolerance_distance-distances[ind])/tolerance_distance) + # else: + # ## This is probably where the false positives (peaks in crystal but not in experiment) should be handled + # continue - pointlist_peak_matches.append(phase_peak_match_intensities) - pointlist_peak_intensity_matches = np.dstack(pointlist_peak_matches) - pointlist_peak_intensity_matches = pointlist_peak_intensity_matches.reshape( - pl_intensities.shape[0], - pointlist_peak_intensity_matches.shape[-1] - ) + # pointlist_peak_matches.append(phase_peak_match_intensities) + # pointlist_peak_intensity_matches = np.dstack(pointlist_peak_matches) + # pointlist_peak_intensity_matches = pointlist_peak_intensity_matches.reshape( + # pl_intensities.shape[0], + # pointlist_peak_intensity_matches.shape[-1] + # ) - if len(pointlist['qx']) > 0: - if mask_peaks is not None: - for i in range(len(mask_peaks)): - if mask_peaks[i] == None: - continue - inds_mask = np.where(pointlist_peak_intensity_matches[:,mask_peaks[i]] != 0)[0] - for mask in range(len(inds_mask)): - pointlist_peak_intensity_matches[inds_mask[mask],i] = 0 - - if method == 'nnls': - phase_weights, phase_residuals = nnls( - pointlist_peak_intensity_matches, - pl_intensities - ) + # if len(pointlist['qx']) > 0: + # if mask_peaks is not None: + # for i in range(len(mask_peaks)): + # if mask_peaks[i] == None: + # continue + # inds_mask = np.where(pointlist_peak_intensity_matches[:,mask_peaks[i]] != 0)[0] + # for mask in range(len(inds_mask)): + # pointlist_peak_intensity_matches[inds_mask[mask],i] = 0 + + # if method == 'nnls': + # phase_weights, phase_residuals = nnls( + # pointlist_peak_intensity_matches, + # pl_intensities + # ) - elif method == 'lstsq': - phase_weights, phase_residuals, rank, singluar_vals = lstsq( - pointlist_peak_intensity_matches, - pl_intensities, - rcond = -1 - ) - phase_residuals = np.sum(phase_residuals) - else: - raise ValueError(method + ' Not yet implemented. Try nnls or lstsq.') - else: - phase_weights = np.zeros((pointlist_peak_intensity_matches.shape[1],)) - phase_residuals = np.NaN - return pointlist_peak_intensity_matches, phase_weights, phase_residuals, crystal_identity + # elif method == 'lstsq': + # phase_weights, phase_residuals, rank, singluar_vals = lstsq( + # pointlist_peak_intensity_matches, + # pl_intensities, + # rcond = -1 + # ) + # phase_residuals = np.sum(phase_residuals) + # else: + # raise ValueError(method + ' Not yet implemented. Try nnls or lstsq.') + # else: + # phase_weights = np.zeros((pointlist_peak_intensity_matches.shape[1],)) + # phase_residuals = np.NaN + # return pointlist_peak_intensity_matches, phase_weights, phase_residuals, crystal_identity - # def plot_peak_matches( - # self, - # pointlistarray, - # position, - # tolerance_distance, - # ind_orientation, - # pointlist_peak_intensity_matches, - # ): - # """ - # A method to view how the tolerance distance impacts the peak matches associated with - # the quantify_phase_pointlist method. + # # def plot_peak_matches( + # # self, + # # pointlistarray, + # # position, + # # tolerance_distance, + # # ind_orientation, + # # pointlist_peak_intensity_matches, + # # ): + # # """ + # # A method to view how the tolerance distance impacts the peak matches associated with + # # the quantify_phase_pointlist method. - # Args: - # pointlistarray, - # position, - # tolerance_distance - # pointlist_peak_intensity_matches - # """ - # pointlist = pointlistarray.get_pointlist(position[0],position[1]) + # # Args: + # # pointlistarray, + # # position, + # # tolerance_distance + # # pointlist_peak_intensity_matches + # # """ + # # pointlist = pointlistarray.get_pointlist(position[0],position[1]) - # for m in range(pointlist_peak_intensity_matches.shape[1]): - # bragg_peaks_fit = self.crystals[m].generate_diffraction_pattern( - # self.orientation_maps[m].get_orientation(position[0], position[1]), - # ind_orientation = ind_orientation - # ) - # peak_inds = np.where(bragg_peaks_fit.data['intensity'] == pointlist_peak_intensity_matches[:,m]) + # # for m in range(pointlist_peak_intensity_matches.shape[1]): + # # bragg_peaks_fit = self.crystals[m].generate_diffraction_pattern( + # # self.orientation_maps[m].get_orientation(position[0], position[1]), + # # ind_orientation = ind_orientation + # # ) + # # peak_inds = np.where(bragg_peaks_fit.data['intensity'] == pointlist_peak_intensity_matches[:,m]) - # fig, (ax1, ax2) = plt.subplots(2,1,figsize = figsize) - # ax1 = plot_diffraction_pattern(pointlist,) - # return + # # fig, (ax1, ax2) = plt.subplots(2,1,figsize = figsize) + # # ax1 = plot_diffraction_pattern(pointlist,) + # # return diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index f393dfe7c..07243384b 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -645,7 +645,7 @@ def plot_orientation_zones( ax.axes.set_xlim3d(left=plot_limit[0], right=plot_limit[1]) ax.axes.set_ylim3d(bottom=plot_limit[0], top=plot_limit[1]) ax.axes.set_zlim3d(bottom=plot_limit[0], top=plot_limit[1]) - ax.set_box_aspect((1, 1, 1)) + # ax.set_box_aspect((1, 1, 1)) ax.set_axis_off() # ax.setxticklabels([]) # fig.subplots_adjust(left=0, right=1, bottom=0, top=1) @@ -1626,7 +1626,7 @@ def axisEqual3D(ax): r = maxsize / 2 for ctr, dim in zip(centers, "xyz"): getattr(ax, "set_{}lim".format(dim))(ctr - r, ctr + r) - ax.set_box_aspect((1, 1, 1)) + # ax.set_box_aspect((1, 1, 1)) def atomic_colors(Z, scheme="jmol"): From 2156d07f65df3eefca4e40813996884f7085ebdf Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 16 Feb 2023 20:55:32 -0800 Subject: [PATCH 11/25] Phase mapping plus plotting functions --- py4DSTEM/process/diffraction/crystal_phase.py | 591 +++++++++--------- 1 file changed, 287 insertions(+), 304 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 43746084a..2e3921df1 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -44,7 +44,7 @@ def __init__( # self.orientation_maps = [xtal.orientation_map for xtal in crystals] # Get some attributes from crystals - self.k_max = np.zeros(self.num_crystals, dtype='int') + self.k_max = np.zeros(self.num_crystals) self.num_matches = np.zeros(self.num_crystals, dtype='int') self.crystal_identity = np.zeros((0,2), dtype='int') for a0 in range(self.num_crystals): @@ -74,10 +74,19 @@ def quantify_single_pattern( power_experiment = 0.5, power_calculated = 0.5, plot_result = True, - scale_markers_experiment = 10, + scale_markers_experiment = 4, scale_markers_calculated = 4000, crystal_inds_plot = None, + phase_colors = np.array(( + (1.0,0.0,0.0,1.0), + (0.0,0.8,1.0,1.0), + (0.0,0.6,0.0,1.0), + (1.0,0.0,1.0,1.0), + (0.0,0.2,1.0,1.0), + (1.0,0.8,0.0,1.0), + )), figsize = (12,8), + verbose = True, returnfig = False, ): """ @@ -152,6 +161,7 @@ def quantify_single_pattern( if plot_result: matches = np.zeros((bragg_peaks_fit.data.shape[0]),dtype='bool') + # Loop over all people for a1 in range(bragg_peaks_fit.data.shape[0]): dist2 = (bragg_peaks_fit.data['qx'][a1] - qx)**2 \ + (bragg_peaks_fit.data['qy'][a1] - qy)**2 @@ -183,25 +193,33 @@ def quantify_single_pattern( obs = intensity # Solve for phase coefficients - phase_weights, phase_residual = nnls( - basis, - obs, - ) - - print(np.round(phase_weights,decimals=2)) - # print() - # print(np.array(unpaired_peaks)) - # print() - - # initialize matching array - - - # phase_peak_match_intensities = np.zeros((pointlist['intensity'].shape)) - # bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( - # self.orientation_maps[c].get_orientation(position[0], position[1]), - # ind_orientation = m - # ) - + try: + phase_weights, phase_residual = nnls( + basis, + obs, + ) + except: + phase_weights = np.zeros(self.num_fits) + phase_residual = np.sqrt(np.sum(intensity**2)) + + if verbose: + ind_max = np.argmax(phase_weights) + # print() + print('\033[1m' + 'phase_weight or_ind name' + '\033[0m') + # print() + for a0 in range(self.num_fits): + c = self.crystal_identity[a0,0] + m = self.crystal_identity[a0,1] + line = '{:>12} {:>8} {:<12}'.format( + np.round(phase_weights[a0],decimals=2), + m, + self.names[c] + ) + if a0 == ind_max: + print('\033[1m' + line + '\033[0m') + else: + print(line) + # print() # Plotting if plot_result: @@ -280,23 +298,16 @@ def quantify_single_pattern( # plot calculated diffraction patterns - # Currently just hardcoded for 6 max phases - cvals = np.array(( - (1.0,0.0,0.0,1.0), - (0.0,0.8,1.0,1.0), - (0.0,0.6,0.0,1.0), - (1.0,0.0,1.0,1.0), - (0.0,0.2,1.0,1.0), - (1.0,0.8,0.0,1.0), - )) - uvals = np.array(( - (1.0,0.0,0.0,0.2), - (0.0,0.8,1.0,0.2), - (0.0,0.6,0.0,0.2), - (1.0,0.0,1.0,0.2), - (0.0,0.2,1.0,0.2), - (1.0,0.8,0.0,0.2), - )) + uvals = phase_colors.copy() + uvals[:,3] = 0.3 + # uvals = np.array(( + # (1.0,0.0,0.0,0.2), + # (0.0,0.8,1.0,0.2), + # (0.0,0.6,0.0,0.2), + # (1.0,0.0,1.0,0.2), + # (0.0,0.2,1.0,0.2), + # (1.0,0.8,0.0,0.2), + # )) mvals = ['v','^','<','>','d','s',] for a0 in range(self.num_fits): @@ -316,14 +327,14 @@ def quantify_single_pattern( qx_fit[matches_fit], s = scale_markers_calculated * int_fit[matches_fit], marker = mvals[c], - facecolor = cvals[c,:], + facecolor = phase_colors[c,:], ) ax.scatter( qy_fit[np.logical_not(matches_fit)], qx_fit[np.logical_not(matches_fit)], s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], marker = mvals[c], - facecolor = uvals[c,:], + facecolor = phase_colors[c,:], ) # legend @@ -332,7 +343,7 @@ def quantify_single_pattern( dx_leg*(a0+1), s = 200, marker = mvals[c], - facecolor = cvals[c,:], + facecolor = phase_colors[c,:], ) else: ax.scatter( @@ -340,7 +351,7 @@ def quantify_single_pattern( qx_fit[matches_fit], s = scale_markers_calculated * int_fit[matches_fit], marker = mvals[c], - edgecolors = cvals[c,:], + edgecolors = uvals[c,:], facecolors = (1,1,1,0.5), linewidth = 2, ) @@ -360,7 +371,7 @@ def quantify_single_pattern( dx_leg*(a0+1), s = 200, marker = mvals[c], - edgecolors = cvals[c,:], + edgecolors = uvals[c,:], facecolors = (1,1,1,0.5), ) @@ -387,276 +398,248 @@ def quantify_single_pattern( return phase_weights, phase_residual def quantify_phase( + self, + pointlistarray: PointListArray, + corr_kernel_size = 0.04, + include_false_positives = True, + sigma_excitation_error = 0.02, + power_experiment = 0.5, + power_calculated = 0.5, + progress_bar = True, + ): + """ + Quantify phase of all diffraction patterns. + """ + + # init results arrays + self.phase_weights = np.zeros(( + pointlistarray.shape[0], + pointlistarray.shape[1], + self.num_fits, + )) + self.phase_residuals = np.zeros(( + pointlistarray.shape[0], + pointlistarray.shape[1], + )) + + for rx, ry in tqdmnd( + *pointlistarray.shape, + desc="Matching Orientations", + unit=" PointList", + disable=not progress_bar, + ): + # calculate phase weights + phase_weights, phase_residual = self.quantify_single_pattern( + pointlistarray = pointlistarray, + xy_position = (rx,ry), + corr_kernel_size = corr_kernel_size, + include_false_positives = include_false_positives, + sigma_excitation_error = sigma_excitation_error, + power_experiment = power_experiment, + power_calculated = power_calculated, + plot_result = False, + verbose = False, + returnfig = False, + ) + self.phase_weights[rx,ry] = phase_weights + self.phase_residuals[rx,ry] = phase_residual + def plot_phase_weights( + self, + weight_range = (0.5,1,0), + weight_normalize = True, + cmap = 'gray', + show_ticks = False, + show_axes = True, + figsize = (6,6), + returnfig = False, ): + """ + Plot the individual phase weight maps and residuals. + """ + + # intensity range for plotting + if weight_normalize: + scale = np.median(np.max(self.phase_weights,axis=2)) + else: + scale = 1 + weight_range = np.array(weight_range) * scale + # plotting + fig,ax = plt.subplots( + self.num_crystals + 1, + 1, + figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + for a0 in range(self.num_crystals): + sub = self.crystal_identity[:,0] == a0 + im = np.sum(self.phase_weights[:,:,sub],axis=2) + im = np.clip( + (im - weight_range[0]) / (weight_range[1] - weight_range[0]), + 0,1) + ax[a0].imshow( + im, + vmin = 0, + vmax = 1, + cmap = cmap, + ) + ax[a0].set_title( + self.names[a0], + fontsize = 16, + ) + if not show_ticks: + ax[a0].set_xticks([]) + ax[a0].set_yticks([]) + if not show_axes: + ax[a0].set_axis_off() + + # plot residuals + im = np.clip( + (self.phase_residuals - weight_range[0]) / (weight_range[1] - weight_range[0]), + 0,1) + ax[self.num_crystals].imshow( + im, + vmin = 0, + vmax = 1, + cmap = cmap, + ) + ax[self.num_crystals].set_title( + 'Residuals', + fontsize = 16, + ) + if not show_ticks: + ax[self.num_crystals].set_xticks([]) + ax[self.num_crystals].set_yticks([]) + if not show_axes: + ax[self.num_crystals].set_axis_off() + if returnfig: + return fig, ax - def plot_all_phase_maps( + def plot_phase_maps( self, - map_scale_values = None, - index = 0, - layout = 0, - ): + weight_threshold = 0.5, + weight_normalize = True, + plot_combine = False, + crystal_inds_plot = None, + phase_colors = np.array(( + (1.0,0.0,0.0), + (0.0,0.8,1.0), + (0.0,0.6,0.0), + (1.0,0.0,1.0), + (0.0,0.2,1.0), + (1.0,0.8,0.0), + )), + show_ticks = False, + show_axes = True, + figsize = (6,6), + return_phase_estimate = False, + return_rgb_images = False, + returnfig = False, + ): """ - Visualize phase maps of dataset. - - Args: - map_scale_values (float): Value to scale correlations by + Plot the individual phase weight maps and residuals. """ - phase_maps = [] - if map_scale_values == None: - map_scale_values = [1] * len(self.orientation_maps) - corr_sum = np.sum([(self.orientation_maps[m].corr[:,:,index] * map_scale_values[m]) for m in range(len(self.orientation_maps))]) - for m in range(len(self.orientation_maps)): - phase_maps.append(self.orientation_maps[m].corr[:,:,index] / corr_sum) - if layout == 0: - show_image_grid(lambda i:phase_maps[i], 1, len(phase_maps), cmap = 'inferno') - elif layout == 1: - show_image_grid(lambda i:phase_maps[i], len(phase_maps), 1, cmap = 'inferno') - return - - def plot_phase_map( - self, - index = 0, - cmap = None - - ): - corr_array = np.dstack([maps.corr[:,:,index] for maps in self.orientation_maps]) - best_corr_score = np.max(corr_array,axis=2) - best_match_phase = [np.where(corr_array[:,:,p] == best_corr_score, True,False) - for p in range(len(self.orientation_maps)) - ] - - if cmap == None: - cm = plt.get_cmap('rainbow') - cmap = [cm(1.*i/len(self.orientation_maps)) for i in range(len(self.orientation_maps))] - - fig, (ax) = plt.subplots(figsize = (6,6)) - ax.matshow(np.zeros((self.orientation_maps[0].num_x, self.orientation_maps[0].num_y)), cmap = 'gray') - ax.axis('off') - - for m in range(len(self.orientation_maps)): - c0, c1 = (cmap[m][0]*0.35,cmap[m][1]*0.35,cmap[m][2]*0.35,1), cmap[m] - cm = mpl.colors.LinearSegmentedColormap.from_list('cmap', [c0,c1], N = 10) - ax.matshow( - np.ma.array( - self.orientation_maps[m].corr[:,:,index], - mask = best_match_phase[m]), - cmap = cm) - plt.show() - - return - - # Potentially introduce a way to check best match out of all orientations in phase plan and plug into model - # to quantify phase - - # def phase_plan( - # self, - # method, - # zone_axis_range: np.ndarray = np.array([[0, 1, 1], [1, 1, 1]]), - # angle_step_zone_axis: float = 2.0, - # angle_coarse_zone_axis: float = None, - # angle_refine_range: float = None, - # angle_step_in_plane: float = 2.0, - # accel_voltage: float = 300e3, - # intensity_power: float = 0.25, - # tol_peak_delete=None, - # tol_distance: float = 0.01, - # fiber_axis = None, - # fiber_angles = None, - # ): - # return - - # def quantify_phase( - # self, - # pointlistarray, - # tolerance_distance = 0.08, - # method = 'nnls', - # intensity_power = 0, - # mask_peaks = None - # ): - # """ - # Quantification of the phase of a crystal based on the crystal instances and the pointlistarray. - - # Args: - # pointlisarray (pointlistarray): Pointlistarray to quantify phase of - # tolerance_distance (float): Distance allowed between a peak and match - # method (str): Numerical method used to quantify phase - # intensity_power (float): ... - # mask_peaks (list, optional): A pointer of which positions to mask peaks from - - # Details: - # """ - # if isinstance(pointlistarray, PointListArray): - - # phase_weights = np.zeros(( - # pointlistarray.shape[0], - # pointlistarray.shape[1], - # np.sum([map.num_matches for map in self.orientation_maps]) - # )) - # phase_residuals = np.zeros(pointlistarray.shape) - # for Rx, Ry in tqdmnd(pointlistarray.shape[0], pointlistarray.shape[1]): - # _, phase_weight, phase_residual, crystal_identity = self.quantify_phase_pointlist( - # pointlistarray, - # position = [Rx, Ry], - # tolerance_distance=tolerance_distance, - # method = method, - # intensity_power = intensity_power, - # mask_peaks = mask_peaks - # ) - # phase_weights[Rx,Ry,:] = phase_weight - # phase_residuals[Rx,Ry] = phase_residual - # self.phase_weights = phase_weights - # self.phase_residuals = phase_residuals - # self.crystal_identity = crystal_identity - # return - # else: - # return TypeError('pointlistarray must be of type pointlistarray.') - # return - - # def quantify_phase_pointlist( - # self, - # pointlistarray, - # position, - # method = 'nnls', - # tolerance_distance = 0.08, - # intensity_power = 0, - # mask_peaks = None - # ): - # """ - # Args: - # pointlisarray (pointlistarray): Pointlistarray to quantify phase of - # position (tuple/list): Position of pointlist in pointlistarray - # tolerance_distance (float): Distance allowed between a peak and match - # method (str): Numerical method used to quantify phase - # intensity_power (float): ... - # mask_peaks (list, optional): A pointer of which positions to mask peaks from - - # Returns: - # pointlist_peak_intensity_matches (np.ndarray): Peak matches in the rows of array and the crystals in the columns - # phase_weights (np.ndarray): Weights of each phase - # phase_residuals (np.ndarray): Residuals - # crystal_identity (list): List of lists, where the each entry represents the position in the - # crystal and orientation match that is associated with the phase - # weights. for example, if the output was [[0,0], [0,1], [1,0], [0,1]], - # the first entry [0,0] in phase weights is associated with the first crystal - # the first match within that crystal. [0,1] is the first crystal and the - # second match within that crystal. - # """ - # # Things to add: - # # 1. Better cost for distance from peaks in pointlists - # # 2. Iterate through multiple tolerance_distance values to find best value. Cost function residuals, or something else? - - # pointlist = pointlistarray.get_pointlist(position[0], position[1]) - # pl_mask = np.where((pointlist['qx'] == 0) & (pointlist['qy'] == 0), 1, 0) - # pointlist.remove(pl_mask) - # # False Negatives (exp peak with no match in crystal instances) will appear here, already coded in - - # if intensity_power == 0: - # pl_intensities = np.ones(pointlist['intensity'].shape) - # else: - # pl_intensities = pointlist['intensity']**intensity_power - # #Prepare matches for modeling - # pointlist_peak_matches = [] - # crystal_identity = [] - - # for c in range(len(self.crystals)): - # for m in range(self.orientation_maps[c].num_matches): - # crystal_identity.append([c,m]) - # phase_peak_match_intensities = np.zeros((pointlist['intensity'].shape)) - # bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( - # self.orientation_maps[c].get_orientation(position[0], position[1]), - # ind_orientation = m - # ) - # #Find the best match peak within tolerance_distance and add value in the right position - # for d in range(pointlist['qx'].shape[0]): - # distances = [] - # for p in range(bragg_peaks_fit['qx'].shape[0]): - # distances.append( - # np.sqrt((pointlist['qx'][d] - bragg_peaks_fit['qx'][p])**2 + - # (pointlist['qy'][d]-bragg_peaks_fit['qy'][p])**2) - # ) - # ind = np.where(distances == np.min(distances))[0][0] - - # #Potentially for-loop over multiple values for 'tolerance_distance' to find best tolerance_distance value - # if distances[ind] <= tolerance_distance: - # ## Somewhere in this if statement is probably where better distances from the peak should be coded in - # if intensity_power == 0: #This could potentially be a different intensity_power arg - # phase_peak_match_intensities[d] = 1**((tolerance_distance-distances[ind])/tolerance_distance) - # else: - # phase_peak_match_intensities[d] = bragg_peaks_fit['intensity'][ind]**((tolerance_distance-distances[ind])/tolerance_distance) - # else: - # ## This is probably where the false positives (peaks in crystal but not in experiment) should be handled - # continue + + # intensity range for plotting + if weight_normalize: + scale = np.median(np.max(self.phase_weights,axis=2)) + else: + scale = 1 + weight_threshold = weight_threshold * scale + + # init + im_all = np.zeros(( + self.num_crystals, + self.phase_weights.shape[0], + self.phase_weights.shape[1])) + im_rgb_all = np.zeros(( + self.num_crystals, + self.phase_weights.shape[0], + self.phase_weights.shape[1], + 3)) + + # phase weights over threshold + for a0 in range(self.num_crystals): + sub = self.crystal_identity[:,0] == a0 + im = np.sum(self.phase_weights[:,:,sub],axis=2) + im_all[a0] = np.maximum(im - weight_threshold, 0) + + # estimate compositions + im_sum = np.sum(im_all, axis = 0) + sub = im_sum > 0.0 + for a0 in range(self.num_crystals): + im_all[a0][sub] /= im_sum[sub] + + for a1 in range(3): + im_rgb_all[a0,:,:,a1] = im_all[a0] * phase_colors[a0,a1] + + if plot_combine: + if crystal_inds_plot is None: + im_rgb = np.sum(im_rgb_all, axis = 0) + else: + im_rgb = np.sum(im_rgb_all[np.array(crystal_inds_plot)], axis = 0) + + im_rgb = np.clip(im_rgb,0,1) + + fig,ax = plt.subplots(1,1,figsize=figsize) + ax.imshow( + im_rgb, + ) + ax.set_title( + 'Phase Maps', + fontsize = 16, + ) + if not show_ticks: + ax.set_xticks([]) + ax.set_yticks([]) + if not show_axes: + ax.set_axis_off() + + else: + # plotting + fig,ax = plt.subplots( + self.num_crystals, + 1, + figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + + for a0 in range(self.num_crystals): - # pointlist_peak_matches.append(phase_peak_match_intensities) - # pointlist_peak_intensity_matches = np.dstack(pointlist_peak_matches) - # pointlist_peak_intensity_matches = pointlist_peak_intensity_matches.reshape( - # pl_intensities.shape[0], - # pointlist_peak_intensity_matches.shape[-1] - # ) - - # if len(pointlist['qx']) > 0: - # if mask_peaks is not None: - # for i in range(len(mask_peaks)): - # if mask_peaks[i] == None: - # continue - # inds_mask = np.where(pointlist_peak_intensity_matches[:,mask_peaks[i]] != 0)[0] - # for mask in range(len(inds_mask)): - # pointlist_peak_intensity_matches[inds_mask[mask],i] = 0 - - # if method == 'nnls': - # phase_weights, phase_residuals = nnls( - # pointlist_peak_intensity_matches, - # pl_intensities - # ) - - # elif method == 'lstsq': - # phase_weights, phase_residuals, rank, singluar_vals = lstsq( - # pointlist_peak_intensity_matches, - # pl_intensities, - # rcond = -1 - # ) - # phase_residuals = np.sum(phase_residuals) - # else: - # raise ValueError(method + ' Not yet implemented. Try nnls or lstsq.') - # else: - # phase_weights = np.zeros((pointlist_peak_intensity_matches.shape[1],)) - # phase_residuals = np.NaN - # return pointlist_peak_intensity_matches, phase_weights, phase_residuals, crystal_identity - - # # def plot_peak_matches( - # # self, - # # pointlistarray, - # # position, - # # tolerance_distance, - # # ind_orientation, - # # pointlist_peak_intensity_matches, - # # ): - # # """ - # # A method to view how the tolerance distance impacts the peak matches associated with - # # the quantify_phase_pointlist method. - - # # Args: - # # pointlistarray, - # # position, - # # tolerance_distance - # # pointlist_peak_intensity_matches - # # """ - # # pointlist = pointlistarray.get_pointlist(position[0],position[1]) - - # # for m in range(pointlist_peak_intensity_matches.shape[1]): - # # bragg_peaks_fit = self.crystals[m].generate_diffraction_pattern( - # # self.orientation_maps[m].get_orientation(position[0], position[1]), - # # ind_orientation = ind_orientation - # # ) - # # peak_inds = np.where(bragg_peaks_fit.data['intensity'] == pointlist_peak_intensity_matches[:,m]) - - # # fig, (ax1, ax2) = plt.subplots(2,1,figsize = figsize) - # # ax1 = plot_diffraction_pattern(pointlist,) - # # return - + ax[a0].imshow( + im_rgb_all[a0], + ) + ax[a0].set_title( + self.names[a0], + fontsize = 16, + ) + if not show_ticks: + ax[a0].set_xticks([]) + ax[a0].set_yticks([]) + if not show_axes: + ax[a0].set_axis_off() + + # All possible returns + if return_phase_estimate: + if returnfig: + return im_all, fig, ax + else: + return im_all + elif return_rgb_images: + if plot_combine: + if returnfig: + return im_rgb, fig, ax + else: + return im_rgb + else: + if returnfig: + return im_rgb_all, fig, ax + else: + return im_rgb_all + else: + if returnfig: + return fig, ax + + + \ No newline at end of file From a1f8a401d8b16cb378c7aa1193fbd8aa1f5b3779 Mon Sep 17 00:00:00 2001 From: Tara Prasad Mishra Date: Thu, 16 Feb 2023 22:13:31 -0800 Subject: [PATCH 12/25] sync with the crystal phase of py4dstem upstream repo --- .../io/datastructure/py4dstem/datacube_fns.py | 5 - py4DSTEM/process/diffraction/crystal_phase.py | 200 ------------------ 2 files changed, 205 deletions(-) diff --git a/py4DSTEM/io/datastructure/py4dstem/datacube_fns.py b/py4DSTEM/io/datastructure/py4dstem/datacube_fns.py index acae63a2d..45d9b1eea 100644 --- a/py4DSTEM/io/datastructure/py4dstem/datacube_fns.py +++ b/py4DSTEM/io/datastructure/py4dstem/datacube_fns.py @@ -1070,11 +1070,6 @@ def find_Bragg_disks( ml_model_path = ml_model_path, ml_num_attempts = ml_num_attempts, ml_batch_size = ml_batch_size, -<<<<<<< HEAD - - # _qt_progress_bar = _qt_progress_bar, -======= ->>>>>>> 8a78ac5ad29e0a99ee82a48032397a9cfe004922 ) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 6b6f07060..64b9a4692 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -441,80 +441,11 @@ def quantify_phase( verbose = False, returnfig = False, ) -<<<<<<< HEAD - phase_weights[Rx,Ry,:] = phase_weight - phase_residuals[Rx,Ry] = phase_residual - self.phase_weights = phase_weights - self.phase_residuals = phase_residuals - self.crystal_identity = crystal_identity - return - else: - return TypeError('pointlistarray must be of type pointlistarray.') - return - - def compare_intensitylists( - self, - masterpointlist, - masterintensitylist, - bragg_peaks_fit, - tolerance_distance, - intensity_power - ): - """ - Function to compare the exisiting point list enteries with the array. - """ - # Add a column of zeros in the master intensity list to make way for the new fitted intensity list - zeros = np.zeros((masterintensitylist.shape[0], 1)) - masterintensitylist = np.concatenate((masterintensitylist,zeros),axis=1) - - # Compare with the exisiting bragg_peaks_fit with the masterpointlist. - # Make a temporary intensity list to store the intensities of the the bragg_peaks_fit. - - if intensity_power == 0: - temporary_pl_intensities = np.ones(bragg_peaks_fit['intensity'].shape) - else: - temporary_pl_intensities = bragg_peaks_fit['intensity']**intensity_power - - - # Go through the bragg_peaks_fit to find if the master list has an entry or not. - for d in range(bragg_peaks_fit['qx'].shape[0]): - distances = [] - # Making a numpy array of the fitted bragg peak - bragg_peak_point=np.array([bragg_peaks_fit['qx'][d],bragg_peaks_fit['qy'][d]]) - for p in range(masterpointlist.shape[0]): - distances.append(np.linalg.norm(bragg_peak_point-masterpointlist[p]) - ) - ind = np.where(distances == np.min(distances))[0][0] - # Potentially loop over to find the best tolerance distance. - if distances[ind] <= tolerance_distance: - columns_masterintensitylist = len(masterintensitylist[0]) - masterintensitylist[ind][columns_masterintensitylist-1]=temporary_pl_intensities[d] - - else: - ## The point list is not in the mega list of point list so the point list last row of masterpointlist - masterpointlist = np.vstack((masterpointlist,bragg_peak_point)) - ## Add a row to the intensity list such that all the remaining intensity lists should be 0 but only the new bragg intensity list is non zero but intensity power - new_intensity_list_row = np.zeros((1, masterintensitylist.shape[1]-1)) - new_intensity_list_row = np.append(new_intensity_list_row, [temporary_pl_intensities[d]]) - new_intensity_list_row = new_intensity_list_row.reshape((1,-1)) - masterintensitylist = np.concatenate((masterintensitylist,new_intensity_list_row),axis=0) - - - - - return masterpointlist,masterintensitylist - - - - - def quantify_phase_pointlist( -======= self.phase_weights[rx,ry] = phase_weights self.phase_residuals[rx,ry] = phase_residual def plot_phase_weights( ->>>>>>> 8a78ac5ad29e0a99ee82a48032397a9cfe004922 self, weight_range = (0.5,1,0), weight_normalize = True, @@ -528,89 +459,6 @@ def plot_phase_weights( Plot the individual phase weight maps and residuals. """ -<<<<<<< HEAD - Returns: - pointlist_peak_intensity_matches (np.ndarray): Peak matches in the rows of array and the crystals in the columns - phase_weights (np.ndarray): Weights of each phase - phase_residuals (np.ndarray): Residuals - crystal_identity (list): List of lists, where the each entry represents the position in the - crystal and orientation match that is associated with the phase - weights. for example, if the output was [[0,0], [0,1], [1,0], [0,1]], - the first entry [0,0] in phase weights is associated with the first crystal - the first match within that crystal. [0,1] is the first crystal and the - second match within that crystal. - """ - # Things to add: - # 1. Better cost for distance from peaks in pointlists - # 2. Iterate through multiple tolerance_distance values to find best value. Cost function residuals, or something else? - # 3. Make a flag variable for the experimental dataset which turns 1 if it is encountered in the simulated dataset. - - pointlist = pointlistarray.get_pointlist(position[0], position[1]) - ## Remove the central beam - pl_mask = np.where((pointlist['qx'] == 0) & (pointlist['qy'] == 0), 1, 0) - pointlist.remove(pl_mask) - # False Negatives (exp peak with no match in crystal instances) will appear here, already coded in - - if intensity_power == 0: - pl_intensities = np.ones(pointlist['intensity'].shape) - else: - pl_intensities = pointlist['intensity']**intensity_power - - #Prepare matches for modeling - pointlist_peak_matches = [] - crystal_identity = [] - ## Initialize the megapointlist and master intensity list with the experimental intensity - masterpointlist = np.column_stack((pointlist['qx'],pointlist['qy'])) - masterintensitylist = pl_intensities - ## Convert masterintensitylist to a 2D array - masterintensitylist = np.array(masterintensitylist, ndmin=2).T - ## Loop over the number of crystals. - for c in range(len(self.crystals)): - ## Loop over the number of num matches which is the number of orientation candidates. - # This value of num matches was supplied when the orientation map was created. - for m in range(self.orientation_maps[c].num_matches): - # Set crystal identity - crystal_identity.append([c,m]) - # For a given crystal class generate a diffraction pattern given a orientation crystal and given num match - bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern( - self.orientation_maps[c].get_orientation(position[0], position[1]), - ind_orientation = m - ) - # Check if there are any experimental intensity observed at all. - if len(masterpointlist !=0): - # Send this bragg_peaks_fit to the compare function to be compared with mega point list and master intensity list. - masterpointlist,masterintensitylist=self.compare_intensitylists(masterpointlist,masterintensitylist,bragg_peaks_fit,tolerance_distance,intensity_power) - else: - continue - - - ### The intensity and point lists are accumulated in the masterintensitylist and masterpointlist. - # The first column of the intensity lists are the observed experimental intensities. - observed_intensities = masterintensitylist[:,0] - expected_intensities = masterintensitylist[:,1:] - - if len(observed_intensities) > 0: - if mask_peaks is not None: - for i in range(len(mask_peaks)): - if mask_peaks[i] == None: - continue - inds_mask = np.where(expected_intensities[:,mask_peaks[i]] != 0)[0] - for mask in range(len(inds_mask)): - expected_intensities[inds_mask[mask],i] = 0 - if method == 'nnls': - phase_weights, phase_residuals = nnls( - expected_intensities, - observed_intensities - ) - - elif method == 'lstsq': - phase_weights, phase_residuals, rank, singluar_vals = lstsq( - expected_intensities, - observed_intensities, - rcond = -1 - ) - phase_residuals = np.sum(phase_residuals) -======= # intensity range for plotting if weight_normalize: scale = np.median(np.max(self.phase_weights,axis=2)) @@ -731,7 +579,6 @@ def plot_phase_maps( if plot_combine: if crystal_inds_plot is None: im_rgb = np.sum(im_rgb_all, axis = 0) ->>>>>>> 8a78ac5ad29e0a99ee82a48032397a9cfe004922 else: im_rgb = np.sum(im_rgb_all[np.array(crystal_inds_plot)], axis = 0) @@ -752,52 +599,6 @@ def plot_phase_maps( ax.set_axis_off() else: -<<<<<<< HEAD - # Find the number of expected phases - number_expected_phases=0 - for c in range(len(self.crystals)): - for m in range(self.orientation_maps[c].num_matches): - number_expected_phases+=1 - - # If there are no diffraction patterns - phase_weights = np.zeros(number_expected_phases) - phase_residuals = np.NaN - - return expected_intensities, phase_weights, phase_residuals, crystal_identity - - - # def plot_peak_matches( - # self, - # pointlistarray, - # position, - # tolerance_distance, - # ind_orientation, - # pointlist_peak_intensity_matches, - # ): - # """ - # A method to view how the tolerance distance impacts the peak matches associated with - # the quantify_phase_pointlist method. - - # Args: - # pointlistarray, - # position, - # tolerance_distance - # pointlist_peak_intensity_matches - # """ - # pointlist = pointlistarray.get_pointlist(position[0],position[1]) - - # for m in range(pointlist_peak_intensity_matches.shape[1]): - # bragg_peaks_fit = self.crystals[m].generate_diffraction_pattern( - # self.orientation_maps[m].get_orientation(position[0], position[1]), - # ind_orientation = ind_orientation - # ) - # peak_inds = np.where(bragg_peaks_fit.data['intensity'] == pointlist_peak_intensity_matches[:,m]) - - # fig, (ax1, ax2) = plt.subplots(2,1,figsize = figsize) - # ax1 = plot_diffraction_pattern(pointlist,) - # return - -======= # plotting fig,ax = plt.subplots( self.num_crystals, @@ -842,4 +643,3 @@ def plot_phase_maps( ->>>>>>> 8a78ac5ad29e0a99ee82a48032397a9cfe004922 From ae91f2404eca4c85b093b617dd98082e0e343fcd Mon Sep 17 00:00:00 2001 From: Tara Prasad Mishra Date: Sun, 19 Feb 2023 08:36:24 -0800 Subject: [PATCH 13/25] change class variable name --- py4DSTEM/process/diffraction/crystal_phase.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 64b9a4692..7909ec459 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -20,7 +20,7 @@ class CrystalPhase: """ - name: str + names: str num_crystals: int def __init__( From 73e495b9a977b37b823e3ea8eb8c71489509f3eb Mon Sep 17 00:00:00 2001 From: Colin Date: Tue, 7 Mar 2023 10:16:14 -0800 Subject: [PATCH 14/25] Testing normalization by total DF peak intensity Also adding plot layouts --- py4DSTEM/process/diffraction/crystal_phase.py | 76 ++++++++++++++----- 1 file changed, 58 insertions(+), 18 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 7909ec459..0277543f9 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -114,6 +114,8 @@ def quantify_single_pattern( else: intensity = bragg_peaks.data["intensity"][keep]**power_experiment intensity0 = bragg_peaks.data["intensity"][np.logical_not(keep)]**power_experiment + int_total = np.sum(intensity) + # init basis array if include_false_positives: @@ -393,9 +395,9 @@ def quantify_single_pattern( if returnfig: - return phase_weights, phase_residual, fig, ax + return phase_weights, phase_residual, int_total, fig, ax else: - return phase_weights, phase_residual + return phase_weights, phase_residual, int_total def quantify_phase( self, @@ -421,6 +423,10 @@ def quantify_phase( pointlistarray.shape[0], pointlistarray.shape[1], )) + self.int_total = np.zeros(( + pointlistarray.shape[0], + pointlistarray.shape[1], + )) for rx, ry in tqdmnd( *pointlistarray.shape, @@ -429,7 +435,7 @@ def quantify_phase( disable=not progress_bar, ): # calculate phase weights - phase_weights, phase_residual = self.quantify_single_pattern( + phase_weights, phase_residual, int_peaks = self.quantify_single_pattern( pointlistarray = pointlistarray, xy_position = (rx,ry), corr_kernel_size = corr_kernel_size, @@ -443,15 +449,18 @@ def quantify_phase( ) self.phase_weights[rx,ry] = phase_weights self.phase_residuals[rx,ry] = phase_residual + self.int_total[rx,ry] = int_peaks def plot_phase_weights( self, weight_range = (0.5,1,0), - weight_normalize = True, + weight_normalize = False, + total_intensity_normalize = True, cmap = 'gray', show_ticks = False, show_axes = True, + layout = 0, figsize = (6,6), returnfig = False, ): @@ -459,22 +468,37 @@ def plot_phase_weights( Plot the individual phase weight maps and residuals. """ + # Normalization if required to total DF peak intensity + phase_weights = self.phase_weights.copy() + phase_residuals = self.phase_residuals.copy() + if total_intensity_normalize: + sub = self.int_total > 0.0 + for a0 in range(self.num_fits): + phase_weights[:,:,a0][sub] /= self.int_total[sub] + phase_residuals[sub] /= self.int_total[sub] + # intensity range for plotting if weight_normalize: - scale = np.median(np.max(self.phase_weights,axis=2)) + scale = np.median(np.max(phase_weights,axis=2)) else: scale = 1 weight_range = np.array(weight_range) * scale # plotting - fig,ax = plt.subplots( - self.num_crystals + 1, - 1, - figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + if layout == 0: + fig,ax = plt.subplots( + 1, + self.num_crystals + 1, + figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + elif layout == 1: + fig,ax = plt.subplots( + self.num_crystals + 1, + 1, + figsize=(figsize[0],(self.num_fits+1)*figsize[1])) for a0 in range(self.num_crystals): sub = self.crystal_identity[:,0] == a0 - im = np.sum(self.phase_weights[:,:,sub],axis=2) + im = np.sum(phase_weights[:,:,sub],axis=2) im = np.clip( (im - weight_range[0]) / (weight_range[1] - weight_range[0]), 0,1) @@ -496,7 +520,8 @@ def plot_phase_weights( # plot residuals im = np.clip( - (self.phase_residuals - weight_range[0]) / (weight_range[1] - weight_range[0]), + (phase_residuals - weight_range[0]) \ + / (weight_range[1] - weight_range[0]), 0,1) ax[self.num_crystals].imshow( im, @@ -522,6 +547,8 @@ def plot_phase_maps( self, weight_threshold = 0.5, weight_normalize = True, + total_intensity_normalize = True, + plot_combine = False, crystal_inds_plot = None, phase_colors = np.array(( @@ -534,6 +561,7 @@ def plot_phase_maps( )), show_ticks = False, show_axes = True, + layout = 0, figsize = (6,6), return_phase_estimate = False, return_rgb_images = False, @@ -543,9 +571,15 @@ def plot_phase_maps( Plot the individual phase weight maps and residuals. """ + phase_weights = self.phase_weights.copy() + if total_intensity_normalize: + sub = self.int_total > 0.0 + for a0 in range(self.num_fits): + phase_weights[:,:,a0][sub] /= self.int_total[sub] + # intensity range for plotting if weight_normalize: - scale = np.median(np.max(self.phase_weights,axis=2)) + scale = np.median(np.max(phase_weights,axis=2)) else: scale = 1 weight_threshold = weight_threshold * scale @@ -564,7 +598,7 @@ def plot_phase_maps( # phase weights over threshold for a0 in range(self.num_crystals): sub = self.crystal_identity[:,0] == a0 - im = np.sum(self.phase_weights[:,:,sub],axis=2) + im = np.sum(phase_weights[:,:,sub],axis=2) im_all[a0] = np.maximum(im - weight_threshold, 0) # estimate compositions @@ -600,11 +634,17 @@ def plot_phase_maps( else: # plotting - fig,ax = plt.subplots( - self.num_crystals, - 1, - figsize=(figsize[0],(self.num_fits+1)*figsize[1])) - + if layout == 0: + fig,ax = plt.subplots( + 1, + self.num_crystals, + figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + elif layout == 1: + fig,ax = plt.subplots( + self.num_crystals, + 1, + figsize=(figsize[0],(self.num_fits+1)*figsize[1])) + for a0 in range(self.num_crystals): ax[a0].imshow( From e6b9f2eb1c136deb9d5ddb6b912872f1b9872ec4 Mon Sep 17 00:00:00 2001 From: Ben Savitzky Date: Mon, 13 Mar 2023 21:34:03 -0700 Subject: [PATCH 15/25] updates --- py4DSTEM/__init__.py | 14 +++++++++++--- py4DSTEM/process/diffraction/crystal_viz.py | 4 ++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/__init__.py b/py4DSTEM/__init__.py index 047b7d980..7855d83b6 100644 --- a/py4DSTEM/__init__.py +++ b/py4DSTEM/__init__.py @@ -1,6 +1,11 @@ from py4DSTEM.version import __version__ from py4DSTEM.utils.tqdmnd import tqdmnd +# test paths +from os.path import dirname,join +_TESTPATH = join(dirname(__file__), "test/unit_test_data") + + # submodules @@ -18,8 +23,11 @@ from py4DSTEM.utils.configuration_checker import check_config -# test paths -from os.path import dirname,join -_TESTPATH = join(dirname(__file__), "test/unit_test_data") + +# classes + +from py4DSTEM.process.diffraction import Crystal + + diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 07243384b..20838cd5b 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -298,6 +298,7 @@ def plot_scattering_intensity( bragg_intensity_power=1.0, bragg_k_broadening=0.005, figsize: Union[list, tuple, np.ndarray] = (12, 6), + title: Optional[str] = None, returnfig: bool = False, ): """ @@ -317,6 +318,7 @@ def plot_scattering_intensity( bragg_intensity_power (float): bragg_peaks scaled by intensities**bragg_intensity_power. bragg_k_broadening float): Broadening applied to bragg_peaks. figsize (list, tuple, np.ndarray): Figure size for plot. + title (str or None): Title returnfig (bool): Return figure and axes handles if this is True. Returns: @@ -392,6 +394,8 @@ def plot_scattering_intensity( ax.set_xlabel("Scattering Vector k [1/A]", fontsize=14) ax.set_yticks([]) ax.set_ylabel("Magnitude", fontsize=14) + if title is not None: + ax.set_title(title, fontsize=16) if returnfig: return fig, ax From 20bb94f51191d58d9a1f2d084ec000e9fc9f688b Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Fri, 14 Jul 2023 01:58:04 -0400 Subject: [PATCH 16/25] bugfix --- py4DSTEM/preprocess/preprocess.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/py4DSTEM/preprocess/preprocess.py b/py4DSTEM/preprocess/preprocess.py index 755d07ae4..f1052d72d 100644 --- a/py4DSTEM/preprocess/preprocess.py +++ b/py4DSTEM/preprocess/preprocess.py @@ -283,12 +283,6 @@ def bin_data_diffraction( # set calibration pixel size datacube.calibration.set_Q_pixel_size(Qpixsize) - # remake Cartesian coordinate system - datacube.qyy,datacube.qxx = np.meshgrid( - np.arange(0,datacube.Q_Ny), - np.arange(0,datacube.Q_Nx) - ) - # return return datacube From 7b9d330440f02bacbc75673b5a4f63a92c91b687 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Fri, 14 Jul 2023 11:36:39 -0400 Subject: [PATCH 17/25] bugfix --- py4DSTEM/datacube/virtualimage.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/datacube/virtualimage.py b/py4DSTEM/datacube/virtualimage.py index 6731c0fd0..d91b23906 100644 --- a/py4DSTEM/datacube/virtualimage.py +++ b/py4DSTEM/datacube/virtualimage.py @@ -325,7 +325,7 @@ def position_detector( shift_center = None, scan_position = None, invert = False, - color = 'r', + color = 'c', alpha = 0.7, **kwargs ): @@ -382,10 +382,10 @@ def position_detector( # data if data is None: keys = ['dp_mean','dp_max','dp_median'] + image = None for k in keys: - image = None try: - image = data.tree(k) + image = self.tree(k) break except: pass @@ -393,6 +393,7 @@ def position_detector( image = self[0,0] elif isinstance(data, np.ndarray): assert(data.shape == self.Qshape), f"Can't position a detector over an image with a shape that is different from diffraction space. Diffraction space in this dataset has shape {self.Qshape} but the image passed has shape {data.shape}" + image = data elif isinstance(data,tuple): rx,ry = data[:2] image = self[rx,ry] From e62642d866268f2fc15a85f876e022a50b500ea4 Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 9 Aug 2023 15:53:06 -0700 Subject: [PATCH 18/25] Fixing merge conflicts --- py4DSTEM/__init__.py | 8 ++--- py4DSTEM/process/diffraction/crystal_phase.py | 34 +++++-------------- py4DSTEM/process/diffraction/crystal_viz.py | 4 --- 3 files changed, 12 insertions(+), 34 deletions(-) diff --git a/py4DSTEM/__init__.py b/py4DSTEM/__init__.py index 6f362df93..aed475452 100644 --- a/py4DSTEM/__init__.py +++ b/py4DSTEM/__init__.py @@ -91,18 +91,18 @@ from py4DSTEM.utils.configuration_checker import check_config # TODO - config .toml -<<<<<<< HEAD +# <<<<<<< HEAD # classes -from py4DSTEM.process.diffraction import Crystal +# from py4DSTEM.process.diffraction import Crystal -======= +# ======= # testing from os.path import dirname,join _TESTPATH = join(dirname(__file__), "../test/unit_test_data") ->>>>>>> dev +# >>>>>>> dev diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 86ed7fe6e..235d2c3e3 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -4,19 +4,19 @@ import matplotlib as mpl import matplotlib.pyplot as plt -<<<<<<< HEAD -from py4DSTEM.utils.tqdmnd import tqdmnd -from py4DSTEM.visualize import show, show_image_grid -# from py4DSTEM.io.datastructure.emd.pointlistarray import PointListArray -# from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern -from py4DSTEM.io.datastructure import PointList, PointListArray +# <<<<<<< HEAD +# from py4DSTEM.utils.tqdmnd import tqdmnd +# from py4DSTEM.visualize import show, show_image_grid +# # from py4DSTEM.io.datastructure.emd.pointlistarray import PointListArray +# # from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern +# from py4DSTEM.io.datastructure import PointList, PointListArray from dataclasses import dataclass, field -======= +# ======= from emdfile import tqdmnd, PointListArray from py4DSTEM.visualize import show, show_image_grid from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern ->>>>>>> dev +# >>>>>>> dev @dataclass class CrystalPhase: @@ -66,19 +66,10 @@ def __init__( if names is not None: self.names = names else: -<<<<<<< HEAD self.names = ['crystal'] * self.num_crystals - def quantify_single_pattern( -======= - raise TypeError('orientation_maps must be a list of orientation maps.') - self.name = name - return - - def plot_all_phase_maps( ->>>>>>> dev self, pointlistarray: PointListArray, xy_position = (0,0), @@ -695,12 +686,3 @@ def plot_phase_maps( if returnfig: return fig, ax - - -<<<<<<< HEAD -======= - # fig, (ax1, ax2) = plt.subplots(2,1,figsize = figsize) - # ax1 = plot_diffraction_pattern(pointlist,) - # return - ->>>>>>> dev diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index ebaac0b4a..582eb07ad 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -296,12 +296,8 @@ def plot_scattering_intensity( bragg_k_power=0.0, bragg_intensity_power=1.0, bragg_k_broadening=0.005, -<<<<<<< HEAD figsize: Union[list, tuple, np.ndarray] = (12, 6), title: Optional[str] = None, -======= - figsize: Union[list, tuple, np.ndarray] = (10, 4), ->>>>>>> dev returnfig: bool = False, ): """ From fa4736ffe9ea2535ed50c9cfe6f4250299f7f050 Mon Sep 17 00:00:00 2001 From: cophus Date: Wed, 9 Aug 2023 16:30:39 -0700 Subject: [PATCH 19/25] updates to polardata --- py4DSTEM/braggvectors/braggvector_methods.py | 243 ++----- py4DSTEM/braggvectors/braggvectors.py | 29 +- py4DSTEM/data/calibration.py | 10 +- py4DSTEM/datacube/datacube.py | 35 +- py4DSTEM/datacube/virtualimage.py | 17 +- py4DSTEM/io/filereaders/__init__.py | 3 +- py4DSTEM/io/filereaders/read_abTEM.py | 81 +++ py4DSTEM/io/filereaders/read_arina.py | 115 ++++ py4DSTEM/io/google_drive_downloader.py | 57 ++ py4DSTEM/io/importfile.py | 37 +- .../legacy/legacy13/v13_emd_classes/array.py | 6 +- py4DSTEM/io/parsefiletype.py | 84 ++- py4DSTEM/io/read.py | 124 ++-- py4DSTEM/preprocess/preprocess.py | 8 +- py4DSTEM/process/calibration/origin.py | 9 +- py4DSTEM/process/diffraction/crystal.py | 150 ++++- py4DSTEM/process/diffraction/crystal_ACOM.py | 48 +- .../process/diffraction/crystal_calibrate.py | 35 +- py4DSTEM/process/diffraction/crystal_viz.py | 80 ++- py4DSTEM/process/diffraction/flowlines.py | 24 +- py4DSTEM/process/fit/fit.py | 4 - py4DSTEM/process/latticevectors/fit.py | 15 +- py4DSTEM/process/latticevectors/index.py | 57 +- py4DSTEM/process/latticevectors/strain.py | 33 +- .../iterative_multislice_ptychography.py | 6 +- py4DSTEM/process/polar/__init__.py | 2 +- py4DSTEM/process/polar/polar_analysis.py | 46 ++ py4DSTEM/process/polar/polar_datacube.py | 16 +- py4DSTEM/process/polar/polar_fits.py | 51 +- py4DSTEM/process/polar/polar_peaks.py | 114 +++- py4DSTEM/process/rdf/amorph.py | 2 +- py4DSTEM/process/strain.py | 617 +++++++++++++++--- py4DSTEM/process/utils/utils.py | 4 +- py4DSTEM/version.py | 2 +- py4DSTEM/visualize/overlay.py | 22 +- py4DSTEM/visualize/show.py | 8 +- py4DSTEM/visualize/vis_special.py | 96 --- setup.py | 1 + test/gettestdata.py | 21 +- test/test_nonnative_io/test_arina.py | 19 + test/test_strain.py | 2 + 41 files changed, 1613 insertions(+), 720 deletions(-) create mode 100644 py4DSTEM/io/filereaders/read_abTEM.py create mode 100644 py4DSTEM/io/filereaders/read_arina.py create mode 100644 test/test_nonnative_io/test_arina.py diff --git a/py4DSTEM/braggvectors/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py index 53f800ed1..be766ad49 100644 --- a/py4DSTEM/braggvectors/braggvector_methods.py +++ b/py4DSTEM/braggvectors/braggvector_methods.py @@ -186,7 +186,11 @@ def get_virtual_image( mode = None, geometry = None, name = 'bragg_virtual_image', - returncalc = True + returncalc = True, + center = True, + ellipse = True, + pixel = True, + rotate = True, ): ''' Calculates a virtual image based on the values of the Braggvectors @@ -204,13 +208,22 @@ def get_virtual_image( - 'circle', 'circular': nested 2-tuple, ((qx,qy),radius) - 'annular' or 'annulus': nested 2-tuple, ((qx,qy),(radius_i,radius_o)) - All values are in pixels. Note that (qx,qy) can be skipped, which - assumes peaks centered at (0,0) + Values can be in pixels or calibrated units. Note that (qx,qy) + can be skipped, which assumes peaks centered at (0,0). + center: bool + Apply calibration - center coordinate. + ellipse: bool + Apply calibration - elliptical correction. + pixel: bool + Apply calibration - pixel size. + rotate: bool + Apply calibration - QR rotation. Returns ------- virtual_im : VirtualImage ''' + # parse inputs circle_modes = ['circular','circle'] annulus_modes = ['annular','annulus'] @@ -220,13 +233,13 @@ def get_virtual_image( # set geometry if mode is None: if geometry is None: - center = None + qxy_center = None radial_range = np.array((0,np.inf)) else: if len(geometry[0]) == 0: - center = None + qxy_center = None else: - center = np.array(geometry[0]) + qxy_center = np.array(geometry[0]) if isinstance(geometry[1], int) or isinstance(geometry[1], float): radial_range = np.array((0,geometry[1])) elif len(geometry[1]) == 0: @@ -236,30 +249,44 @@ def get_virtual_image( elif mode == 'circular' or mode == 'circle': radial_range = np.array((0,geometry[1])) if len(geometry[0]) == 0: - center = None + qxy_center = None else: - center = np.array(geometry[0]) + qxy_center = np.array(geometry[0]) elif mode == 'annular' or mode == 'annulus': radial_range = np.array(geometry[1]) if len(geometry[0]) == 0: - center = None + qxy_center = None else: - center = np.array(geometry[0]) + qxy_center = np.array(geometry[0]) # allocate space im_virtual = np.zeros(self.shape) # generate image - for rx,ry in tqdmnd(self.shape[0],self.shape[1]): - p = self.raw[rx,ry] + for rx,ry in tqdmnd( + self.shape[0], + self.shape[1], + ): + # Get user-specified Bragg vectors + p = self.get_vectors( + rx, + ry, + center = center, + ellipse = ellipse, + pixel = pixel, + rotate = rotate, + ) + if p.data.shape[0] > 0: if radial_range is None: im_virtual[rx,ry] = np.sum(p.I) else: - if center is None: + if qxy_center is None: qr = np.hypot(p.qx,p.qy) else: - qr = np.hypot(p.qx - center[0],p.qy - center[1]) + qr = np.hypot( + p.qx - qxy_center[0], + p.qy - qxy_center[1]) sub = np.logical_and( qr >= radial_range[0], qr < radial_range[1]) @@ -284,7 +311,7 @@ def get_virtual_image( } ) # attach to the tree - self.attach( ans) + self.attach(ans) # return if returncalc: @@ -634,192 +661,6 @@ def fit_p_ellipse( if returncalc: return p_ellipse - - # Deprecated?? - - def index_bragg_directions( - self, - x0 = None, - y0 = None, - plot = True, - bvm_vis_params = {}, - returncalc = False, - ): - """ - From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of - lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the - reciprocal lattice directions. - - Args: - x0 (float): x-coord of origin - y0 (float): y-coord of origin - Plot (bool): plot results - """ - - if x0 is None: - x0 = self.Qshape[0]/2 - if y0 is None: - y0 = self.Qshape[0]/2 - - from py4DSTEM.process.latticevectors import index_bragg_directions - _, _, braggdirections = index_bragg_directions( - x0, - y0, - self.g['x'], - self.g['y'], - self.g1, - self.g2 - ) - - self.braggdirections = braggdirections - - if plot: - from py4DSTEM.visualize import show_bragg_indexing - show_bragg_indexing( - self.bvm_centered, - **bvm_vis_params, - braggdirections = braggdirections, - points = True - ) - - if returncalc: - return braggdirections - - - - def add_indices_to_braggpeaks( - self, - maxPeakSpacing, - mask = None, - returncalc = False, - ): - """ - Using the peak positions (qx,qy) and indices (h,k) in the PointList lattice, - identify the indices for each peak in the PointListArray braggpeaks. - Return a new braggpeaks_indexed PointListArray, containing a copy of braggpeaks plus - three additional data columns -- 'h','k', and 'index_mask' -- specifying the peak - indices with the ints (h,k) and indicating whether the peak was successfully indexed - or not with the bool index_mask. If `mask` is specified, only the locations where - mask is True are indexed. - - Args: - maxPeakSpacing (float): Maximum distance from the ideal lattice points - to include a peak for indexing - qx_shift,qy_shift (number): the shift of the origin in the `lattice` PointList - relative to the `braggpeaks` PointListArray - mask (bool): Boolean mask, same shape as the pointlistarray, indicating which - locations should be indexed. This can be used to index different regions of - the scan with different lattices - """ - from py4DSTEM.process.latticevectors import add_indices_to_braggpeaks - - bragg_peaks_indexed = add_indices_to_braggpeaks( - self.vectors, - self.braggdirections, - maxPeakSpacing = maxPeakSpacing, - qx_shift = self.Qshape[0]/2, - qy_shift = self.Qshape[1]/2, - ) - - self.bragg_peaks_indexed = bragg_peaks_indexed - - if returncalc: - return bragg_peaks_indexed - - - def fit_lattice_vectors_all_DPs(self, returncalc = False): - """ - Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks, given some - known (h,k) indexing. - - - """ - - from py4DSTEM.process.latticevectors import fit_lattice_vectors_all_DPs - g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_peaks_indexed) - self.g1g2_map = g1g2_map - if returncalc: - return g1g2_map - - def get_strain_from_reference_region(self, mask, returncalc = False): - """ - Gets a strain map from the reference region of real space specified by mask and the - lattice vector map g1g2_map. - - Args: - mask (ndarray of bools): use lattice vectors from g1g2_map scan positions - wherever mask==True - - """ - from py4DSTEM.process.latticevectors import get_strain_from_reference_region - - strainmap_median_g1g2 = get_strain_from_reference_region( - self.g1g2_map, - mask = mask, - ) - - self.strainmap_median_g1g2 = strainmap_median_g1g2 - - if returncalc: - return strainmap_median_g1g2 - - - def get_strain_from_reference_g1g2(self, mask, returncalc = False): - """ - Gets a strain map from the reference lattice vectors g1,g2 and lattice vector map - g1g2_map. - - - Args: - mask (ndarray of bools): use lattice vectors from g1g2_map scan positions - wherever mask==True - - """ - from py4DSTEM.process.latticevectors import get_reference_g1g2 - g1_ref,g2_ref = get_reference_g1g2(self.g1g2_map, mask) - - from py4DSTEM.process.latticevectors import get_strain_from_reference_g1g2 - strainmap_reference_g1g2 = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref) - - self.strainmap_reference_g1g2 = strainmap_reference_g1g2 - - if returncalc: - return strainmap_reference_g1g2 - - def get_rotated_strain_map(self, mode, g_reference = None, returncalc = True, flip_theta = False): - """ - Starting from a strain map defined with respect to the xy coordinate system of - diffraction space, i.e. where exx and eyy are the compression/tension along the Qx - and Qy directions, respectively, get a strain map defined with respect to some other - right-handed coordinate system, in which the x-axis is oriented along (xaxis_x, - xaxis_y). - - Args: - g_referencce (tupe): reference coordinate system for xaxis_x and xaxis_y - """ - - assert mode in ("median","reference") - if g_reference is None: - g_reference = np.subtract(self.g1, self.g2) - - from py4DSTEM.process.latticevectors import get_rotated_strain_map - - if mode == "median": - strainmap_raw = self.strainmap_median_g1g2 - elif mode == "reference": - strainmap_raw = self.strainmap_reference_g1g2 - - strainmap = get_rotated_strain_map( - strainmap_raw, - xaxis_x = g_reference[0], - xaxis_y = g_reference[1], - flip_theta = flip_theta - ) - - if returncalc: - return strainmap - - def mask_in_Q( self, mask, diff --git a/py4DSTEM/braggvectors/braggvectors.py b/py4DSTEM/braggvectors/braggvectors.py index 14b89fd98..f1ff406d0 100644 --- a/py4DSTEM/braggvectors/braggvectors.py +++ b/py4DSTEM/braggvectors/braggvectors.py @@ -65,7 +65,7 @@ def __init__( Rshape, Qshape, name = 'braggvectors', - verbose = True, + verbose = False, calibration = None ): Custom.__init__(self,name=name) @@ -236,9 +236,16 @@ def setcal( "rotate" : rotate, } if self.verbose: - print('current calstate: ', self.calstate) + print('current calibration state: ', self.calstate) pass + def calibrate(self): + """ + Autoupdate the calstate when relevant calibrations are set + """ + self.setcal() + + # vector getter method @@ -250,7 +257,7 @@ def get_vectors( ellipse, pixel, rotate - ): + ): """ Returns the bragg vectors at the specified scan position with the specified calibration state. @@ -268,6 +275,7 @@ def get_vectors( ------- vectors : BVects """ + ans = self._v_uncal[scan_x,scan_y].data ans = self.cal._transform( data = ans, @@ -282,17 +290,16 @@ def get_vectors( # copy - def copy(self, name=None): name = name if name is not None else self.name+"_copy" braggvector_copy = BraggVectors( - self.Rshape, - self.Qshape, - name=name, + self.Rshape, + self.Qshape, + name=name, calibration = self.calibration.copy() ) - - braggvector_copy._v_uncal = self._v_uncal.copy() + + braggvector_copy.set_raw_vectors( self._v_uncal.copy() ) for k in self.metadata.keys(): braggvector_copy.metadata = self.metadata[k].copy() return braggvector_copy @@ -526,6 +533,4 @@ def _transform( # return - return ans - - + return ans \ No newline at end of file diff --git a/py4DSTEM/data/calibration.py b/py4DSTEM/data/calibration.py index 6077864f3..50ec8f6f9 100644 --- a/py4DSTEM/data/calibration.py +++ b/py4DSTEM/data/calibration.py @@ -64,8 +64,8 @@ class Calibration(Metadata): theta, * p_ellipse, * ellipse, * - QR_rotation_degrees, - QR_flip, + QR_rotation_degrees, * + QR_flip, * QR_rotflip, * probe_semiangle, probe_param, @@ -598,11 +598,13 @@ def ellipse(self,x): # Q/R-space rotation and flip + @call_calibrate def set_QR_rotation_degrees(self,x): self._params['QR_rotation_degrees'] = x def get_QR_rotation_degrees(self): return self._get_value('QR_rotation_degrees') + @call_calibrate def set_QR_flip(self,x): self._params['QR_flip'] = x def get_QR_flip(self): @@ -617,8 +619,8 @@ def set_QR_rotflip(self, rot_flip): flip (bool): True indicates a Q/R axes flip """ rot,flip = rot_flip - self.set_QR_rotation_degrees(rot) - self.set_QR_flip(flip) + self._params['QR_rotation_degrees'] = rot + self._params['QR_flip'] = flip def get_QR_rotflip(self): rot = self.get_QR_rotation_degrees() flip = self.get_QR_flip() diff --git a/py4DSTEM/datacube/datacube.py b/py4DSTEM/datacube/datacube.py index 81db0ed9b..ae3a82a36 100644 --- a/py4DSTEM/datacube/datacube.py +++ b/py4DSTEM/datacube/datacube.py @@ -2,8 +2,8 @@ import numpy as np from scipy.interpolate import interp1d -from scipy.ndimage import (binary_opening, binary_dilation,distance_transform_edt, - binary_fill_holes, gaussian_filter1d,gaussian_filter) +from scipy.ndimage import (binary_opening, binary_dilation, + distance_transform_edt, binary_fill_holes, gaussian_filter1d, gaussian_filter) from typing import Optional,Union from emdfile import Array, Metadata, Node, Root, tqdmnd @@ -125,6 +125,9 @@ def calibrate(self): self._qxx,self._qyy = np.meshgrid( dim_qx,dim_qy ) self._rxx,self._ryy = np.meshgrid( dim_rx,dim_ry ) + self._qyy_raw,self._qxx_raw = np.meshgrid( np.arange(self.Q_Ny),np.arange(self.Q_Nx) ) + self._ryy_raw,self._rxx_raw = np.meshgrid( np.arange(self.R_Ny),np.arange(self.R_Nx) ) + # coordinate meshgrids @@ -140,6 +143,18 @@ def qxx(self): @property def qyy(self): return self._qyy + @property + def rxx_raw(self): + return self._rxx_raw + @property + def ryy_raw(self): + return self._ryy_raw + @property + def qxx_raw(self): + return self._qxx_raw + @property + def qyy_raw(self): + return self._qyy_raw # coordinate meshgrids with shifted origin def qxxs(self,rx,ry): @@ -1061,11 +1076,12 @@ def get_beamstop_mask( # im = self.tree["dp_max"].data.astype('float') if not "dp_max" in self._branch.keys(): self.get_dp_max(); - im = self.tree("dp_max").data.astype('float') + im = self.tree("dp_max").data.copy().astype('float') else: if not "dp_mean" in self._branch.keys(): self.get_dp_mean(); - im = self.tree("dp_mean").data + im = self.tree("dp_mean").data.copy() + # if not "dp_mean" in self.tree.keys(): # self.get_dp_mean(); # im = self.tree["dp_mean"].data.astype('float') @@ -1119,7 +1135,7 @@ def get_beamstop_mask( ) # Add to tree - self.attach( x ) + self.tree(x) # return if returncalc: @@ -1170,7 +1186,7 @@ def get_radial_bkgrnd( # define the 2D cartesian coordinate system origin = self.calibration.get_origin() origin = origin[0][rx,ry],origin[1][rx,ry] - qxx,qyy = self.qxx-origin[0], self.qyy-origin[1] + qxx,qyy = self.qxx_raw-origin[0], self.qyy_raw-origin[1] # get distance qr in polar-elliptical coords ellipse = self.calibration.get_ellipse() @@ -1455,9 +1471,6 @@ def get_braggmask( vects = braggvectors.raw[rx,ry] # loop for idx in range(len(vects.data)): - qr = np.hypot(self.qxx-vects.qx[idx], self.qyy-vects.qy[idx]) + qr = np.hypot(self.qxx_raw-vects.qx[idx], self.qyy_raw-vects.qy[idx]) mask = np.logical_and(mask, qr>radius) - return mask - - - + return mask \ No newline at end of file diff --git a/py4DSTEM/datacube/virtualimage.py b/py4DSTEM/datacube/virtualimage.py index d91b23906..5e2681eb6 100644 --- a/py4DSTEM/datacube/virtualimage.py +++ b/py4DSTEM/datacube/virtualimage.py @@ -10,7 +10,7 @@ import inspect from emdfile import tqdmnd,Metadata -from py4DSTEM.data import Calibration, RealSlice, Data +from py4DSTEM.data import Calibration, RealSlice, Data, DiffractionSlice from py4DSTEM.visualize.show import show @@ -322,7 +322,7 @@ def position_detector( data = None, centered = None, calibrated = None, - shift_center = None, + shift_center = False, scan_position = None, invert = False, color = 'c', @@ -381,6 +381,7 @@ def position_detector( # data if data is None: + image = None keys = ['dp_mean','dp_max','dp_median'] image = None for k in keys: @@ -389,11 +390,14 @@ def position_detector( break except: pass - if image is None: - image = self[0,0] + if image is None: + image = self[0,0] elif isinstance(data, np.ndarray): assert(data.shape == self.Qshape), f"Can't position a detector over an image with a shape that is different from diffraction space. Diffraction space in this dataset has shape {self.Qshape} but the image passed has shape {data.shape}" image = data + elif isinstance(data, DiffractionSlice): + assert(data.shape == self.Qshape), f"Can't position a detector over an image with a shape that is different from diffraction space. Diffraction space in this dataset has shape {self.Qshape} but the image passed has shape {data.shape}" + image = data.data elif isinstance(data,tuple): rx,ry = data[:2] image = self[rx,ry] @@ -402,10 +406,7 @@ def position_detector( # shift center if shift_center is None: - if isinstance(data,tuple): - shift_center = True - else: - shift_center = False + shift_center = False elif shift_center == True: assert(isinstance(data,tuple)), "If shift_center is set to True, `data` should be a 2-tuple (rx,ry). To shift the detector mask while using some other input for `data`, set `shift_center` to a 2-tuple (rx,ry)" elif isinstance(shift_center,tuple): diff --git a/py4DSTEM/io/filereaders/__init__.py b/py4DSTEM/io/filereaders/__init__.py index d256334a8..b6f4eb0a2 100644 --- a/py4DSTEM/io/filereaders/__init__.py +++ b/py4DSTEM/io/filereaders/__init__.py @@ -2,4 +2,5 @@ from py4DSTEM.io.filereaders.read_K2 import read_gatan_K2_bin from py4DSTEM.io.filereaders.empad import read_empad from py4DSTEM.io.filereaders.read_mib import load_mib - +from py4DSTEM.io.filereaders.read_arina import read_arina +from py4DSTEM.io.filereaders.read_abTEM import read_abTEM diff --git a/py4DSTEM/io/filereaders/read_abTEM.py b/py4DSTEM/io/filereaders/read_abTEM.py new file mode 100644 index 000000000..1fec9e73e --- /dev/null +++ b/py4DSTEM/io/filereaders/read_abTEM.py @@ -0,0 +1,81 @@ +import h5py +from py4DSTEM.data import DiffractionSlice, RealSlice +from py4DSTEM.datacube import DataCube + +def read_abTEM( + filename, + mem="RAM", + binfactor: int = 1, +): + """ + File reader for abTEM datasets + Args: + filename: str with path to file + mem (str): Must be "RAM" or "MEMMAP". Specifies how the data is + loaded; "RAM" transfer the data from storage to RAM, while "MEMMAP" + leaves the data in storage and creates a memory map which points to + the diffraction patterns, allowing them to be retrieved individually + from storage. + binfactor (int): Diffraction space binning factor for bin-on-load. + + Returns: + DataCube + """ + assert mem == "RAM", "read_abTEM does not support memory mapping" + assert binfactor == 1, "abTEM files can only be read at full resolution" + + with h5py.File(filename, "r") as f: + datasets = {} + for key in f.keys(): + datasets[key] = f.get(key)[()] + + data = datasets["array"] + + sampling = datasets["sampling"] + units = datasets["units"] + + assert len(data.shape) in (2, 4), "abtem reader supports only 4D and 2D data" + + if len(data.shape) == 4: + + datacube = DataCube(data=data) + + datacube.calibration.set_R_pixel_size(sampling[0]) + if sampling[0] != sampling[1]: + print( + "Warning: py4DSTEM currently only handles uniform x,y sampling. Setting sampling with x calibration" + ) + datacube.calibration.set_Q_pixel_size(sampling[2]) + if sampling[2] != sampling[3]: + print( + "Warning: py4DSTEM currently only handles uniform qx,qy sampling. Setting sampling with qx calibration" + ) + + if units[0] == b"\xc3\x85": + datacube.calibration.set_R_pixel_units("A") + else: + datacube.calibration.set_R_pixel_units(units[0].decode("utf-8")) + + datacube.calibration.set_Q_pixel_units(units[2].decode("utf-8")) + + return datacube + + else: + if units[0] == b"mrad": + diffraction = DiffractionSlice(data=data) + if sampling[0] != sampling[1]: + print( + "Warning: py4DSTEM currently only handles uniform qx,qy sampling. Setting sampling with x calibration" + ) + diffraction.calibration.set_Q_pixel_units(units[0].decode("utf-8")) + diffraction.calibration.set_Q_pixel_size(sampling[0]) + return diffraction + else: + image = RealSlice(data=data) + if sampling[0] != sampling[1]: + print( + "Warning: py4DSTEM currently only handles uniform x,y sampling. Setting sampling with x calibration" + ) + image.calibration.set_Q_pixel_units("A") + image.calibration.set_Q_pixel_size(sampling[0]) + return image diff --git a/py4DSTEM/io/filereaders/read_arina.py b/py4DSTEM/io/filereaders/read_arina.py new file mode 100644 index 000000000..323b5643f --- /dev/null +++ b/py4DSTEM/io/filereaders/read_arina.py @@ -0,0 +1,115 @@ +import h5py +import hdf5plugin +import numpy as np +from py4DSTEM.datacube import DataCube +from py4DSTEM.preprocess.utils import bin2D + + +def read_arina( + filename, + scan_width=1, + mem="RAM", + binfactor: int = 1, + dtype_bin: float = None, + flatfield: np.ndarray = None, +): + + """ + File reader for arina 4D-STEM datasets + Args: + filename: str with path to master file + scan_width: x dimension of scan + mem (str): Must be "RAM" or "MEMMAP". Specifies how the data is + loaded; "RAM" transfer the data from storage to RAM, while "MEMMAP" + leaves the data in storage and creates a memory map which points to + the diffraction patterns, allowing them to be retrieved individually + from storage. + binfactor (int): Diffraction space binning factor for bin-on-load. + dtype_bin(float): specify datatype for bin on load if need something + other than uint16 + flatfield (np.ndarray): + flatfield forcorrection factors + + Returns: + DataCube + """ + assert mem == "RAM", "read_arina does not support memory mapping" + + f = h5py.File(filename, "r") + nimages = 0 + + # Count the number of images in all datasets + for dset in f["entry"]["data"]: + nimages = nimages + f["entry"]["data"][dset].shape[0] + height = f["entry"]["data"][dset].shape[1] + width = f["entry"]["data"][dset].shape[2] + dtype = f["entry"]["data"][dset].dtype + + width = width // binfactor + height = height // binfactor + + assert ( + nimages % scan_width < 1e-6 + ), "scan_width must be integer multiple of x*y size" + + if dtype.type is np.uint32: + print("Dataset is uint32 but will be converted to uint16") + dtype = np.dtype(np.uint16) + + if dtype_bin: + array_3D = np.empty((nimages, width, height), dtype=dtype_bin) + else: + array_3D = np.empty((nimages, width, height), dtype=dtype) + + image_index = 0 + + if flatfield is None: + correction_factors = 1 + else: + # Avoid div by 0 errors -> pixel with value 0 will be set to meadian + flatfield[flatfield == 0] = 1 + correction_factors = np.median(flatfield) / flatfield + + for dset in f["entry"]["data"]: + image_index = _processDataSet( + f["entry"]["data"][dset], + image_index, + array_3D, + binfactor, + correction_factors, + ) + + if f.__bool__(): + f.close() + + scan_height = int(nimages / scan_width) + + datacube = DataCube( + np.flip( + array_3D.reshape( + scan_width, scan_height, array_3D.data.shape[1], array_3D.data.shape[2] + ), + 0, + ) + ) + + return datacube + + +def _processDataSet(dset, start_index, array_3D, binfactor, correction_factors): + image_index = start_index + nimages_dset = dset.shape[0] + + for i in range(nimages_dset): + if binfactor == 1: + array_3D[image_index] = np.multiply( + dset[i].astype(array_3D.dtype), correction_factors + ) + else: + array_3D[image_index] = bin2D( + np.multiply(dset[i].astype(array_3D.dtype), correction_factors), + binfactor, + ) + + image_index = image_index + 1 + return image_index diff --git a/py4DSTEM/io/google_drive_downloader.py b/py4DSTEM/io/google_drive_downloader.py index 51e3a70d7..86ad1a9f4 100644 --- a/py4DSTEM/io/google_drive_downloader.py +++ b/py4DSTEM/io/google_drive_downloader.py @@ -83,6 +83,50 @@ 'test_realslice_io.h5', '1siH80-eRJwG5R6AnU4vkoqGWByrrEz1y' ), + 'test_arina_master' : ( + 'STO_STEM_bench_20us_master.h5', + '1q_4IjFuWRkw5VM84NhxrNTdIq4563BOC' + ), + 'test_arina_01' : ( + 'STO_STEM_bench_20us_data_000001.h5', + '1_3Dbm22-hV58iffwK9x-3vqJUsEXZBFQ' + ), + 'test_arina_02' : ( + 'STO_STEM_bench_20us_data_000002.h5', + '1x29RzHLnCzP0qthLhA1kdlUQ09ENViR8' + ), + 'test_arina_03' : ( + 'STO_STEM_bench_20us_data_000003.h5', + '1qsbzdEVD8gt4DYKnpwjfoS_Mg4ggObAA' + ), + 'test_arina_04' : ( + 'STO_STEM_bench_20us_data_000004.h5', + '1Lcswld0Y9fNBk4-__C9iJbc854BuHq-h' + ), + 'test_arina_05' : ( + 'STO_STEM_bench_20us_data_000005.h5', + '13YTO2ABsTK5nObEr7RjOZYCV3sEk3gt9' + ), + 'test_arina_06' : ( + 'STO_STEM_bench_20us_data_000006.h5', + '1RywPXt6HRbCvjgjSuYFf60QHWlOPYXwy' + ), + 'test_arina_07' : ( + 'STO_STEM_bench_20us_data_000007.h5', + '1GRoBecCvAUeSIujzsPywv1vXKSIsNyoT' + ), + 'test_arina_08' : ( + 'STO_STEM_bench_20us_data_000008.h5', + '1sTFuuvgKbTjZz1lVUfkZbbTDTQmwqhuU' + ), + 'test_arina_09' : ( + 'STO_STEM_bench_20us_data_000009.h5', + '1JmBiMg16iMVfZ5wz8z_QqcNPVRym1Ezh' + ), + 'test_arina_10' : ( + 'STO_STEM_bench_20us_data_000010.h5', + '1_90xAfclNVwMWwQ-YKxNNwBbfR1nfHoB' + ), 'test_strain' : ( 'downsample_Si_SiGe_analysis_braggdisks_cal.h5', '1bYgDdAlnWHyFmY-SwN3KVpMutWBI5MhP' @@ -112,6 +156,19 @@ 'legacy_v0.14', 'test_realslice_io', ), + 'test_arina' : ( + 'test_arina_master', + 'test_arina_01', + 'test_arina_02', + 'test_arina_03', + 'test_arina_04', + 'test_arina_05', + 'test_arina_06', + 'test_arina_07', + 'test_arina_08', + 'test_arina_09', + 'test_arina_10', + ), 'test_braggvectors' : ( 'Au_sim', ), diff --git a/py4DSTEM/io/importfile.py b/py4DSTEM/io/importfile.py index 17b052601..20a3759a2 100644 --- a/py4DSTEM/io/importfile.py +++ b/py4DSTEM/io/importfile.py @@ -1,18 +1,18 @@ # Reader functions for non-native file types import pathlib -from os.path import exists, splitext -from typing import Union, Optional +from os.path import exists +from typing import Optional, Union -from py4DSTEM.io.parsefiletype import _parse_filetype from py4DSTEM.io.filereaders import ( - read_empad, + load_mib, + read_abTEM, + read_arina, read_dm, + read_empad, read_gatan_K2_bin, - load_mib ) - - +from py4DSTEM.io.parsefiletype import _parse_filetype def import_file( @@ -37,6 +37,7 @@ def import_file( from storage. binfactor (int): Diffraction space binning factor for bin-on-load. filetype (str): Used to override automatic filetype detection. + options include "dm", "empad", "gatan_K2_bin", "mib", "arina", "abTEM" **kwargs: any additional kwargs are passed to the downstream reader - refer to the individual filetype reader function call signatures and docstrings for more details. @@ -55,9 +56,7 @@ def import_file( "RAM", "MEMMAP", ], 'Error: argument mem must be either "RAM" or "MEMMAP"' - assert isinstance( - binfactor, int - ), "Error: argument binfactor must be an integer" + assert isinstance(binfactor, int), "Error: argument binfactor must be an integer" assert binfactor >= 1, "Error: binfactor must be >= 1" if binfactor > 1: assert ( @@ -66,13 +65,17 @@ def import_file( filetype = _parse_filetype(filepath) if filetype is None else filetype - if filetype == 'EMD': - raise Exception("EMD file detected - use py4DSTEM.read, not py4DSTEM.import_file!") + if filetype in ("emd", "legacy"): + raise Exception( + "EMD file or py4DSTEM detected - use py4DSTEM.read, not py4DSTEM.import_file!" + ) assert filetype in [ "dm", "empad", "gatan_K2_bin", - "mib" + "mib", + "arina", + "abTEM" # "kitware_counted", ], "Error: filetype not recognized" @@ -85,10 +88,12 @@ def import_file( # elif filetype == "kitware_counted": # data = read_kitware_counted(filepath, mem, binfactor, metadata=metadata, **kwargs) elif filetype == "mib": - data = load_mib(filepath, mem=mem, binfactor=binfactor,**kwargs) + data = load_mib(filepath, mem=mem, binfactor=binfactor, **kwargs) + elif filetype == "arina": + data = read_arina(filepath, mem=mem, binfactor=binfactor, **kwargs) + elif filetype == "abTEM": + data = read_abTEM(filepath, mem=mem, binfactor=binfactor, **kwargs) else: raise Exception("Bad filetype!") return data - - diff --git a/py4DSTEM/io/legacy/legacy13/v13_emd_classes/array.py b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/array.py index 4b385a694..8b20779f8 100644 --- a/py4DSTEM/io/legacy/legacy13/v13_emd_classes/array.py +++ b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/array.py @@ -322,8 +322,10 @@ def set_dim( values for the n'th dim vector. Accepts: n (int): specifies which dim vector - dim (list or array): length must be either 2, or equal to the - length of the n'th axis of the data array + dim (list or array): length must be either 1 or 2, or equal to the + length of the n'th axis of the data array. If length is 1 specifies step + size of dim vector and starts at 0. If length is 2, specifies start + and step of dim vector. units (Optional, str): name: (Optional, str): """ diff --git a/py4DSTEM/io/parsefiletype.py b/py4DSTEM/io/parsefiletype.py index 5903ce814..1838f89b6 100644 --- a/py4DSTEM/io/parsefiletype.py +++ b/py4DSTEM/io/parsefiletype.py @@ -1,9 +1,18 @@ # File parser utility from os.path import splitext +import py4DSTEM.io.legacy as legacy +import emdfile as emd +import h5py + +import emdfile as emd +import h5py +import py4DSTEM.io.legacy as legacy + def _parse_filetype(fp): - """ Accepts a path to a data file, and returns the file type as a string. + """ + Accepts a path to a data file, and returns the file type as a string. """ _, fext = splitext(fp) fext = fext.lower() @@ -13,7 +22,20 @@ def _parse_filetype(fp): ".py4dstem", ".emd", ]: - return "H5" + if emd._is_EMD_file(fp): + return "emd" + + elif legacy.is_py4DSTEM_file(fp): + return "legacy" + + elif _is_arina(fp): + return "arina" + + elif _is_abTEM(fp): + return "abTEM" + else: + raise Exception("not supported `h5` data type") + elif fext in [ ".dm", ".dm3", @@ -21,17 +43,67 @@ def _parse_filetype(fp): ]: return "dm" elif fext in [".raw"]: - return "empad" + return "empad" elif fext in [".mrc"]: - return "mrc_relativity" + return "mrc_relativity" elif fext in [".gtg", ".bin"]: - return "gatan_K2_bin" + return "gatan_K2_bin" elif fext in [".kitware_counted"]: - return "kitware_counted" + return "kitware_counted" elif fext in [".mib", ".MIB"]: return "mib" else: raise Exception(f"Unrecognized file extension {fext}.") +def _is_arina(filepath): + """ + Check if an h5 file is an Arina file. + """ + with h5py.File(filepath,'r') as f: + try: + assert("entry" in f.keys()) + except AssertionError: + return False + try: + assert("NX_class" in f["entry"].attrs.keys()) + except AssertionError: + return False + return True + +def _is_abTEM(filepath): + """ + Check if an h5 file is an abTEM file. + """ + with h5py.File(filepath,'r') as f: + try: + assert("array" in f.keys()) + except AssertionError: + return False + return True + +def _is_arina(filepath): + """ + Check if an h5 file is an Arina file. + """ + with h5py.File(filepath, "r") as f: + try: + assert "entry" in f.keys() + except AssertionError: + return False + try: + assert "NX_class" in f["entry"].attrs.keys() + except AssertionError: + return False + return True +def _is_abTEM(filepath): + """ + Check if an h5 file is an abTEM file. + """ + with h5py.File(filepath, "r") as f: + try: + assert "array" in f.keys() + except AssertionError: + return False + return True diff --git a/py4DSTEM/io/read.py b/py4DSTEM/io/read.py index 79291fe86..bab555eaf 100644 --- a/py4DSTEM/io/read.py +++ b/py4DSTEM/io/read.py @@ -1,25 +1,23 @@ # Reader for native files -from pathlib import Path -from os.path import exists -from typing import Optional,Union import warnings +from os.path import exists +from pathlib import Path +from typing import Optional, Union -import py4DSTEM import emdfile as emd -from py4DSTEM.io.parsefiletype import _parse_filetype import py4DSTEM.io.legacy as legacy - - +from py4DSTEM.data import Data +from py4DSTEM.io.parsefiletype import _parse_filetype def read( - filepath: Union[str,Path], + filepath: Union[str, Path], datapath: Optional[str] = None, - tree: Optional[Union[bool,str]] = True, + tree: Optional[Union[bool, str]] = True, verbose: Optional[bool] = False, **kwargs, - ): +): """ A file reader for native py4DSTEM / EMD files. To read non-native formats, use `py4DSTEM.import_file`. @@ -66,53 +64,61 @@ def read( # parse filetype er1 = f"filepath must be a string or Path, not {type(filepath)}" er2 = f"specified filepath '{filepath}' does not exist" - assert(isinstance(filepath, (str,Path) )), er1 - assert(exists(filepath)), er2 + assert isinstance(filepath, (str, Path)), er1 + assert exists(filepath), er2 filetype = _parse_filetype(filepath) - assert filetype == "H5", f"`py4DSTEM.read` loads native HDF5 formatted files, but a file of type {filetype} was detected. Try loading it using py4DSTEM.import_file" + assert filetype in ( + "emd", + "legacy", + ), f"`py4DSTEM.read` loads native HDF5 formatted files, but a file of type {filetype} was detected. Try loading it using py4DSTEM.import_file" # support older `root` input if datapath is None: - if 'root' in kwargs: - datapath = kwargs['root'] + if "root" in kwargs: + datapath = kwargs["root"] # EMD 1.0 formatted files (py4DSTEM v0.14+) - if emd._is_EMD_file(filepath): + if filetype == "emd": + + # check version version = emd._get_EMD_version(filepath) - if verbose: print(f"EMD version {version[0]}.{version[1]}.{version[2]} detected. Reading...") - assert emd._version_is_geq(version,(1,0,0)), f"EMD version {version} detected. Expected version >= 1.0.0" + if verbose: + print( + f"EMD version {version[0]}.{version[1]}.{version[2]} detected. Reading..." + ) + assert emd._version_is_geq( + version, (1, 0, 0) + ), f"EMD version {version} detected. Expected version >= 1.0.0" # read - data = emd.read( - filepath, - emdpath = datapath, - tree = tree - ) + data = emd.read(filepath, emdpath=datapath, tree=tree) + if verbose: + print("Data was read from file. Adding calibration links...") # add calibration links - if isinstance(data,py4DSTEM.Data): + if isinstance(data, Data): with warnings.catch_warnings(): - warnings.simplefilter('ignore') + warnings.simplefilter("ignore") cal = data.calibration - elif isinstance(data,py4DSTEM.Root): + elif isinstance(data, emd.Root): try: - cal = data.metadata['calibration'] + cal = data.metadata["calibration"] except KeyError: cal = None else: cal = None if cal is not None: try: - root_treepath = cal['_root_treepath'] - target_paths = cal['_target_paths'] - del(cal._params['_target_paths']) + root_treepath = cal["_root_treepath"] + target_paths = cal["_target_paths"] + del cal._params["_target_paths"] for p in target_paths: try: - p = p.replace(root_treepath,'') + p = p.replace(root_treepath, "") d = data.root.tree(p) - cal.register_target( d ) - if hasattr(d,'setcal'): + cal.register_target(d) + if hasattr(d, "setcal"): d.setcal() except AssertionError: pass @@ -121,68 +127,70 @@ def read( cal.calibrate() # return - if verbose: print("Done.") + if verbose: + print("Done.") return data - # legacy py4DSTEM files (v <= 0.13) else: - assert legacy.is_py4DSTEM_file(filepath), "path points to an H5 file which is neither an EMD 1.0+ file, nor a recognized legacy py4DSTEM file." - + assert ( + filetype == "legacy" + ), "path points to an H5 file which is neither an EMD 1.0+ file, nor a recognized legacy py4DSTEM file." # read v13 if legacy.is_py4DSTEM_version13(filepath): # load the data - if verbose: print(f"Legacy py4DSTEM version 13 file detected. Reading...") - kwargs['root'] = datapath - kwargs['tree'] = tree + if verbose: + print("Legacy py4DSTEM version 13 file detected. Reading...") + kwargs["root"] = datapath + kwargs["tree"] = tree data = legacy.read_legacy13( filepath=filepath, **kwargs, ) - if verbose: print("Done.") + if verbose: + print("Done.") return data - # read <= v12 else: # parse the root/data_id from the datapath arg if datapath is not None: - datapath = datapath.split('/') + datapath = datapath.split("/") try: - datapath.remove('') + datapath.remove("") except ValueError: pass rootgroup = datapath[0] - if len(datapath)>1: - datapath = '/'.join(rootgroup[1:]) + if len(datapath) > 1: + datapath = "/".join(rootgroup[1:]) else: datapath = None else: rootgroups = legacy.get_py4DSTEM_topgroups(filepath) - if len(rootgroups)>1: - print('multiple root groups in a legacy file found - returning list of root names; please pass one as `datapath`') + if len(rootgroups) > 1: + print( + "multiple root groups in a legacy file found - returning list of root names; please pass one as `datapath`" + ) return rootgroups - elif len(rootgroups)==0: - raise Exception('No rootgroups found') + elif len(rootgroups) == 0: + raise Exception("No rootgroups found") else: rootgroup = rootgroups[0] datapath = None - # load the data - if verbose: print(f"Legacy py4DSTEM version <= 12 file detected. Reading...") - kwargs['topgroup'] = rootgroup + if verbose: + print("Legacy py4DSTEM version <= 12 file detected. Reading...") + kwargs["topgroup"] = rootgroup if datapath is not None: - kwargs['data_id'] = datapath + kwargs["data_id"] = datapath data = legacy.read_legacy12( filepath=filepath, **kwargs, ) - if verbose: print("Done.") + if verbose: + print("Done.") return data - - - diff --git a/py4DSTEM/preprocess/preprocess.py b/py4DSTEM/preprocess/preprocess.py index f1052d72d..4001f80cb 100644 --- a/py4DSTEM/preprocess/preprocess.py +++ b/py4DSTEM/preprocess/preprocess.py @@ -283,6 +283,7 @@ def bin_data_diffraction( # set calibration pixel size datacube.calibration.set_Q_pixel_size(Qpixsize) + # return return datacube @@ -647,14 +648,9 @@ def resample_data_diffraction( datacube.data = fourier_resample( datacube.data, scale=resampling_factor, output_size=output_size ) - - if not resampling_factor: - resampling_factor = old_size[2] / output_size[0] - if datacube.calibration.get_Q_pixel_size() is not None: - datacube.calibration.set_Q_pixel_size(datacube.calibration.get_Q_pixel_size() / resampling_factor) if not resampling_factor: - resampling_factor = old_size[2] / output_size[0] + resampling_factor = output_size[0] / old_size[2] if datacube.calibration.get_Q_pixel_size() is not None: datacube.calibration.set_Q_pixel_size(datacube.calibration.get_Q_pixel_size() / resampling_factor) diff --git a/py4DSTEM/process/calibration/origin.py b/py4DSTEM/process/calibration/origin.py index 8821b43d0..19d2f0c55 100644 --- a/py4DSTEM/process/calibration/origin.py +++ b/py4DSTEM/process/calibration/origin.py @@ -134,14 +134,14 @@ def fit_origin( # Fit data if mask is None: - popt_x, pcov_x, qx0_fit = fit_2D( + popt_x, pcov_x, qx0_fit, _ = fit_2D( f, qx0_meas, robust=robust, robust_steps=robust_steps, robust_thresh=robust_thresh, ) - popt_y, pcov_y, qy0_fit = fit_2D( + popt_y, pcov_y, qy0_fit, _ = fit_2D( f, qy0_meas, robust=robust, @@ -150,7 +150,7 @@ def fit_origin( ) else: - popt_x, pcov_x, qx0_fit = fit_2D( + popt_x, pcov_x, qx0_fit, _ = fit_2D( f, qx0_meas, robust=robust, @@ -158,7 +158,7 @@ def fit_origin( robust_thresh=robust_thresh, data_mask=mask == True, ) - popt_y, pcov_y, qy0_fit = fit_2D( + popt_y, pcov_y, qy0_fit, _ = fit_2D( f, qy0_meas, robust=robust, @@ -359,4 +359,3 @@ def get_origin_beamstop(datacube: DataCube, mask: np.ndarray, **kwargs): return qx0, qy0 - diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 228602692..4d4d4a248 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -4,7 +4,6 @@ import matplotlib.pyplot as plt from fractions import Fraction from typing import Union, Optional -from copy import deepcopy from scipy.optimize import curve_fit import sys @@ -78,6 +77,7 @@ def __init__( 1 number: the lattice parameter for a cubic cell 3 numbers: the three lattice parameters for an orthorhombic cell 6 numbers: the a,b,c lattice parameters and É‘,β,É£ angles for any cell + 3x3 array: row vectors containing the (u,v,w) lattice vectors. """ # Initialize Crystal @@ -92,7 +92,10 @@ def __init__( else: raise Exception("Number of positions and atomic numbers do not match") - # unit cell, as either [a a a 90 90 90], [a b c 90 90 90], or [a b c alpha beta gamma] + # unit cell, as one of: + # [a a a 90 90 90] + # [a b c 90 90 90] + # [a b c alpha beta gamma] cell = np.asarray(cell, dtype="float_") if np.size(cell) == 1: self.cell = np.hstack([cell, cell, cell, 90, 90, 90]) @@ -100,34 +103,48 @@ def __init__( self.cell = np.hstack([cell, 90, 90, 90]) elif np.size(cell) == 6: self.cell = cell + elif np.shape(cell)[0] == 3 and np.shape(cell)[1] == 3: + self.lat_real = np.array(cell) + a = np.linalg.norm(self.lat_real[0,:]) + b = np.linalg.norm(self.lat_real[1,:]) + c = np.linalg.norm(self.lat_real[2,:]) + alpha = np.rad2deg(np.arccos(np.clip(np.sum( + self.lat_real[1,:]*self.lat_real[2,:])/b/c,-1,1))) + beta = np.rad2deg(np.arccos(np.clip(np.sum( + self.lat_real[0,:]*self.lat_real[2,:])/a/c,-1,1))) + gamma = np.rad2deg(np.arccos(np.clip(np.sum( + self.lat_real[0,:]*self.lat_real[1,:])/a/b,-1,1))) + self.cell = (a,b,c,alpha,beta,gamma) else: - raise Exception("Cell cannot contain " + np.size(cell) + " elements") + raise Exception("Cell cannot contain " + np.size(cell) + " entries") # pymatgen flag self.pymatgen_available = False # Calculate lattice parameters self.calculate_lattice() - + def calculate_lattice(self): - # calculate unit cell lattice vectors - a = self.cell[0] - b = self.cell[1] - c = self.cell[2] - alpha = np.deg2rad(self.cell[3]) - beta = np.deg2rad(self.cell[4]) - gamma = np.deg2rad(self.cell[5]) - f = np.cos(beta) * np.cos(gamma) - np.cos(alpha) - vol = a*b*c*np.sqrt(1 \ - + 2*np.cos(alpha)*np.cos(beta)*np.cos(gamma) \ - - np.cos(alpha)**2 - np.cos(beta)**2 - np.cos(gamma)**2) - self.lat_real = np.array( - [ - [a, 0, 0], - [b*np.cos(gamma), b*np.sin(gamma), 0], - [c*np.cos(beta), -c*f/np.sin(gamma), vol/(a*b*np.sin(gamma))], - ] - ) + + if not hasattr(self, 'lat_real'): + # calculate unit cell lattice vectors + a = self.cell[0] + b = self.cell[1] + c = self.cell[2] + alpha = np.deg2rad(self.cell[3]) + beta = np.deg2rad(self.cell[4]) + gamma = np.deg2rad(self.cell[5]) + f = np.cos(beta) * np.cos(gamma) - np.cos(alpha) + vol = a*b*c*np.sqrt(1 \ + + 2*np.cos(alpha)*np.cos(beta)*np.cos(gamma) \ + - np.cos(alpha)**2 - np.cos(beta)**2 - np.cos(gamma)**2) + self.lat_real = np.array( + [ + [a, 0, 0], + [b*np.cos(gamma), b*np.sin(gamma), 0], + [c*np.cos(beta), -c*f/np.sin(gamma), vol/(a*b*np.sin(gamma))], + ] + ) # Inverse lattice, metric tensors self.metric_real = self.lat_real @ self.lat_real.T @@ -139,6 +156,49 @@ def calculate_lattice(self): self.pymatgen_available = True else: self.pymatgen_available = False + + def get_strained_crystal( + self, + exx = 0.0, + eyy = 0.0, + ezz = 0.0, + exy = 0.0, + exz = 0.0, + eyz = 0.0, + deformation_matrix = None, + return_deformation_matrix = False, + ): + """ + This method returns new Crystal class with strain applied. The directions of (x,y,z) + are with respect to the default Crystal orientation, which can be checked with + print(Crystal.lat_real) applied to the original Crystal. + + Strains are given in fractional values, so exx = 0.01 is 1% strain along the x direction. + """ + + # deformation matrix + if deformation_matrix is None: + deformation_matrix = np.array([ + [1.0+exx, 1.0*exy, 1.0*exz], + [1.0*exy, 1.0+eyy, 1.0*eyz], + [1.0*exz, 1.0*eyz, 1.0+ezz], + ]) + + # new unit cell + lat_new = self.lat_real @ deformation_matrix + + # make new crystal class + from py4DSTEM.process.diffraction import Crystal + crystal_strained = Crystal( + positions = self.positions.copy(), + numbers = self.numbers.copy(), + cell = lat_new, + ) + + if return_deformation_matrix: + return crystal_strained, deformation_matrix + else: + return crystal_strained def from_CIF(CIF, conventional_standard_structure=True): @@ -386,13 +446,28 @@ def calculate_structure_factors( k_max: float = 2.0, tol_structure_factor: float = 1e-4, return_intensities: bool = False, - ): + ): + + """ Calculate structure factors for all hkl indices up to max scattering vector k_max - Args: - k_max (numpy float): max scattering vector to include (1/Angstroms) - tol_structure_factor (numpy float): tolerance for removing low-valued structure factors + Parameters + -------- + + k_max: float + max scattering vector to include (1/Angstroms) + tol_structure_factor: float + tolerance for removing low-valued structure factors + return_intensities: bool + return the intensities and positions of all structure factor peaks. + + Returns + -------- + (q_SF, I_SF) + Tuple of the q vectors and intensities of each structure factor. + + """ # Store k_max @@ -425,7 +500,7 @@ def calculate_structure_factors( hkl = np.vstack([xa.ravel(), ya.ravel(), za.ravel()]) # g_vec_all = self.lat_inv @ hkl g_vec_all = (hkl.T @ self.lat_inv).T - + # Delete lattice vectors outside of k_max keep = np.linalg.norm(g_vec_all, axis=0) <= self.k_max self.hkl = hkl[:, keep] @@ -898,12 +973,25 @@ def calculate_bragg_peak_histogram( k = np.arange(k_min, k_max + k_step, k_step) k_num = k.shape[0] - # experimental data histogram + # set rotate and ellipse based on their availability + rotate = bragg_peaks.calibration.get_QR_rotation_degrees() + ellipse = bragg_peaks.calibration.get_ellipse() + rotate = False if rotate is None else True + ellipse = False if ellipse is None else True + + # concatenate all peaks bigpl = np.concatenate( [ - bragg_peaks.cal[i, j].data - for i in range(bragg_peaks.shape[0]) - for j in range(bragg_peaks.shape[1]) + bragg_peaks.get_vectors( + rx, + ry, + center = True, + ellipse = ellipse, + pixel = True, + rotate = rotate, + ).data + for rx in range(bragg_peaks.shape[0]) + for ry in range(bragg_peaks.shape[1]) ] ) qr = np.sqrt(bigpl["qx"] ** 2 + bigpl["qy"] ** 2) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 4d02dcb0b..bffa5b620 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -8,6 +8,8 @@ from py4DSTEM.process.diffraction.utils import Orientation, OrientationMap, axisEqual3D from py4DSTEM.process.utils import electron_wavelength_angstrom +from warnings import warn + from numpy.linalg import lstsq try: import cupy as cp @@ -767,6 +769,18 @@ def match_orientations( num_x=bragg_peaks_array.shape[0], num_y=bragg_peaks_array.shape[1], num_matches=num_matches_return) + + #check cal state + if bragg_peaks_array.calstate['ellipse'] == False: + ellipse = False + warn('Warning: bragg peaks not elliptically calibrated') + else: + ellipse = True + if bragg_peaks_array.calstate['rotate'] == False: + rotate = False + warn('bragg peaks not rotationally calibrated') + else: + rotate = True for rx, ry in tqdmnd( *bragg_peaks_array.shape, @@ -774,9 +788,17 @@ def match_orientations( unit=" PointList", disable=not progress_bar, ): + vectors = bragg_peaks_array.get_vectors( + scan_x=rx, + scan_y=ry, + center=True, + ellipse=ellipse, + pixel=True, + rotate=rotate + ) orientation = self.match_single_pattern( - bragg_peaks_array.cal[rx, ry], + bragg_peaks=vectors, num_matches_return=num_matches_return, min_number_peaks=min_number_peaks, inversion_symmetry=inversion_symmetry, @@ -1639,6 +1661,18 @@ def calculate_strain( corr_kernel_size = self.orientation_kernel_size radius_max_2 = corr_kernel_size**2 + #check cal state + if bragg_peaks_array.calstate['ellipse'] == False: + ellipse = False + warn('bragg peaks not elliptically calibrated') + else: + ellipse = True + if bragg_peaks_array.calstate['rotate'] == False: + rotate = False + warn('bragg peaks not rotationally calibrated') + else: + rotate = True + # Loop over all probe positions for rx, ry in tqdmnd( *bragg_peaks_array.shape, @@ -1647,7 +1681,14 @@ def calculate_strain( disable=not progress_bar, ): # Get bragg peaks from experiment and reference - p = bragg_peaks_array.cal[rx,ry] + p = bragg_peaks_array.get_vectors( + scan_x=rx, + scan_y=ry, + center=True, + ellipse=ellipse, + pixel=True, + rotate=rotate + ) if p.data.shape[0] >= min_num_peaks: p_ref = self.generate_diffraction_pattern( @@ -2070,5 +2111,4 @@ def symmetry_reduce_directions( } # "-3m": ["fiber", [0, 0, 1], [90.0, 60.0]], - # "-3m": ["fiber", [0, 0, 1], [180.0, 30.0]], - + # "-3m": ["fiber", [0, 0, 1], [180.0, 30.0]], \ No newline at end of file diff --git a/py4DSTEM/process/diffraction/crystal_calibrate.py b/py4DSTEM/process/diffraction/crystal_calibrate.py index 2d08cd03c..1b65480f5 100644 --- a/py4DSTEM/process/diffraction/crystal_calibrate.py +++ b/py4DSTEM/process/diffraction/crystal_calibrate.py @@ -24,7 +24,7 @@ def calibrate_pixel_size( k_step = 0.002, k_broadening = 0.002, fit_all_intensities = True, - set_calibration = True, + set_calibration_in_place = False, verbose = True, plot_result = False, figsize: Union[list, tuple, np.ndarray] = (12, 6), @@ -60,8 +60,13 @@ def calibrate_pixel_size( figsize (list, tuple, np.ndarray): Figure size of the plot. returnfig (bool): Return handles figure and axis - Returns: - fig, ax (handles): Optional figure and axis handles, if returnfig=True. + Returns + _______ + + + + fig, ax: handles, optional + Figure and axis handles, if returnfig=True. """ @@ -112,17 +117,21 @@ def fit_profile(k, *coefs): # Get the answer pix_size_prev = bragg_peaks.calibration.get_Q_pixel_size() - ans = pix_size_prev / scale_pixel_size + pixel_size_new = pix_size_prev / scale_pixel_size - # if requested, apply calibrations - if set_calibration: - bragg_peaks.calibration.set_Q_pixel_size( ans ) + # if requested, apply calibrations in place + if set_calibration_in_place: + bragg_peaks.calibration.set_Q_pixel_size( pixel_size_new ) bragg_peaks.calibration.set_Q_pixel_units('A^-1') - bragg_peaks.setcal() - # Output + # Output calibrated Bragg peaks + bragg_peaks_cali = bragg_peaks.copy() + bragg_peaks_cali.calibration.set_Q_pixel_size( pixel_size_new ) + bragg_peaks_cali.calibration.set_Q_pixel_units('A^-1') + + # Output pixel size if verbose: - print(f"Calibrated pixel size = {np.round(ans, decimals=8)} A^-1") + print(f"Calibrated pixel size = {np.round(pixel_size_new, decimals=8)} A^-1") # Plotting if plot_result: @@ -163,9 +172,9 @@ def fit_profile(k, *coefs): # return if returnfig and plot_result: - return ans, (fig,ax) + return bragg_peaks_cali, (fig,ax) else: - return ans + return bragg_peaks_cali @@ -463,4 +472,4 @@ def fitfun(self, k, *coefs_fit): "432": [[0, 0, 0, 3, 3, 3], [True, True, True, False, False, False]], #Cubic "-43m": [[0, 0, 0, 3, 3, 3], [True, True, True, False, False, False]], #Cubic "m-3m": [[0, 0, 0, 3, 3, 3], [True, True, True, False, False, False]], #Cubic - } + } \ No newline at end of file diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 9c1f5b667..a9420fee4 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -296,30 +296,48 @@ def plot_scattering_intensity( bragg_k_power=0.0, bragg_intensity_power=1.0, bragg_k_broadening=0.005, - figsize: Union[list, tuple, np.ndarray] = (12, 6), + figsize: Union[list, tuple, np.ndarray] = (10, 4), returnfig: bool = False, ): """ 1D plot of the structure factors - Args: - k_min (float): min k value for profile range. - k_max (float): max k value for profile range. - k_step (float): step size of k in profile range. - k_broadening (float): Broadening of simulated pattern. - k_power_scale (float): Scale SF intensities by k**k_power_scale. - int_power_scale (float): Scale SF intensities**int_power_scale. - int_scale (float): Scale output profile by this value. - remove_origin (bool): Remove origin from plot. - bragg_peaks (BraggVectors): Passed in bragg_peaks for comparison with simulated pattern. - bragg_k_power (float): bragg_peaks scaled by k**bragg_k_power. - bragg_intensity_power (float): bragg_peaks scaled by intensities**bragg_intensity_power. - bragg_k_broadening float): Broadening applied to bragg_peaks. - figsize (list, tuple, np.ndarray): Figure size for plot. - returnfig (bool): Return figure and axes handles if this is True. - - Returns: - fig, ax (optional) figure and axes handles + Parameters + -------- + + k_min: float + min k value for profile range. + k_max: float + max k value for profile range. + k_step: float + Step size of k in profile range. + k_broadening: float + Broadening of simulated pattern. + k_power_scale: float + Scale SF intensities by k**k_power_scale. + int_power_scale: float + Scale SF intensities**int_power_scale. + int_scale: float + Scale output profile by this value. + remove_origin: bool + Remove origin from plot. + bragg_peaks: BraggVectors + Passed in bragg_peaks for comparison with simulated pattern. + bragg_k_power: float + bragg_peaks scaled by k**bragg_k_power. + bragg_intensity_power: float + bragg_peaks scaled by intensities**bragg_intensity_power. + bragg_k_broadening: float + Broadening applied to bragg_peaks. + figsize: list, tuple, np.ndarray + Figure size for plot. + returnfig (bool): + Return figure and axes handles if this is True. + + Returns + -------- + fig, ax (optional) + figure and axes handles """ # k coordinates @@ -342,12 +360,25 @@ def plot_scattering_intensity( # If Bragg peaks are passed in, compute 1D integral if bragg_peaks is not None: + # set rotate and ellipse based on their availability + rotate = bragg_peaks.calibration.get_QR_rotation_degrees() + ellipse = bragg_peaks.calibration.get_ellipse() + rotate = False if rotate is None else True + ellipse = False if ellipse is None else True + # concatenate all peaks bigpl = np.concatenate( [ - bragg_peaks.cal[i, j].data - for i in range(bragg_peaks.shape[0]) - for j in range(bragg_peaks.shape[1]) + bragg_peaks.get_vectors( + rx, + ry, + center = True, + ellipse = ellipse, + pixel = True, + rotate = rotate, + ).data + for rx in range(bragg_peaks.shape[0]) + for ry in range(bragg_peaks.shape[1]) ] ) @@ -903,6 +934,9 @@ def plot_diffraction_pattern( ax.set_ylabel("$q_x$ [Ă…$^{-1}$]") if plot_range_kx_ky is not None: + plot_range_kx_ky = np.array(plot_range_kx_ky) + if plot_range_kx_ky.ndim == 0: + plot_range_kx_ky = np.array((plot_range_kx_ky,plot_range_kx_ky)) ax.set_xlim((-plot_range_kx_ky[0], plot_range_kx_ky[0])) ax.set_ylim((-plot_range_kx_ky[1], plot_range_kx_ky[1])) else: @@ -1846,4 +1880,4 @@ def plot_ring_pattern( plt.show() if returnfig: - return fig, ax + return fig, ax \ No newline at end of file diff --git a/py4DSTEM/process/diffraction/flowlines.py b/py4DSTEM/process/diffraction/flowlines.py index 27d4f9381..cf84f69f5 100644 --- a/py4DSTEM/process/diffraction/flowlines.py +++ b/py4DSTEM/process/diffraction/flowlines.py @@ -519,6 +519,7 @@ def make_flowline_rainbow_image( power_scaling = 1.0, sum_radial_bins = False, plot_images = True, + figsize = None, ): """ Generate RGB output images from the flowline arrays. @@ -535,6 +536,7 @@ def make_flowline_rainbow_image( power_scaling (float): Power law scaling for flowline intensity output. sum_radial_bins (bool): Sum all radial bins (alternative is to output separate images). plot_images (bool): Plot the outputs for quick visualization. + figsize (2-tuple): Size of output figure. Returns: im_flowline (array): 3D or 4D array containing flowline images @@ -613,7 +615,14 @@ def make_flowline_rainbow_image( im_flowline = np.min(im_flowline,axis=0)[None,:,:,:] if plot_images is True: - fig,ax = plt.subplots(im_flowline.shape[0],1,figsize=(10,im_flowline.shape[0]*10)) + if figsize is None: + fig,ax = plt.subplots( + im_flowline.shape[0],1, + figsize=(10,im_flowline.shape[0]*10)) + else: + fig,ax = plt.subplots( + im_flowline.shape[0],1, + figsize=figsize) if im_flowline.shape[0] > 1: for a0 in range(im_flowline.shape[0]): @@ -729,6 +738,7 @@ def make_flowline_combined_image( power_scaling = 1.0, sum_radial_bins = True, plot_images = True, + figsize = None, ): """ Generate RGB output images from the flowline arrays. @@ -742,6 +752,7 @@ def make_flowline_combined_image( power_scaling (float): Power law scaling for flowline intensities. sum_radial_bins (bool): Sum outputs over radial bins. plot_images (bool): Plot the output images for quick visualization. + figsize (2-tuple): Size of output figure. Returns: im_flowline (array): flowline images @@ -787,7 +798,14 @@ def make_flowline_combined_image( if plot_images is True: - fig,ax = plt.subplots(im_flowline.shape[0],1,figsize=(10,im_flowline.shape[0]*10)) + if figsize is None: + fig,ax = plt.subplots( + im_flowline.shape[0],1, + figsize=(10,im_flowline.shape[0]*10)) + else: + fig,ax = plt.subplots( + im_flowline.shape[0],1, + figsize=figsize) if im_flowline.shape[0] > 1: for a0 in range(im_flowline.shape[0]): @@ -1143,4 +1161,4 @@ def set_intensity(orient,xy_t_int): mode=['clip','clip','wrap']) orient.ravel()[inds_1D] = orient.ravel()[inds_1D] + xy_t_int[:,3]*( dx)*( dy)*( dt) - return orient + return orient \ No newline at end of file diff --git a/py4DSTEM/process/fit/fit.py b/py4DSTEM/process/fit/fit.py index 9abb713f7..32809ddb1 100644 --- a/py4DSTEM/process/fit/fit.py +++ b/py4DSTEM/process/fit/fit.py @@ -269,7 +269,3 @@ def fit_2D_polar_gaussian( robust_steps = robust_steps, robust_thresh = robust_thresh ) - - - - diff --git a/py4DSTEM/process/latticevectors/fit.py b/py4DSTEM/process/latticevectors/fit.py index 822ffdea5..fef72aca3 100644 --- a/py4DSTEM/process/latticevectors/fit.py +++ b/py4DSTEM/process/latticevectors/fit.py @@ -104,13 +104,22 @@ def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5): # Make RealSlice to contain outputs slicelabels = ('x0','y0','g1x','g1y','g2x','g2y','error','mask') - g1g2_map = RealSlice(data=np.zeros((braggpeaks.shape[0],braggpeaks.shape[1],8)), - slicelabels=slicelabels, name='g1g2_map') + g1g2_map = RealSlice( + data=np.zeros( + (8, braggpeaks.shape[0],braggpeaks.shape[1]) + ), + slicelabels=slicelabels, name='g1g2_map' + ) # Fit lattice vectors for (Rx, Ry) in tqdmnd(braggpeaks.shape[0],braggpeaks.shape[1]): braggpeaks_curr = braggpeaks.get_pointlist(Rx,Ry) - qx0,qy0,g1x,g1y,g2x,g2y,error = fit_lattice_vectors(braggpeaks_curr, x0, y0, minNumPeaks) + qx0,qy0,g1x,g1y,g2x,g2y,error = fit_lattice_vectors( + braggpeaks_curr, + x0, + y0, + minNumPeaks + ) # Store data if g1x is not None: g1g2_map.get_slice('x0').data[Rx,Ry] = qx0 diff --git a/py4DSTEM/process/latticevectors/index.py b/py4DSTEM/process/latticevectors/index.py index cdf6b00fd..189e7f10f 100644 --- a/py4DSTEM/process/latticevectors/index.py +++ b/py4DSTEM/process/latticevectors/index.py @@ -80,6 +80,9 @@ def index_bragg_directions(x0, y0, gx, gy, g1, g2): temp_array = np.zeros([], dtype = coords) bragg_directions = PointList(data = temp_array) bragg_directions.add_data_by_field((gx,gy,h,k)) + mask = np.zeros(bragg_directions['qx'].shape[0]) + mask[0] = 1 + bragg_directions.remove(mask) return h,k, bragg_directions @@ -152,8 +155,14 @@ def generate_lattice(ux,uy,vx,vy,x0,y0,Q_Nx,Q_Ny,h_max=None,k_max=None): return ideal_lattice -def add_indices_to_braggpeaks(braggpeaks, lattice, maxPeakSpacing, qx_shift=0, - qy_shift=0, mask=None): +def add_indices_to_braggvectors( + braggpeaks, + lattice, + maxPeakSpacing, + qx_shift=0, + qy_shift=0, + mask=None + ): """ Using the peak positions (qx,qy) and indices (h,k) in the PointList lattice, identify the indices for each peak in the PointListArray braggpeaks. @@ -181,43 +190,41 @@ def add_indices_to_braggpeaks(braggpeaks, lattice, maxPeakSpacing, qx_shift=0, 'h', 'k', containing the indices of each indexable peak. """ - assert isinstance(braggpeaks,PointListArray) - assert np.all([name in braggpeaks.dtype.names for name in ('qx','qy','intensity')]) - assert isinstance(lattice, PointList) - assert np.all([name in lattice.dtype.names for name in ('qx','qy','h','k')]) + # assert isinstance(braggpeaks,BraggVectors) + # assert isinstance(lattice, PointList) + # assert np.all([name in lattice.dtype.names for name in ('qx','qy','h','k')]) if mask is None: - mask = np.ones(braggpeaks.shape,dtype=bool) + mask = np.ones(braggpeaks.Rshape,dtype=bool) - assert mask.shape == braggpeaks.shape, 'mask must have same shape as pointlistarray' + assert mask.shape == braggpeaks.Rshape, 'mask must have same shape as pointlistarray' assert mask.dtype == bool, 'mask must be boolean' - indexed_braggpeaks = braggpeaks.copy() - # add the coordinates if they don't exist - if not ('h' in braggpeaks.dtype.names): - indexed_braggpeaks = indexed_braggpeaks.add_fields([('h',int)]) - if not ('k' in braggpeaks.dtype.names): - indexed_braggpeaks = indexed_braggpeaks.add_fields([('k',int)]) + coords = [('qx',float),('qy',float),('intensity',float),('h',int),('k',int)] + + indexed_braggpeaks = PointListArray( + dtype = coords, + shape = braggpeaks.Rshape, + ) # loop over all the scan positions for Rx, Ry in tqdmnd(mask.shape[0],mask.shape[1]): - if mask[Rx,Ry]: - pl = indexed_braggpeaks.get_pointlist(Rx,Ry) - rm_peak_mask = np.zeros(pl.length,dtype=bool) - - for i in range(pl.length): + if mask[Rx,Ry]: + pl = braggpeaks.cal[Rx,Ry] + for i in range(pl.data.shape[0]): r2 = (pl.data['qx'][i]-lattice.data['qx'] + qx_shift)**2 + \ (pl.data['qy'][i]-lattice.data['qy'] + qy_shift)**2 ind = np.argmin(r2) if r2[ind] <= maxPeakSpacing**2: - pl.data['h'][i] = lattice.data['h'][ind] - pl.data['k'][i] = lattice.data['k'][ind] - else: - rm_peak_mask[i] = True - pl.remove(rm_peak_mask) + indexed_braggpeaks[Rx,Ry].add_data_by_field(( + pl.data['qx'][i], + pl.data['qy'][i], + pl.data['intensity'][i], + lattice.data['h'][ind], + lattice.data['k'][ind] + )) - indexed_braggpeaks.name = braggpeaks.name + "_indexed" return indexed_braggpeaks diff --git a/py4DSTEM/process/latticevectors/strain.py b/py4DSTEM/process/latticevectors/strain.py index 50b9bddc9..7a586bd69 100644 --- a/py4DSTEM/process/latticevectors/strain.py +++ b/py4DSTEM/process/latticevectors/strain.py @@ -71,9 +71,11 @@ def get_strain_from_reference_g1g2(g1g2_map, g1, g2): # Get RealSlice for output storage R_Nx,R_Ny = g1g2_map.get_slice('g1x').shape - strain_map = RealSlice(data=np.zeros((R_Nx,R_Ny,5)), - slicelabels=('e_xx','e_yy','e_xy','theta','mask'), - name='strain_map') + strain_map = RealSlice( + data=np.zeros((5, R_Nx, R_Ny)), + slicelabels=('e_xx','e_yy','e_xy','theta','mask'), + name='strain_map' + ) # Get reference lattice matrix g1x,g1y = g1 @@ -130,7 +132,8 @@ def get_strain_from_reference_region(g1g2_map, mask): Note 1: the strain matrix has been symmetrized, so e_xy and e_yx are identical """ assert isinstance(g1g2_map, RealSlice) - assert np.all([name in g1g2_map.slicelabels for name in ('g1x','g1y','g2x','g2y','mask')]) + assert np.all( + [name in g1g2_map.slicelabels for name in ('g1x','g1y','g2x','g2y','mask')]) assert mask.dtype == bool g1,g2 = get_reference_g1g2(g1g2_map,mask) @@ -169,18 +172,20 @@ def get_rotated_strain_map(unrotated_strain_map, xaxis_x, xaxis_y, flip_theta): sint2 = sint**2 Rx,Ry = unrotated_strain_map.get_slice('e_xx').data.shape - rotated_strain_map = RealSlice(data=np.zeros((Rx,Ry,5)), - slicelabels=['e_xx','e_xy','e_yy','theta','mask'], - name=unrotated_strain_map.name+"_rotated".format(np.degrees(theta))) - - rotated_strain_map.data[:,:,0] = cost2*unrotated_strain_map.get_slice('e_xx').data - 2*cost*sint*unrotated_strain_map.get_slice('e_xy').data + sint2*unrotated_strain_map.get_slice('e_yy').data - rotated_strain_map.data[:,:,1] = cost*sint*(unrotated_strain_map.get_slice('e_xx').data-unrotated_strain_map.get_slice('e_yy').data) + (cost2-sint2)*unrotated_strain_map.get_slice('e_xy').data - rotated_strain_map.data[:,:,2] = sint2*unrotated_strain_map.get_slice('e_xx').data + 2*cost*sint*unrotated_strain_map.get_slice('e_xy').data + cost2*unrotated_strain_map.get_slice('e_yy').data + rotated_strain_map = RealSlice( + data=np.zeros((5, Rx,Ry)), + slicelabels=['e_xx','e_xy','e_yy','theta','mask'], + name=unrotated_strain_map.name+"_rotated".format(np.degrees(theta)) + ) + + rotated_strain_map.data[0,:,:] = cost2*unrotated_strain_map.get_slice('e_xx').data - 2*cost*sint*unrotated_strain_map.get_slice('e_xy').data + sint2*unrotated_strain_map.get_slice('e_yy').data + rotated_strain_map.data[1,:,:] = cost*sint*(unrotated_strain_map.get_slice('e_xx').data-unrotated_strain_map.get_slice('e_yy').data) + (cost2-sint2)*unrotated_strain_map.get_slice('e_xy').data + rotated_strain_map.data[2,:,:] = sint2*unrotated_strain_map.get_slice('e_xx').data + 2*cost*sint*unrotated_strain_map.get_slice('e_xy').data + cost2*unrotated_strain_map.get_slice('e_yy').data if flip_theta == True: - rotated_strain_map.data[:,:,3] = -unrotated_strain_map.get_slice('theta').data + rotated_strain_map.data[3,:,:] = -unrotated_strain_map.get_slice('theta').data else: - rotated_strain_map.data[:,:,3] = unrotated_strain_map.get_slice('theta').data - rotated_strain_map.data[:,:,4] = unrotated_strain_map.get_slice('mask').data + rotated_strain_map.data[3,:,:] = unrotated_strain_map.get_slice('theta').data + rotated_strain_map.data[4,:,:] = unrotated_strain_map.get_slice('mask').data return rotated_strain_map diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 1e9bd2cbb..92f8c0bf3 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -974,7 +974,7 @@ def _gradient_descent_adjoint( ) # back-transmit - exit_waves *= xp.conj(obj) / xp.abs(obj) ** 2 + exit_waves *= xp.conj(obj) #/ xp.abs(obj) ** 2 if s > 0: # back-propagate @@ -1076,7 +1076,7 @@ def _projection_sets_adjoint( ) # back-transmit - exit_waves_copy *= xp.conj(obj) / xp.abs(obj) ** 2 + exit_waves_copy *= xp.conj(obj) # / xp.abs(obj) ** 2 if s > 0: # back-propagate @@ -3067,4 +3067,4 @@ def _return_object_fft( obj = np.angle(obj) obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) \ No newline at end of file diff --git a/py4DSTEM/process/polar/__init__.py b/py4DSTEM/process/polar/__init__.py index ddf0a9e50..79e13a054 100644 --- a/py4DSTEM/process/polar/__init__.py +++ b/py4DSTEM/process/polar/__init__.py @@ -1,3 +1,3 @@ from py4DSTEM.process.polar.polar_datacube import PolarDatacube from py4DSTEM.process.polar.polar_fits import fit_amorphous_ring, plot_amorphous_ring -from py4DSTEM.process.polar.polar_peaks import find_peaks_single_pattern, find_peaks, refine_peaks, plot_radial_peaks, plot_radial_background, make_orientation_histogram +from py4DSTEM.process.polar.polar_peaks import find_peaks_single_pattern, find_peaks, refine_peaks, plot_radial_peaks, plot_radial_background, make_orientation_histogram \ No newline at end of file diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 0a6089f4e..fa6a40a4f 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -6,10 +6,17 @@ from emdfile import tqdmnd +<<<<<<< Updated upstream def calculate_FEM_global( self, use_median_local = False, use_median_global = False, +======= +def calculate_radial_statistics( + self, + median_local = False, + median_global = False, +>>>>>>> Stashed changes plot_results = False, figsize = (8,4), returnval = False, @@ -42,22 +49,48 @@ def calculate_FEM_global( self.scattering_vector = self.radial_bins * self.qstep * self.calibration.get_Q_pixel_size() self.scattering_vector_units = self.calibration.get_Q_pixel_units() +<<<<<<< Updated upstream # init radial data array +======= + # init radial data arrays +>>>>>>> Stashed changes self.radial_all = np.zeros(( self._datacube.shape[0], self._datacube.shape[1], self.polar_shape[1], )) +<<<<<<< Updated upstream +======= + self.radial_all_std = np.zeros(( + self._datacube.shape[0], + self._datacube.shape[1], + self.polar_shape[1], + )) + +>>>>>>> Stashed changes # Compute the radial mean for each probe position for rx, ry in tqdmnd( self._datacube.shape[0], self._datacube.shape[1], +<<<<<<< Updated upstream desc="Global FEM", unit=" probe positions", disable=not progress_bar): self.radial_all[rx,ry] = np.mean(self.data[rx,ry],axis=0) +======= + desc="Radial statistics", + unit=" probe positions", + disable=not progress_bar): + + self.radial_all[rx,ry] = np.mean( + self.data[rx,ry], + axis=0) + self.radial_all_std[rx,ry] = np.sqrt(np.mean( + (self.data[rx,ry] - self.radial_all[rx,ry][None])**2, + axis=0)) +>>>>>>> Stashed changes self.radial_avg = np.mean(self.radial_all, axis=(0,1)) self.radial_var = np.mean( @@ -138,5 +171,18 @@ def calculate_FEM_local( """ +<<<<<<< Updated upstream 1+1 +======= + pass + + +# def radial_average( +# self, +# figsize = (8,6), +# returnfig = False, +# ): + + +>>>>>>> Stashed changes diff --git a/py4DSTEM/process/polar/polar_datacube.py b/py4DSTEM/process/polar/polar_datacube.py index 3f3db0eca..b6ef4ee66 100644 --- a/py4DSTEM/process/polar/polar_datacube.py +++ b/py4DSTEM/process/polar/polar_datacube.py @@ -19,7 +19,7 @@ def __init__( n_annular = 180, qscale = None, mask = None, - mask_thresh = 0.25, + mask_thresh = 0.1, ellipse = True, two_fold_symmetry = False, ): @@ -95,7 +95,12 @@ def __init__( pass from py4DSTEM.process.polar.polar_analysis import ( +<<<<<<< Updated upstream calculate_FEM_global, +======= + # calculate_FEM_global, + calculate_radial_statistics, +>>>>>>> Stashed changes plot_FEM_global, calculate_FEM_local, ) @@ -127,9 +132,9 @@ def set_radial_bins( self._qmax, self._qstep ) - self.qscale = self._qscale self._radial_step = self._datacube.calibration.get_Q_pixel_size() * self._qstep self.set_polar_shape() + self.qscale = self._qscale @property def qmin(self): @@ -241,7 +246,7 @@ def qscale(self): def qscale(self,x): self._qscale = x if x is not None: - self._qscale_ar = np.arange(self.polar_shape[1])**x + self._qscale_ar = (self.qq / self.qq[-1])**x # expose raw data @@ -453,7 +458,7 @@ def _transform( ) # scale the normalization array by the bin density - norm_array = ans_norm*self._polarcube._annular_bin_step[np.newaxis] + norm_array = ans_norm * self._polarcube._annular_bin_step[np.newaxis] mask_bool = norm_array < mask_thresh # apply normalization @@ -588,5 +593,4 @@ def __repr__(self): space = ' '*len(self.__class__.__name__)+' ' string = f"{self.__class__.__name__}( " string += "Retrieves the diffraction pattern at scan position (x,y) in polar coordinates when sliced with [x,y]." - return string - + return string \ No newline at end of file diff --git a/py4DSTEM/process/polar/polar_fits.py b/py4DSTEM/process/polar/polar_fits.py index 82085d6fb..e231dda07 100644 --- a/py4DSTEM/process/polar/polar_fits.py +++ b/py4DSTEM/process/polar/polar_fits.py @@ -7,11 +7,12 @@ def fit_amorphous_ring( im, - center, - radial_range, + center = None, + radial_range = None, coefs = None, mask_dp = None, show_fit_mask = False, + maxfev = None, verbose = False, plot_result = True, plot_log_scale = False, @@ -28,15 +29,19 @@ def fit_amorphous_ring( im: np.array 2D image array to perform fitting on center: np.array - (x,y) center coordinates for fitting mask + (x,y) center coordinates for fitting mask. If not specified + by the user, we will assume the center coordinate is (im.shape-1)/2. radial_range: np.array - (radius_inner, radius_outer) radial range to perform fitting over + (radius_inner, radius_outer) radial range to perform fitting over. + If not specified by the user, we will assume (im.shape[0]/4,im.shape[0]/2). coefs: np.array (optional) Array containing initial fitting coefficients for the amorphous fit. mask_dp: np.array Dark field mask for fitting, in addition to the radial range specified above. show_fit_mask: bool Set to true to preview the fitting mask and initial guess for the ellipse params + maxfev: int + Max number of fitting evaluations for curve_fit. verbose: bool Print fit results plot_result: bool @@ -58,6 +63,14 @@ def fit_amorphous_ring( 11 parameter elliptic fit coefficients """ + # Default values + if center is None: + center = np.array(( + (im.shape[0]-1)/2, + (im.shape[1]-1)/2)) + if radial_range is None: + radial_range = (im.shape[0]/4, im.shape[0]/2) + # coordinates xa,ya = np.meshgrid( np.arange(im.shape[0]), @@ -149,14 +162,26 @@ def fit_amorphous_ring( else: # Perform elliptic fitting int_mean = np.mean(vals) - coefs = curve_fit( - amorphous_model, - basis, - vals / int_mean, - p0=coefs, - xtol = 1e-12, - bounds = (lb,ub), - )[0] + + if maxfev is None: + coefs = curve_fit( + amorphous_model, + basis, + vals / int_mean, + p0=coefs, + xtol = 1e-8, + bounds = (lb,ub), + )[0] + else: + coefs = curve_fit( + amorphous_model, + basis, + vals / int_mean, + p0=coefs, + xtol = 1e-8, + bounds = (lb,ub), + maxfev = maxfev, + )[0] coefs[4] = np.mod(coefs[4],2*np.pi) coefs[5:8] *= int_mean # bounds=bounds @@ -356,4 +381,4 @@ def amorphous_model(basis, *coefs): sub = np.logical_not(sub) int_model[sub] += int12*np.exp(dr2[sub]/(-2*sigma2**2)) - return int_model + return int_model \ No newline at end of file diff --git a/py4DSTEM/process/polar/polar_peaks.py b/py4DSTEM/process/polar/polar_peaks.py index 3f367b398..6a6e0860a 100644 --- a/py4DSTEM/process/polar/polar_peaks.py +++ b/py4DSTEM/process/polar/polar_peaks.py @@ -5,7 +5,7 @@ from scipy.ndimage import gaussian_filter, gaussian_filter1d from scipy.signal import peak_prominences from skimage.feature import peak_local_max -from scipy.optimize import curve_fit +from scipy.optimize import curve_fit, leastsq import warnings # from emdfile import tqdmnd, PointList, PointListArray @@ -34,7 +34,8 @@ def find_peaks_single_pattern( return_background = False, plot_result = True, plot_power_scale = 1.0, - plot_scale_size = 100.0, + plot_scale_size = 10.0, + figsize = (12,6), returnfig = False, ): """ @@ -62,10 +63,41 @@ def find_peaks_single_pattern( radial_background_thresh: float Relative order of sorted values to use as background estimate. Setting to 0.5 is equivalent to median, 0.0 is min value. - + num_peaks_max = 100 + Max number of peaks to return. + threshold_abs: float + Absolute image intensity threshold for peaks. + threshold_prom_annular: float + Threshold for prominance, along annular direction. + threshold_prom_radial: float + Threshold for prominance, along radial direction. + remove_masked_peaks: bool + Delete peaks that are in the region masked by "mask" + scale_sigma_annular: float + Scaling of the estimated annular standard deviation. + scale_sigma_radial: float + Scaling of the estimated radial standard deviation. + return_background: bool + Return the background signal. + plot_result: + Plot the detector peaks + plot_power_scale: float + Image intensity power law scaling. + plot_scale_size: float + Marker scaling in the plot. + figsize: 2-tuple + Size of the result plotting figure. + returnfig: bool + Return the figure and axes handles. + Returns -------- + peaks_polar : pointlist + The detected peaks + fig, ax : (optional) + Figure and axes handles + """ # if needed, generate mask from Bragg peaks @@ -151,7 +183,7 @@ def find_peaks_single_pattern( trace_annular, annular_ind_center, ) - sigma_annular = scale_sigma_annular * np.maximum( + sigma_annular = scale_sigma_annular * np.minimum( annular_ind_center - p_annular[1], p_annular[2] - annular_ind_center) @@ -161,7 +193,7 @@ def find_peaks_single_pattern( trace_radial, np.atleast_1d(peaks[a0,1]), ) - sigma_radial = scale_sigma_radial * np.maximum( + sigma_radial = scale_sigma_radial * np.minimum( peaks[a0,1] - p_radial[1], p_radial[2] - peaks[a0,1]) @@ -266,7 +298,7 @@ def find_peaks_single_pattern( st = np.sin(t) - fig,ax = plt.subplots(figsize=(12,6)) + fig,ax = plt.subplots(figsize=figsize) ax.imshow( im_plot, @@ -685,6 +717,7 @@ def model_radial_background( ring_int = None, refine_model = True, plot_result = True, + figsize = (8,4), ): """ User provided radial background model, of the form: @@ -751,6 +784,8 @@ def model_radial_background( self.background_coefs[3*a0+3] = ring_int[a0] self.background_coefs[3*a0+4] = ring_sigma[a0] self.background_coefs[3*a0+5] = ring_position[a0] + lb = np.zeros_like(self.background_coefs) + ub = np.ones_like(self.background_coefs) * np.inf # Create background model def background_model(q, *coefs): @@ -776,7 +811,7 @@ def background_model(q, *coefs): self.background_radial_mean[self.background_mask], p0 = self.background_coefs, xtol = 1e-12, - # bounds = (lb,ub), + bounds = (lb,ub), )[0] # plotting @@ -784,6 +819,7 @@ def background_model(q, *coefs): self.plot_radial_background( q_pixel_units = False, plot_background_model = True, + figsize = figsize, ) @@ -794,6 +830,7 @@ def refine_peaks( # reset_fits_to_init_positions = False, scale_sigma_estimate = 0.5, min_num_pixels_fit = 10, + maxfev = None, progress_bar = True, ): """ @@ -816,6 +853,8 @@ def refine_peaks( Factor to reduce sigma of peaks by, to prevent fit from running away. min_num_pixels_fit: int Minimum number of pixels to perform fitting + maxfev: int + Maximum number of iterations in fit. Set to a low number for a fast fit. progress_bar: bool Enable progress bar @@ -896,6 +935,11 @@ def refine_peaks( s_radial * scale_sigma_estimate, )) + # bounds + lb = np.zeros_like(coefs_all) + ub = np.ones_like(coefs_all) * np.inf + + # Construct fitting model def fit_image(basis, *coefs): coefs = np.squeeze(np.array(coefs)) @@ -928,14 +972,25 @@ def fit_image(basis, *coefs): try: with warnings.catch_warnings(): warnings.simplefilter('ignore') - coefs_all = curve_fit( - fit_image, - basis[mask_bool.ravel(),:], - im_polar[mask_bool], - p0 = coefs_all, - xtol = 1e-12, - # bounds = (lb,ub), - )[0] + if maxfev is None: + coefs_all = curve_fit( + fit_image, + basis[mask_bool.ravel(),:], + im_polar[mask_bool], + p0 = coefs_all, + xtol = 1e-12, + bounds = (lb,ub), + )[0] + else: + coefs_all = curve_fit( + fit_image, + basis[mask_bool.ravel(),:], + im_polar[mask_bool], + p0 = coefs_all, + xtol = 1e-12, + maxfev = maxfev, + bounds = (lb,ub), + )[0] # Output refined peak parameters coefs_peaks = np.reshape( @@ -951,9 +1006,24 @@ def fit_image(basis, *coefs): ]), name = 'peaks_polar') except: - # if fitting has failed, we will output the mean background signal, - # but none of the peaks. - pass + # if fitting has failed, we will still output the last iteration + # TODO - add a flag for unconverged fits + coefs_peaks = np.reshape( + coefs_all[(3*num_rings+3):], + (5,num_peaks)).T + self.peaks_refine[rx,ry] = PointList( + coefs_peaks.ravel().view([ + ('qt', float), + ('qr', float), + ('intensity', float), + ('sigma_annular', float), + ('sigma_radial', float), + ]), + name = 'peaks_polar') + + # mean background signal, + # # but none of the peaks. + # pass # Output refined parameters for background coefs_bg = coefs_all[:(3*num_rings+3)] @@ -1154,6 +1224,9 @@ def make_orientation_histogram( v_sigma = np.linspace(-2,2,2*peak_sigma_samples+1) w_sigma = np.exp(-v_sigma**2/2) + if use_refined_peaks is False: + warnings.warn("Orientation histogram is using non-refined peak positions") + # Loop over all probe positions for a0 in range(num_radii): t = "Generating histogram " + str(a0) @@ -1199,7 +1272,10 @@ def make_orientation_histogram( # If needed, expand signal using peak sigma to write into multiple bins if use_peak_sigma: - theta_std = self.peaks_refine[rx,ry]['sigma_annular'][sub] / dtheta + if use_refined_peaks: + theta_std = self.peaks_refine[rx,ry]['sigma_annular'][sub] / dtheta + else: + theta_std = self.peaks[rx,ry]['sigma_annular'][sub] / dtheta t = (t[:,None] + theta_std[:,None]*v_sigma[None,:]).ravel() intensity = (intensity[:,None] * w_sigma[None,:]).ravel() diff --git a/py4DSTEM/process/rdf/amorph.py b/py4DSTEM/process/rdf/amorph.py index 9c80a2807..a537896b9 100644 --- a/py4DSTEM/process/rdf/amorph.py +++ b/py4DSTEM/process/rdf/amorph.py @@ -111,7 +111,7 @@ def plot_strains(strains, cmap="RdBu_r", vmin=None, vmax=None, mask=None): cmap, vmin, vmax: imshow parameters mask: real space mask of values not to show (black) """ - cmap = matplotlib.cm.get_cmap(cmap) + cmap = plt.get_cmap(cmap) if vmin is None: vmin = np.min(strains) if vmax is None: diff --git a/py4DSTEM/process/strain.py b/py4DSTEM/process/strain.py index e999c02d1..db252f75b 100644 --- a/py4DSTEM/process/strain.py +++ b/py4DSTEM/process/strain.py @@ -1,13 +1,17 @@ # Defines the Strain class -import numpy as np from typing import Optional -from py4DSTEM.data import RealSlice, Data -from py4DSTEM.braggvectors import BraggVectors +import matplotlib.pyplot as plt +import numpy as np +from py4DSTEM import PointList +from py4DSTEM.braggvectors import BraggVectors +from py4DSTEM.data import Data, RealSlice +from py4DSTEM.preprocess.utils import get_maxima_2D +from py4DSTEM.visualize import add_bragg_index_labels, add_pointlabels, add_vector, show -class StrainMap(RealSlice,Data): +class StrainMap(RealSlice, Data): """ Stores strain map. @@ -15,64 +19,80 @@ class StrainMap(RealSlice,Data): """ - def __init__( - self, - braggvectors: BraggVectors, - name: Optional[str] = 'strainmap' - ): + def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap"): """ TODO """ - assert(isinstance(braggvectors,BraggVectors)), f"braggvectors myst be BraggVectors, not type {type(braggvectors)}" + assert isinstance( + braggvectors, BraggVectors + ), f"braggvectors must be BraggVectors, not type {type(braggvectors)}" # initialize as a RealSlice RealSlice.__init__( self, - name = name, - data = np.empty(( - 6, - braggvectors.Rshape[0], - braggvectors.Rshape[1], - )), - slicelabels = [ - 'exx', - 'eyy', - 'exy', - 'theta', - 'mask', - 'error' - ] + name=name, + data=np.empty( + ( + 6, + braggvectors.Rshape[0], + braggvectors.Rshape[1], + ) + ), + slicelabels=["exx", "eyy", "exy", "theta", "mask", "error"], ) # set up braggvectors + # this assigns the bvs, ensures the origin is calibrated, + # and adds the strainmap to the bvs' tree self.braggvectors = braggvectors - # TODO - how to handle changes to braggvectors - # option: register with calibrations and add a .calibrate method - # which {{does something}} when origin changes - # TODO - include ellipse cal or no? - - assert(self.root is not None) # initialize as Data - Data.__init__( - self, - calibration = self.braggvectors.calibration - ) - + Data.__init__(self) + + # set calstate + # this property is used only to check to make sure that + # the braggvectors being used throughout a workflow are + # the same. The state of calibration of the vectors is noted + # here, and then checked each time the vectors are used - + # if they differ, an error message and instructions for + # re-calibration are issued + self.calstate = self.braggvectors.calstate + assert self.calstate["center"], "braggvectors must be centered" + # get the BVM + # a new BVM using the current calstate is computed + self.bvm = self.braggvectors.histogram(mode="cal") # braggvector properties @property def braggvectors(self): return self._braggvectors + @braggvectors.setter - def braggvectors(self,x): - assert(isinstance(x,BraggVectors)), f".braggvectors must be BraggVectors, not type {type(x)}" - assert(x.calibration.origin is not None), f"braggvectors must have a calibrated origin" + def braggvectors(self, x): + assert isinstance( + x, BraggVectors + ), f".braggvectors must be BraggVectors, not type {type(x)}" + assert ( + x.calibration.origin is not None + ), f"braggvectors must have a calibrated origin" self._braggvectors = x - self._braggvectors.tree(self,force=True) - + self._braggvectors.tree(self, force=True) + def reset_calstate(self): + """ + Resets the calibration state. This recomputes the BVM, and removes any computations + this StrainMap instance has stored, which will need to be recomputed. + """ + for attr in ( + "g0", + "g1", + "g2", + ): + if hasattr(self, attr): + delattr(self, attr) + self.calstate = self.braggvectors.calstate + pass # Class methods @@ -81,10 +101,8 @@ def choose_lattice_vectors( index_g0, index_g1, index_g2, - mode = 'centered', - plot = True, - subpixel = 'multicorr', - upsample_factor = 16, + subpixel="multicorr", + upsample_factor=16, sigma=0, minAbsoluteIntensity=0, minRelativeIntensity=0, @@ -92,95 +110,492 @@ def choose_lattice_vectors( minSpacing=0, edgeBoundary=1, maxNumPeaks=10, - bvm_vis_params = {}, - returncalc = False, - ): + figsize=(12, 6), + c_indices="lightblue", + c0="g", + c1="r", + c2="r", + c_vectors="r", + c_vectorlabels="w", + size_indices=20, + width_vectors=1, + size_vectorlabels=20, + vis_params={}, + returncalc=False, + returnfig=False, + ): """ Choose which lattice vectors to use for strain mapping. - Args: - index_g0 (int): origin - index_g1 (int): second point of vector 1 - index_g2 (int): second point of vector 2 - mode (str): centered or raw bragg map - plot (bool): plot bragg vector maps and vectors - subpixel (str): specifies the subpixel resolution algorithm to use. - must be in ('pixel','poly','multicorr'), which correspond - to pixel resolution, subpixel resolution by fitting a - parabola, and subpixel resultion by Fourier upsampling. - upsample_factor: the upsampling factor for the 'multicorr' - algorithm - sigma: if >0, applies a gaussian filter - maxNumPeaks: the maximum number of maxima to return - minAbsoluteIntensity, minRelativeIntensity, relativeToPeak, - minSpacing, edgeBoundary, maxNumPeaks: filtering applied - after maximum detection and before subpixel refinement + Overlays the bvm with the points detected via local 2D + maxima detection, plus an index for each point. User selects + 3 points using the overlaid indices, which are identified as + the origin and the termini of the lattice vectors g1 and g2. + + Parameters + ---------- + index_g0 : int + selected index for the origin + index_g1 : int + selected index for g1 + index_g2 :int + selected index for g2 + subpixel : str in ('pixel','poly','multicorr') + See the docstring for py4DSTEM.preprocess.get_maxima_2D + upsample_factor : int + See the py4DSTEM.preprocess.get_maxima_2D docstring + sigma : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + minAbsoluteIntensity : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + minRelativeIntensity : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + relativeToPeak : int + See the py4DSTEM.preprocess.get_maxima_2D docstring + minSpacing : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + edgeBoundary : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + maxNumPeaks : int + See the py4DSTEM.preprocess.get_maxima_2D docstring + figsize : 2-tuple + the size of the figure + c_indices : color + color of the maxima + c0 : color + color of the origin + c1 : color + color of g1 point + c2 : color + color of g2 point + c_vectors : color + color of the g1/g2 vectors + c_vectorlabels : color + color of the vector labels + size_indices : number + size of the indices + width_vectors : number + width of the vectors + size_vectorlabels : number + size of the vector labels + vis_params : dict + additional visualization parameters passed to `show` + returncalc : bool + toggles returning the answer + returnfig : bool + toggles returning the figure + + Returns + ------- + (optional) : None or (g0,g1,g2) or (fig,(ax1,ax2)) or both of the latter """ - from py4DSTEM.process.utils import get_maxima_2D - - if mode == "centered": - bvm = self.bvm_centered - else: - bvm = self.bvm_raw - + # validate inputs + for i in (index_g0, index_g1, index_g2): + assert isinstance(i, (int, np.integer)), "indices must be integers!" + # check the calstate + assert ( + self.calstate == self.braggvectors.calstate + ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." + + # find the maxima g = get_maxima_2D( - bvm, - subpixel = subpixel, - upsample_factor = upsample_factor, - sigma = sigma, - minAbsoluteIntensity = minAbsoluteIntensity, - minRelativeIntensity = minRelativeIntensity, - relativeToPeak = relativeToPeak, - minSpacing = minSpacing, - edgeBoundary = edgeBoundary, - maxNumPeaks = maxNumPeaks, + self.bvm.data, + subpixel=subpixel, + upsample_factor=upsample_factor, + sigma=sigma, + minAbsoluteIntensity=minAbsoluteIntensity, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minSpacing=minSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, ) + # get the lattice vectors + gx, gy = g["x"], g["y"] + g0 = gx[index_g0], gy[index_g0] + g1x = gx[index_g1] - g0[0] + g1y = gy[index_g1] - g0[1] + g2x = gx[index_g2] - g0[0] + g2y = gy[index_g2] - g0[1] + g1, g2 = (g1x, g1y), (g2x, g2y) + + # make the figure + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) + show(self.bvm.data, figax=(fig, ax1), **vis_params) + show(self.bvm.data, figax=(fig, ax2), **vis_params) + + # Add indices to left panel + d = {"x": gx, "y": gy, "size": size_indices, "color": c_indices} + d0 = { + "x": gx[index_g0], + "y": gy[index_g0], + "size": size_indices, + "color": c0, + "fontweight": "bold", + "labels": [str(index_g0)], + } + d1 = { + "x": gx[index_g1], + "y": gy[index_g1], + "size": size_indices, + "color": c1, + "fontweight": "bold", + "labels": [str(index_g1)], + } + d2 = { + "x": gx[index_g2], + "y": gy[index_g2], + "size": size_indices, + "color": c2, + "fontweight": "bold", + "labels": [str(index_g2)], + } + add_pointlabels(ax1, d) + add_pointlabels(ax1, d0) + add_pointlabels(ax1, d1) + add_pointlabels(ax1, d2) + + # Add vectors to right panel + dg1 = { + "x0": gx[index_g0], + "y0": gy[index_g0], + "vx": g1[0], + "vy": g1[1], + "width": width_vectors, + "color": c_vectors, + "label": r"$g_1$", + "labelsize": size_vectorlabels, + "labelcolor": c_vectorlabels, + } + dg2 = { + "x0": gx[index_g0], + "y0": gy[index_g0], + "vx": g2[0], + "vy": g2[1], + "width": width_vectors, + "color": c_vectors, + "label": r"$g_2$", + "labelsize": size_vectorlabels, + "labelcolor": c_vectorlabels, + } + add_vector(ax2, dg1) + add_vector(ax2, dg2) + + # store vectors self.g = g + self.g0 = g0 + self.g1 = g1 + self.g2 = g2 + + # return + if returncalc and returnfig: + return (g0, g1, g2), (fig, (ax1, ax2)) + elif returncalc: + return (g0, g1, g2) + elif returnfig: + return (fig, (ax1, ax2)) + else: + return + + def fit_lattice_vectors( + self, + x0=None, + y0=None, + max_peak_spacing=2, + mask=None, + plot=True, + vis_params={}, + returncalc=False, + ): + """ + From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of + lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the + reciprocal lattice directions. + + Args: + x0 : floagt + x-coord of origin + y0 : float + y-coord of origin + max_peak_spacing: float + Maximum distance from the ideal lattice points + to include a peak for indexing + mask: bool + Boolean mask, same shape as the pointlistarray, indicating which + locations should be indexed. This can be used to index different regions of + the scan with different lattices + plot:bool + plot results if tru + vis_params : dict + additional visualization parameters passed to `show` + returncalc : bool + if True, returns bragg_directions, bragg_vectors_indexed, g1g2_map + """ + # check the calstate + assert ( + self.calstate == self.braggvectors.calstate + ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." - from py4DSTEM.visualize import select_lattice_vectors - g1,g2 = select_lattice_vectors( - bvm, - gx = g['x'], - gy = g['y'], - i0 = index_g0, - i1 = index_g1, - i2 = index_g2, - **bvm_vis_params, + if x0 is None: + x0 = self.braggvectors.Qshape[0] / 2 + if y0 is None: + y0 = self.braggvectors.Qshape[0] / 2 + + # index braggvectors + from py4DSTEM.process.latticevectors import index_bragg_directions + + _, _, braggdirections = index_bragg_directions( + x0, y0, self.g["x"], self.g["y"], self.g1, self.g2 ) - self.g1 = g1 - self.g2 = g2 + self.braggdirections = braggdirections + + if plot: + self.show_bragg_indexing( + self.bvm, + bragg_directions=braggdirections, + points=True, + **vis_params, + ) + + # add indicies to braggvectors + from py4DSTEM.process.latticevectors import add_indices_to_braggvectors + + bragg_vectors_indexed = add_indices_to_braggvectors( + self.braggvectors, + self.braggdirections, + maxPeakSpacing=max_peak_spacing, + qx_shift=self.braggvectors.Qshape[0] / 2, + qy_shift=self.braggvectors.Qshape[1] / 2, + mask=mask, + ) + + self.bragg_vectors_indexed = bragg_vectors_indexed + + # fit bragg vectors + from py4DSTEM.process.latticevectors import fit_lattice_vectors_all_DPs + + g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_vectors_indexed) + self.g1g2_map = g1g2_map if returncalc: - return g1, g2 + braggdirections, bragg_vectors_indexed, g1g2_map + + def get_strain( + self, mask=None, g_reference=None, flip_theta=False, returncalc=False, **kwargs + ): + """ + mask: nd.array (bool) + Use lattice vectors from g1g2_map scan positions + wherever mask==True. If mask is None gets median strain + map from entire field of view. If mask is not None, gets + reference g1 and g2 from region and then calculates strain. + g_reference: nd.array of form [x,y] + G_reference (tupe): reference coordinate system for + xaxis_x and xaxis_y + flip_theta: bool + If True, flips rotation coordinate system + returncal: bool + It True, returns rotated map + """ + # check the calstate + assert ( + self.calstate == self.braggvectors.calstate + ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." + if mask is None: + mask = np.ones(self.g1g2_map.shape, dtype="bool") + from py4DSTEM.process.latticevectors import get_strain_from_reference_region + strainmap_g1g2 = get_strain_from_reference_region( + self.g1g2_map, + mask=mask, + ) + else: + from py4DSTEM.process.latticevectors import get_reference_g1g2 + g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, mask) + from py4DSTEM.process.latticevectors import get_strain_from_reference_g1g2 + strainmap_g1g2 = get_strain_from_reference_g1g2( + self.g1g2_map, g1_ref, g2_ref + ) + self.strainmap_g1g2 = strainmap_g1g2 + if g_reference is None: + g_reference = np.subtract(self.g1, self.g2) + from py4DSTEM.process.latticevectors import get_rotated_strain_map + strainmap_rotated = get_rotated_strain_map( + self.strainmap_g1g2, + xaxis_x=g_reference[0], + xaxis_y=g_reference[1], + flip_theta=flip_theta, + ) - # IO methods + self.strainmap_rotated = strainmap_rotated + + from py4DSTEM.visualize import show_strain + + figsize = kwargs.pop("figsize", (14, 4)) + vrange_exx = kwargs.pop("vrange_exx", [-2.0, 2.0]) + vrange_theta = kwargs.pop("vrange_theta", [-2.0, 2.0]) + ticknumber = kwargs.pop("ticknumber", 3) + bkgrd = kwargs.pop("bkgrd", False) + axes_plots = kwargs.pop("axes_plots", ()) + + fig, ax = show_strain( + self.strainmap_rotated, + vrange_exx=vrange_exx, + vrange_theta=vrange_theta, + ticknumber=ticknumber, + axes_plots=axes_plots, + bkgrd=bkgrd, + figsize=figsize, + **kwargs, + returnfig=True, + ) + + if not np.all(mask == True): + ax[0][0].imshow(mask, alpha=0.2, cmap="binary") + ax[0][1].imshow(mask, alpha=0.2, cmap="binary") + ax[1][0].imshow(mask, alpha=0.2, cmap="binary") + ax[1][1].imshow(mask, alpha=0.2, cmap="binary") + + if returncalc: + return self.strainmap_rotated + + def show_lattice_vectors( + ar, + x0, + y0, + g1, + g2, + color="r", + width=1, + labelsize=20, + labelcolor="w", + returnfig=False, + **kwargs, + ): + """Adds the vectors g1,g2 to an image, with tail positions at (x0,y0). g1 and g2 are 2-tuples (gx,gy).""" + fig, ax = show(ar, returnfig=True, **kwargs) + + # Add vectors + dg1 = { + "x0": x0, + "y0": y0, + "vx": g1[0], + "vy": g1[1], + "width": width, + "color": color, + "label": r"$g_1$", + "labelsize": labelsize, + "labelcolor": labelcolor, + } + dg2 = { + "x0": x0, + "y0": y0, + "vx": g2[0], + "vy": g2[1], + "width": width, + "color": color, + "label": r"$g_2$", + "labelsize": labelsize, + "labelcolor": labelcolor, + } + add_vector(ax, dg1) + add_vector(ax, dg2) + + if returnfig: + return fig, ax + else: + plt.show() + return + + def show_bragg_indexing( + self, + ar, + bragg_directions, + voffset=5, + hoffset=0, + color="w", + size=20, + points=True, + pointcolor="r", + pointsize=50, + returnfig=False, + **kwargs, + ): + """ + Shows an array with an overlay describing the Bragg directions + + Accepts: + ar (arrray) the image + bragg_directions (PointList) the bragg scattering directions; must have coordinates + 'qx','qy','h', and 'k'. Optionally may also have 'l'. + """ + assert isinstance(bragg_directions, PointList) + for k in ("qx", "qy", "h", "k"): + assert k in bragg_directions.data.dtype.fields + + fig, ax = show(ar, returnfig=True, **kwargs) + d = { + "bragg_directions": bragg_directions, + "voffset": voffset, + "hoffset": hoffset, + "color": color, + "size": size, + "points": points, + "pointsize": pointsize, + "pointcolor": pointcolor, + } + add_bragg_index_labels(ax, d) - # TODO - copy method + if returnfig: + return fig, ax + else: + plt.show() + return + + def copy(self, name=None): + name = name if name is not None else self.name + "_copy" + strainmap_copy = StrainMap(self.braggvectors) + for attr in ( + "g", + "g0", + "g1", + "g2", + "calstate", + "bragg_directions", + "bragg_vectors_indexed", + "g1g2_map", + "strainmap_g1g2", + "strainmap_rotated", + ): + if hasattr(self, attr): + setattr(strainmap_copy, attr, getattr(self, attr)) + + for k in self.metadata.keys(): + strainmap_copy.metadata = self.metadata[k].copy() + return strainmap_copy + + # IO methods # read @classmethod - def _get_constructor_args(cls,group): + def _get_constructor_args(cls, group): """ Returns a dictionary of args/values to pass to the class constructor """ ar_constr_args = RealSlice._get_constructor_args(group) args = { - 'data' : ar_constr_args['data'], - 'name' : ar_constr_args['name'], + "data": ar_constr_args["data"], + "name": ar_constr_args["name"], } return args - - - diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py index 1df4e78c5..86257b4dc 100644 --- a/py4DSTEM/process/utils/utils.py +++ b/py4DSTEM/process/utils/utils.py @@ -59,7 +59,7 @@ def radial_reduction( def plot(img, title='Image', savePath=None, cmap='inferno', show=True, vmax=None, figsize=(10, 10), scale=None): fig, ax = plt.subplots(figsize=figsize) - im = ax.imshow(img, interpolation='nearest', cmap=plt.cm.get_cmap(cmap), vmax=vmax) + im = ax.imshow(img, interpolation='nearest', cmap=plt.get_cmap(cmap), vmax=vmax) divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.05) plt.colorbar(im, cax=cax) @@ -636,7 +636,7 @@ def fourier_resample( #def plot(img, title='Image', savePath=None, cmap='inferno', show=True, vmax=None, # figsize=(10, 10), scale=None): # fig, ax = plt.subplots(figsize=figsize) -# im = ax.imshow(img, interpolation='nearest', cmap=plt.cm.get_cmap(cmap), vmax=vmax) +# im = ax.imshow(img, interpolation='nearest', cmap=plt.get_cmap(cmap), vmax=vmax) # divider = make_axes_locatable(ax) # cax = divider.append_axes("right", size="5%", pad=0.05) # plt.colorbar(im, cax=cax) diff --git a/py4DSTEM/version.py b/py4DSTEM/version.py index 4009e43c9..9df5075b8 100644 --- a/py4DSTEM/version.py +++ b/py4DSTEM/version.py @@ -1,2 +1,2 @@ -__version__='0.14.2' +__version__='0.14.3' diff --git a/py4DSTEM/visualize/overlay.py b/py4DSTEM/visualize/overlay.py index e0c87a427..7e7147a15 100644 --- a/py4DSTEM/visualize/overlay.py +++ b/py4DSTEM/visualize/overlay.py @@ -437,7 +437,7 @@ def add_bragg_index_labels(ax,d): Adds labels for indexed bragg directions to a plot, using the parameters in dict d. The dictionary d has required and optional parameters as follows: - braggdirections (req'd) (PointList) the Bragg directions. This PointList must have + bragg_directions (req'd) (PointList) the Bragg directions. This PointList must have the fields 'qx','qy','h', and 'k', and may optionally have 'l' voffset (number) vertical offset for the labels hoffset (number) horizontal offset for the labels @@ -450,12 +450,12 @@ def add_bragg_index_labels(ax,d): # handle inputs assert isinstance(ax,Axes) # bragg directions - assert('braggdirections' in d.keys()) - braggdirections = d['braggdirections'] - assert isinstance(braggdirections,PointList) + assert('bragg_directions' in d.keys()) + bragg_directions = d['bragg_directions'] + assert isinstance(bragg_directions,PointList) for k in ('qx','qy','h','k'): - assert k in braggdirections.data.dtype.fields - include_l = True if 'l' in braggdirections.data.dtype.fields else False + assert k in bragg_directions.data.dtype.fields + include_l = True if 'l' in bragg_directions.data.dtype.fields else False # offsets hoffset = d['hoffset'] if 'hoffset' in d.keys() else 0 voffset = d['voffset'] if 'voffset' in d.keys() else 5 @@ -474,20 +474,20 @@ def add_bragg_index_labels(ax,d): # add the points if points: - ax.scatter(braggdirections.data['qy'],braggdirections.data['qx'], + ax.scatter(bragg_directions.data['qy'],bragg_directions.data['qx'], color=pointcolor,s=pointsize) # add index labels - for i in range(braggdirections.length): - x,y = braggdirections.data['qx'][i],braggdirections.data['qy'][i] + for i in range(bragg_directions.length): + x,y = bragg_directions.data['qx'][i],bragg_directions.data['qy'][i] x -= voffset y += hoffset - h,k = braggdirections.data['h'][i],braggdirections.data['k'][i] + h,k = bragg_directions.data['h'][i],bragg_directions.data['k'][i] h = str(h) if h>=0 else r'$\overline{{{}}}$'.format(np.abs(h)) k = str(k) if k>=0 else r'$\overline{{{}}}$'.format(np.abs(k)) s = h+','+k if include_l: - l = braggdirections.data['l'][i] + l = bragg_directions.data['l'][i] l = str(l) if l>=0 else r'$\overline{{{}}}$'.format(np.abs(l)) s += l ax.text(y,x,s,color=color,size=size,ha='center',va='bottom') diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index f63bda993..3b9d99e43 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -567,12 +567,8 @@ def show( ax.matshow(mask_display,cmap=cmap,alpha=mask_alpha,vmin=vmin,vmax=vmax) # ...or, plot its histogram else: - # hist,bin_edges = np.histogram( - # _ar, - # bins=np.linspace(np.min(_ar),np.max(_ar),num=n_bins)) - hist,bin_edges = np.histogram( - _ar, - bins=np.linspace(vmin,vmax,num=n_bins)) + hist,bin_edges = np.histogram(_ar,bins=np.linspace(np.min(_ar), + np.max(_ar),num=n_bins)) w = bin_edges[1]-bin_edges[0] x = bin_edges[:-1]+w/2. ax.bar(x,hist,width=w) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index b487048e2..43cf7fff8 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -590,102 +590,6 @@ def select_point(ar,x,y,i,color='lightblue',color_selected='r',size=20,returnfig return -def select_lattice_vectors(ar,gx,gy,i0,i1,i2, - c_indices='lightblue',c0='g',c1='r',c2='r',c_vectors='r',c_vectorlabels='w', - size_indices=20,width_vectors=1,size_vectorlabels=20, - figsize=(12,6),returnfig=False,**kwargs): - """ - This function accepts a set of reciprocal lattice points (gx,gy) and three indices - (i0,i1,i2). Using those indices as, respectively, the origin, the endpoint of g1, and - the endpoint of g2, this function computes the basis lattice vectors g1,g2, visualizes - them, and returns them. To compute these vectors without visualizing, use - latticevectors.get_selected_lattice_vectors(). - - Returns: - if returnfig==False: g1,g2 - if returnfig==True g1,g2,fig,ax - """ - from py4DSTEM.process.latticevectors import get_selected_lattice_vectors - - # Make the figure - fig,(ax1,ax2) = plt.subplots(1,2,figsize=figsize) - show(ar,figax=(fig,ax1),**kwargs) - show(ar,figax=(fig,ax2),**kwargs) - - # Add indices to left panel - d = {'x':gx,'y':gy,'size':size_indices,'color':c_indices} - d0 = {'x':gx[i0],'y':gy[i0],'size':size_indices,'color':c0,'fontweight':'bold','labels':[str(i0)]} - d1 = {'x':gx[i1],'y':gy[i1],'size':size_indices,'color':c1,'fontweight':'bold','labels':[str(i1)]} - d2 = {'x':gx[i2],'y':gy[i2],'size':size_indices,'color':c2,'fontweight':'bold','labels':[str(i2)]} - add_pointlabels(ax1,d) - add_pointlabels(ax1,d0) - add_pointlabels(ax1,d1) - add_pointlabels(ax1,d2) - - # Compute vectors - g1,g2 = get_selected_lattice_vectors(gx,gy,i0,i1,i2) - - # Add vectors to right panel - dg1 = {'x0':gx[i0],'y0':gy[i0],'vx':g1[0],'vy':g1[1],'width':width_vectors, - 'color':c_vectors,'label':r'$g_1$','labelsize':size_vectorlabels,'labelcolor':c_vectorlabels} - dg2 = {'x0':gx[i0],'y0':gy[i0],'vx':g2[0],'vy':g2[1],'width':width_vectors, - 'color':c_vectors,'label':r'$g_2$','labelsize':size_vectorlabels,'labelcolor':c_vectorlabels} - add_vector(ax2,dg1) - add_vector(ax2,dg2) - - if returnfig: - return g1,g2,fig,(ax1,ax2) - else: - plt.show() - return g1,g2 - - -def show_lattice_vectors(ar,x0,y0,g1,g2,color='r',width=1,labelsize=20,labelcolor='w',returnfig=False,**kwargs): - """ Adds the vectors g1,g2 to an image, with tail positions at (x0,y0). g1 and g2 are 2-tuples (gx,gy). - """ - fig,ax = show(ar,returnfig=True,**kwargs) - - # Add vectors - dg1 = {'x0':x0,'y0':y0,'vx':g1[0],'vy':g1[1],'width':width, - 'color':color,'label':r'$g_1$','labelsize':labelsize,'labelcolor':labelcolor} - dg2 = {'x0':x0,'y0':y0,'vx':g2[0],'vy':g2[1],'width':width, - 'color':color,'label':r'$g_2$','labelsize':labelsize,'labelcolor':labelcolor} - add_vector(ax,dg1) - add_vector(ax,dg2) - - if returnfig: - return fig,ax - else: - plt.show() - return - - -def show_bragg_indexing(ar,braggdirections,voffset=5,hoffset=0,color='w',size=20, - points=True,pointcolor='r',pointsize=50,returnfig=False,**kwargs): - """ - Shows an array with an overlay describing the Bragg directions - - Accepts: - ar (arrray) the image - bragg_directions (PointList) the bragg scattering directions; must have coordinates - 'qx','qy','h', and 'k'. Optionally may also have 'l'. - """ - assert isinstance(braggdirections,PointList) - for k in ('qx','qy','h','k'): - assert k in braggdirections.data.dtype.fields - - fig,ax = show(ar,returnfig=True,**kwargs) - d = {'braggdirections':braggdirections,'voffset':voffset,'hoffset':hoffset,'color':color, - 'size':size,'points':points,'pointsize':pointsize,'pointcolor':pointcolor} - add_bragg_index_labels(ax,d) - - if returnfig: - return fig,ax - else: - plt.show() - return - - def show_max_peak_spacing(ar,spacing,braggdirections,color='g',lw=2,returnfig=False,**kwargs): """ Show a circle of radius `spacing` about each Bragg direction """ diff --git a/setup.py b/setup.py index cb9da8169..b0c7fa081 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,7 @@ 'numpy >= 1.19', 'scipy >= 1.5.2', 'h5py >= 3.2.0', + 'hdf5plugin >= 4.1.3', 'ncempy >= 1.8.1', 'matplotlib >= 3.2.2', 'scikit-image >= 0.17.2', diff --git a/test/gettestdata.py b/test/gettestdata.py index b3d8a0a40..a84e5b9b3 100644 --- a/test/gettestdata.py +++ b/test/gettestdata.py @@ -53,23 +53,24 @@ # Set data collection key if args.data == 'tutorials': - data = 'tutorials' + data = ['tutorials'] elif args.data == 'io': - data = 'test_io' + data = ['test_io','test_arina'] elif args.data == 'basic': - data = 'small_datacube' + data = ['small_datacube'] elif args.data == 'strain': - data = 'strain' + data = ['strain'] else: raise Exception(f"invalid data choice, {parser.data}") # Download data -download( - data, - destination = testpath, - overwrite = args.overwrite, - verbose = args.verbose -) +for d in data: + download( + d, + destination = testpath, + overwrite = args.overwrite, + verbose = args.verbose + ) # Always download the basic datacube if args.data != 'basic': diff --git a/test/test_nonnative_io/test_arina.py b/test/test_nonnative_io/test_arina.py new file mode 100644 index 000000000..c27cb8ef5 --- /dev/null +++ b/test/test_nonnative_io/test_arina.py @@ -0,0 +1,19 @@ +import py4DSTEM +import emdfile +from os.path import join + + +# Set filepaths +filepath = join(py4DSTEM._TESTPATH, "test_arina/STO_STEM_bench_20us_master.h5") + + +def test_read_arina(): + + # read + data = py4DSTEM.import_file( filepath ) + + # check imported data + assert isinstance(data, emdfile.Array) + assert isinstance(data, py4DSTEM.DataCube) + + diff --git a/test/test_strain.py b/test/test_strain.py index 5bfa0efd3..bc9b8b58c 100644 --- a/test/test_strain.py +++ b/test/test_strain.py @@ -27,5 +27,7 @@ def test_strainmap_instantiation(self): ) assert(isinstance(strainmap, StrainMap)) + assert(strainmap.calibration is not None) + assert(strainmap.calibration is strainmap.braggvectors.calibration) From 4e1be96c7f58d267abf7c85ffd83d82687f40fb6 Mon Sep 17 00:00:00 2001 From: cophus Date: Wed, 9 Aug 2023 16:31:53 -0700 Subject: [PATCH 20/25] Fixing merge conflicts --- py4DSTEM/process/polar/polar_analysis.py | 30 +----------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index fa6a40a4f..447baaf1d 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -6,17 +6,10 @@ from emdfile import tqdmnd -<<<<<<< Updated upstream -def calculate_FEM_global( - self, - use_median_local = False, - use_median_global = False, -======= def calculate_radial_statistics( self, median_local = False, median_global = False, ->>>>>>> Stashed changes plot_results = False, figsize = (8,4), returnval = False, @@ -49,37 +42,22 @@ def calculate_radial_statistics( self.scattering_vector = self.radial_bins * self.qstep * self.calibration.get_Q_pixel_size() self.scattering_vector_units = self.calibration.get_Q_pixel_units() -<<<<<<< Updated upstream - # init radial data array -======= # init radial data arrays ->>>>>>> Stashed changes self.radial_all = np.zeros(( self._datacube.shape[0], self._datacube.shape[1], self.polar_shape[1], )) -<<<<<<< Updated upstream -======= self.radial_all_std = np.zeros(( self._datacube.shape[0], self._datacube.shape[1], self.polar_shape[1], )) ->>>>>>> Stashed changes - # Compute the radial mean for each probe position for rx, ry in tqdmnd( self._datacube.shape[0], self._datacube.shape[1], -<<<<<<< Updated upstream - desc="Global FEM", - unit=" probe positions", - disable=not progress_bar): - - self.radial_all[rx,ry] = np.mean(self.data[rx,ry],axis=0) -======= desc="Radial statistics", unit=" probe positions", disable=not progress_bar): @@ -90,7 +68,6 @@ def calculate_radial_statistics( self.radial_all_std[rx,ry] = np.sqrt(np.mean( (self.data[rx,ry] - self.radial_all[rx,ry][None])**2, axis=0)) ->>>>>>> Stashed changes self.radial_avg = np.mean(self.radial_all, axis=(0,1)) self.radial_var = np.mean( @@ -171,9 +148,7 @@ def calculate_FEM_local( """ -<<<<<<< Updated upstream - 1+1 -======= + pass @@ -183,6 +158,3 @@ def calculate_FEM_local( # returnfig = False, # ): - ->>>>>>> Stashed changes - From 46effbb395e253e573e35d0b3cfa3165dac492a9 Mon Sep 17 00:00:00 2001 From: cophus Date: Wed, 9 Aug 2023 16:32:31 -0700 Subject: [PATCH 21/25] Merge conflict fixing --- py4DSTEM/process/polar/polar_datacube.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/py4DSTEM/process/polar/polar_datacube.py b/py4DSTEM/process/polar/polar_datacube.py index b6ef4ee66..c0b8871f9 100644 --- a/py4DSTEM/process/polar/polar_datacube.py +++ b/py4DSTEM/process/polar/polar_datacube.py @@ -95,12 +95,8 @@ def __init__( pass from py4DSTEM.process.polar.polar_analysis import ( -<<<<<<< Updated upstream - calculate_FEM_global, -======= # calculate_FEM_global, calculate_radial_statistics, ->>>>>>> Stashed changes plot_FEM_global, calculate_FEM_local, ) From a9c941d8d84cf0c411a5680d94fd834ea7c3cdbc Mon Sep 17 00:00:00 2001 From: cophus Date: Wed, 9 Aug 2023 20:57:35 -0700 Subject: [PATCH 22/25] Fixing the phase mapping module --- py4DSTEM/process/diffraction/crystal_phase.py | 173 +++++++++++------- 1 file changed, 108 insertions(+), 65 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 235d2c3e3..b25c0cd78 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -73,13 +73,16 @@ def quantify_single_pattern( self, pointlistarray: PointListArray, xy_position = (0,0), - corr_kernel_size = 0.04, + corr_kernel_size = 0.02, + corr_distance_scale = 1.0, include_false_positives = True, + max_number_phases = 3, sigma_excitation_error = 0.02, power_experiment = 0.5, power_calculated = 0.5, plot_result = True, - scale_markers_experiment = 4, + plot_only_nonzero_phases = True, + scale_markers_experiment = 10, scale_markers_calculated = 4000, crystal_inds_plot = None, phase_colors = np.array(( @@ -101,8 +104,26 @@ def quantify_single_pattern( # tolerance tol2 = 4e-4 + # calibrations + center = pointlistarray.calstate['center'] + ellipse = pointlistarray.calstate['ellipse'] + pixel = pointlistarray.calstate['pixel'] + rotate = pointlistarray.calstate['rotate'] + if center is False: + raise ValueError('Bragg peaks must be center calibration') + if pixel is False: + raise ValueError('Bragg peaks must have pixel size calibration') + # TODO - potentially warn the user if ellipse / rotate calibration not available + # Experimental values - bragg_peaks = pointlistarray.get_pointlist(xy_position[0],xy_position[1]).copy() + bragg_peaks = pointlistarray.get_vectors( + xy_position[0], + xy_position[1], + center = center, + ellipse = ellipse, + pixel = pixel, + rotate = rotate) + # bragg_peaks = pointlistarray.get_pointlist(xy_position[0],xy_position[1]).copy() keep = bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2 > tol2 # ind_center_beam = np.argmin( # bragg_peaks.data["qx"]**2 + bragg_peaks.data["qy"]**2) @@ -176,12 +197,14 @@ def quantify_single_pattern( val_min = dist2[ind_min] if val_min < radius_max_2: - weight = 1 - np.sqrt(dist2[ind_min]) / corr_kernel_size + # weight = 1 - np.sqrt(dist2[ind_min]) / corr_kernel_size + weight = 1 + corr_distance_scale * \ + np.sqrt(dist2[ind_min]) / corr_kernel_size basis[ind_min,a0] = weight * int_fit[a1] if plot_result: matches[a1] = True elif include_false_positives: - unpaired_peaks.append([a0,int_fit[a1]]) + unpaired_peaks.append([a0,int_fit[a1]*(1 + corr_distance_scale)]) if plot_result: library_peaks.append(bragg_peaks_fit) @@ -196,15 +219,30 @@ def quantify_single_pattern( basis = np.vstack((basis, basis_aug)) obs = np.hstack((intensity, np.zeros(len(unpaired_peaks)))) + else: obs = intensity # Solve for phase coefficients try: - phase_weights, phase_residual = nnls( - basis, - obs, - ) + phase_weights = np.zeros(self.num_fits) + inds_solve = np.ones(self.num_fits,dtype='bool') + + search = True + while search is True: + phase_weights_cand, phase_residual_cand = nnls( + basis[:,inds_solve], + obs, + ) + + if np.count_nonzero(phase_weights_cand > 0.0) <= max_number_phases: + phase_weights[inds_solve] = phase_weights_cand + phase_residual = phase_residual_cand + search = False + else: + inds = np.where(inds_solve)[0] + inds_solve[inds[np.argmin(phase_weights_cand)]] = False + except: phase_weights = np.zeros(self.num_fits) phase_residual = np.sqrt(np.sum(intensity**2)) @@ -303,7 +341,6 @@ def quantify_single_pattern( **text_params) - # plot calculated diffraction patterns uvals = phase_colors.copy() uvals[:,3] = 0.3 @@ -328,59 +365,61 @@ def quantify_single_pattern( int_fit = library_int[a0] matches_fit = library_matches[a0] - if np.mod(m,2) == 0: - ax.scatter( - qy_fit[matches_fit], - qx_fit[matches_fit], - s = scale_markers_calculated * int_fit[matches_fit], - marker = mvals[c], - facecolor = phase_colors[c,:], - ) - ax.scatter( - qy_fit[np.logical_not(matches_fit)], - qx_fit[np.logical_not(matches_fit)], - s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], - marker = mvals[c], - facecolor = phase_colors[c,:], - ) - - # legend - ax_leg.scatter( - 0, - dx_leg*(a0+1), - s = 200, - marker = mvals[c], - facecolor = phase_colors[c,:], - ) - else: - ax.scatter( - qy_fit[matches_fit], - qx_fit[matches_fit], - s = scale_markers_calculated * int_fit[matches_fit], - marker = mvals[c], - edgecolors = uvals[c,:], - facecolors = (1,1,1,0.5), - linewidth = 2, - ) - ax.scatter( - qy_fit[np.logical_not(matches_fit)], - qx_fit[np.logical_not(matches_fit)], - s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], - marker = mvals[c], - edgecolors = uvals[c,:], - facecolors = (1,1,1,0.5), - linewidth = 2, - ) - - # legend - ax_leg.scatter( - 0, - dx_leg*(a0+1), - s = 200, - marker = mvals[c], - edgecolors = uvals[c,:], - facecolors = (1,1,1,0.5), - ) + if plot_only_nonzero_phases is False or phase_weights[a0] > 0: + + if np.mod(m,2) == 0: + ax.scatter( + qy_fit[matches_fit], + qx_fit[matches_fit], + s = scale_markers_calculated * int_fit[matches_fit], + marker = mvals[c], + facecolor = phase_colors[c,:], + ) + ax.scatter( + qy_fit[np.logical_not(matches_fit)], + qx_fit[np.logical_not(matches_fit)], + s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], + marker = mvals[c], + facecolor = phase_colors[c,:], + ) + + # legend + ax_leg.scatter( + 0, + dx_leg*(a0+1), + s = 200, + marker = mvals[c], + facecolor = phase_colors[c,:], + ) + else: + ax.scatter( + qy_fit[matches_fit], + qx_fit[matches_fit], + s = scale_markers_calculated * int_fit[matches_fit], + marker = mvals[c], + edgecolors = uvals[c,:], + facecolors = (1,1,1,0.5), + linewidth = 2, + ) + ax.scatter( + qy_fit[np.logical_not(matches_fit)], + qx_fit[np.logical_not(matches_fit)], + s = scale_markers_calculated * int_fit[np.logical_not(matches_fit)], + marker = mvals[c], + edgecolors = uvals[c,:], + facecolors = (1,1,1,0.5), + linewidth = 2, + ) + + # legend + ax_leg.scatter( + 0, + dx_leg*(a0+1), + s = 200, + marker = mvals[c], + edgecolors = uvals[c,:], + facecolors = (1,1,1,0.5), + ) # legend text ax_leg.text( @@ -407,8 +446,10 @@ def quantify_single_pattern( def quantify_phase( self, pointlistarray: PointListArray, - corr_kernel_size = 0.04, + corr_kernel_size = 0.02, + corr_distance_scale = 1.0, include_false_positives = True, + max_number_phases = 3, sigma_excitation_error = 0.02, power_experiment = 0.5, power_calculated = 0.5, @@ -444,7 +485,9 @@ def quantify_phase( pointlistarray = pointlistarray, xy_position = (rx,ry), corr_kernel_size = corr_kernel_size, + corr_distance_scale = corr_distance_scale, include_false_positives = include_false_positives, + max_number_phases = max_number_phases, sigma_excitation_error = sigma_excitation_error, power_experiment = power_experiment, power_calculated = power_calculated, @@ -459,7 +502,7 @@ def quantify_phase( def plot_phase_weights( self, - weight_range = (0.5,1,0), + weight_range = (0.5,1.0), weight_normalize = False, total_intensity_normalize = True, cmap = 'gray', From 0ae8d72dcde562117c1f6ec8a0879f75b5681b79 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 14 Aug 2023 11:53:20 -0700 Subject: [PATCH 23/25] Cleaning up --- py4DSTEM/process/diffraction/crystal_phase.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index b25c0cd78..4fac83281 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -12,11 +12,9 @@ # from py4DSTEM.io.datastructure import PointList, PointListArray from dataclasses import dataclass, field -# ======= from emdfile import tqdmnd, PointListArray from py4DSTEM.visualize import show, show_image_grid from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern -# >>>>>>> dev @dataclass class CrystalPhase: From fe797121bb8c36c8a03bc589ad759ecb0d028a7e Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 14 Aug 2023 17:22:12 -0700 Subject: [PATCH 24/25] Fixing CUDA disk detection --- py4DSTEM/braggvectors/diskdetection.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/braggvectors/diskdetection.py b/py4DSTEM/braggvectors/diskdetection.py index fb7755349..6ce488555 100644 --- a/py4DSTEM/braggvectors/diskdetection.py +++ b/py4DSTEM/braggvectors/diskdetection.py @@ -591,8 +591,12 @@ def _find_Bragg_disks_CUDA_unbatched( batching=False) # Populate a BraggVectors instance and return - braggvectors = BraggVectors( datacube.Rshape, datacube.Qshape ) - braggvectors._v_uncal = peaks + braggvectors = BraggVectors( + datacube.Rshape, + datacube.Qshape, + name = peaks.name) + braggvectors.set_raw_vectors(peaks) + return braggvectors @@ -637,12 +641,13 @@ def _find_Bragg_disks_CUDA_batched( batching=True) # Populate a BraggVectors instance and return - braggvectors = BraggVectors( datacube.Rshape, datacube.Qshape ) - braggvectors._v_uncal = peaks - return braggvectors - - + braggvectors = BraggVectors( + datacube.Rshape, + datacube.Qshape, + name = peaks.name) + braggvectors.set_raw_vectors(peaks) + return braggvectors # Distributed - ipyparallel From 6e9817b53826342ad419553a81073d3fac70a6f7 Mon Sep 17 00:00:00 2001 From: cophus Date: Tue, 29 Aug 2023 15:42:00 +1000 Subject: [PATCH 25/25] Adding moire lattice generation and plotting --- py4DSTEM/process/diffraction/crystal.py | 454 +++++++++++++++++- py4DSTEM/process/diffraction/crystal_phase.py | 7 - 2 files changed, 447 insertions(+), 14 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 856a9a028..14c7e32a7 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -2,6 +2,7 @@ import numpy as np import matplotlib.pyplot as plt +from matplotlib.patches import Circle from fractions import Fraction from typing import Union, Optional from scipy.optimize import curve_fit @@ -718,13 +719,12 @@ def generate_diffraction_pattern( np.array([],dtype=pl_dtype) ) if np.any(keep_int): - bragg_peaks.add_data_by_field( - [ - gx_proj, - gy_proj, - gz_proj, - g_int[keep_int], - h,k,l]) + bragg_peaks.add_data_by_field([ + gx_proj, + gy_proj, + gz_proj, + g_int[keep_int], + h,k,l]) else: pl_dtype = np.dtype([ ("qx", "float64"), @@ -1048,3 +1048,443 @@ def calculate_bragg_peak_histogram( int_exp /= np.max(int_exp) return k, int_exp + + +def generate_moire( + bragg_peaks_0, + bragg_peaks_1, + thresh_0 = 0.0002, + thresh_1 = 0.0002, + int_range = (0,5e-3), + exx_1 = 0.0, + eyy_1 = 0.0, + exy_1 = 0.0, + phi_1 = 0.0, + power = 2.0, + k_max = 1.0, + plot_result = True, + plot_subpixel = True, + labels = None, + marker_size_parent = 16, + marker_size_moire = 4, + text_size_parent = 10, + text_size_moire = 6, + add_labels_parent = False, + add_labels_moire = False, + dist_labels = 0.03, + dist_check = 0.06, + sep_labels = 0.03, + figsize = (8,6), + return_moire = False, + returnfig = False, + ): + """ + Calculate a Moire lattice from 2 parent diffraction patterns. + + Parameters + -------- + bragg_peaks_0: BraggVector + Bragg vectors for parent lattice 0. + bragg_peaks_1: BraggVector + Bragg vectors for parent lattice 1. + thresh_0: float + thresh_1: float + int_range: (float, float) + exx_1: float + eyy_1: float + exy_1: float + phi_1: float + power: float + k_max: float + plot_result: bool + plot_subpixel: bool + labels: list + List of text labels for parent lattices + marker_size_parent: float + marker_size_moire: float + text_size_parent: float + text_size_moire: float + add_labels_parent: bool + add_labels_moire: bool + dist_labels: float + dist_check: float + sep_labels: float + figsize: (float,float) + return_moire: bool + returnfig: bool + + Returns + -------- + bragg_peaksMoire: BraggVector (optjonal) + Bragg vectors for moire lattice. + fig, ax: matplotlib handles (optional) + Figure and axes handles for the moire plot. + + """ + + # peak labels + if labels is None: + labels = ('crystal 0', 'crystal 1') + + # get intenties of all peaks + int0 = bragg_peaks_0['intensity']**(power/2.0) + int1 = bragg_peaks_1['intensity']**(power/2.0) + + # peaks above threshold + sub0 = int0 >= thresh_0 + sub1 = int1 >= thresh_1 + + # Remove origin (assuming brightest peak) + ind0_or = np.argmax(bragg_peaks_0['intensity']) + ind1_or = np.argmax(bragg_peaks_1['intensity']) + sub0[ind0_or] = False + sub1[ind1_or] = False + int0_sub = int0[sub0] + int1_sub = int1[sub1] + + # Get peaks + qx0 = bragg_peaks_0['qx'][sub0] + qy0 = bragg_peaks_0['qy'][sub0] + qx1_init = bragg_peaks_1['qx'][sub1] + qy1_init = bragg_peaks_1['qy'][sub1] + + # peak labels + if add_labels_parent or add_labels_moire or return_moire: + def overline(x): + return str(x) if x >= 0 else (r"\overline{" + str(np.abs(x)) + "}") + + h0 = bragg_peaks_0['h'][sub0] + k0 = bragg_peaks_0['k'][sub0] + l0 = bragg_peaks_0['l'][sub0] + h1 = bragg_peaks_1['h'][sub1] + k1 = bragg_peaks_1['k'][sub1] + l1 = bragg_peaks_1['l'][sub1] + + # apply strain tensor to lattice 1 + # infinitesimal + # m = np.array([ + # [1 + exx_1, (exy_1 - phi_1)*0.5], + # [(exy_1 _ phi_1)*0.5, 1 + eyy_1], + # ]) + # finite rotation + m = np.array([ + [np.cos(phi_1), -np.sin(phi_1)], + [np.sin(phi_1), np.cos(phi_1)], + ]) @ np.array([ + [1 + exx_1, exy_1*0.5], + [exy_1*0.5, 1 + eyy_1], + ]) + qx1 = m[0,0] * qx1_init + m[0,1] * qy1_init + qy1 = m[1,0] * qx1_init + m[1,1] * qy1_init + + # Generate moire lattice + ind0, ind1 = np.meshgrid( + np.arange(np.sum(sub0)), + np.arange(np.sum(sub1)), + indexing = 'ij', + ) + # ind0 = ind0.ravel() + # ind1 = ind1.ravel() + qx = qx0[ind0] + qx1[ind1] + qy = qy0[ind0] + qy1[ind1] + # int_moire = int0_sub[ind0] + int1_sub[ind1] + int_moire = (int0_sub[ind0] * int1_sub[ind1]) ** 0.5 + + # moire labels + if add_labels_moire or return_moire: + m_h0 = h0[ind0] + m_k0 = k0[ind0] + m_l0 = l0[ind0] + m_h1 = h1[ind1] + m_k1 = k1[ind1] + m_l1 = l1[ind1] + + # If needed, convert moire peaks to BraggVector class + if return_moire: + pl_dtype = np.dtype([ + ("qx", "float"), + ("qy", "float"), + ("intensity", "float"), + ("h0", "int"), + ("k0", "int"), + ("l0", "int"), + ("h1", "int"), + ("k1", "int"), + ("l1", "int"), + ]) + bragg_moire = PointList( + np.array([],dtype=pl_dtype) + ) + bragg_moire.add_data_by_field([ + qx.ravel(), + qy.ravel(), + int_moire.ravel(), + m_h0.ravel(),m_k0.ravel(),m_l0.ravel(), + m_h1.ravel(),m_k1.ravel(),m_l1.ravel(), + ]) + + + # plot outputs + if plot_result: + fig = plt.figure(figsize = figsize) + ax = fig.add_axes([0.09,0.09,0.65,0.9]) + ax_labels = fig.add_axes([0.75,0,0.25,1]) + + + text_params_parent = { + "ha": "center", + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "size": text_size_parent, + } + text_params_moire = { + "ha": "center", + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "size": text_size_moire, + } + + + if plot_subpixel is False: + + # moire + ax.scatter( + qy, + qx, + # color = (0,0,0,1), + c = int_moire, + s = marker_size_moire, + cmap = 'gray_r', + vmin = int_range[0], + vmax = int_range[1], + antialiased=True, + ) + + # parent lattices + ax.scatter( + qy0, + qx0, + color = (1,0,0,1), + s = marker_size_parent, + antialiased=True, + ) + ax.scatter( + qy1, + qx1, + color = (0,0.7,1,1), + s = marker_size_parent, + antialiased=True, + ) + + # origin + ax.scatter( + 0, + 0, + color = (0,0,0,1), + s = marker_size_parent, + antialiased=True, + ) + + else: + # moire peaks + int_all = np.clip( + (int_moire - int_range[0]) / (int_range[1] - int_range[0]), + 0,1) + keep = np.logical_and.reduce(( + qx >= -k_max, + qx <= k_max, + qy >= -k_max, + qy <= k_max + )) + for x, y, int_marker in zip(qx[keep], qy[keep], int_all[keep]): + ax.add_artist(Circle( + xy=(y, x), + radius = np.sqrt(marker_size_moire)/800.0, + color = (1-int_marker,1-int_marker,1-int_marker), + )) + if add_labels_moire: + for a0 in range(qx.size): + if keep.ravel()[a0]: + x0 = qx.ravel()[a0] + y0 = qy.ravel()[a0] + d2 = (qx.ravel()-x0)**2 + (qy.ravel()-y0)**2 + sub = d2 < dist_check**2 + xc = np.mean(qx.ravel()[sub]) + yc = np.mean(qy.ravel()[sub]) + xp = x0 - xc + yp = y0 - yc + if xp == 0 and yp == 0.0: + xp = x0 - dist_labels + yp = y0 + else: + leng = np.linalg.norm((xp,yp)) + xp = x0 + xp * dist_labels / leng + yp = y0 + yp * dist_labels / leng + + ax.text( + yp, + xp - sep_labels, + "$" + overline(m_h0.ravel()[a0]) \ + + overline(m_k0.ravel()[a0]) \ + + overline(m_l0.ravel()[a0]) + "$", + c = 'r', + **text_params_moire, + ) + ax.text( + yp, + xp, + "$" + overline(m_h1.ravel()[a0]) \ + + overline(m_k1.ravel()[a0]) \ + + overline(m_l1.ravel()[a0]) + "$", + c = (0,0.7,1.0), + **text_params_moire, + ) + + + keep = np.logical_and.reduce(( + qx0 >= -k_max, + qx0 <= k_max, + qy0 >= -k_max, + qy0 <= k_max + )) + for x, y in zip(qx0[keep], qy0[keep]): + ax.add_artist(Circle( + xy=(y, x), + radius = np.sqrt(marker_size_parent)/800.0, + color = (1,0,0), + )) + if add_labels_parent: + for a0 in range(qx0.size): + if keep.ravel()[a0]: + xp = qx0.ravel()[a0] - dist_labels + yp = qy0.ravel()[a0] + ax.text( + yp, + xp, + "$" + overline(h0.ravel()[a0]) \ + + overline(k0.ravel()[a0]) \ + + overline(l0.ravel()[a0]) + "$", + c = 'k', + **text_params_parent, + ) + + keep = np.logical_and.reduce(( + qx1 >= -k_max, + qx1 <= k_max, + qy1 >= -k_max, + qy1 <= k_max + )) + for x, y in zip(qx1[keep], qy1[keep]): + ax.add_artist(Circle( + xy=(y, x), + radius = np.sqrt(marker_size_parent)/800.0, + color = (0,0.7,1), + )) + if add_labels_parent: + for a0 in range(qx1.size): + if keep.ravel()[a0]: + xp = qx1.ravel()[a0] - dist_labels + yp = qy1.ravel()[a0] + ax.text( + yp, + xp, + "$" + overline(h1.ravel()[a0]) \ + + overline(k1.ravel()[a0]) \ + + overline(l1.ravel()[a0]) + "$", + c = 'k', + **text_params_parent, + ) + + # origin + ax.add_artist(Circle( + xy=(0, 0), + radius = np.sqrt(marker_size_parent)/800.0, + color = (0,0,0), + )) + + ax.set_xlim((-k_max,k_max)) + ax.set_ylim((-k_max,k_max)) + ax.set_ylabel('$q_x$ (1/A)') + ax.set_xlabel('$q_y$ (1/A)') + ax.invert_yaxis() + + # labels + ax_labels.scatter( + 0, + 0, + color = (1,0,0,1), + s = marker_size_parent, + ) + ax_labels.scatter( + 0, + -1, + color = (0,0.7,1,1), + s = marker_size_parent, + ) + ax_labels.scatter( + 0, + -2, + color = (0,0,0,1), + s = marker_size_moire, + ) + ax_labels.text( + 0.4, + -0.2, + labels[0], + fontsize = 14, + ) + ax_labels.text( + 0.4, + -1.2, + labels[1], + fontsize = 14, + ) + ax_labels.text( + 0.4, + -2.2, + 'MoirĂ© lattice', + fontsize = 14, + ) + + ax_labels.text( + 0, + -4.2, + labels[1] + ' $\epsilon_{xx}$ = ' + str(np.round(exx_1*100,2)) + '%', + fontsize = 14, + ) + ax_labels.text( + 0, + -5.2, + labels[1] + ' $\epsilon_{yy}$ = ' + str(np.round(eyy_1*100,2)) + '%', + fontsize = 14, + ) + ax_labels.text( + 0, + -6.2, + labels[1] + ' $\epsilon_{xy}$ = ' + str(np.round(exy_1*100,2)) + '%', + fontsize = 14, + ) + ax_labels.text( + 0, + -7.2, + labels[1] + ' $\phi$ = ' + str(np.round(phi_1*180/np.pi,2)) + '$^\circ$', + fontsize = 14, + + ) + + ax_labels.set_xlim((-1,4)) + ax_labels.set_ylim((-21,1)) + + ax_labels.axis('off') + + if return_moire: + if returnfig: + return bragg_moire, fig, ax + else: + return bragg_moire + if returnfig: + return fig, ax + + diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 4fac83281..5b29ee4fb 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -4,13 +4,6 @@ import matplotlib as mpl import matplotlib.pyplot as plt -# <<<<<<< HEAD -# from py4DSTEM.utils.tqdmnd import tqdmnd -# from py4DSTEM.visualize import show, show_image_grid -# # from py4DSTEM.io.datastructure.emd.pointlistarray import PointListArray -# # from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern -# from py4DSTEM.io.datastructure import PointList, PointListArray - from dataclasses import dataclass, field from emdfile import tqdmnd, PointListArray from py4DSTEM.visualize import show, show_image_grid