Skip to content

Commit

Permalink
add sliding window BA doesn't work well
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelchang committed Sep 19, 2019
1 parent 6ee9c3b commit 2c340c5
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 3 deletions.
1 change: 1 addition & 0 deletions launch/slam_eval.launch
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
bundle_adjustment_loss_coefficient: 0.05
bundle_adjustment_logging: true
bundle_adjustment_num_threads: 20
local_bundle_adjustment_window: 0
stereo_matcher_window_size: 256
stereo_matcher_num_scales: 5
stereo_matcher_error_threshold: 20
Expand Down
10 changes: 8 additions & 2 deletions src/module/reconstruction_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ void ReconstructionModule::Update(std::vector<data::Landmark> &landmarks)
lastLandmarksSize_ = landmarks.size();
}

void ReconstructionModule::BundleAdjust(std::vector<data::Landmark> &landmarks)
void ReconstructionModule::BundleAdjust(std::vector<data::Landmark> &landmarks, const std::vector<int> &frame_ids)
{
bundleAdjuster_->Optimize(landmarks);
bundleAdjuster_->Optimize(landmarks, frame_ids);

for (int i = 0; i < lastLandmarksSize_; i++)
{
Expand All @@ -56,6 +56,12 @@ void ReconstructionModule::BundleAdjust(std::vector<data::Landmark> &landmarks)
}
}

void ReconstructionModule::BundleAdjust(std::vector<data::Landmark> &landmarks)
{
std::vector<int> temp;
BundleAdjust(landmarks, temp);
}

