diff --git a/CMakeLists.txt b/CMakeLists.txt index b9bc989..f302062 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,6 +63,7 @@ add_library(omni_slam_eval_lib src/data/frame.cc src/data/feature.cc src/data/landmark.cc + src/feature/tracker.cc src/feature/lk_tracker.cc src/feature/detector.cc src/feature/matcher.cc diff --git a/src/feature/lk_tracker.cc b/src/feature/lk_tracker.cc index ffda2f2..2b79ea8 100644 --- a/src/feature/lk_tracker.cc +++ b/src/feature/lk_tracker.cc @@ -7,37 +7,18 @@ namespace omni_slam namespace feature { -LKTracker::LKTracker(const int window_size, const int num_scales, const float delta_pix_err_thresh, const float err_thresh, const int template_update_rate, const int term_count, const double term_eps) - : windowSize_(window_size / pow(2, num_scales), window_size / pow(2, num_scales)), +LKTracker::LKTracker(const int window_size, const int num_scales, const float delta_pix_err_thresh, const float err_thresh, const int keyframe_interval, const int term_count, const double term_eps) + : Tracker(keyframe_interval), + windowSize_(window_size / pow(2, num_scales), window_size / pow(2, num_scales)), numScales_(num_scales), errThresh_(err_thresh), deltaPixErrThresh_(delta_pix_err_thresh), - termCrit_(cv::TermCriteria::COUNT | cv::TermCriteria::EPS, term_count, term_eps), - templateUpdateRate_(template_update_rate), - prevFrame_(nullptr) + termCrit_(cv::TermCriteria::COUNT | cv::TermCriteria::EPS, term_count, term_eps) { } -void LKTracker::Init(data::Frame &init_frame) +int LKTracker::DoTrack(std::vector &landmarks, data::Frame &cur_frame, std::vector &errors, bool stereo) { - frameNum_ = 0; - prevId_ = init_frame.GetID(); - prevFrame_ = &init_frame; - prevTemplateId_ = init_frame.GetID(); - prevImg_ = init_frame.GetImage().clone(); - if (init_frame.HasStereoImage()) - { - prevStereoImg_ = init_frame.GetStereoImage().clone(); - } -} - -int LKTracker::Track(std::vector &landmarks, data::Frame &cur_frame, std::vector &errors, bool stereo) -{ - if (prevImg_.empty()) - { - return 0; - } - bool curCompressed = cur_frame.IsCompressed(); std::vector pointsToTrack; std::vector origKpt; std::vector origInx; @@ -49,23 +30,37 @@ int LKTracker::Track(std::vector &landmarks, data::Frame &cur_fr for (int i = 0; i < landmarks.size(); i++) { data::Landmark &landmark = landmarks[i]; - const data::Feature *feat = landmark.GetObservationByFrameID(prevTemplateId_); + const data::Feature *feat = landmark.GetObservationByFrameID(keyframeId_); const data::Feature *featPrev = landmark.GetObservationByFrameID(prevId_); - if (feat != nullptr && featPrev != nullptr) + if (feat != nullptr) { pointsToTrack.push_back(feat->GetKeypoint().pt); - results.push_back(featPrev->GetKeypoint().pt); + if (featPrev != nullptr) + { + results.push_back(featPrev->GetKeypoint().pt); + } + else + { + results.push_back(feat->GetKeypoint().pt); + } origKpt.push_back(feat->GetKeypoint()); origInx.push_back(i); } - if (cur_frame.HasStereoImage() && !prevStereoImg_.empty()) + if (cur_frame.HasStereoImage() && !keyframeStereoImg_.empty()) { - const data::Feature *stereoFeat = landmark.GetStereoObservationByFrameID(prevTemplateId_); + const data::Feature *stereoFeat = landmark.GetStereoObservationByFrameID(keyframeId_); const data::Feature *stereoFeatPrev = landmark.GetStereoObservationByFrameID(prevId_); - if (stereoFeat != nullptr && stereoFeatPrev != nullptr) + if (stereoFeat != nullptr) { stereoPointsToTrack.push_back(stereoFeat->GetKeypoint().pt); - stereoResults.push_back(stereoFeatPrev->GetKeypoint().pt); + if (stereoFeatPrev != nullptr) + { + stereoResults.push_back(stereoFeatPrev->GetKeypoint().pt); + } + else + { + stereoResults.push_back(stereoFeat->GetKeypoint().pt); + } stereoOrigKpt.push_back(stereoFeat->GetKeypoint()); stereoOrigInx.push_back(i); } @@ -73,40 +68,29 @@ int LKTracker::Track(std::vector &landmarks, data::Frame &cur_fr } if (pointsToTrack.size() == 0) { - prevId_ = cur_frame.GetID(); - prevFrame_ = &cur_frame; - if (++frameNum_ % templateUpdateRate_ == 0) - { - prevTemplateId_ = cur_frame.GetID(); - prevImg_ = cur_frame.GetImage().clone(); - if (cur_frame.HasStereoImage()) - { - prevStereoImg_ = cur_frame.GetStereoImage().clone(); - } - } return 0; } std::vector status; std::vector err; std::vector stereoStatus; std::vector stereoErr; - if (prevId_ == prevTemplateId_) + if (prevId_ == keyframeId_) { - cv::calcOpticalFlowPyrLK(prevImg_, cur_frame.GetImage(), pointsToTrack, results, status, err, windowSize_, numScales_, termCrit_, 0); + cv::calcOpticalFlowPyrLK(keyframeImg_, cur_frame.GetImage(), pointsToTrack, results, status, err, windowSize_, numScales_, termCrit_, 0); } else { - cv::calcOpticalFlowPyrLK(prevImg_, cur_frame.GetImage(), pointsToTrack, results, status, err, windowSize_, numScales_, termCrit_, cv::OPTFLOW_USE_INITIAL_FLOW); + cv::calcOpticalFlowPyrLK(keyframeImg_, cur_frame.GetImage(), pointsToTrack, results, status, err, windowSize_, numScales_, termCrit_, cv::OPTFLOW_USE_INITIAL_FLOW); } if (stereoPointsToTrack.size() > 0) { - if (prevId_ == prevTemplateId_) + if (prevId_ == keyframeId_) { - cv::calcOpticalFlowPyrLK(prevStereoImg_, cur_frame.GetStereoImage(), stereoPointsToTrack, stereoResults, stereoStatus, stereoErr, windowSize_, numScales_, termCrit_, 0); + cv::calcOpticalFlowPyrLK(keyframeStereoImg_, cur_frame.GetStereoImage(), stereoPointsToTrack, stereoResults, stereoStatus, stereoErr, windowSize_, numScales_, termCrit_, 0); } else { - cv::calcOpticalFlowPyrLK(prevStereoImg_, cur_frame.GetStereoImage(), stereoPointsToTrack, stereoResults, stereoStatus, stereoErr, windowSize_, numScales_, termCrit_, cv::OPTFLOW_USE_INITIAL_FLOW); + cv::calcOpticalFlowPyrLK(keyframeStereoImg_, cur_frame.GetStereoImage(), stereoPointsToTrack, stereoResults, stereoStatus, stereoErr, windowSize_, numScales_, termCrit_, cv::OPTFLOW_USE_INITIAL_FLOW); } } errors.clear(); @@ -162,22 +146,6 @@ int LKTracker::Track(std::vector &landmarks, data::Frame &cur_fr landmark.AddStereoObservation(feat); } } - if (curCompressed) - { - cur_frame.CompressImages(); - } - - prevId_ = cur_frame.GetID(); - prevFrame_ = &cur_frame; - if (++frameNum_ % templateUpdateRate_ == 0) - { - prevTemplateId_ = cur_frame.GetID(); - prevImg_ = cur_frame.GetImage().clone(); - if (cur_frame.HasStereoImage()) - { - prevStereoImg_ = cur_frame.GetStereoImage().clone(); - } - } return numGood; } diff --git a/src/feature/lk_tracker.h b/src/feature/lk_tracker.h index 2f7634a..2169e68 100644 --- a/src/feature/lk_tracker.h +++ b/src/feature/lk_tracker.h @@ -16,23 +16,16 @@ namespace feature class LKTracker : public Tracker { public: - LKTracker(const int window_size, const int num_scales, const float delta_pix_err_thresh = 5., const float err_thresh = 20., const int template_update_rate = 1, const int term_count = 50, const double term_eps = 0.01); - void Init(data::Frame &init_frame); - int Track(std::vector &landmarks, data::Frame &cur_frame, std::vector &errors, bool stereo = true); + LKTracker(const int window_size, const int num_scales, const float delta_pix_err_thresh = 5., const float err_thresh = 20., const int keyframe_interval = 1, const int term_count = 50, const double term_eps = 0.01); private: + int DoTrack(std::vector &landmarks, data::Frame &cur_frame, std::vector &errors, bool stereo); + cv::TermCriteria termCrit_; const cv::Size windowSize_; const int numScales_; const float errThresh_; const float deltaPixErrThresh_; - const int templateUpdateRate_; - int frameNum_{0}; - cv::Mat prevImg_; - cv::Mat prevStereoImg_; - int prevTemplateId_; - int prevId_; - data::Frame *prevFrame_; }; } diff --git a/src/feature/tracker.cc b/src/feature/tracker.cc new file mode 100644 index 0000000..d52623f --- /dev/null +++ b/src/feature/tracker.cc @@ -0,0 +1,62 @@ +#include "tracker.h" + +namespace omni_slam +{ +namespace feature +{ + +Tracker::Tracker(const int keyframe_interval) + : keyframeInterval_(keyframe_interval), + prevFrame_(nullptr) +{ +} + +void Tracker::Init(data::Frame &init_frame) +{ + frameNum_ = 0; + prevId_ = init_frame.GetID(); + prevFrame_ = &init_frame; + keyframeId_ = init_frame.GetID(); + keyframeImg_ = init_frame.GetImage().clone(); + if (init_frame.HasStereoImage()) + { + keyframeStereoImg_ = init_frame.GetStereoImage().clone(); + } +} + +const data::Frame* Tracker::GetLastKeyframe() +{ + return prevFrame_; +} + +int Tracker::Track(std::vector &landmarks, data::Frame &cur_frame, std::vector &errors, bool stereo) +{ + if (keyframeImg_.empty()) + { + return 0; + } + bool wasCompressed = cur_frame.IsCompressed(); + + int count = DoTrack(landmarks, cur_frame, errors, stereo); + + prevId_ = cur_frame.GetID(); + prevFrame_ = &cur_frame; + if (++frameNum_ % keyframeInterval_ == 0) + { + keyframeId_ = cur_frame.GetID(); + keyframeImg_ = cur_frame.GetImage().clone(); + if (cur_frame.HasStereoImage()) + { + keyframeStereoImg_ = cur_frame.GetStereoImage().clone(); + } + } + if (wasCompressed && !cur_frame.IsCompressed()) + { + cur_frame.CompressImages(); + } + + return count; +} + +} +} diff --git a/src/feature/tracker.h b/src/feature/tracker.h index 2a5c6b1..bbbe78a 100644 --- a/src/feature/tracker.h +++ b/src/feature/tracker.h @@ -12,8 +12,25 @@ namespace feature class Tracker { public: - virtual void Init(data::Frame &init_frame) = 0; - virtual int Track(std::vector &landmarks, data::Frame &cur_frame, std::vector &errors, bool stereo = true) = 0; + Tracker(const int keyframe_interval = 1); + + virtual void Init(data::Frame &init_frame); + int Track(std::vector &landmarks, data::Frame &cur_frame, std::vector &errors, bool stereo = true); + + const data::Frame* GetLastKeyframe(); + +protected: + cv::Mat keyframeImg_; + cv::Mat keyframeStereoImg_; + int keyframeId_; + int prevId_; + const data::Frame *prevFrame_; + +private: + virtual int DoTrack(std::vector &landmarks, data::Frame &cur_frame, std::vector &errors, bool stereo) = 0; + + int frameNum_{0}; + const int keyframeInterval_; }; } diff --git a/src/module/tracking_module.cc b/src/module/tracking_module.cc index 59d8ca2..2975d9d 100644 --- a/src/module/tracking_module.cc +++ b/src/module/tracking_module.cc @@ -40,13 +40,14 @@ void TrackingModule::Update(std::unique_ptr &frame) return; } + const data::Frame *keyframe = tracker_->GetLastKeyframe(); vector trackErrors; int tracks = tracker_->Track(landmarks_, *frames_.back(), trackErrors); if (fivePointChecker_ && tracks > 0) { Matrix3d E; std::vector inlierIndices; - fivePointChecker_->ComputeE(landmarks_, **next(frames_.rbegin()), *frames_.back(), E, inlierIndices); + fivePointChecker_->ComputeE(landmarks_, *keyframe, *frames_.back(), E, inlierIndices); std::unordered_set inlierSet(inlierIndices.begin(), inlierIndices.end()); for (int i = 0; i < landmarks_.size(); i++) { @@ -59,7 +60,7 @@ void TrackingModule::Update(std::unique_ptr &frame) { Matrix3d E; std::vector inlierIndices; - fivePointChecker_->ComputeE(landmarks_, **next(frames_.rbegin()), *frames_.back(), E, inlierIndices, true); + fivePointChecker_->ComputeE(landmarks_, *keyframe, *frames_.back(), E, inlierIndices, true); std::unordered_set inlierSet(inlierIndices.begin(), inlierIndices.end()); for (int i = 0; i < landmarks_.size(); i++) { @@ -149,6 +150,10 @@ void TrackingModule::Update(std::unique_ptr &frame) void TrackingModule::Redetect() { + if (tracker_->GetLastKeyframe() != frames_.back().get()) + { + return; + } int imsize = max(frames_.back()->GetImage().rows, frames_.back()->GetImage().cols); #pragma omp parallel for collapse(2) for (int i = 0; i < rs_.size() - 1; i++)