Auto formatting (#34)
- Use black: `python -m black .`
- [Update flake8 rules](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#flake8)
sarlinpe authored Sep 27, 2023
1 parent 86d8f67 commit ca40df7
Showing 11 changed files with 529 additions and 406 deletions.
3 changes: 3 additions & 0 deletions .flake8
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 88
+extend-ignore = E203
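Black wraps lines at 88 characters, so flake8's `max-line-length` is raised to match. `E203` (whitespace before `:`) is ignored because black deliberately puts a space on both sides of the colon in slices whose bounds are compound expressions, which E203 would flag. A toy illustration (hypothetical snippet, not from this repo):

```python
# Black's slice style: symmetric spaces around ":" when the bounds are
# compound expressions. flake8 reports E203 ("whitespace before ':'")
# on this line unless the rule is ignored.
def crop(row, lower, upper, offset):
    return row[lower + offset : upper + offset]
```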
16 changes: 16 additions & 0 deletions .github/workflows/black.yml
@@ -0,0 +1,16 @@
+name: Black Formatting Check
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    types: [ assigned, opened, synchronize, reopened ]
+jobs:
+  formatting-check:
+    name: Formatting Check
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - uses: psf/black@stable
+        with:
+          jupyter: true
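The workflow runs black in check mode on every push to `main` and on pull requests; `jupyter: true` makes the `psf/black` action cover notebooks as well. The same gate can be reproduced locally before pushing; a minimal sketch, assuming black is installed with its Jupyter extra (`pip install "black[jupyter]"`):

```python
# Local equivalent of the CI check. "--check" only reports files that
# would be reformatted and exits non-zero; drop it to rewrite files.
import subprocess

subprocess.run(["python", "-m", "black", "--check", "."], check=True)
```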
210 changes: 125 additions & 85 deletions benchmark.py
@@ -1,4 +1,3 @@
-
 # Benchmark script for LightGlue on real images
 from pathlib import Path
 import argparse
@@ -15,9 +14,9 @@
 torch.set_grad_enabled(False)
 
 
-def measure(matcher, data, device='cuda', r=100):
+def measure(matcher, data, device="cuda", r=100):
     timings = np.zeros((r, 1))
-    if device.type == 'cuda':
+    if device.type == "cuda":
         starter = torch.cuda.Event(enable_timing=True)
         ender = torch.cuda.Event(enable_timing=True)
     # warmup
@@ -26,7 +25,7 @@ def measure(matcher, data, device='cuda', r=100):
     # measurements
     with torch.no_grad():
         for rep in range(r):
-            if device.type == 'cuda':
+            if device.type == "cuda":
                 starter.record()
                 _ = matcher(data)
                 ender.record()
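This hunk only swaps quote styles, but it covers the heart of the timing logic: on CUDA, `measure` brackets each run with `torch.cuda.Event` records instead of reading a host clock, because kernel launches are asynchronous and a host timer would largely measure launch overhead. A standalone sketch of the same pattern (illustrative names, not repo code):

```python
import torch

def time_gpu(fn, reps=100):
    # The events are recorded on the CUDA stream, so elapsed_time spans
    # actual GPU execution between the two markers.
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(reps):
        starter.record()
        fn()
        ender.record()
        torch.cuda.synchronize()  # block until the GPU reaches `ender`
        times.append(starter.elapsed_time(ender))  # milliseconds
    return sum(times) / len(times)
```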
@@ -40,77 +39,99 @@ def measure(matcher, data, device='cuda', r=100):
             timings[rep] = curr_time
     mean_syn = np.sum(timings) / r
     std_syn = np.std(timings)
-    return {'mean': mean_syn, 'std': std_syn}
+    return {"mean": mean_syn, "std": std_syn}
 
 
 def print_as_table(d, title, cnames):
     print()
-    header = f'{title:30} '+' '.join([f'{x:>7}' for x in cnames])
+    header = f"{title:30} " + " ".join([f"{x:>7}" for x in cnames])
     print(header)
-    print('-'*len(header))
+    print("-" * len(header))
     for k, l in d.items():
-        print(f'{k:30}', ' '.join([f'{x:>7.1f}' for x in l]))
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Benchmark script for LightGlue')
-    parser.add_argument('--device', choices=['auto', 'cuda', 'cpu', 'mps'],
-                        default='auto', help='device to benchmark on')
-    parser.add_argument('--compile', action='store_true',
-                        help='Compile LightGlue runs')
-    parser.add_argument('--no_flash', action='store_true',
-                        help='disable FlashAttention')
-    parser.add_argument('--no_prune_thresholds', action='store_true',
-                        help='disable pruning thresholds (i.e. always do pruning)')
-    parser.add_argument('--add_superglue', action='store_true',
-                        help='add SuperGlue to the benchmark (requires hloc)')
-    parser.add_argument('--measure', default='time',
-                        choices=['time', 'log-time', 'throughput'])
-    parser.add_argument('--repeat', '--r', type=int, default=100,
-                        help='repetitions of measurements')
-    parser.add_argument('--num_keypoints', nargs="+", type=int,
-                        default=[256, 512, 1024, 2048, 4096],
-                        help='number of keypoints (list separated by spaces)')
-    parser.add_argument('--matmul_precision', default='highest',
-                        choices=['highest', 'high', 'medium'])
-    parser.add_argument('--save', default=None, type=str,
-                        help='path where figure should be saved')
+        print(f"{k:30}", " ".join([f"{x:>7.1f}" for x in l]))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Benchmark script for LightGlue")
+    parser.add_argument(
+        "--device",
+        choices=["auto", "cuda", "cpu", "mps"],
+        default="auto",
+        help="device to benchmark on",
+    )
+    parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs")
+    parser.add_argument(
+        "--no_flash", action="store_true", help="disable FlashAttention"
+    )
+    parser.add_argument(
+        "--no_prune_thresholds",
+        action="store_true",
+        help="disable pruning thresholds (i.e. always do pruning)",
+    )
+    parser.add_argument(
+        "--add_superglue",
+        action="store_true",
+        help="add SuperGlue to the benchmark (requires hloc)",
+    )
+    parser.add_argument(
+        "--measure", default="time", choices=["time", "log-time", "throughput"]
+    )
+    parser.add_argument(
+        "--repeat", "--r", type=int, default=100, help="repetitions of measurements"
+    )
+    parser.add_argument(
+        "--num_keypoints",
+        nargs="+",
+        type=int,
+        default=[256, 512, 1024, 2048, 4096],
+        help="number of keypoints (list separated by spaces)",
+    )
+    parser.add_argument(
+        "--matmul_precision", default="highest", choices=["highest", "high", "medium"]
+    )
+    parser.add_argument(
+        "--save", default=None, type=str, help="path where figure should be saved"
+    )
     args = parser.parse_intermixed_args()
 
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    if args.device != 'auto':
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if args.device != "auto":
         device = torch.device(args.device)
 
-    print('Running benchmark on device:', device)
+    print("Running benchmark on device:", device)
 
-    images = Path('assets')
+    images = Path("assets")
     inputs = {
-        'easy': (load_image(images / 'DSC_0411.JPG'),
-                 load_image(images / 'DSC_0410.JPG')),
-        'difficult': (load_image(images / 'sacre_coeur1.jpg'),
-                      load_image(images / 'sacre_coeur2.jpg')),
+        "easy": (
+            load_image(images / "DSC_0411.JPG"),
+            load_image(images / "DSC_0410.JPG"),
+        ),
+        "difficult": (
+            load_image(images / "sacre_coeur1.jpg"),
+            load_image(images / "sacre_coeur2.jpg"),
+        ),
     }
 
     configs = {
-        'LightGlue-full': {
-            'depth_confidence': -1,
-            'width_confidence': -1,
+        "LightGlue-full": {
+            "depth_confidence": -1,
+            "width_confidence": -1,
         },
         # 'LG-prune': {
         #     'width_confidence': -1,
         # },
         # 'LG-depth': {
         #     'depth_confidence': -1,
         # },
-        'LightGlue-adaptive': {}
+        "LightGlue-adaptive": {},
    }
 
     if args.compile:
-        configs = {**configs, **{k+'-compile': v for k, v in configs.items()}}
+        configs = {**configs, **{k + "-compile": v for k, v in configs.items()}}
 
     sg_configs = {
         # 'SuperGlue': {},
-        'SuperGlue-fast': {'sinkhorn_iterations': 5}
+        "SuperGlue-fast": {"sinkhorn_iterations": 5}
     }
 
     torch.set_float32_matmul_precision(args.matmul_precision)
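Beyond the quote change, the rewrites in this hunk show black's "magic trailing comma" behavior: a call or literal that fits within 88 columns stays on one line, while anything black has to split is exploded to one element per line and given a trailing comma. A toy before/after (not repo code):

```python
# Fits within the line limit: black keeps it on one line, no trailing comma.
pair = ("DSC_0411.JPG", "DSC_0410.JPG")

# Must be split (or already carries a trailing comma): black explodes it
# one element per line and appends the trailing comma itself.
config = {
    "depth_confidence": -1,
    "width_confidence": -1,
}
```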
@@ -119,89 +140,108 @@ def print_as_table(d, title, cnames):
 
     extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
     extractor = extractor.eval().to(device)
-    figsize = (len(inputs)*4.5, 4.5)
+    figsize = (len(inputs) * 4.5, 4.5)
     fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize)
     axes = axes if len(inputs) > 1 else [axes]
-    fig.canvas.manager.set_window_title(f'LightGlue benchmark ({device.type})')
+    fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})")
 
     for title, ax in zip(inputs.keys(), axes):
-        ax.set_xscale('log', base=2)
+        ax.set_xscale("log", base=2)
         bases = [2**x for x in range(7, 16)]
         ax.set_xticks(bases, bases)
-        ax.grid(which='major')
-        if args.measure == 'log-time':
-            ax.set_yscale('log')
+        ax.grid(which="major")
+        if args.measure == "log-time":
+            ax.set_yscale("log")
             yticks = [10**x for x in range(6)]
             ax.set_yticks(yticks, yticks)
             mpos = [10**x * i for x in range(6) for i in range(2, 10)]
-            mlabel = [10**x * i if i in [2, 5] else None for x in range(6) for i in range(2, 10)]
+            mlabel = [
+                10**x * i if i in [2, 5] else None
+                for x in range(6)
+                for i in range(2, 10)
+            ]
             ax.set_yticks(mpos, mlabel, minor=True)
-            ax.grid(which='minor', linewidth=0.2)
+            ax.grid(which="minor", linewidth=0.2)
         ax.set_title(title)
 
         ax.set_xlabel("# keypoints")
-        if args.measure == 'throughput':
-            ax.set_ylabel("Throughput [pairs/s]")
+        if args.measure == "throughput":
+            ax.set_ylabel("Throughput [pairs/s]")
         else:
             ax.set_ylabel("Latency [ms]")
 
     for name, conf in configs.items():
-        print('Run benchmark for:', name)
+        print("Run benchmark for:", name)
         torch.cuda.empty_cache()
-        matcher = LightGlue(
-            features='superpoint', flash=not args.no_flash, **conf)
+        matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf)
         if args.no_prune_thresholds:
             matcher.pruning_keypoint_thresholds = {
-                k: -1 for k in matcher.pruning_keypoint_thresholds}
+                k: -1 for k in matcher.pruning_keypoint_thresholds
+            }
         matcher = matcher.eval().to(device)
-        if name.endswith('compile'):
+        if name.endswith("compile"):
             import torch._dynamo
+
             torch._dynamo.reset()  # avoid buffer overflow
             matcher.compile()
-        for (pair_name, ax) in zip(inputs.keys(), axes):
+        for pair_name, ax in zip(inputs.keys(), axes):
             image0, image1 = [x.to(device) for x in inputs[pair_name]]
             runtimes = []
             for num_kpts in args.num_keypoints:
-                extractor.conf['max_num_keypoints'] = num_kpts
+                extractor.conf["max_num_keypoints"] = num_kpts
                 feats0 = extractor.extract(image0)
                 feats1 = extractor.extract(image1)
-                runtime = measure(matcher,
-                                  {'image0': feats0, 'image1': feats1},
-                                  device=device, r=args.repeat)['mean']
+                runtime = measure(
+                    matcher,
+                    {"image0": feats0, "image1": feats1},
+                    device=device,
+                    r=args.repeat,
+                )["mean"]
                 results[pair_name][name].append(
-                    1000/runtime if args.measure == 'throughput' else runtime)
-            ax.plot(args.num_keypoints, results[pair_name][name], label=name,
-                    marker='o')
+                    1000 / runtime if args.measure == "throughput" else runtime
+                )
+            ax.plot(
+                args.num_keypoints, results[pair_name][name], label=name, marker="o"
+            )
         del matcher, feats0, feats1
 
     if args.add_superglue:
         from hloc.matchers.superglue import SuperGlue
+
         for name, conf in sg_configs.items():
-            print('Run benchmark for:', name)
+            print("Run benchmark for:", name)
             matcher = SuperGlue(conf)
             matcher = matcher.eval().to(device)
-            for (pair_name, ax) in zip(inputs.keys(), axes):
+            for pair_name, ax in zip(inputs.keys(), axes):
                 image0, image1 = [x.to(device) for x in inputs[pair_name]]
                 runtimes = []
                 for num_kpts in args.num_keypoints:
-                    extractor.conf['max_num_keypoints'] = num_kpts
+                    extractor.conf["max_num_keypoints"] = num_kpts
                     feats0 = extractor.extract(image0)
                     feats1 = extractor.extract(image1)
                     data = {
-                        'image0': image0[None],
-                        'image1': image1[None],
-                        **{k+'0': v for k, v in feats0.items()},
-                        **{k+'1': v for k, v in feats1.items()}
+                        "image0": image0[None],
+                        "image1": image1[None],
+                        **{k + "0": v for k, v in feats0.items()},
+                        **{k + "1": v for k, v in feats1.items()},
                     }
-                    data['scores0'] = data['keypoint_scores0']
-                    data['scores1'] = data['keypoint_scores1']
-                    data['descriptors0'] = data['descriptors0'].transpose(-1, -2).contiguous()
-                    data['descriptors1'] = data['descriptors1'].transpose(-1, -2).contiguous()
-                    runtime = measure(matcher, data, device=device, r=args.repeat)['mean']
+                    data["scores0"] = data["keypoint_scores0"]
+                    data["scores1"] = data["keypoint_scores1"]
+                    data["descriptors0"] = (
+                        data["descriptors0"].transpose(-1, -2).contiguous()
+                    )
+                    data["descriptors1"] = (
+                        data["descriptors1"].transpose(-1, -2).contiguous()
+                    )
+                    runtime = measure(matcher, data, device=device, r=args.repeat)[
+                        "mean"
+                    ]
                     results[pair_name][name].append(
-                        1000/runtime if args.measure == 'throughput' else runtime)
-                ax.plot(args.num_keypoints, results[pair_name][name], label=name,
-                        marker='o')
+                        1000 / runtime if args.measure == "throughput" else runtime
+                    )
+                ax.plot(
+                    args.num_keypoints, results[pair_name][name], label=name, marker="o"
+                )
             del matcher, data, image0, image1, feats0, feats1
 
     for name, runtimes in results.items():
[Diff truncated here; the remaining 8 changed files are not shown.]
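For context, `measure` is the function the whole script revolves around. A minimal way to call it outside the argparse flow; a hypothetical sketch that assumes the repo's `lightglue` package exposes `LightGlue`, `SuperPoint`, and `load_image` as the script's imports suggest, plus the bundled `assets/` images:

```python
import torch
from benchmark import measure  # safe: the script's main block is guarded
from lightglue import LightGlue, SuperPoint
from lightglue.utils import load_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
extractor = SuperPoint(max_num_keypoints=512).eval().to(device)
matcher = LightGlue(features="superpoint").eval().to(device)

image0 = load_image("assets/DSC_0411.JPG").to(device)
image1 = load_image("assets/DSC_0410.JPG").to(device)
feats0, feats1 = extractor.extract(image0), extractor.extract(image1)

# measure() expects a torch.device so it can branch on device.type.
stats = measure(matcher, {"image0": feats0, "image1": feats1}, device=device, r=10)
print(f"mean matching latency: {stats['mean']:.1f} ms")
```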
