diff --git a/clic/src/tier4/threshold_otsu.cpp b/clic/src/tier4/threshold_otsu.cpp index 1553cca0..b248dee8 100644 --- a/clic/src/tier4/threshold_otsu.cpp +++ b/clic/src/tier4/threshold_otsu.cpp @@ -11,86 +11,22 @@ namespace cle::tier4 { -// -// auto -// threshold_otsu_func(const Device::Pointer & device, const Array::Pointer & src, Array::Pointer dst) -> Array::Pointer -// { -// constexpr int bin = 256; -// const float min_intensity = tier2::minimum_of_all_pixels_func(device, src); -// const float max_intensity = tier2::maximum_of_all_pixels_func(device, src); -// auto hist_array = Array::create(bin, 1, 1, 1, dType::FLOAT, mType::BUFFER, src->device()); -// tier3::histogram_func(device, src, hist_array, bin, min_intensity, max_intensity); -// std::vector histogram_array(hist_array->size()); -// hist_array->readTo(histogram_array.data()); -// double threshold = -1; -// double max_variance = -1; -// double variance = 0; -// double sum_1 = 0; -// double sum_2 = 0; -// double weight_1 = 0; -// double weight_2 = 0; -// double mean_1 = 0; -// double mean_2 = 0; -// const double nb_pixels = src->size(); -// const double intensity_factor = static_cast(max_intensity - min_intensity); -// // const double intensity_factor = static_cast(max_intensity - min_intensity) / static_cast(bin - -// 1); -// // old implementation - to be removed -// // std::vector range(histogram_array.size()); -// // std::iota(range.begin(), range.end(), 0.0); -// // std::transform(range.begin(), range.end(), range.begin(), [intensity_factor, min_intensity](float intensity) { -// // return intensity * intensity_factor + static_cast(min_intensity); -// // }); -// // -// std::vector range(bin); -// std::iota(range.begin(), range.end(), 0.0f); -// std::transform(range.begin(), range.end(), range.begin(), [intensity_factor, min_intensity](double value) { -// return (value * intensity_factor) / (bin - 1) + static_cast(min_intensity); -// }); -// sum_1 = std::transform_reduce( -// range.begin(), range.end(), histogram_array.begin(), 0.0f, std::plus<>(), [](double intensity, float hist_value) -// { -// return intensity * static_cast(hist_value); -// }); -// for (size_t index = 0; index < range.size(); ++index) -// { -// if (histogram_array[index] == 0) -// { -// continue; -// } -// weight_1 += histogram_array[index]; -// weight_2 = nb_pixels - weight_1; -// sum_2 += range[index] * static_cast(histogram_array[index]); -// mean_1 = sum_2 / weight_1; -// mean_2 = (sum_1 - sum_2) / weight_2; -// variance = weight_1 * weight_2 * ((mean_1 - mean_2) * (mean_1 - mean_2)); -// if (variance > max_variance) -// { -// threshold = range[index]; -// max_variance = variance; -// } -// } -// std::cout << "Threshold: " << threshold << std::endl; -// tier0::create_like(src, dst, dType::BINARY); -// return tier1::greater_constant_func(device, src, dst, static_cast(threshold)); -// } - - auto threshold_otsu_func(const Device::Pointer & device, const Array::Pointer & src, Array::Pointer dst) -> Array::Pointer { + // Initialize histogram constexpr int bin = 256; const float min_intensity = tier2::minimum_of_all_pixels_func(device, src); const float max_intensity = tier2::maximum_of_all_pixels_func(device, src); - + double range = max_intensity - min_intensity; + + // Compute histogram auto hist_array = Array::create(bin, 1, 1, 1, dType::FLOAT, mType::BUFFER, src->device()); tier3::histogram_func(device, src, hist_array, bin, min_intensity, max_intensity); std::vector counts(hist_array->size()); hist_array->readTo(counts.data()); - - double range = max_intensity - min_intensity; - + // Compute bin centers std::vector bin_centers(bin); std::iota(bin_centers.begin(), bin_centers.end(), 0.0); std::transform( @@ -98,15 +34,14 @@ threshold_otsu_func(const Device::Pointer & device, const Array::Pointer & src, return (value * range) / (bin - 1) + static_cast(min_intensity); }); - std::vector weight1(bin), weight2(bin), mean1(bin), mean2(bin), variance12(bin - 1); - // Compute weight1 + std::vector weight1(bin), weight2(bin), mean1(bin), mean2(bin), variance12(bin - 1); std::partial_sum(counts.begin(), counts.end(), weight1.begin()); // Compute weight2 std::vector reversed_counts(counts.rbegin(), counts.rend()); std::partial_sum(reversed_counts.begin(), reversed_counts.end(), weight2.rbegin()); - + // Compute mean1 std::vector counts_bin_centers(bin); std::transform(counts.begin(), counts.end(), bin_centers.begin(), counts_bin_centers.begin(), std::multiplies<>()); @@ -123,13 +58,12 @@ threshold_otsu_func(const Device::Pointer & device, const Array::Pointer & src, variance12[i] = weight1[i] * weight2[i + 1] * (mean1[i] - mean2[i + 1]) * (mean1[i] - mean2[i + 1]); } + // Find the maximum variance and threshold value associated with it auto max_it = std::max_element(variance12.begin(), variance12.end()); size_t idx = std::distance(variance12.begin(), max_it); - double threshold = bin_centers[idx]; - - std::cout << "Threshold: " << threshold << std::endl; - + + // Create binary image with threshold tier0::create_like(src, dst, dType::BINARY); return tier1::greater_constant_func(device, src, dst, static_cast(threshold)); }