Skip to content

Commit

Permalink
Merge pull request #79 from SamueleBumbaca/SamueleBumbaca-patch-1
Browse files Browse the repository at this point in the history
LoftrRomaToMultiview mask
  • Loading branch information
lcmrl authored Oct 8, 2024
2 parents 5c7bab7 + 36f5cbb commit 629449a
Showing 1 changed file with 25 additions and 41 deletions.
66 changes: 25 additions & 41 deletions src/deep_image_matching/utils/loftr_roma_to_multiview.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from copy import deepcopy
from collections import defaultdict
from deep_image_matching.io.h5_to_db import COLMAPDatabase, image_ids_to_pair_id

from typing import Optional
import os, h5py, warnings
import numpy as np
from tqdm import tqdm
Expand All @@ -17,7 +17,6 @@
def get_focal(image_path, err_on_default=False):
image = Image.open(image_path)
max_size = max(image.size)

exif = image.getexif()
focal = None
if exif is not None:
Expand All @@ -28,26 +27,20 @@ def get_focal(image_path, err_on_default=False):
if ExifTags.TAGS.get(tag, None) == 'FocalLengthIn35mmFilm':
focal_35mm = float(value)
break

if focal_35mm is not None:
focal = focal_35mm / 35. * max_size

if focal is None:
if err_on_default:
raise RuntimeError("Failed to find focal length")

# failed to find it in exif, use prior
FOCAL_PRIOR = 1.2
focal = FOCAL_PRIOR * max_size

return focal

def create_camera(db, image_path, camera_model):
image = Image.open(image_path)
width, height = image.size

focal = get_focal(image_path)

if camera_model == 'simple-pinhole':
model = 0 # simple pinhole
param_arr = np.array([focal, width / 2, height / 2])
Expand All @@ -60,15 +53,13 @@ def create_camera(db, image_path, camera_model):
elif camera_model == 'opencv':
model = 4 # opencv
param_arr = np.array([focal, focal, width / 2, height / 2, 0., 0., 0., 0.])

return db.add_camera(model, width, height, param_arr)


def add_keypoints(db, h5_path, image_path, camera_model, single_camera = True):
keypoint_f = h5py.File(os.path.join(h5_path, 'keypoints.h5'), 'r')
fname_to_id = {}
db.clean_keypoints()

for filename in tqdm(list(keypoint_f.keys())):
keypoints = keypoint_f[filename][()]
fname_with_ext = filename
Expand All @@ -79,36 +70,29 @@ def add_keypoints(db, h5_path, image_path, camera_model, single_camera = True):
image_id, camera_id = images[filename]
fname_to_id[filename] = image_id
db.add_keypoints(image_id, keypoints)

return fname_to_id

def add_matches(db, h5_path, fname_to_id):
db.clean_matches()
db.clean_two_view_geometries()
match_file = h5py.File(os.path.join(h5_path, 'matches.h5'), 'r')

added = set()
n_keys = len(match_file.keys())
n_total = (n_keys * (n_keys - 1)) // 2

with tqdm(total=n_total) as pbar:
for key_1 in match_file.keys():
group = match_file[key_1]
for key_2 in group.keys():
id_1 = fname_to_id[key_1]
id_2 = fname_to_id[key_2]

pair_id = image_ids_to_pair_id(id_1, id_2)
if pair_id in added:
warnings.warn(f'Pair {pair_id} ({id_1}, {id_2}) already added!')
continue

matches = group[key_2][()]
#db.add_matches(id_1, id_2, matches)
db.add_two_view_geometry(id_1, id_2, matches)

added.add(pair_id)

pbar.update(1)

def get_unique_idxs(A, dim=1):
Expand All @@ -122,8 +106,9 @@ def get_unique_idxs(A, dim=1):

def import_into_colmap(img_dir,
feature_dir ='.featureout',
database_path = 'colmap.db'
database_path = 'database.db'
):

db = COLMAPDatabase.connect(database_path)
#db.create_tables()
single_camera = False
Expand All @@ -133,7 +118,6 @@ def import_into_colmap(img_dir,
feature_dir,
fname_to_id,
)

db.commit()
return

