Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix grid decimation filter on empty input #8

Merged
merged 1 commit into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 119 additions & 129 deletions src/filter_grid_decimation/GridDecimationFilter.cpp
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
/******************************************************************************
* Copyright (c) 2023, Antoine Lavenant ([email protected])
*
* All rights reserved.
*
****************************************************************************/
* Copyright (c) 2023, Antoine Lavenant ([email protected])
*
* All rights reserved.
*
****************************************************************************/

#include "GridDecimationFilter.hpp"

#include <pdal/PointView.hpp>
#include <pdal/StageFactory.hpp>

#include <sstream>
#include <cstdarg>
#include <sstream>

namespace pdal
{
namespace pdal {

static StaticPluginInfo const s_info
{
static StaticPluginInfo const s_info{
"filters.grid_decimation_deprecated", // better to use the pdal gridDecimation plugIN
"keep max or min points in a grid",
"",
Expand All @@ -27,153 +25,145 @@ CREATE_SHARED_STAGE(GridDecimationFilter, s_info)

std::string GridDecimationFilter::getName() const { return s_info.name; }

GridDecimationFilter::GridDecimationFilter() : m_args(new GridDecimationFilter::GridArgs)
{}
GridDecimationFilter::GridDecimationFilter() : m_args(new GridDecimationFilter::GridArgs) {}

GridDecimationFilter::~GridDecimationFilter() {}

GridDecimationFilter::~GridDecimationFilter()
{}
void GridDecimationFilter::addArgs(ProgramArgs &args) {
args.add("resolution", "Cell edge size, in units of X/Y", m_args->m_edgeLength, 1.);
args.add("output_type", "Point kept into the cells ('min', 'max')", m_args->m_methodKeep, "max");
args.add("output_dimension", "Name of the added dimension", m_args->m_nameOutDimension, "grid");
args.add("output_wkt", "Export the grid as wkt", m_args->m_nameWktgrid, "");
}

void GridDecimationFilter::initialize() {}

void GridDecimationFilter::addArgs(ProgramArgs& args)
{
args.add("resolution", "Cell edge size, in units of X/Y",m_args->m_edgeLength, 1.);
args.add("output_type", "Point keept into the cells ('min', 'max')", m_args->m_methodKeep, "max" );
args.add("output_dimension", "Name of the added dimension", m_args->m_nameOutDimension, "grid" );
args.add("output_wkt", "Export the grid as wkt", m_args->m_nameWktgrid, "" );
void GridDecimationFilter::prepared(PointTableRef table) { PointLayoutPtr layout(table.layout()); }

}
void GridDecimationFilter::ready(PointTableRef table) {
if (m_args->m_edgeLength <= 0)
throwError("resolution must be positive.");

void GridDecimationFilter::initialize()
{
}
if (m_args->m_methodKeep != "max" && m_args->m_methodKeep != "min")
throwError("The output_type must be 'max' or 'min'.");

if (m_args->m_nameOutDimension.empty())
throwError("The output_dimension must be given.");

void GridDecimationFilter::prepared(PointTableRef table)
{
PointLayoutPtr layout(table.layout());
if (!m_args->m_nameWktgrid.empty())
std::remove(m_args->m_nameWktgrid.c_str());
}

void GridDecimationFilter::ready(PointTableRef table)
{
if (m_args->m_edgeLength <=0)
throwError("resolution must be positive.");

if (m_args->m_methodKeep != "max" && m_args->m_methodKeep != "min")
throwError("The output_type must be 'max' or 'min'.");

if (m_args->m_nameOutDimension.empty())
throwError("The output_dimension must be given.");

if (!m_args->m_nameWktgrid.empty())
std::remove(m_args->m_nameWktgrid.c_str());
void GridDecimationFilter::addDimensions(PointLayoutPtr layout) {
m_args->m_dim =
layout->registerOrAssignDim(m_args->m_nameOutDimension, Dimension::Type::Unsigned8);
}

void GridDecimationFilter::addDimensions(PointLayoutPtr layout)
{
m_args->m_dim = layout->registerOrAssignDim(m_args->m_nameOutDimension, Dimension::Type::Unsigned8);
void GridDecimationFilter::processOne(BOX2D bounds, PointRef &point, PointViewPtr view) {
// get the grid cell
double x = point.getFieldAs<double>(Dimension::Id::X);
double y = point.getFieldAs<double>(Dimension::Id::Y);
int id = point.getFieldAs<double>(Dimension::Id::PointId);

// if x==(xmax of the cell), we assume the point are in the upper cell
// if y==(ymax of the cell), we assume the point are in the right cell
int width = static_cast<int>((x - bounds.minx) / m_args->m_edgeLength);
int height = static_cast<int>((y - bounds.miny) / m_args->m_edgeLength);

// to avoid numeric pb with the division (append if the point is on the grid)
if (x < bounds.minx + width * m_args->m_edgeLength)
width--;
if (y < bounds.miny + height * m_args->m_edgeLength)
height--;
if (x >= bounds.minx + (width + 1) * m_args->m_edgeLength)
width++;
if (y >= bounds.miny + (height + 1) * m_args->m_edgeLength)
height++;

auto mptRefid = this->grid.find(std::make_pair(width, height));
assert(mptRefid != this->grid.end());
auto ptRefid = mptRefid->second;

if (ptRefid == -1) {
this->grid[std::make_pair(width, height)] = point.pointId();
return;
}

PointRef ptRef = view->point(ptRefid);

double z = point.getFieldAs<double>(Dimension::Id::Z);
double zRef = ptRef.getFieldAs<double>(Dimension::Id::Z);

if (this->m_args->m_methodKeep == "max" && z > zRef)
this->grid[std::make_pair(width, height)] = point.pointId();
if (this->m_args->m_methodKeep == "min" && z < zRef)
this->grid[std::make_pair(width, height)] = point.pointId();
}

void GridDecimationFilter::processOne(BOX2D bounds, PointRef& point, PointViewPtr view)
{
//get the grid cell
double x = point.getFieldAs<double>(Dimension::Id::X);
double y = point.getFieldAs<double>(Dimension::Id::Y);
int id = point.getFieldAs<double>(Dimension::Id::PointId);

// if x==(xmax of the cell), we assume the point are in the upper cell
// if y==(ymax of the cell), we assume the point are in the right cell
int width = static_cast<int>((x - bounds.minx) / m_args->m_edgeLength);
int height = static_cast<int>((y - bounds.miny) / m_args->m_edgeLength);

// to avoid numeric pb with the division (append if the point is on the grid)
if (x < bounds.minx+width*m_args->m_edgeLength) width--;
if (y < bounds.miny+height*m_args->m_edgeLength) height--;
if (x >= bounds.minx+(width+1)*m_args->m_edgeLength) width++;
if (y >= bounds.miny+(height+1)*m_args->m_edgeLength) height++;

auto mptRefid = this->grid.find( std::make_pair(width,height) );
assert( mptRefid != this->grid.end() );
auto ptRefid = mptRefid->second;

if (ptRefid==-1)
{
this->grid[ std::make_pair(width,height) ] = point.pointId();
return;
}

PointRef ptRef = view->point(ptRefid);
void GridDecimationFilter::createGrid(BOX2D bounds) {

double z = point.getFieldAs<double>(Dimension::Id::Z);
double zRef = ptRef.getFieldAs<double>(Dimension::Id::Z);
size_t d_width = std::floor((bounds.maxx - bounds.minx) / m_args->m_edgeLength) + 1;
size_t d_height = std::floor((bounds.maxy - bounds.miny) / m_args->m_edgeLength) + 1;

if (this->m_args->m_methodKeep == "max" && z>zRef)
this->grid[ std::make_pair(width,height) ] = point.pointId();
if (this->m_args->m_methodKeep == "min" && z<zRef)
this->grid[ std::make_pair(width,height) ] = point.pointId();
}
if (d_width < 0.0 || d_width > (std::numeric_limits<int>::max)())
throwError("Grid width out of range.");
if (d_height < 0.0 || d_height > (std::numeric_limits<int>::max)())
throwError("Grid height out of range.");

void GridDecimationFilter::createGrid(BOX2D bounds)
{
size_t d_width = std::floor((bounds.maxx - bounds.minx) / m_args->m_edgeLength) + 1;
size_t d_height = std::floor((bounds.maxy - bounds.miny) / m_args->m_edgeLength) + 1;

if (d_width < 0.0 || d_width > (std::numeric_limits<int>::max)())
throwError("Grid width out of range.");
if (d_height < 0.0 || d_height > (std::numeric_limits<int>::max)())
throwError("Grid height out of range.");

int width = static_cast<int>(d_width);
int height = static_cast<int>(d_height);

std::vector<Polygon> vgrid;

for (size_t l(0); l<height; l++)
for (size_t c(0); c<width; c++)
{
BOX2D bounds_dalle (bounds.minx + c*m_args->m_edgeLength, bounds.miny + l*m_args->m_edgeLength,
bounds.minx + (c+1)*m_args->m_edgeLength, bounds.miny + (l+1)*m_args->m_edgeLength );
vgrid.push_back(Polygon(bounds_dalle));
this->grid.insert( std::make_pair( std::make_pair(c,l), -1) );
}
int width = static_cast<int>(d_width);
int height = static_cast<int>(d_height);

if (!m_args->m_nameWktgrid.empty())
{
std::ofstream oss (m_args->m_nameWktgrid);
for (auto pol : vgrid)
oss << pol.wkt() << std::endl;
std::vector<Polygon> vgrid;

for (size_t l(0); l < height; l++)
for (size_t c(0); c < width; c++) {
BOX2D bounds_dalle(bounds.minx + c * m_args->m_edgeLength,
bounds.miny + l * m_args->m_edgeLength,
bounds.minx + (c + 1) * m_args->m_edgeLength,
bounds.miny + (l + 1) * m_args->m_edgeLength);
vgrid.push_back(Polygon(bounds_dalle));
this->grid.insert(std::make_pair(std::make_pair(c, l), -1));
}


if (!m_args->m_nameWktgrid.empty()) {
std::ofstream oss(m_args->m_nameWktgrid);
for (auto pol : vgrid)
oss << pol.wkt() << std::endl;
}
}

PointViewSet GridDecimationFilter::run(PointViewPtr view)
{
PointViewSet GridDecimationFilter::run(PointViewPtr view) {
if (view->empty()) {
if (!m_args->m_nameWktgrid.empty())
std::ofstream{m_args->m_nameWktgrid};

} else {
BOX2D bounds;
view->calculateBounds(bounds);
createGrid(bounds);

for (PointId i = 0; i < view->size(); ++i)
{
PointRef point = view->point(i);
processOne(bounds,point,view);
for (PointId i = 0; i < view->size(); ++i) {
PointRef point = view->point(i);
processOne(bounds, point, view);
}

std::set<PointId> keepPoint;
for (auto it : this->grid)
if (it.second != -1)
keepPoint.insert(it.second);

for (PointId i = 0; i < view->size(); ++i)
{
PointRef point = view->point(i);
if (keepPoint.find(view->point(i).pointId()) != keepPoint.end())
point.setField(m_args->m_dim, int64_t(1));
else
point.setField(m_args->m_dim, int64_t(0));
if (it.second != -1)
keepPoint.insert(it.second);

for (PointId i = 0; i < view->size(); ++i) {
PointRef point = view->point(i);
if (keepPoint.find(view->point(i).pointId()) != keepPoint.end())
point.setField(m_args->m_dim, int64_t(1));
else
point.setField(m_args->m_dim, int64_t(0));
}

PointViewSet viewSet;
viewSet.insert(view);
return viewSet;
}

PointViewSet viewSet;
viewSet.insert(view);
return viewSet;
}

} // namespace pdal
36 changes: 29 additions & 7 deletions test/test_grid_decimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import tempfile
from test import utils

import numpy as np
import pdal
import pdaltools.las_info as li
import pytest
Expand All @@ -14,14 +15,14 @@ def contains(bounds, x, y):
return bounds[0] <= x and x < bounds[1] and bounds[2] <= y and y < bounds[3]


def run_filter(type, resolution):
def run_filter(output_type, resolution):

ini_las = "test/data/4_6.las"

tmp_out_wkt = tempfile.NamedTemporaryFile(suffix=f"_{resolution}.wkt").name

filter = "filters.grid_decimation_deprecated"
utils.pdal_has_plugin(filter)
filter_name = "filters.grid_decimation_deprecated"
utils.pdal_has_plugin(filter_name)

bounds = li.las_get_xy_bounds(ini_las)

Expand All @@ -32,9 +33,9 @@ def run_filter(type, resolution):
PIPELINE = [
{"type": "readers.las", "filename": ini_las},
{
"type": filter,
"type": filter_name,
"resolution": resolution,
"output_type": type,
"output_type": output_type,
"output_dimension": "grid",
"output_wkt": tmp_out_wkt,
},
Expand Down Expand Up @@ -75,10 +76,10 @@ def run_filter(type, resolution):
continue

z = pt["Z"]
if type == "max":
if output_type == "max":
if ZRef == 0 or z > ZRef:
ZRef = z
elif type == "min":
elif output_type == "min":
if ZRef == 0 or z < ZRef:
ZRef = z

Expand Down Expand Up @@ -112,3 +113,24 @@ def test_grid_decimation_max(resolution):
)
def test_grid_decimation_min(resolution):
run_filter("min", resolution)


def test_grid_decimation_empty():
ini_las = "test/data/4_6.las"
with tempfile.NamedTemporaryFile(suffix="_empty.wkt") as tmp_out_wkt:
pipeline = pdal.Pipeline() | pdal.Reader.las(filename=ini_las)
pipeline |= pdal.Filter.grid_decimation_deprecated(
resolution=10,
output_type="min",
output_dimension="grid",
output_wkt=tmp_out_wkt.name,
where="Classification==123", # should create an empty result
)
pipeline.execute()

with open(tmp_out_wkt.name, "r") as f:
reader = csv.reader(f, delimiter="\t")
lines = [line for line in reader]
assert len(lines) == 0

assert np.all(pipeline.arrays[0]["grid"] == 0)
Loading