ReconstructionModule::Stats& ReconstructionModule::GetStats()
{
return stats_;
Expand Down
1 change: 1 addition & 0 deletions src/module/reconstruction_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class ReconstructionModule
ReconstructionModule(std::unique_ptr<reconstruction::Triangulator> &&triangulator, std::unique_ptr<optimization::BundleAdjuster> &&bundle_adjuster);

void Update(std::vector<data::Landmark> &landmarks);
void BundleAdjust(std::vector<data::Landmark> &landmarks, const std::vector<int> &frame_ids);
void BundleAdjust(std::vector<data::Landmark> &landmarks);

Stats& GetStats();
Expand Down
62 changes: 61 additions & 1 deletion src/optimization/bundle_adjuster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ BundleAdjuster::BundleAdjuster(int max_iterations, double loss_coeff, int num_th
solverOptions_.logging_type = log ? ceres::PER_MINIMIZER_ITERATION : ceres::SILENT;
}

bool BundleAdjuster::Optimize(std::vector<data::Landmark> &landmarks)
bool BundleAdjuster::Optimize(std::vector<data::Landmark> &landmarks, const std::vector<int> &frame_ids)
{
std::vector<double> landmarkEstimates;
landmarkEstimates.reserve(3 * landmarks.size());
Expand All @@ -32,9 +32,32 @@ bool BundleAdjuster::Optimize(std::vector<data::Landmark> &landmarks)
ceres::LossFunction *loss_function = new ceres::HuberLoss(lossCoeff_);
for (const data::Landmark &landmark : landmarks)
{
if (frame_ids.size() > 0)
{
bool observed = false;
for (int id : frame_ids)
{
if (landmark.IsObservedInFrame(id))
{
observed = true;
break;
}
}
if (!observed)
{
continue;
}
}
bool hasEstCameraPoses = false;
for (const data::Feature &feature : landmark.GetObservations())
{
if (frame_ids.size() > 0)
{
if (std::find(frame_ids.begin(), frame_ids.end(), feature.GetFrame().GetID()) == frame_ids.end())
{
continue;
}
}
if (feature.GetFrame().HasEstimatedPose() && feature.GetFrame().IsEstimatedByLandmark(landmark.GetID()))
{
hasEstCameraPoses = true;
Expand Down Expand Up @@ -63,6 +86,13 @@ bool BundleAdjuster::Optimize(std::vector<data::Landmark> &landmarks)
}
for (const data::Feature &feature : landmark.GetObservations())
{
if (frame_ids.size() > 0)
{
if (std::find(frame_ids.begin(), frame_ids.end(), feature.GetFrame().GetID()) == frame_ids.end())
{
continue;
}
}
if (!feature.GetFrame().HasEstimatedPose() && feature.GetFrame().HasPose())
{
if (!landmark.HasEstimatedPosition())
Expand Down Expand Up @@ -154,9 +184,32 @@ bool BundleAdjuster::Optimize(std::vector<data::Landmark> &landmarks)
int inx = 0;
for (data::Landmark &landmark : landmarks)
{
if (frame_ids.size() > 0)
{
bool observed = false;
for (int id : frame_ids)
{
if (landmark.IsObservedInFrame(id))
{
observed = true;
break;
}
}
if (!observed)
{
continue;
}
}
bool hasEstCameraPoses = false;
for (const data::Feature &feature : landmark.GetObservations())
{
if (frame_ids.size() > 0)
{
if (std::find(frame_ids.begin(), frame_ids.end(), feature.GetFrame().GetID()) == frame_ids.end())
{
continue;
}
}
if (feature.GetFrame().HasEstimatedPose() && feature.GetFrame().IsEstimatedByLandmark(landmark.GetID()))
{
hasEstCameraPoses = true;
Expand All @@ -181,8 +234,15 @@ bool BundleAdjuster::Optimize(std::vector<data::Landmark> &landmarks)
const Matrix<double, 3, 4> pose = util::TFUtil::QuaternionTranslationToPoseMatrix(quat, t);
frame.second->SetEstimatedInversePose(pose);
}
problem_.reset(new ceres::Problem());
return true;
}

bool BundleAdjuster::Optimize(std::vector<data::Landmark> &landmarks)
{
std::vector<int> tmp;
Optimize(landmarks, tmp);
}

}
}
1 change: 1 addition & 0 deletions src/optimization/bundle_adjuster.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class BundleAdjuster
public:
BundleAdjuster(int max_iterations = 500, double loss_coeff = 0.1, int num_threads = 1, bool log = false);

bool Optimize(std::vector<data::Landmark> &landmarks, const std::vector<int> &frame_ids);
bool Optimize(std::vector<data::Landmark> &landmarks);

private:
Expand Down
16 changes: 16 additions & 0 deletions src/ros/slam_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace ros
SLAMEval::SLAMEval(const ::ros::NodeHandle &nh, const ::ros::NodeHandle &nh_private)
: OdometryEval<true>(nh, nh_private), ReconstructionEval<true>(nh, nh_private), StereoEval(nh, nh_private)
{
this->nhp_.param("local_bundle_adjustment_window", baSlidingWindow_, 0);
}

void SLAMEval::InitPublishers()
Expand All @@ -31,8 +32,23 @@ void SLAMEval::ProcessFrame(unique_ptr<data::Frame> &&frame)
trackingModule_->Update(frame);
odometryModule_->Update(trackingModule_->GetLandmarks(), *trackingModule_->GetFrames().back());
reconstructionModule_->Update(trackingModule_->GetLandmarks());
if (baSlidingWindow_ > 0 && (frameNum_ + 1) % baSlidingWindow_ == 0)
{
std::vector<int> frameIds;
frameIds.reserve(baSlidingWindow_);
for (auto it = trackingModule_->GetFrames().rbegin(); it != trackingModule_->GetFrames().rend(); ++it)
{
frameIds.push_back((*it)->GetID());
if (frameIds.size() >= baSlidingWindow_)
{
break;
}
}
reconstructionModule_->BundleAdjust(trackingModule_->GetLandmarks(), frameIds);
}
trackingModule_->Redetect();
stereoModule_->Update(*trackingModule_->GetFrames().back(), trackingModule_->GetLandmarks());
frameNum_++;
}

void SLAMEval::Finish()
Expand Down
3 changes: 3 additions & 0 deletions src/ros/slam_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ class SLAMEval : public OdometryEval<true>, ReconstructionEval<true>, StereoEval
void GetResultsData(std::map<std::string, std::vector<std::vector<double>>> &data);
void Visualize(cv_bridge::CvImagePtr &base_img);
void Visualize(cv_bridge::CvImagePtr &base_img, cv_bridge::CvImagePtr &base_stereo_img);

int baSlidingWindow_;
int frameNum_{0};
};

}
Expand Down

0 comments on commit 2c340c5

Please sign in to comment.