Add TV along time
spinicist committed May 12, 2023
commit 23263f2, 1 parent a14eb4d

Showing 4 changed files with 43 additions and 33 deletions.
src/cmd/admm.cpp: 30 additions & 19 deletions
@@ -44,6 +44,7 @@ int main_admm(args::Subparser &parser)
 
   // Default is TV on spatial dimensions, i.e. classic compressed sensing
   args::Flag tv(parser, "TV", "Total Variation", {"tv"});
+  args::ValueFlag<float> tvt(parser, "TVT", "Total Variation along time/frames/basis", {"tvt"});
   args::Flag tgv(parser, "TGV", "Total Generalized Variation", {"tgv"});
   args::Flag l1(parser, "L1", "Simple L1 regularization", {"l1"});
   args::Flag nmrent(parser, "E", "NMR Entropy", {"nmrent"});
@@ -79,24 +80,10 @@ int main_admm(args::Subparser &parser)
 
   std::vector<std::shared_ptr<LinOps::Op<Cx>>> reg_ops;
   std::vector<std::shared_ptr<Prox<Cx>>> prox;
-  std::shared_ptr<LinOps::Op<Cx>> ext_x; // Need for TGV, sigh
-  if (wavelets) {
-    prox.push_back(std::make_shared<ThresholdWavelets>(λ.Get(), sz, width.Get(), wavelets.Get()));
-    reg_ops.push_back(std::make_shared<TensorIdentity<Cx, 4>>(sz));
-  } else if (patchSize) {
-    prox.push_back(std::make_shared<LLR>(λ.Get(), patchSize.Get(), winSize.Get(), sz));
-    reg_ops.push_back(std::make_shared<TensorIdentity<Cx, 4>>(sz));
-  } else if (nmrent) {
-    prox.push_back(std::make_shared<NMREntropy>(λ.Get()));
-    reg_ops.push_back(std::make_shared<TensorIdentity<Cx, 4>>(sz));
-  } else if (l1) {
-    prox.push_back(std::make_shared<SoftThreshold>(λ.Get()));
-    reg_ops.push_back(std::make_shared<TensorIdentity<Cx, 4>>(sz));
-  } else if (tv) {
-    prox.push_back(std::make_shared<SoftThreshold>(λ.Get()));
-    reg_ops.push_back(std::make_shared<GradOp>(sz));
-  } else if (tgv) {
-    auto grad_x = std::make_shared<GradOp>(sz);
+  std::shared_ptr<LinOps::Op<Cx>> ext_x = std::make_shared<TensorIdentity<Cx, 4>>(sz); // Need for TGV, sigh
+
+  if (tgv) {
+    auto grad_x = std::make_shared<GradOp>(sz, std::vector<Index>{1, 2, 3});
     ext_x = std::make_shared<LinOps::Extract<Cx>>(recon->cols() + grad_x->rows(), 0, recon->cols());
     auto ext_v = std::make_shared<LinOps::Extract<Cx>>(recon->cols() + grad_x->rows(), recon->cols(), grad_x->rows());
     auto op1 = std::make_shared<LinOps::Subtract<Cx>>(std::make_shared<LinOps::Multiply<Cx>>(grad_x, ext_x), ext_v);
@@ -111,10 +98,34 @@ int main_admm(args::Subparser &parser)
       Log::Tensor(fmt::format("admm-x-{:02d}", ii), sz, xv.data());
       Log::Tensor(fmt::format("admm-v-{:02d}", ii), grad_x->oshape, xv.data() + Product(sz));
     };
-  } else {
+  } else if (wavelets) {
+    prox.push_back(std::make_shared<ThresholdWavelets>(λ.Get(), sz, width.Get(), wavelets.Get()));
+    reg_ops.push_back(std::make_shared<TensorIdentity<Cx, 4>>(sz));
+  } else if (patchSize) {
+    prox.push_back(std::make_shared<LLR>(λ.Get(), patchSize.Get(), winSize.Get(), sz));
+    reg_ops.push_back(std::make_shared<TensorIdentity<Cx, 4>>(sz));
+  } else if (nmrent) {
+    prox.push_back(std::make_shared<NMREntropy>(λ.Get()));
+    reg_ops.push_back(std::make_shared<TensorIdentity<Cx, 4>>(sz));
+  } else if (l1) {
+    prox.push_back(std::make_shared<SoftThreshold>(λ.Get()));
+    reg_ops.push_back(std::make_shared<TensorIdentity<Cx, 4>>(sz));
+  } else if (tv) {
+    prox.push_back(std::make_shared<SoftThreshold>(λ.Get()));
+    reg_ops.push_back(std::make_shared<GradOp>(sz, std::vector<Index>{1, 2, 3}));
+  }
+
+  if (tvt) {
+    prox.push_back(std::make_shared<SoftThreshold>(tvt.Get()));
+    reg_ops.push_back(std::make_shared<LinOps::Multiply<Cx>>(std::make_shared<GradOp>(sz, std::vector<Index>{0}), ext_x));
+  }
+
+  if (prox.size() == 0) {
     Log::Fail("Must specify at least one regularizer");
   }
 
   ADMM admm{
     A,
     M,
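
The reordering is what makes the new flag composable: the spatial regularizers remain mutually exclusive in the if/else chain, tgv now comes first so that ext_x is re-pointed at the x-part of the stacked [x; v] variable before anything composes with it, and --tvt is appended as an independent (prox, operator) pair whose threshold comes from the flag's own value rather than the shared λ. Below is a minimal, self-contained sketch of that pairing, not rl's actual classes: complex soft-thresholding (the operation a SoftThreshold prox applies) acting on a forward difference along the time axis of a toy space-by-time array. The periodic boundary is an assumption; this commit does not show ForwardDiff's boundary handling.

    #include <complex>
    #include <cstdio>
    #include <vector>

    using Cxf = std::complex<float>;

    // Complex soft-threshold: shrink the magnitude by t, keep the phase.
    Cxf SoftThresh(Cxf const z, float const t)
    {
      float const m = std::abs(z);
      return (m > t) ? z * ((m - t) / m) : Cxf{0.f, 0.f};
    }

    int main()
    {
      int const nt = 4, nx = 3; // frames (time) by voxels (space)
      std::vector<Cxf> img(nt * nx, Cxf{1.f, 0.f});
      img[2 * nx + 1] = Cxf{5.f, 0.f}; // a temporal jump at frame 2, voxel 1

      // Forward difference along time (periodic boundary assumed)
      std::vector<Cxf> grad(nt * nx);
      for (int it = 0; it < nt; it++) {
        for (int ix = 0; ix < nx; ix++) {
          grad[it * nx + ix] = img[((it + 1) % nt) * nx + ix] - img[it * nx + ix];
        }
      }

      // Shrinking the temporal differences promotes piecewise-constant
      // time courses, which is the point of TV along time.
      for (auto &g : grad) { g = SoftThresh(g, 0.5f); }
      std::printf("d/dt at (t=1, x=1) after shrinkage: %.2f\n", grad[1 * nx + 1].real());
    }

In ADMM terms each (prox, operator) pair contributes its own split variable, so the temporal term simply adds one more shrinkage step per iteration.
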
src/cmd/transform.cpp: 2 additions & 2 deletions
@@ -48,7 +48,7 @@ int main_transform(args::Subparser &parser)
     auto input = reader.readTensor<Cx5>(HD5::Keys::Image);
     Sz4 dims = FirstN<4>(input.dimensions());
     Cx6 output(AddBack(dims, 3, input.dimension(4)));
-    GradOp g(dims);
+    GradOp g(dims, std::vector<Index>{1, 2, 3});
     for (Index iv = 0; iv < input.dimension(4); iv++) {
       output.chip<5>(iv) = g.forward(CChipMap(input, iv));
     }
@@ -57,7 +57,7 @@ int main_transform(args::Subparser &parser)
     auto input = reader.readTensor<Cx6>("grad");
     Sz4 dims = FirstN<4>(input.dimensions());
     Cx5 output(AddBack(dims, input.dimension(5)));
-    GradOp g(dims);
+    GradOp g(dims, std::vector<Index>{1, 2, 3});
     for (Index iv = 0; iv < input.dimension(5); iv++) {
       output.chip<4>(iv) = g.adjoint(CChipMap(input, iv));
     }
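
Both branches now pass the spatial dimensions {1, 2, 3} explicitly, so the transform command's behaviour is unchanged. The forward/adjoint pair it exercises should satisfy the inner-product identity <Gx, y> = <x, G'y>; here is a small self-contained check of that identity for a periodic one-axis difference, standing in (under that assumed boundary) for GradOp's forward and adjoint:

    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main()
    {
      int const n = 5;
      std::vector<double> x{1, 4, 2, 8, 5}, y{3, 1, 4, 1, 5};
      std::vector<double> Gx(n), GTy(n);
      for (int i = 0; i < n; i++) {
        Gx[i] = x[(i + 1) % n] - x[i];      // forward difference
        GTy[i] = y[(i - 1 + n) % n] - y[i]; // its adjoint: negated backward difference
      }
      double const lhs = std::inner_product(Gx.begin(), Gx.end(), y.begin(), 0.0);
      double const rhs = std::inner_product(x.begin(), x.end(), GTy.begin(), 0.0);
      std::printf("<Gx,y> = %g, <x,G'y> = %g\n", lhs, rhs); // the two should agree
    }
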
src/op/grad.cpp: 7 additions & 11 deletions
@@ -25,33 +25,29 @@ inline auto BackwardDiff(T1 const &a, T2 &&b, Sz4 const dims, Index const dim)
 }
 } // namespace
 
-GradOp::GradOp(InDims const dims)
-  : Parent("GradOp", dims, AddBack(dims, 3))
+GradOp::GradOp(InDims const ishape, std::vector<Index> const &d)
+  : Parent("GradOp", ishape, AddBack(ishape, (Index)d.size())),
+    dims_{d}
 {
 }
 
 void GradOp::forward(InCMap const &x, OutMap &y) const
 {
   auto const time = this->startForward(x);
   y.setZero();
-  for (Index ii = 0; ii < 3; ii++) {
-    ForwardDiff(x, y.chip<4>(ii), x.dimensions(), ii + 1);
+  for (Index ii = 0; ii < (Index)dims_.size(); ii++) {
+    ForwardDiff(x, y.chip<4>(ii), x.dimensions(), dims_[ii]);
   }
-  // Log::Tensor("grad-fwd-x", x.dimensions(), x.data());
-  // Log::Tensor("grad-fwd-y", y.dimensions(), y.data());
   this->finishForward(y, time);
 }
 
 void GradOp::adjoint(OutCMap const &y, InMap &x) const
 {
   auto const time = this->startAdjoint(y);
   x.setZero();
-  for (Index ii = 0; ii < 3; ii++) {
-    BackwardDiff(y.chip<4>(ii), x, x.dimensions(), ii + 1);
-    // Log::Tensor(fmt::format("grad-adj-temp-{}", ii), x.dimensions(), x.data());
+  for (Index ii = 0; ii < (Index)dims_.size(); ii++) {
+    BackwardDiff(y.chip<4>(ii), x, x.dimensions(), dims_[ii]);
   }
-  // Log::Tensor("grad-adj-y", y.dimensions(), y.data());
-  // Log::Tensor("grad-adj-x", x.dimensions(), x.data());
   this->finishAdjoint(x, time);
 }
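
The constructor now stores the requested axes and both loops run over dims_ instead of the hard-coded three spatial dimensions, writing one output channel per axis. A toy version of the forward loop over plain arrays (periodic boundaries assumed; rl's ForwardDiff helper is only partly visible in this hunk):

    #include <cstdio>
    #include <vector>

    using Index = long;

    int main()
    {
      Index const nt = 3, nx = 4; // axis 0 = time, axis 1 = space
      std::vector<float> x(nt * nx);
      for (Index i = 0; i < nt * nx; i++) { x[i] = (float)(i * i % 7); }

      std::vector<Index> const dims{0, 1};              // axes to differentiate
      std::vector<float> y(dims.size() * nt * nx, 0.f); // one channel per axis

      for (Index ii = 0; ii < (Index)dims.size(); ii++) {
        for (Index it = 0; it < nt; it++) {
          for (Index ix = 0; ix < nx; ix++) {
            Index const jt = (dims[ii] == 0) ? (it + 1) % nt : it;
            Index const jx = (dims[ii] == 1) ? (ix + 1) % nx : ix;
            y[(ii * nt + it) * nx + ix] = x[jt * nx + jx] - x[it * nx + ix];
          }
        }
      }
      std::printf("d/dt at (0,0): %g, d/dx at (0,0): %g\n", y[0], y[nt * nx]);
    }
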
src/op/grad.hpp: 4 additions & 1 deletion
@@ -7,8 +7,11 @@ namespace rl {
 struct GradOp final : TensorOperator<Cx, 4, 5>
 {
   OP_INHERIT(Cx, 4, 5)
-  GradOp(InDims const dims);
+  GradOp(InDims const ishape, std::vector<Index> const &gradDims);
   OP_DECLARE()
+
+private:
+  std::vector<Index> dims_;
 };
 
 struct GradVecOp final : TensorOperator<Cx, 5, 5>
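
The header change is the whole of the new API: one extra constructor argument selecting the axes, with the output gaining a trailing dimension of that many channels, as the AddBack call in the constructor implies. A small sketch of that shape bookkeeping, with AddBack's semantics assumed from its use above:

    #include <algorithm>
    #include <array>
    #include <cstddef>
    #include <cstdio>

    using Index = long;

    // Assumed semantics of rl's AddBack: append one extent to a shape.
    template <std::size_t N>
    std::array<Index, N + 1> AddBack(std::array<Index, N> const &in, Index const back)
    {
      std::array<Index, N + 1> out;
      std::copy(in.begin(), in.end(), out.begin());
      out[N] = back;
      return out;
    }

    int main()
    {
      std::array<Index, 4> const ishape{8, 64, 64, 64}; // frames/basis + 3 spatial
      auto const oshapeSpatial = AddBack(ishape, 3);    // dims {1,2,3} -> [8,64,64,64,3]
      auto const oshapeTime = AddBack(ishape, 1);       // dims {0}     -> [8,64,64,64,1]
      std::printf("last extents: %ld and %ld\n", oshapeSpatial.back(), oshapeTime.back());
    }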
