-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
65 lines (55 loc) · 1.8 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import multiprocessing
import torch
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import utils
import argparse
import sfm_loss
import networks.architectures
import cv2
import time
import random
from pathlib import Path
import os
import options
import sys
from sfm_trainer import SfMTrainer
from fundamental_trainer import FundamentalTrainer
from debugger import Debugger as SfMDebugger
from sfm_tester import SfMTester
from point_trainer import PointTrainer
from debugger_point import DebuggerPoint
from debugger_fcons import DebuggerFcons
def parse_args(extra=None, overwrite=None):
    """Parse command-line options for the selected action.

    Args:
        extra: Option names to request in addition to the ones every action
            accepts (``net``, ``workers``, ``device``, ``load``). Any iterable
            of names is accepted; ``None`` means no extra options.
        overwrite: Mapping forwarded unchanged to ``options.get_args`` as its
            ``overwrite`` argument; ``None`` means an empty mapping.

    Returns:
        Whatever ``options.get_args`` returns for this option set.
    """
    always = ["net", "workers", "device", "load"]
    # None sentinels instead of ``[]`` / ``{}`` defaults: a default container
    # is created once and shared by every call, so any mutation made
    # downstream (e.g. inside options.get_args) would leak into later calls.
    extra = [] if extra is None else list(extra)
    overwrite = {} if overwrite is None else overwrite
    return options.get_args(
        description="Train, debug or test a network",
        options=always + extra,
        overwrite=overwrite,
    )
def main():
    """Run the action named by the first command-line argument.

    The action name (e.g. ``sfm-train``) is removed from ``sys.argv`` so
    that the remaining arguments are parsed by ``parse_args`` with the
    option set belonging to that action. The action object is then built
    and its ``run()`` method is called.

    Exits with status 1 (message on stderr) when no action is given or the
    action name is unknown.
    """
    # Each factory builds its action lazily, so only the selected action
    # parses the command line (each one requests a different option set).
    actions = {
        # Depth
        "sfm-train": lambda: SfMTrainer(
            parse_args(["name", "batch", "train", "loss", "dataset"])),
        "sfm-debug": lambda: SfMDebugger(parse_args(["loss", "dataset"])),
        "sfm-test": lambda: SfMTester(
            parse_args(["dataset", "loss"], overwrite={"batch": 1})),
        # Unsuperpoint
        "point-train": lambda: PointTrainer(
            parse_args(["name", "batch", "train"])),
        "point-debug": lambda: DebuggerPoint(parse_args(["loss", "batch"])),
        # Fundamental consensus
        "fcons-train": lambda: FundamentalTrainer(
            parse_args(["name", "batch", "train"])),
        "fcons-debug": lambda: DebuggerFcons(parse_args(["loss", "batch"])),
    }
    usage = "Usage: %s {%s} [options]" % (
        sys.argv[0] if sys.argv else "main.py", ",".join(actions))

    # Guard before popping: a bare ``sys.argv.pop(1)`` raises IndexError
    # when the script is started without an action.
    if len(sys.argv) < 2:
        print(usage, file=sys.stderr)
        sys.exit(1)
    choice = sys.argv.pop(1)

    factory = actions.get(choice)
    if factory is None:
        print("No such action to perform: %s" % choice, file=sys.stderr)
        print(usage, file=sys.stderr)
        # Non-zero status so shells/schedulers see the failure; the original
        # site-provided ``exit()`` returned status 0.
        sys.exit(1)

    action = factory()
    action.run()
if __name__ == "__main__":
    # Select the "spawn" start method before any worker processes are
    # created. force=True replaces a start method that has already been set
    # instead of raising RuntimeError. (Presumably chosen for CUDA/torch
    # compatibility in worker processes — TODO confirm.)
    multiprocessing.set_start_method("spawn", force=True)
    main()