Add support for pathlib.Path objects as input #460

Closed
wants to merge 11 commits
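The change applied to the Python files below follows one pattern: widen `isinstance` checks from `str` to `(str, Path)` and coerce with `str()` or `Path()` wherever a concrete type is needed. A minimal, self-contained sketch of that pattern (not solaris code; the helper name is made up):

from pathlib import Path


def _normalize_input(p):
    # Hypothetical helper illustrating the pattern used throughout this PR:
    # accept either a str or a pathlib.Path, then coerce explicitly where a
    # concrete type is required.
    if isinstance(p, (str, Path)):
        if Path(p).is_file():
            return [str(p)]
        raise ValueError("p must point to an existing file")
    raise TypeError("p must be a str or a pathlib.Path")


print(_normalize_input(__file__))        # str input
print(_normalize_input(Path(__file__)))  # Path input, same result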
4 changes: 1 addition & 3 deletions .github/workflows/tests.yml
@@ -1,8 +1,6 @@
 name: Test Python package

-on:
-  - push
-
+on: [push, pull_request]
 jobs:
   test:
     runs-on: ubuntu-latest
2 changes: 1 addition & 1 deletion environment.yaml
@@ -9,7 +9,7 @@ dependencies:
   - geopandas>=0.7.0
   - matplotlib>=3.1.2
   - numpy>=1.17.3
-  - opencv-python>=4.1
+  - opencv>=4.1
   - pandas>=0.25.3
   - pyproj>=2.1
   - PyYAML>=5.4
15 changes: 8 additions & 7 deletions solaris/data/coco.py
@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+from pathlib import Path

 import geopandas as gpd
 import numpy as np
@@ -47,7 +48,7 @@ def geojson2coco(

     Arguments
     ---------
-    image_src : :class:`str` or :class:`list` or :class:`dict`
+    image_src : :class:`str` or :class:`pathlib.Path` or :class:`list` or :class:`dict`
         Source image(s) to use in the dataset. This can be::

             1. a string path to an image,
@@ -149,8 +150,8 @@ def geojson2coco(
     logger.setLevel(_get_logging_level(int(verbose)))
     logger.debug("Preparing image filename: image ID dict.")
     # pdb.set_trace()
-    if isinstance(image_src, str):
-        if image_src.endswith("json"):
+    if isinstance(image_src, (str, Path)):
+        if str(image_src).endswith("json"):
             logger.debug("COCO json provided. Extracting fname:id dict.")
             with open(image_src, "r") as f:
                 image_ref = json.load(f)
@@ -599,13 +600,13 @@ def _get_fname_list(p, recursive=False, extension=".tif"):
     """Get a list of filenames from p, which can be a dir, fname, or list."""
     if isinstance(p, list):
         return p
-    elif isinstance(p, str):
-        if os.path.isdir(p):
+    elif isinstance(p, (str, Path)):
+        if Path(p).is_dir():
             return get_files_recursively(
                 p, traverse_subdirs=recursive, extension=extension
             )
-        elif os.path.isfile(p):
-            return [p]
+        elif Path(p).is_file():
+            return [str(p)]
         else:
             raise ValueError("If a string is provided, it must be a valid" " path.")
     else:
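A short usage sketch of the `_get_fname_list` helper changed above; the directory is hypothetical, and the helper is private but importable from `solaris.data.coco`:

from pathlib import Path

from solaris.data.coco import _get_fname_list

# "data/images" is a hypothetical directory; both calls should now be accepted
# and return the same list of .tif filenames.
print(_get_fname_list("data/images", recursive=True, extension=".tif"))
print(_get_fname_list(Path("data/images"), recursive=True, extension=".tif"))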
34 changes: 16 additions & 18 deletions solaris/eval/base.py
@@ -1,10 +1,12 @@
 import os
+from pathlib import Path

 import geopandas as gpd
 import pandas as pd
 import shapely.wkt
 from fiona._err import CPLE_OpenFailedError
 from fiona.errors import DriverError
+from solaris.utils.core import _check_gdf_load
 from tqdm.auto import tqdm

 from . import iou
@@ -29,28 +31,29 @@ class Evaluator:

     Arguments
     ---------
-    ground_truth_vector_file : str
+    ground_truth_vector_file : `str` or :class:`pathlib.Path`
         Path to .geojson file for ground truth.

     """

     def __init__(self, ground_truth_vector_file):
         # Load Ground Truth : Ground Truth should be in geojson or shape file
-        try:
-            if ground_truth_vector_file.lower().endswith("json"):
-                self.load_truth(ground_truth_vector_file)
-            elif ground_truth_vector_file.lower().endswith("csv"):
-                self.load_truth(ground_truth_vector_file, truthCSV=True)
-            self.ground_truth_fname = ground_truth_vector_file
-        except AttributeError:  # handles passing gdf instead of path to file
-            self.ground_truth_GDF = ground_truth_vector_file
+        if isinstance(ground_truth_vector_file, (str, Path)):
+            self.ground_truth_fname = str(ground_truth_vector_file)
+        else:
             self.ground_truth_fname = "GeoDataFrame variable"
+
+        if isinstance(ground_truth_vector_file, (str, Path)) and str(ground_truth_vector_file).lower().endswith("csv"):
+            self.load_truth(ground_truth_vector_file, truthCSV=True)
+        else:
+            self.load_truth(ground_truth_vector_file)
         self.ground_truth_sindex = self.ground_truth_GDF.sindex  # get sindex
         # create deep copy of ground truth file for calculations
         self.ground_truth_GDF_Edit = self.ground_truth_GDF.copy(deep=True)
         self.proposal_GDF = gpd.GeoDataFrame([])  # initialize proposal GDF

     def __repr__(self):

         return "Evaluator {}".format(os.path.split(self.ground_truth_fname)[-1])

     def get_iou_by_building(self):
@@ -509,7 +512,7 @@ def load_proposal(

         Arguments
         ---------
-        proposal_vector_file : str
+        proposal_vector_file : `str` or :class:`pathlib.Path`
             Path to the file containing proposal vector objects. This can be
             a .geojson or a .csv.
         conf_field_list : list, optional
@@ -540,7 +543,7 @@ def load_proposal(
         """

         # Load Proposal if proposal_vector_file is a path to a file
-        if os.path.isfile(proposal_vector_file):
+        if Path(proposal_vector_file).is_file():
             # if it's a CSV format, first read into a pd df and then convert
             # to gpd gdf by loading in geometries using shapely
             if proposalCSV:
@@ -588,7 +591,7 @@ def load_truth(

         Arguments
         ---------
-        ground_truth_vector_file : str
+        ground_truth_vector_file : `str` or :class:`pathlib.Path`
             Path to the ground truth vector file. Must be either .geojson or
             .csv format.
         truthCSV : bool, optional
@@ -617,12 +620,7 @@
                 ],
             )
         else:
-            try:
-                self.ground_truth_GDF = gpd.read_file(ground_truth_vector_file)
-            except (CPLE_OpenFailedError, DriverError):  # empty geojson
-                self.ground_truth_GDF = gpd.GeoDataFrame(
-                    {"sindex": [], "condition": [], "geometry": []}
-                )
+            self.ground_truth_GDF = _check_gdf_load(ground_truth_vector_file)
         # force calculation of spatialindex
         self.ground_truth_sindex = self.ground_truth_GDF.sindex
         # create deep copy of ground truth file for calculations
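A usage sketch for the updated `Evaluator`; the file names are hypothetical and the remaining `load_proposal` arguments are assumed to keep their defaults:

from pathlib import Path

from solaris.eval.base import Evaluator

# Ground truth and proposals can now be passed as pathlib.Path objects as well
# as plain strings (file names are hypothetical).
evaluator = Evaluator(Path("ground_truth.geojson"))
evaluator.load_proposal(Path("proposals.geojson"))
print(evaluator)  # __repr__ shows the basename of the ground truth file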
8 changes: 4 additions & 4 deletions solaris/eval/pixel.py
@@ -84,13 +84,13 @@ def f1(
         ``1``, values < `prop_threshold` will be set to ``0``.
     show_plot : bool, optional
         Switch to plot the outputs. Defaults to ``False``.
-    im_file : str, optional
+    im_file : `str` or :class:`pathlib.Path`, optional
         Image file corresponding to the masks. Ignored if
         ``show_plot == False``. Defaults to ``''``.
     show_colorbar : bool, optional
         Switch to show colorbar. Ignored if ``show_plot == False``.
         Defaults to ``False``.
-    plot_file : str, optional
+    plot_file : `str` or :class:`pathlib.Path`, optional
         Output file if plotting. Ignored if ``show_plot == False``.
         Defaults to ``''``.
     dpi : int, optional
@@ -167,7 +167,7 @@ def f1(
         plt.suptitle(title, fontsize=fontsize)

         # ground truth
-        if len(im_file) > 0:
+        if len(str(im_file)) > 0:
             # raw image
             ax1.imshow(cv2.imread(im_file, 1))
             # ground truth
@@ -211,7 +211,7 @@ def f1(
         # fig.tight_layout(rect=[0, 0.03, 1, 0.95])
         plt.subplots_adjust(top=0.8)

-        if len(plot_file) > 0:
+        if len(str(plot_file)) > 0:
             plt.savefig(plot_file, dpi=dpi)
         print("Time to create and save F1 plots:", time.time() - t0, "seconds")
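For context on the `len(str(...))` change above: `len()` is not defined for `pathlib.Path` objects, so wrapping the argument in `str()` keeps the "was a file given?" check working for both input types. A standalone illustration (the file name is hypothetical):

from pathlib import Path

im_file = Path("image.tif")  # hypothetical; may also be "" when no image is given
# len(im_file) would raise TypeError, because Path objects define no __len__
if len(str(im_file)) > 0:
    print("an image file was supplied:", im_file)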
37 changes: 19 additions & 18 deletions solaris/eval/vector.py
@@ -1,5 +1,6 @@
 import glob
 import os
+from pathlib import Path

 import geopandas as gpd
 import numpy as np
@@ -49,9 +50,9 @@ def get_all_objects(
         unique classes present in each
     Arguments
     ---------
-    proposal_polygons_dir : str
+    proposal_polygons_dir : `str` or :class:`pathlib.Path`
         The path that contains any model proposal polygons
-    gt_polygons_dir : str
+    gt_polygons_dir : `str` or :class:`pathlib.Path`
         The path that contains the ground truth polygons
     prediction_cat_attrib : str
         The column or attribute within the predictions that specifies
@@ -71,23 +72,23 @@
         A union of the prop_objs and gt_objs lists
     """
     objs = []
-    os.chdir(proposal_polygons_dir)
+    os.chdir(str(proposal_polygons_dir))
     search = "*" + file_format
     proposal_geojsons = glob.glob(search)
     for geojson in tqdm(proposal_geojsons):
-        ground_truth_poly = os.path.join(gt_polygons_dir, geojson)
+        ground_truth_poly = Path(gt_polygons_dir) / geojson
         if os.path.exists(ground_truth_poly):
             ground_truth_gdf = gpd.read_file(ground_truth_poly)
             proposal_gdf = gpd.read_file(geojson)
             for index, row in proposal_gdf.iterrows():
                 objs.append(row[prediction_cat_attrib])
     prop_objs = list(set(objs))
-    os.chdir(gt_polygons_dir)
+    os.chdir(str(gt_polygons_dir))
     search = "*" + file_format
     objs = []
     gt_geojsons = glob.glob(search)
     for geojson in tqdm(gt_geojsons):
-        proposal_poly = os.path.join(proposal_polygons_dir, geojson)
+        proposal_poly = Path(proposal_polygons_dir) / geojson
         if os.path.exists(proposal_poly):
             proposal_gdf = gpd.read_file(proposal_poly)
             ground_truth_gdf = gpd.read_file(geojson)
@@ -114,9 +115,9 @@ def precision_calc(
         calculate metric for classes that exist in the ground truth.
     Arguments
     ---------
-    proposal_polygons_dir : str
+    proposal_polygons_dir : `str` or :class:`pathlib.Path`
         The path that contains any model proposal polygons
-    gt_polygons_dir : str
+    gt_polygons_dir : `str` or :class:`pathlib.Path`
         The path that contains the ground truth polygons
     prediction_cat_attrib : str
         The column or attribute within the predictions that specifies
@@ -148,7 +149,7 @@
         All confidences for each object for each class
     """
     ious = []
-    os.chdir(proposal_polygons_dir)
+    os.chdir(str(proposal_polygons_dir))
     search = "*" + file_format
     proposal_geojsons = glob.glob(search)
     iou_holder = []
@@ -166,7 +167,7 @@
         confidences.append([])

     for geojson in tqdm(proposal_geojsons):
-        ground_truth_poly = os.path.join(gt_polygons_dir, geojson)
+        ground_truth_poly = Path(gt_polygons_dir) / geojson
         if os.path.exists(ground_truth_poly):
             ground_truth_gdf = gpd.read_file(ground_truth_poly)
             proposal_gdf = gpd.read_file(geojson)
@@ -241,9 +242,9 @@ def recall_calc(
         calculate metric for classes that exist in the ground truth.
     Arguments
     ---------
-    proposal_polygons_dir : str
+    proposal_polygons_dir : `str` or :class:`pathlib.Path`
         The path that contains any model proposal polygons
-    gt_polygons_dir : str
+    gt_polygons_dir : `str` or :class:`pathlib.Path`
         The path that contains the ground truth polygons
     prediction_cat_attrib : str
         The column or attribute within the predictions that specifies
@@ -270,7 +271,7 @@
         The mean recall score of recall_by_class
     """
     ious = []
-    os.chdir(gt_polygons_dir)
+    os.chdir(str(gt_polygons_dir))
     search = "*" + file_format
     gt_geojsons = glob.glob(search)
     iou_holder = []
@@ -285,7 +286,7 @@
     for i in range(len(object_subset)):
         iou_holder.append([])
     for geojson in tqdm(gt_geojsons):
-        proposal_poly = os.path.join(proposal_polygons_dir, geojson)
+        proposal_poly = Path(proposal_polygons_dir) / geojson
         if os.path.exists(proposal_poly):
             proposal_gdf = gpd.read_file(proposal_poly)
             ground_truth_gdf = gpd.read_file(geojson)
@@ -353,9 +354,9 @@ def mF1(
         only calculate metric for classes that exist in the ground truth.
     Arguments
     ---------
-    proposal_polygons_dir : str
+    proposal_polygons_dir : `str` or :class:`pathlib.Path`
         The path that contains any model proposal polygons
-    gt_polygons_dir : str
+    gt_polygons_dir : `str` or :class:`pathlib.Path`
         The path that contains the ground truth polygons
     prediction_cat_attrib : str
         The column or attribute within the predictions that specifies
@@ -480,9 +481,9 @@ def mAP_score(

     Arguments
     ---------
-    proposal_polygons_dir : str
+    proposal_polygons_dir : `str` or :class:`pathlib.Path`
         The path that contains any model proposal polygons
-    gt_polygons_dir : str
+    gt_polygons_dir : `str` or :class:`pathlib.Path`
         The path that contains the ground truth polygons
     prediction_cat_attrib : str
         The column or attribute within the predictions that specifies
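The core pattern applied throughout `solaris/eval/vector.py` above is that `Path(directory) / name` works whether the directory arrives as a `str` or a `Path`, and the resulting `Path` can be handed directly to `os.path.exists` and `geopandas.read_file`. A standalone sketch with hypothetical names:

import os
from pathlib import Path

gt_polygons_dir = "ground_truth"  # could equally be Path("ground_truth")
geojson = "tile_001.geojson"      # hypothetical tile name from glob.glob()
ground_truth_poly = Path(gt_polygons_dir) / geojson
print(ground_truth_poly, os.path.exists(ground_truth_poly))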
10 changes: 6 additions & 4 deletions solaris/raster/image.py
@@ -1,3 +1,5 @@
+from pathlib import Path
+
 import numpy as np
 import rasterio
@@ -9,7 +11,7 @@ def get_geo_transform(raster_src):

     Arguments
     ---------
-    raster_src : str, :class:`rasterio.DatasetReader`, or `osgeo.gdal.Dataset`
+    raster_src : str, :class:`pathlib.Path`, :class:`rasterio.DatasetReader`, or `osgeo.gdal.Dataset`
         Path to a raster image with georeferencing data to apply to `geom`.
         Alternatively, an opened :class:`rasterio.Band` object or
         :class:`osgeo.gdal.Dataset` object can be provided. Required if not
@@ -21,7 +23,7 @@ def get_geo_transform(raster_src):
         An affine transformation object to the image's location in its CRS.
     """

-    if isinstance(raster_src, str):
+    if isinstance(raster_src, (str, Path)):
         affine_obj = rasterio.open(raster_src).transform
     elif isinstance(raster_src, rasterio.DatasetReader):
         affine_obj = raster_src.transform
@@ -175,7 +177,7 @@ def stitch_images(
 # ---------
 # array : :class:`numpy.ndarray`
 # A numpy array with a the shape: [Channels, X, Y] or [X, Y]
-# out_name : str
+# out_name : str or :class:`pathlib.Path`
 # The output name and path for your image
 # proj : :class:`gdal.projection`
 # A projection, can be extracted from an image opened with gdal with
@@ -200,7 +202,7 @@ def stitch_images(
 # driver = gdal.GetDriverByName("GTiff")
 # if len(array.shape) == 2:
 # array = array[np.newaxis, ...]
-# os.makedirs(os.path.dirname(os.path.abspath(out_name)), exist_ok=True)
+# Path(out_name).resolve().parent.mkdir(parents=True, exist_ok=True)
 # dataset = driver.Create(out_name, array.shape[2], array.shape[1], array.shape[0], out_format)
 # if verbose is True:
 # print("Array Shape, should be [Channels, X, Y] or [X,Y]:", array.shape)
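A usage sketch for the updated `get_geo_transform`, with a hypothetical image path:

from pathlib import Path

from solaris.raster.image import get_geo_transform

# The GeoTIFF path can now be given as a Path as well as a str
# ("imagery/tile_001.tif" is hypothetical).
transform = get_geo_transform(Path("imagery/tile_001.tif"))
print(transform)  # an affine transform describing the image's georeferencing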