Skip to content

Commit

Permalink
tracker refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelchang committed Oct 5, 2019
1 parent 489bd89 commit 96f7f79
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 78 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ add_library(omni_slam_eval_lib
src/data/frame.cc
src/data/feature.cc
src/data/landmark.cc
src/feature/tracker.cc
src/feature/lk_tracker.cc
src/feature/detector.cc
src/feature/matcher.cc
Expand Down
96 changes: 32 additions & 64 deletions src/feature/lk_tracker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,37 +7,18 @@ namespace omni_slam
namespace feature
{

LKTracker::LKTracker(const int window_size, const int num_scales, const float delta_pix_err_thresh, const float err_thresh, const int template_update_rate, const int term_count, const double term_eps)
: windowSize_(window_size / pow(2, num_scales), window_size / pow(2, num_scales)),
LKTracker::LKTracker(const int window_size, const int num_scales, const float delta_pix_err_thresh, const float err_thresh, const int keyframe_interval, const int term_count, const double term_eps)
: Tracker(keyframe_interval),
windowSize_(window_size / pow(2, num_scales), window_size / pow(2, num_scales)),
numScales_(num_scales),
errThresh_(err_thresh),
deltaPixErrThresh_(delta_pix_err_thresh),
termCrit_(cv::TermCriteria::COUNT | cv::TermCriteria::EPS, term_count, term_eps),
templateUpdateRate_(template_update_rate),
prevFrame_(nullptr)
termCrit_(cv::TermCriteria::COUNT | cv::TermCriteria::EPS, term_count, term_eps)
{
}

void LKTracker::Init(data::Frame &init_frame)
int LKTracker::DoTrack(std::vector<data::Landmark> &landmarks, data::Frame &cur_frame, std::vector<double> &errors, bool stereo)
{
frameNum_ = 0;
prevId_ = init_frame.GetID();
prevFrame_ = &init_frame;
prevTemplateId_ = init_frame.GetID();
prevImg_ = init_frame.GetImage().clone();
if (init_frame.HasStereoImage())
{
prevStereoImg_ = init_frame.GetStereoImage().clone();
}
}

int LKTracker::Track(std::vector<data::Landmark> &landmarks, data::Frame &cur_frame, std::vector<double> &errors, bool stereo)
{
if (prevImg_.empty())
{
return 0;
}
bool curCompressed = cur_frame.IsCompressed();
std::vector<cv::Point2f> pointsToTrack;
std::vector<cv::KeyPoint> origKpt;
std::vector<int> origInx;
Expand All @@ -49,64 +30,67 @@ int LKTracker::Track(std::vector<data::Landmark> &landmarks, data::Frame &cur_fr
for (int i = 0; i < landmarks.size(); i++)
{
data::Landmark &landmark = landmarks[i];
const data::Feature *feat = landmark.GetObservationByFrameID(prevTemplateId_);
const data::Feature *feat = landmark.GetObservationByFrameID(keyframeId_);
const data::Feature *featPrev = landmark.GetObservationByFrameID(prevId_);
if (feat != nullptr && featPrev != nullptr)
if (feat != nullptr)
{
pointsToTrack.push_back(feat->GetKeypoint().pt);
results.push_back(featPrev->GetKeypoint().pt);
if (featPrev != nullptr)
{
results.push_back(featPrev->GetKeypoint().pt);
}
else
{
results.push_back(feat->GetKeypoint().pt);
}
origKpt.push_back(feat->GetKeypoint());
origInx.push_back(i);
}
if (cur_frame.HasStereoImage() && !prevStereoImg_.empty())
if (cur_frame.HasStereoImage() && !keyframeStereoImg_.empty())
{
const data::Feature *stereoFeat = landmark.GetStereoObservationByFrameID(prevTemplateId_);
const data::Feature *stereoFeat = landmark.GetStereoObservationByFrameID(keyframeId_);
const data::Feature *stereoFeatPrev = landmark.GetStereoObservationByFrameID(prevId_);
if (stereoFeat != nullptr && stereoFeatPrev != nullptr)
if (stereoFeat != nullptr)
{
stereoPointsToTrack.push_back(stereoFeat->GetKeypoint().pt);
stereoResults.push_back(stereoFeatPrev->GetKeypoint().pt);
if (stereoFeatPrev != nullptr)
{
stereoResults.push_back(stereoFeatPrev->GetKeypoint().pt);
}
else
{
stereoResults.push_back(stereoFeat->GetKeypoint().pt);
}
stereoOrigKpt.push_back(stereoFeat->GetKeypoint());
stereoOrigInx.push_back(i);
}
}
}
if (pointsToTrack.size() == 0)
{
prevId_ = cur_frame.GetID();
prevFrame_ = &cur_frame;
if (++frameNum_ % templateUpdateRate_ == 0)
{
prevTemplateId_ = cur_frame.GetID();
prevImg_ = cur_frame.GetImage().clone();
if (cur_frame.HasStereoImage())
{
prevStereoImg_ = cur_frame.GetStereoImage().clone();
}
}
return 0;
}
std::vector<unsigned char> status;
std::vector<float> err;
std::vector<unsigned char> stereoStatus;
std::vector<float> stereoErr;
if (prevId_ == prevTemplateId_)
if (prevId_ == keyframeId_)
{
cv::calcOpticalFlowPyrLK(prevImg_, cur_frame.GetImage(), pointsToTrack, results, status, err, windowSize_, numScales_, termCrit_, 0);
cv::calcOpticalFlowPyrLK(keyframeImg_, cur_frame.GetImage(), pointsToTrack, results, status, err, windowSize_, numScales_, termCrit_, 0);
}
else
{
cv::calcOpticalFlowPyrLK(prevImg_, cur_frame.GetImage(), pointsToTrack, results, status, err, windowSize_, numScales_, termCrit_, cv::OPTFLOW_USE_INITIAL_FLOW);
cv::calcOpticalFlowPyrLK(keyframeImg_, cur_frame.GetImage(), pointsToTrack, results, status, err, windowSize_, numScales_, termCrit_, cv::OPTFLOW_USE_INITIAL_FLOW);
}
if (stereoPointsToTrack.size() > 0)
{
if (prevId_ == prevTemplateId_)
if (prevId_ == keyframeId_)
{
cv::calcOpticalFlowPyrLK(prevStereoImg_, cur_frame.GetStereoImage(), stereoPointsToTrack, stereoResults, stereoStatus, stereoErr, windowSize_, numScales_, termCrit_, 0);
cv::calcOpticalFlowPyrLK(keyframeStereoImg_, cur_frame.GetStereoImage(), stereoPointsToTrack, stereoResults, stereoStatus, stereoErr, windowSize_, numScales_, termCrit_, 0);
}
else
{
cv::calcOpticalFlowPyrLK(prevStereoImg_, cur_frame.GetStereoImage(), stereoPointsToTrack, stereoResults, stereoStatus, stereoErr, windowSize_, numScales_, termCrit_, cv::OPTFLOW_USE_INITIAL_FLOW);
cv::calcOpticalFlowPyrLK(keyframeStereoImg_, cur_frame.GetStereoImage(), stereoPointsToTrack, stereoResults, stereoStatus, stereoErr, windowSize_, numScales_, termCrit_, cv::OPTFLOW_USE_INITIAL_FLOW);
}
}
errors.clear();
Expand Down Expand Up @@ -162,22 +146,6 @@ int LKTracker::Track(std::vector<data::Landmark> &landmarks, data::Frame &cur_fr
landmark.AddStereoObservation(feat);
}
}
if (curCompressed)
{
cur_frame.CompressImages();
}

prevId_ = cur_frame.GetID();
prevFrame_ = &cur_frame;
if (++frameNum_ % templateUpdateRate_ == 0)
{
prevTemplateId_ = cur_frame.GetID();
prevImg_ = cur_frame.GetImage().clone();
if (cur_frame.HasStereoImage())
{
prevStereoImg_ = cur_frame.GetStereoImage().clone();
}
}

return numGood;
}
Expand Down
13 changes: 3 additions & 10 deletions src/feature/lk_tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,16 @@ namespace feature
class LKTracker : public Tracker
{
public:
LKTracker(const int window_size, const int num_scales, const float delta_pix_err_thresh = 5., const float err_thresh = 20., const int template_update_rate = 1, const int term_count = 50, const double term_eps = 0.01);
void Init(data::Frame &init_frame);
int Track(std::vector<data::Landmark> &landmarks, data::Frame &cur_frame, std::vector<double> &errors, bool stereo = true);
LKTracker(const int window_size, const int num_scales, const float delta_pix_err_thresh = 5., const float err_thresh = 20., const int keyframe_interval = 1, const int term_count = 50, const double term_eps = 0.01);

private:
int DoTrack(std::vector<data::Landmark> &landmarks, data::Frame &cur_frame, std::vector<double> &errors, bool stereo);

cv::TermCriteria termCrit_;
const cv::Size windowSize_;
const int numScales_;
const float errThresh_;
const float deltaPixErrThresh_;
const int templateUpdateRate_;
int frameNum_{0};
cv::Mat prevImg_;
cv::Mat prevStereoImg_;
int prevTemplateId_;
int prevId_;
data::Frame *prevFrame_;
};

}
Expand Down
62 changes: 62 additions & 0 deletions src/feature/tracker.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#include "tracker.h"

namespace omni_slam
{
namespace feature
{

Tracker::Tracker(const int keyframe_interval)
: keyframeInterval_(keyframe_interval),
prevFrame_(nullptr)
{
}

void Tracker::Init(data::Frame &init_frame)
{
frameNum_ = 0;
prevId_ = init_frame.GetID();
prevFrame_ = &init_frame;
keyframeId_ = init_frame.GetID();
keyframeImg_ = init_frame.GetImage().clone();
if (init_frame.HasStereoImage())
{
keyframeStereoImg_ = init_frame.GetStereoImage().clone();
}
}

const data::Frame* Tracker::GetLastKeyframe()
{
return prevFrame_;
}

int Tracker::Track(std::vector<data::Landmark> &landmarks, data::Frame &cur_frame, std::vector<double> &errors, bool stereo)
{
if (keyframeImg_.empty())
{
return 0;
}
bool wasCompressed = cur_frame.IsCompressed();

int count = DoTrack(landmarks, cur_frame, errors, stereo);

prevId_ = cur_frame.GetID();
prevFrame_ = &cur_frame;
if (++frameNum_ % keyframeInterval_ == 0)
{
keyframeId_ = cur_frame.GetID();
keyframeImg_ = cur_frame.GetImage().clone();
if (cur_frame.HasStereoImage())
{
keyframeStereoImg_ = cur_frame.GetStereoImage().clone();
}
}
if (wasCompressed && !cur_frame.IsCompressed())
{
cur_frame.CompressImages();
}

return count;
}

}
}
21 changes: 19 additions & 2 deletions src/feature/tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,25 @@ namespace feature
class Tracker
{
public:
virtual void Init(data::Frame &init_frame) = 0;
virtual int Track(std::vector<data::Landmark> &landmarks, data::Frame &cur_frame, std::vector<double> &errors, bool stereo = true) = 0;
Tracker(const int keyframe_interval = 1);

virtual void Init(data::Frame &init_frame);
int Track(std::vector<data::Landmark> &landmarks, data::Frame &cur_frame, std::vector<double> &errors, bool stereo = true);

const data::Frame* GetLastKeyframe();

protected:
cv::Mat keyframeImg_;
cv::Mat keyframeStereoImg_;
int keyframeId_;
int prevId_;
const data::Frame *prevFrame_;

private:
virtual int DoTrack(std::vector<data::Landmark> &landmarks, data::Frame &cur_frame, std::vector<double> &errors, bool stereo) = 0;

int frameNum_{0};
const int keyframeInterval_;
};

}
Expand Down
9 changes: 7 additions & 2 deletions src/module/tracking_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,14 @@ void TrackingModule::Update(std::unique_ptr<data::Frame> &frame)
return;
}