Expand All @@ -142,12 +126,19 @@ def LoftrRomaToMultiview(
output_dir: Path,
image_dir: Path,
img_ext: Path,
mask_dir: Optional[Path] = None,
) -> None:

with h5py.File(fr'{input_dir}\features.h5', mode='r') as h5_feats, \
h5py.File(fr'{input_dir}\matches.h5', mode='r') as h5_matches, \
h5py.File(fr'{input_dir}\matches_loftr.h5', mode='w') as h5_out:

if mask_dir:
image_names = [p.name for p in image_dir.glob(f"*{img_ext}")]
mask_paths = {p: os.path.join(mask_dir, p.replace(img_ext,'.png')) for p in image_names}
mask_arr = {}
for p in image_names:
if not os.path.exists(mask_paths[p]):
FileNotFoundError(f"Mask for {p} not found, expected path: {mask_paths[p]}")
mask_arr[p] = np.array(Image.open(mask_paths[p])) > 0 # ensure binary
with h5py.File(fr'{input_dir}/features.h5', mode='r') as h5_feats, \
h5py.File(fr'{input_dir}/matches.h5', mode='r') as h5_matches, \
h5py.File(fr'{input_dir}/matches_loftr.h5', mode='w') as h5_out:
for img1 in h5_matches.keys():
print(img1)
kpts1 = h5_feats[img1]['keypoints'][...]
Expand All @@ -158,17 +149,20 @@ def LoftrRomaToMultiview(
kpts2 = h5_feats[img2]['keypoints'][...]
matches = group_match[img2][...]
h5_out[img1][img2] = np.hstack((kpts1[matches[:,0],:], kpts2[matches[:,1],:]))

kpts = defaultdict(list)
match_indexes = defaultdict(dict)
total_kpts=defaultdict(int)

with h5py.File(fr'{input_dir}\matches_loftr.h5', mode='r') as f_match:
with h5py.File(fr'{input_dir}/matches_loftr.h5', mode='r') as f_match:
for k1 in f_match.keys():
group = f_match[k1]
for k2 in group.keys():
matches = group[k2][...]
total_kpts[k1]
if mask_dir:
matches = matches[
mask_arr[k1][matches[:, 1].astype(int), matches[:, 0].astype(int)] &
mask_arr[k2][matches[:, 3].astype(int), matches[:, 2].astype(int)]
]
total_kpts[k1]#???
kpts[k1].append(matches[:, :2])
kpts[k2].append(matches[:, 2:])
current_match = torch.arange(len(matches)).reshape(-1, 1).repeat(1, 2)
Expand All @@ -177,10 +171,8 @@ def LoftrRomaToMultiview(
total_kpts[k1]+=len(matches)
total_kpts[k2]+=len(matches)
match_indexes[k1][k2]=current_match

for k in kpts.keys():
kpts[k] = np.round(np.concatenate(kpts[k], axis=0))

unique_kpts = {}
unique_match_idxs = {}
out_match = defaultdict(dict)
Expand All @@ -204,22 +196,14 @@ def LoftrRomaToMultiview(
unique_idxs_current2 = get_unique_idxs(m2_semiclean[:, 1], dim=0)
m2_semiclean2 = m2_semiclean[unique_idxs_current2]
out_match[k1][k2] = m2_semiclean2.numpy()

with h5py.File(fr'{output_dir}\keypoints.h5', mode='w') as f_kp:
with h5py.File(fr'{output_dir}/keypoints.h5', mode='w') as f_kp:
for k, kpts1 in unique_kpts.items():
f_kp[k] = kpts1

with h5py.File(fr'{output_dir}\matches.h5', mode='w') as f_match:
with h5py.File(fr'{output_dir}/matches.h5', mode='w') as f_match:
for k1, gr in out_match.items():
group = f_match.require_group(k1)
for k2, match in gr.items():
group[k2] = match

try:
os.remove(f"{output_dir}/database.db")
except:
pass

import_into_colmap(
image_dir,
feature_dir=f"{output_dir}",
Expand All @@ -245,4 +229,4 @@ def LoftrRomaToMultiview(
output_dir,
image_dir,
img_ext,
)
)

0 comments on commit 629449a

Please sign in to comment.