From 36f5cbbd3f00ab3db54068484c95d340b5825ac9 Mon Sep 17 00:00:00 2001 From: Samuele Bumbaca <46762244+SamueleBumbaca@users.noreply.github.com> Date: Tue, 8 Oct 2024 11:50:46 +0200 Subject: [PATCH] LoftrRomaToMultiview masks Masks are required in setups such as the Rolling Table. This adds a way to remove keypoints that fall outside the mask when they are detected by detector-free matchers like RoMa and LoFTR. Each mask must have the same name as its image counterpart, be in .png format, and be stored in a masks folder. --- .../utils/loftr_roma_to_multiview.py | 66 +++++++------------ 1 file changed, 25 insertions(+), 41 deletions(-) diff --git a/src/deep_image_matching/utils/loftr_roma_to_multiview.py b/src/deep_image_matching/utils/loftr_roma_to_multiview.py index 52353bf..d3e38ff 100644 --- a/src/deep_image_matching/utils/loftr_roma_to_multiview.py +++ b/src/deep_image_matching/utils/loftr_roma_to_multiview.py @@ -5,7 +5,7 @@ from copy import deepcopy from collections import defaultdict from deep_image_matching.io.h5_to_db import COLMAPDatabase, image_ids_to_pair_id - +from typing import Optional import os, h5py, warnings import numpy as np from tqdm import tqdm @@ -17,7 +17,6 @@ def get_focal(image_path, err_on_default=False): image = Image.open(image_path) max_size = max(image.size) - exif = image.getexif() focal = None if exif is not None: @@ -28,26 +27,20 @@ def get_focal(image_path, err_on_default=False): if ExifTags.TAGS.get(tag, None) == 'FocalLengthIn35mmFilm': focal_35mm = float(value) break - if focal_35mm is not None: focal = focal_35mm / 35. 
* max_size - if focal is None: if err_on_default: raise RuntimeError("Failed to find focal length") - # failed to find it in exif, use prior FOCAL_PRIOR = 1.2 focal = FOCAL_PRIOR * max_size - return focal def create_camera(db, image_path, camera_model): image = Image.open(image_path) width, height = image.size - focal = get_focal(image_path) - if camera_model == 'simple-pinhole': model = 0 # simple pinhole param_arr = np.array([focal, width / 2, height / 2]) @@ -60,7 +53,6 @@ def create_camera(db, image_path, camera_model): elif camera_model == 'opencv': model = 4 # opencv param_arr = np.array([focal, focal, width / 2, height / 2, 0., 0., 0., 0.]) - return db.add_camera(model, width, height, param_arr) @@ -68,7 +60,6 @@ def add_keypoints(db, h5_path, image_path, camera_model, single_camera = True): keypoint_f = h5py.File(os.path.join(h5_path, 'keypoints.h5'), 'r') fname_to_id = {} db.clean_keypoints() - for filename in tqdm(list(keypoint_f.keys())): keypoints = keypoint_f[filename][()] fname_with_ext = filename @@ -79,36 +70,29 @@ def add_keypoints(db, h5_path, image_path, camera_model, single_camera = True): image_id, camera_id = images[filename] fname_to_id[filename] = image_id db.add_keypoints(image_id, keypoints) - return fname_to_id def add_matches(db, h5_path, fname_to_id): db.clean_matches() db.clean_two_view_geometries() match_file = h5py.File(os.path.join(h5_path, 'matches.h5'), 'r') - added = set() n_keys = len(match_file.keys()) n_total = (n_keys * (n_keys - 1)) // 2 - with tqdm(total=n_total) as pbar: for key_1 in match_file.keys(): group = match_file[key_1] for key_2 in group.keys(): id_1 = fname_to_id[key_1] id_2 = fname_to_id[key_2] - pair_id = image_ids_to_pair_id(id_1, id_2) if pair_id in added: warnings.warn(f'Pair {pair_id} ({id_1}, {id_2}) already added!') continue - matches = group[key_2][()] #db.add_matches(id_1, id_2, matches) db.add_two_view_geometry(id_1, id_2, matches) - added.add(pair_id) - pbar.update(1) def get_unique_idxs(A, dim=1): @@ 
-122,8 +106,9 @@ def get_unique_idxs(A, dim=1): def import_into_colmap(img_dir, feature_dir ='.featureout', - database_path = 'colmap.db' + database_path = 'database.db' ): + db = COLMAPDatabase.connect(database_path) #db.create_tables() single_camera = False @@ -133,7 +118,6 @@ def import_into_colmap(img_dir, feature_dir, fname_to_id, ) - db.commit() return @@ -142,12 +126,19 @@ def LoftrRomaToMultiview( output_dir: Path, image_dir: Path, img_ext: Path, + mask_dir: Optional[Path] = None, ) -> None: - - with h5py.File(fr'{input_dir}\features.h5', mode='r') as h5_feats, \ - h5py.File(fr'{input_dir}\matches.h5', mode='r') as h5_matches, \ - h5py.File(fr'{input_dir}\matches_loftr.h5', mode='w') as h5_out: - + if mask_dir: + image_names = [p.name for p in image_dir.glob(f"*{img_ext}")] + mask_paths = {p: os.path.join(mask_dir, p.replace(img_ext,'.png')) for p in image_names} + mask_arr = {} + for p in image_names: + if not os.path.exists(mask_paths[p]): + FileNotFoundError(f"Mask for {p} not found, expected path: {mask_paths[p]}") + mask_arr[p] = np.array(Image.open(mask_paths[p])) > 0 # ensure binary + with h5py.File(fr'{input_dir}/features.h5', mode='r') as h5_feats, \ + h5py.File(fr'{input_dir}/matches.h5', mode='r') as h5_matches, \ + h5py.File(fr'{input_dir}/matches_loftr.h5', mode='w') as h5_out: for img1 in h5_matches.keys(): print(img1) kpts1 = h5_feats[img1]['keypoints'][...] @@ -158,17 +149,20 @@ def LoftrRomaToMultiview( kpts2 = h5_feats[img2]['keypoints'][...] matches = group_match[img2][...] h5_out[img1][img2] = np.hstack((kpts1[matches[:,0],:], kpts2[matches[:,1],:])) - kpts = defaultdict(list) match_indexes = defaultdict(dict) total_kpts=defaultdict(int) - - with h5py.File(fr'{input_dir}\matches_loftr.h5', mode='r') as f_match: + with h5py.File(fr'{input_dir}/matches_loftr.h5', mode='r') as f_match: for k1 in f_match.keys(): group = f_match[k1] for k2 in group.keys(): matches = group[k2][...] 
- total_kpts[k1] + if mask_dir: + matches = matches[ + mask_arr[k1][matches[:, 1].astype(int), matches[:, 0].astype(int)] & + mask_arr[k2][matches[:, 3].astype(int), matches[:, 2].astype(int)] + ] + total_kpts[k1]#??? kpts[k1].append(matches[:, :2]) kpts[k2].append(matches[:, 2:]) current_match = torch.arange(len(matches)).reshape(-1, 1).repeat(1, 2) @@ -177,10 +171,8 @@ def LoftrRomaToMultiview( total_kpts[k1]+=len(matches) total_kpts[k2]+=len(matches) match_indexes[k1][k2]=current_match - for k in kpts.keys(): kpts[k] = np.round(np.concatenate(kpts[k], axis=0)) - unique_kpts = {} unique_match_idxs = {} out_match = defaultdict(dict) @@ -204,22 +196,14 @@ def LoftrRomaToMultiview( unique_idxs_current2 = get_unique_idxs(m2_semiclean[:, 1], dim=0) m2_semiclean2 = m2_semiclean[unique_idxs_current2] out_match[k1][k2] = m2_semiclean2.numpy() - - with h5py.File(fr'{output_dir}\keypoints.h5', mode='w') as f_kp: + with h5py.File(fr'{output_dir}/keypoints.h5', mode='w') as f_kp: for k, kpts1 in unique_kpts.items(): f_kp[k] = kpts1 - - with h5py.File(fr'{output_dir}\matches.h5', mode='w') as f_match: + with h5py.File(fr'{output_dir}/matches.h5', mode='w') as f_match: for k1, gr in out_match.items(): group = f_match.require_group(k1) for k2, match in gr.items(): group[k2] = match - - try: - os.remove(f"{output_dir}/database.db") - except: - pass - import_into_colmap( image_dir, feature_dir=f"{output_dir}", @@ -245,4 +229,4 @@ def LoftrRomaToMultiview( output_dir, image_dir, img_ext, - ) \ No newline at end of file + )