const data::Frame *keyframe = tracker_->GetLastKeyframe();
vector<double> trackErrors;
int tracks = tracker_->Track(landmarks_, *frames_.back(), trackErrors);
if (fivePointChecker_ && tracks > 0)
{
Matrix3d E;
std::vector<int> inlierIndices;
fivePointChecker_->ComputeE(landmarks_, **next(frames_.rbegin()), *frames_.back(), E, inlierIndices);
fivePointChecker_->ComputeE(landmarks_, *keyframe, *frames_.back(), E, inlierIndices);
std::unordered_set<int> inlierSet(inlierIndices.begin(), inlierIndices.end());
for (int i = 0; i < landmarks_.size(); i++)
{
Expand All @@ -59,7 +60,7 @@ void TrackingModule::Update(std::unique_ptr<data::Frame> &frame)
{
Matrix3d E;
std::vector<int> inlierIndices;
fivePointChecker_->ComputeE(landmarks_, **next(frames_.rbegin()), *frames_.back(), E, inlierIndices, true);
fivePointChecker_->ComputeE(landmarks_, *keyframe, *frames_.back(), E, inlierIndices, true);
std::unordered_set<int> inlierSet(inlierIndices.begin(), inlierIndices.end());
for (int i = 0; i < landmarks_.size(); i++)
{
Expand Down Expand Up @@ -149,6 +150,10 @@ void TrackingModule::Update(std::unique_ptr<data::Frame> &frame)

void TrackingModule::Redetect()
{
if (tracker_->GetLastKeyframe() != frames_.back().get())
{
return;
}
int imsize = max(frames_.back()->GetImage().rows, frames_.back()->GetImage().cols);
#pragma omp parallel for collapse(2)
for (int i = 0; i < rs_.size() - 1; i++)
Expand Down

0 comments on commit 96f7f79

Please sign in to comment.