Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds DoG-HardNet model #103

Merged
merged 14 commits into from
Jan 23, 2024
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ We provide a [demo notebook](demo.ipynb) which shows how to perform feature extr
Here is a minimal script to match two images:

```python
from lightglue import LightGlue, SuperPoint, DISK, SIFT, ALIKED
from lightglue import LightGlue, SuperPoint, DISK, SIFT, ALIKED, DoGHardNet
from lightglue.utils import load_image, rbd

# SuperPoint+LightGlue
Expand Down
1 change: 1 addition & 0 deletions lightglue/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .aliked import ALIKED # noqa
from .disk import DISK # noqa
from .dog_hardnet import DoGHardNet # noqa
from .lightglue import LightGlue # noqa
from .sift import SIFT # noqa
from .superpoint import SuperPoint # noqa
Expand Down
41 changes: 41 additions & 0 deletions lightglue/dog_hardnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
from kornia.color import rgb_to_grayscale
from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori

from .sift import SIFT


class DoGHardNet(SIFT):
    """SIFT-derived extractor whose descriptors are computed by HardNet.

    Keypoints, scales and orientations are obtained from the parent class'
    ``extract_single_image``. The ``"descriptors"`` entry of each per-image
    prediction is then overwritten with the output of a kornia
    ``LAFDescriptor`` wrapping ``HardNet``, reshaped to 128 columns.
    """

    required_data_keys = ["image"]

    def __init__(self, **conf):
        """Build the parent extractor and the HardNet-based LAF descriptor.

        Args:
            **conf: Forwarded unchanged to ``SIFT.__init__``.
        """
        super().__init__(**conf)
        # NOTE(review): the positional ``True`` is presumably HardNet's
        # ``pretrained`` flag -- confirm against the kornia signature.
        self.laf_desc = LAFDescriptor(HardNet(True)).eval()

    def forward(self, data: dict) -> dict:
        """Extract keypoints and HardNet descriptors for a batch of images.

        Args:
            data: Must contain ``"image"``; indexing ``shape[1]`` and slicing
                each item as ``img[:, :h, :w]`` implies a 4-D batch tensor.
                May contain ``"image_size"``, whose k-th row is unpacked as
                ``(w, h)`` and used to crop the k-th image before extraction.
                A missing key and a ``None`` value are treated the same way:
                no cropping.

        Returns:
            A dict with the keys returned by ``extract_single_image`` (with
            ``"descriptors"`` replaced), each value stacked over the batch
            and moved to the device of the input image.

        Raises:
            IndexError: If the batch is empty, since the output keys are
                taken from the first per-image prediction.
        """
        image = data["image"]
        if image.shape[1] == 3:
            image = rgb_to_grayscale(image)
        device = image.device

        # Keep the descriptor on the input's device and in eval mode.
        self.laf_desc = self.laf_desc.to(device)
        self.laf_desc.descriptor = self.laf_desc.descriptor.eval()

        # Convert the optional sizes to integers once, up front, instead of
        # re-reading ``data`` and casting every dimension inside the loop.
        im_size = data.get("image_size")
        if im_size is not None:
            im_size = im_size.long()

        pred = []
        for idx, img in enumerate(image):
            if im_size is not None:
                w, h = im_size[idx]
                img = img[:, : int(h), : int(w)]
            p = self.extract_single_image(img)
            # Build one local affine frame per keypoint from its center,
            # scale and orientation. The orientation is passed through
            # ``torch.rad2deg`` first.
            # NOTE(review): the 6.0 scale multiplier is presumably the patch
            # extent relative to the keypoint scale -- confirm.
            lafs = laf_from_center_scale_ori(
                p["keypoints"].reshape(1, -1, 2),
                6.0 * p["scales"].reshape(1, -1, 1, 1),
                torch.rad2deg(p["oris"]).reshape(1, -1, 1),
            ).to(device)
            p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
            pred.append(p)

        # ``torch.stack`` requires every image to yield equally-shaped
        # tensors for each key.
        return {
            key: torch.stack([p[key] for p in pred], 0).to(device)
            for key in pred[0]
        }
5 changes: 5 additions & 0 deletions lightglue/lightglue.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,11 @@ class LightGlue(nn.Module):
"input_dim": 128,
"add_scale_ori": True,
},
"doghardnet": {
"weights": "doghardnet_lightglue",
"input_dim": 128,
"add_scale_ori": True,
},
}

def __init__(self, features="superpoint", **conf) -> None:
Expand Down