Skip to content

Commit

Permalink
More updates to the polar branch
Browse files Browse the repository at this point in the history
  • Loading branch information
cophus committed Jul 3, 2024
1 parent 99078f0 commit 084dd1b
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 31 deletions.
250 changes: 222 additions & 28 deletions py4DSTEM/process/polar/polar_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,7 @@ def cluster_grains(
threshold_add=0.05,
threshold_grow=0.2,
angle_tolerance_deg=5,
distance_tolerance_px = 1,
plot_grain_clusters=True,
returncalc=False,
**kwargs
Expand All @@ -1048,7 +1049,9 @@ def cluster_grains(
threshold_grow: float
Minimum signal required for a probe position to be added to a cluster.
angle_tolerance_deg: float
Rotation rolerance for clustering grains.
Angular rotation tolerance for clustering grains.
distance_tolerance_px: int
Distance tolerance for clustering grains.
plot_grain_clusters: bool
If True, plots clusters. **kwargs passed to `plot_grain_clusters`
return_calc: bool
Expand Down Expand Up @@ -1107,6 +1110,7 @@ def cluster_grains(
tol = np.deg2rad(angle_tolerance_deg)

# Main loop
vec_search = np.arange(-distance_tolerance_px, distance_tolerance_px+1, dtype="int")
search = True
while search is True:
inds_grain = np.argmax(sig)
Expand All @@ -1127,8 +1131,8 @@ def cluster_grains(
phi_cluster = phi[z]

# Neighbors to search
xr = np.clip(x + np.arange(-1, 2, dtype="int"), 0, sig.shape[0] - 1)
yr = np.clip(y + np.arange(-1, 2, dtype="int"), 0, sig.shape[1] - 1)
xr = np.clip(x + vec_search, 0, sig.shape[0] - 1)
yr = np.clip(y + vec_search, 0, sig.shape[1] - 1)
inds_cand = inds_all[xr[:, None], yr[None], :].ravel()
inds_cand = np.delete(inds_cand, mark.ravel()[inds_cand] == False)

Expand Down Expand Up @@ -1198,6 +1202,7 @@ def plot_clusters(
outline_grains=True,
outline_thickness=1,
fill_grains=0.25,
weight_by_area=False,
smooth_grains=1.0,
cmap="viridis",
figsize=(8, 8),
Expand All @@ -1206,24 +1211,26 @@ def plot_clusters(
returnfig=False,
):
"""
Plot the clusters as an image.
Plot the clusters / domains as an image.
Parameters
--------
area_min: int (optional)
Min cluster size to include, in units of probe positions.
outline_grains: bool (optional)
Set to True to draw grains with outlines
Set to True to draw domains with outlines
outline_thickness: int (optional)
Thickenss of the grain outline
Thickenss of the domain outline
fill_grains: float (optional)
Outlined grains are filled with this value in pixels.
Outlined domains are filled with this value in pixels.
weight_by_area: bool (optional)
Weight the domain fill and edges by each domain's area.
smooth_grains: float (optional)
Grain boundaries are smoothed by this value in pixels.
Domain boundaries are smoothed by this value in pixels.
figsize: tuple
Size of the figure panel
returncalc: bool
Return the grain image.
Return the domain image.
returnfig: bool, optional
Setting this to true returns the figure and axis handles
Expand Down Expand Up @@ -1252,10 +1259,15 @@ def plot_clusters(
dtype="bool",
)

if weight_by_area:
weights = self.cluster_sizes / np.max(self.cluster_sizes) * (1-fill_grains) + fill_grains
else:
weights = np.ones_like(self.cluster_sizes)

# make plotting image
for a0 in tqdmnd(
self.cluster_sizes.shape[0],
desc="Generating grain image",
desc="Generating domain image",
unit=" grains",
disable=not progress_bar,
):
Expand Down Expand Up @@ -1291,7 +1303,7 @@ def plot_clusters(
# im_add = 1 - np.exp(
# distance_transform_edt(im_grain)**2 \
# / (-2*outline_thickness**2))
im_plot += im_add
im_plot += im_add * weights[a0]
# im_plot = np.minimum(im_plot, im_add)
else:
# xg,yg = np.unravel_index(self.cluster_inds[a0], im_plot.shape)
Expand All @@ -1302,7 +1314,7 @@ def plot_clusters(
] = True
im_plot += gaussian_filter(
im_grain.astype("float"), sigma=smooth_grains, mode="nearest"
)
) * weights[a0]

# im_plot[
# self.cluster_inds[a0][0,:],
Expand Down Expand Up @@ -1331,7 +1343,7 @@ def plot_clusters(
return fig, ax


def plot_grain_clusters_area(
def plot_clusters_area(
self,
area_min=None,
area_max=None,
Expand Down Expand Up @@ -1423,14 +1435,16 @@ def plot_grain_clusters_area(
return fig, ax


def plot_grain_clusters_diameter(
def plot_clusters_diameter(
self,
cluster_sizes=None,
diameter_min=None,
diameter_max=None,
diameter_step=1,
weight_intensity=False,
weight_diameter=True,
pixel_area=1.0,
weight_diameter=False,
weight_area=False,
pixel_size=1.0,
pixel_area_units="px",
figsize=(8, 6),
returnfig=False,
Expand All @@ -1440,6 +1454,8 @@ def plot_grain_clusters_diameter(
Parameters
--------
cluster_sizes: np.array
Size in pixels^2 of all clusters.
diameter_min: int (optional)
Min area bin in pixels
diameter_max: int (optional)
Expand All @@ -1449,8 +1465,10 @@ def plot_grain_clusters_diameter(
weight_intensity: bool
Weight histogram by the peak intensity.
weight_diameter: bool
Weight histogram by the area of each bin.
pixel_area: float
Weight histogram by the diameter of each cluster.
weight_diameter: bool
Weight histogram by the area of each cluster.
pixel_size: float
Size of pixel area unit square
pixel_area_units: string
Units of the pixel area
Expand All @@ -1466,9 +1484,14 @@ def plot_grain_clusters_diameter(
"""

cluster_diam = 0.5 * np.sqrt(self.cluster_sizes.astype('float') / np.pi)
if cluster_sizes is None:
cluster_sizes = self.cluster_sizes.copy().astype('float')
cluster_diam = 0.5 * np.sqrt(cluster_sizes / np.pi)

# units
cluster_diam *= pixel_size

# subset
if diameter_max is None:
diameter_max = np.max(cluster_diam)
diameter = np.arange(0, diameter_max, diameter_step)
Expand All @@ -1479,12 +1502,26 @@ def plot_grain_clusters_diameter(
cluster_diam >= diameter_min,
cluster_diam < diameter_max,
)

# histogram
if weight_intensity:
hist = np.bincount(
np.round(cluster_diam[sub] / diameter_step).astype('int'),
weights=self.cluster_sig[sub],
minlength=diameter.shape[0],
)
elif weight_area:
hist = np.bincount(
np.round(cluster_diam[sub] / diameter_step).astype('int'),
weights=cluster_sizes[sub],
minlength=diameter.shape[0],
)
elif weight_diameter:
hist = np.bincount(
np.round(cluster_diam[sub] / diameter_step).astype('int'),
weights=cluster_diam[sub],
minlength=diameter.shape[0],
)
else:
hist = np.bincount(
np.round(cluster_diam[sub] / diameter_step).astype('int'),
Expand All @@ -1495,24 +1532,181 @@ def plot_grain_clusters_diameter(
fig, ax = plt.subplots(figsize=figsize)
if weight_diameter:
ax.bar(
diameter * pixel_area,
diameter,
hist * diameter,
width=0.8 * pixel_area * diameter_step,
width=0.8 * diameter_step,
)
else:
ax.bar(
diameter * pixel_area,
diameter,
hist,
width=0.8 * pixel_area * diameter_step,
width=0.8 * diameter_step,
)
ax.set_xlim((0, diameter_max * pixel_area))
ax.set_xlim((0, diameter_max))
ax.set_xlabel("Domain Size (" + pixel_area_units + ")")

if weight_intensity:
ax.set_ylabel("Total Signal (arb. units)")
# elif weight_area:
# ax.set_ylabel("Area-Weighted Signal (arb. units)")
ax.set_ylabel("Intensity-Weighted Domain Size")
elif weight_area:
ax.set_ylabel("Total Domain Area")
elif weight_diameter:
ax.set_ylabel("Total Domain Diameter")
else:
ax.set_ylabel("Number of Grains")
ax.set_ylabel("Number of Domains")

if weight_intensity or weight_area or weight_diameter:
ax.set_yticks([])

if returnfig:
return fig, ax



def plot_clusters_max_length(
self,
cluster_inds=None,
length_min=None,
length_max=None,
length_step=1,
weight_intensity=False,
weight_length=False,
weight_area=False,
pixel_size=1.0,
pixel_area_units="px",
figsize=(8, 6),
returnfig=False,
):
"""
Plot the histogram of the max length of each cluster (over all angles).
Parameters
--------
cluster_inds: list of np.array
List of clusters, with the indices given for each cluster.
length_min: int (optional)
Min area bin in pixels
length_max: int (optional)
Max area bin in pixels
length_step: int (optional)
Step size of the histogram bin
weight_intensity: bool
Weight histogram by the peak intensity.
weight_length: bool
Weight histogram by the length of each cluster.
weight_diameter: bool
Weight histogram by the area of each cluster.
pixel_size: float
Size of pixel area unit square
pixel_area_units: string
Units of the pixel area
figsize: tuple
Size of the figure panel
returnfig: bool
Setting this to true returns the figure and axis handles
Returns
--------
fig, ax (optional)
Figure and axes handles
"""

if cluster_inds is None:
cluster_inds = self.cluster_inds.copy().astype('float')

# init
num_clusters = len(cluster_inds)
cluster_sizes = np.zeros(num_clusters)
cluster_length = np.zeros(num_clusters)
t_all = np.linspace(0,np.pi,45,endpoint=False)
ct = np.cos(t_all)
st = np.sin(t_all)

# calculate size and max lengths of each cluster
for a0 in range(num_clusters):
cluster_sizes[a0] = cluster_inds[a0].shape[1]

x0 = cluster_inds[a0][0]
y0 = cluster_inds[a0][1]

length_current_max = 0
for a1 in range(t_all.size):
p = x0*ct[a1] + y0*st[a1]
length_current_max = np.maximum(
length_current_max,
np.max(p) - np.min(p) + 1,
)
cluster_length[a0] = length_current_max


# units
cluster_length *= pixel_size

# subset
if length_max is None:
length_max = np.max(cluster_length)
length = np.arange(0, length_max, length_step)
if length_min is None:
sub = cluster_length < length_max
else:
sub = np.logical_and(
cluster_length >= length_min,
cluster_length < length_max,
)

# histogram
if weight_intensity:
hist = np.bincount(
np.round(cluster_length[sub] / length_step).astype('int'),
weights=self.cluster_sig[sub],
minlength=length.shape[0],
)
elif weight_area:
hist = np.bincount(
np.round(cluster_length[sub] / length_step).astype('int'),
weights=cluster_sizes[sub],
minlength=length.shape[0],
)
elif weight_diameter:
hist = np.bincount(
np.round(cluster_length[sub] / length_step).astype('int'),
weights=cluster_diam[sub],
minlength=length.shape[0],
)
else:
hist = np.bincount(
np.round(cluster_length[sub] / length_step).astype('int'),
minlength=length.shape[0],
)

# plotting
fig, ax = plt.subplots(figsize=figsize)
if weight_length:
ax.bar(
length,
hist * length,
width=0.8 * length_step,
)
else:
ax.bar(
length,
hist,
width=0.8 * length_step,
)
ax.set_xlim((0, length_max))
ax.set_xlabel("Maximum Domain Length (" + pixel_area_units + ")")

if weight_intensity:
ax.set_ylabel("Intensity-Weighted Domain Lengths")
elif weight_area:
ax.set_ylabel("Total Domain Area")
elif weight_diameter:
ax.set_ylabel("Total Domain Diameter")
else:
ax.set_ylabel("Number of Domains")

if weight_intensity or weight_area or weight_length:
ax.set_yticks([])

if returnfig:
return fig, ax
Expand Down
Loading

0 comments on commit 084dd1b

Please sign in to comment.