From d006da427bb6d8970a8d463a7385f4d2f315783c Mon Sep 17 00:00:00 2001 From: Jure Zbontar Date: Fri, 30 Oct 2015 11:40:10 +0100 Subject: [PATCH] predict_kitti script --- README.md | 6 +++++ predict_kitti.lua | 64 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100755 predict_kitti.lua diff --git a/README.md b/README.md index 204f099..80a3d9e 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,12 @@ The resulting disparity maps should look like this: - `right.png` +Note that `-disp_max 70` is used only as an example. To reproduce our +results on the KITTI data sets use `-disp_max 228`. + +See the [predict_kitti.lua](predict_kitti.lua) script for how you might +call `main.lua` in a loop, for multiple image pairs. + ### Load the Output Binary Files ### You can load the binary files (if, for example, you want to apply diff --git a/predict_kitti.lua b/predict_kitti.lua new file mode 100755 index 0000000..62b7411 --- /dev/null +++ b/predict_kitti.lua @@ -0,0 +1,64 @@ +#! /usr/bin/env luajit + +--[[ + +This script computes the 3 pixel error on all KITTI 2012 training examples with +the fast architecture. + +Don't use this script to fit hyperparameters; the error is computed on the +training examples. + +This is not the fastest way to use the neural network---a new process is +spawned and the network is loaded from disk for each image pair---but +is probably the safest. + +Usage +----- + + $ ./predict_kitti.lua + 0 0.0028267929719645 + 1 0.026568045683624 + 2 0.039333925127797 + ... + 191 0.078452818068974 + 192 0.012351983422143 + 193 0.066736774940625 + 0.03222369495401 + +]]-- + +require 'image' +require 'torch' +require 'libadcensus' + +path = 'data.kitti/unzip' +cmd = './main.lua kitti fast -a predict' .. + ' -net_fname net/net_kitti_fast_-a_train_all.t7' .. + ' -left %s -right %s -disp_max 228' + +err_sum = 0 +n_te = 194 +for i = 0, n_te - 1 do + -- call mc-cnn + local im0 = ('%s/training/image_0/%06d_10.png'):format(path, i) + local im1 = ('%s/training/image_1/%06d_10.png'):format(path, i) + local im = image.loadPNG(im0) + local img_height = im:size(2) + local img_width = im:size(3) + os.execute(cmd:format(im0, im1) .. ' > /dev/null') + local disp = torch.FloatTensor(torch.FloatStorage('disp.bin')):view(1, 1, img_height, img_width) + + -- ground truth + local ground_truth = torch.FloatTensor(1, img_height, img_width) + adcensus.readPNG16(ground_truth, ('%s/training/disp_noc/%06d_10.png'):format(path, i)) + + -- compute the error + local mask = torch.ne(ground_truth, 0):float() + local bad = torch.add(disp, -1, ground_truth):abs():gt(3):float():cmul(mask) + local err = bad:sum() / mask:sum() + err_sum = err_sum + err + print(i, err) + + collectgarbage() +end +print(err_sum / n_te)