- Requirements
- Installation
- Introduction
- Running the Code
- How Does It Work?
- Code Structure
- Publications
- git
- python3
For the installation I recommend a virtualenv.
git clone [email protected]:titus-leistner/dsac-students.git
cd dsac-students/
python3 -m venv .
source bin/activate
pip install -r requirements.txt
This repository contains the the code for line fitting with DSAC as well as all boilerplate code for circle fitting written by Eric Brachmann. All following information aims at the already implemented line-fitting part. You find more detailed information about the exercise in the included PDF.
Left: Input image. Right: Ground truth line.
We solve this task by training a CNN which predicts a set of 2D points within the image. We fit our desired line to these points using RANSAC.
Left: Input image. Center: Points predicted by a CNN. Right: Line (blue) fitted to the predictions.
Ideally, the CNN would place all its point predictions on the image line segment. But because RANSAC is robust to outlier points, the CNN may choose to allow some erroneous predictions in favor of overall accuracy.
We train the CNN end-to-end from scratch to minimize the deviation between the (robustly) fitted line and the ground truth line.
Just execute python main.py --cpu
to start a training run with the standard settings.
Running python main.py -h
will list all parameter options for playing around.
The code generates training data on the fly, and trains two CNNs in parallel. The first CNN predicts a set of 2D points to which the output line is fitted using DSAC. The second CNN is a baseline where the line parameters are predicted directly, i.e. without DSAC.
In a specified interval during training, both CNNs are tested on a fixed validation set, and the visualization of the predictions is stored in an image such as the following:
Left: Validation inputs. Center: DSAC estimates with dots marking the CNN prediction, blue the fitted line and green the ground truth line. Green borders mark accurate line predictions, red boxes mark inaccurate line predictions (there is a threshold parameter). Right: Predictions of the baseline CNN (direct prediction of line parameters).
Depending on your system specification, one training iteration (with the standard batch size of 32) can take more than one second. This might seem excessive for a simple toy problem. Note that this code is designed for educative clarity rather than speed. The whole DSAC portion of training runs in native Python on a single CPU core, and backpropagation relies soley on standard PyTorch autograd. In any production setting, one would write a C++/CUDA extension encapsulating DSAC for a huge runtime boost. See for example our camera localization pipelines which utilize DSAC here and here.
Vanilla RANSAC works by creating a set of model hypotheses (line hypotheses in our case), scoring them e.g. by inlier counting, and selecting the best one.
DSAC is based on the idea of making hypothesis selection a probabilistic action. The probability of selecting a hypothesis increases with its score (e.g. inlier count). Training the CNN aims at minimizing the expected loss of the selected hypothesis.
More details and a formal description can be found in the papers referenced at the end of this document.
In a nutshell, the training process works like this:
- CNN predicts 2D points
- sample line hypotheses by choosing random pairs of points
- score hypotheses by soft inlier counting, and calculate selection probabilities
- refine hypotheses by re-fitting them to their soft inliers
- calculate expected loss of refined hypotheses w.r.t. selection probabilities
- backprob, update CNN, repeat
For this toy problem, we are interested in observing visually nicely aligned lines rather then the nominal error in line parameters. We thus measure the maximum distance between the predicted line and ground truth within the image, and aim at minimizing this distance as our loss function.
Red arrows mark the error between ground truth line (green) and estimated line (blue) that we try to minimize.
main.py
Main script that handles the training loop.
dsac.py
Encapsulates robust, differentiable line fitting with DSAC (sampling hypotheses, scoring, refinement, expected loss).
line_dataset.py
Generates random, noisy input images with associated ground truth parameters. Also includes functions for visualizing predictions.
line_loss.py
Loss function used to compare predicted and ground truth lines.
line_nn.py
Definition of the CNN architecture which supports prediction of 2D points or direct regression of line parameters.
The following paper introduced DSAC for camera localization (paper link).
@inproceedings{brachmann2017dsac,
title={{DSAC}-{Differentiable RANSAC} for Camera Localization},
author={Brachmann, Eric and Krull, Alexander and Nowozin, Sebastian and Shotton, Jamie and Michel, Frank and Gumhold, Stefan and Rother, Carsten},
booktitle={CVPR},
year={2017}
}
This code uses a soft inlier count instead of a learned scoring function, as suggested in the following paper (paper link).
@inproceedings{brachmann2018lessmore,
title={Learning Less is More-{6D} Camera Localization via {3D} Surface Regression},
author={Brachmann, Eric and Rother, Carsten},
booktitle={CVPR},
year={2018}
}
Please cite one of these papers if you use DSAC or parts of this code in your own work.