Skip to content

Commit

Permalink
Fix bug in altitude bins and curve fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
miquelmassot committed Feb 14, 2022
1 parent 5781ae8 commit 06a4fe6
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 18 deletions.
32 changes: 18 additions & 14 deletions src/correct_images/corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,39 +557,43 @@ def generate_attenuation_correction_parameters(self):

self.bin_band = 0.1
hist_bins = np.arange(self.altitude_min, self.altitude_max, self.bin_band)
# Watch out: need to substract 1 to get the correct number of bins
# because the last bin is not included in the range

images_fn, images_map = open_memmap(
shape=(
len(hist_bins),
len(hist_bins) - 1,
self.image_height * self.image_width,
self.image_channels,
),
dtype=np.float32,
)

distances_fn, distances_map = open_memmap(
shape=(len(hist_bins), self.image_height * self.image_width),
shape=(len(hist_bins) - 1, self.image_height * self.image_width),
dtype=np.float32,
)

# Fill with NaN not to use empty bins
# images_map.fill(np.nan)
# distances_map.fill(np.nan)

distance_vector = None

if self.depth_map_list is not None:
Console.info("Computing depth map histogram with", hist_bins.size, " bins")
Console.info("Computing depth map histogram with", hist_bins.size - 1, " bins")
distance_vector = np.zeros((len(self.depth_map_list), 1))
for i, dm_file in enumerate(self.depth_map_list):
dm_np = depth_map.loader(dm_file, self.image_width, self.image_height)
distance_vector[i] = dm_np.mean(axis=1)
elif self.altitude_list is not None:
Console.info("Computing altitude histogram with", hist_bins.size, " bins")
Console.info("Computing altitude histogram with", hist_bins.size - 1, " bins:")
distance_vector = np.array(self.altitude_list)

if distance_vector is not None:
idxs = np.digitize(distance_vector, hist_bins)
idxs = np.digitize(distance_vector, hist_bins) - 1

# Display histogram in console
for idx_bin in range(hist_bins.size - 1):
tmp_idxs = np.where(idxs == idx_bin)[0]
Console.info(" Bin", idx_bin, "(", hist_bins[idx_bin],
"m < x <", hist_bins[idx_bin + 1], "m):", len(tmp_idxs), "images")

with tqdm_joblib(
tqdm(desc="Computing altitude histogram", total=hist_bins.size - 1,)
):
Expand All @@ -603,7 +607,7 @@ def generate_attenuation_correction_parameters(self):
max_bin_size_gb,
distance_vector,
)
for idx_bin in range(1, hist_bins.size)
for idx_bin in range(hist_bins.size - 1)
)

# calculate attenuation parameters per channel
Expand Down Expand Up @@ -764,7 +768,7 @@ def compute_distance_bin(
):
dimensions = [self.image_height, self.image_width, self.image_channels]
tmp_idxs = np.where(idxs == idx_bin)[0]
# Console.info("In bin", idx_bin, "there are", len(tmp_idxs), "images")

if len(tmp_idxs) > 2:
bin_images = [self.camera_image_list[i] for i in tmp_idxs]
bin_distances_sample = None
Expand Down Expand Up @@ -818,7 +822,7 @@ def compute_distance_bin(
plt.imshow(bin_images_sample)
plt.colorbar()
plt.title("Image bin " + str(idx_bin))
fig_name = base_path / ("bin_images_sample_" + str(idx_bin) + ".png")
fig_name = base_path / ("bin_images_sample_" + str(distance_bin_sample) + "m.png")
#Console.info("Saved figure at", fig_name)
plt.savefig(fig_name, dpi=600)
plt.close(fig)
Expand All @@ -827,7 +831,7 @@ def compute_distance_bin(
plt.imshow(bin_distances_sample)
plt.colorbar()
plt.title("Distance bin " + str(idx_bin))
fig_name = base_path / ("bin_distances_sample_" + str(idx_bin) + ".png")
fig_name = base_path / ("bin_distances_sample_" + str(distance_bin_sample) + "m.png")
#Console.info("Saved figure at", fig_name)
plt.savefig(fig_name, dpi=600)
plt.close(fig)
Expand Down
36 changes: 32 additions & 4 deletions src/correct_images/tools/curve_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,41 @@ def curve_fitting(altitudes: np.ndarray, intensities: np.ndarray) -> np.ndarray:
altitudes = altitudes[np.isfinite(altitudes)]
intensities = intensities[np.isfinite(intensities)]

if altitudes.size == 0:
print("---------\naltitudes: ", altitudes, "\nintensities: ", intensities)
print("ERROR: Empty non-nan altitudes in curve fitting")
return np.array([1, 0, 0])
if intensities.size == 0:
print("---------\naltitudes: ", altitudes, "\nintensities: ", intensities)
print("ERROR: Empty non-nan intensities in curve fitting")
return np.array([1, 0, 0])

altitudes_filt = []
intensities_filt = []
for x, y in zip(altitudes, intensities):
if x > 0 and y > 0:
if x >= 0 and y >= 0:
altitudes_filt.append(x)
intensities_filt.append(y)

if not altitudes_filt or not intensities_filt:
if not altitudes_filt:
print("---------\naltitudes: ", altitudes, "\nintensities: ", intensities, "\naltitudes_filt: ", altitudes_filt,
"\nintensities_filt: ", intensities_filt)
print("ERROR: Altitudes are negative in curve fitting")
if not intensities_filt:
print("---------\naltitudes: ", altitudes, "\nintensities: ", intensities, "\naltitudes_filt: ", altitudes_filt,
"\nintensities_filt: ", intensities_filt)
print("ERROR: Intensities are negative in curve fitting")
return np.array([1, 0, 0])

altitudes_filt = np.array(altitudes_filt)
intensities_filt = np.array(intensities_filt)

c_upper_bound = intensities_filt.min()
try:
c_upper_bound = intensities_filt.min()
except ValueError: # raised if it is empty.
c_upper_bound = np.finfo(float).eps

if c_upper_bound <= 0:
# c should be slightly greater than zero to avoid error
# 'Each lower bound must be strictly less than each upper bound.'
Expand All @@ -109,7 +133,11 @@ def curve_fitting(altitudes: np.ndarray, intensities: np.ndarray) -> np.ndarray:

# Avoid zero divisions
b = 0.0
c = intensities_filt.min() * 0.5
try:
c = intensities_filt.min() * 0.5
except ValueError: # raised if it is empty.
c = np.finfo(float).eps

if intensities_filt[idx_1] != 0:
b = (np.log((int_0 - c) / (int_1 - c))) / (alt_0 - alt_1)
a = (int_1 - c) / np.exp(b * alt_1)
Expand Down Expand Up @@ -152,5 +180,5 @@ def curve_fitting(altitudes: np.ndarray, intensities: np.ndarray) -> np.ndarray:
return tmp_params.x
except (ValueError, UnboundLocalError) as e:
print("ERROR: Value Error due to Overflow", a, b, c)
print("Parameters calculated are unoptimised because of Value Error", e)
print("Parameters calculated are unoptimised because of error", e)
return init_params

0 comments on commit 06a4fe6

Please sign in to comment.