Skip to content

Commit

Permalink
Add residual outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
spinicist committed May 16, 2023
1 parent ab4ebc7 commit 75f13de
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 46 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# v0.11

- Fixed many bugs that crept into v0.10.
- Functionality to write out residual k-space/images
- Small but important tweaks to how the ADMM algorithm works including more sensible defaults.
- Added a through-time TV regularizer option for ADMM.
- Added a tool to calcuate a basis set from temporal images.
Expand Down
25 changes: 15 additions & 10 deletions src/cmd/admm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ int main_admm(args::Subparser &parser)

args::ValueFlag<std::string> pre(parser, "P", "Pre-conditioner (none/kspace/filename)", {"pre"}, "kspace");
args::ValueFlag<float> preBias(parser, "BIAS", "Pre-conditioner Bias (1)", {"pre-bias", 'b'}, 1.f);
args::ValueFlag<std::vector<float>, VectorReader<float>> basisScales(parser, "S", "Basis scales", {"basis-scales"});
args::ValueFlag<Index> inner_its(parser, "ITS", "Max inner iterations (2)", {"max-its"}, 2);
args::ValueFlag<float> atol(parser, "A", "Tolerance on A", {"atol"}, 1.e-6f);
args::ValueFlag<float> btol(parser, "B", "Tolerance on b", {"btol"}, 1.e-6f);
Expand Down Expand Up @@ -62,8 +61,7 @@ int main_admm(args::Subparser &parser)
Info const &info = traj.info();
auto recon = make_recon(coreOpts, sdcOpts, senseOpts, traj, reader);
auto M = make_kspace_pre(pre.Get(), recon->oshape, traj, ReadBasis(coreOpts.basisFile.Get()), preBias.Get());
auto N = make_scales_pre(basisScales.Get(), recon->ishape);
auto A = std::make_shared<LinOps::Multiply<Cx>>(recon, N);
std::shared_ptr<LinOps::Op<Cx>> A = recon;
auto const sz = recon->ishape;

std::function<void(Index const iter, ADMM::Vector const &x)> debug_x = [sz](Index const ii, ADMM::Vector const &x) {
Expand All @@ -76,7 +74,10 @@ int main_admm(args::Subparser &parser)
float const scale = Scaling(coreOpts.scaling, recon, M->adjoint(CChipMap(allData, 0)));
allData.device(Threads::GlobalDevice()) = allData * allData.constant(scale);
Index const volumes = allData.dimension(4);
Cx5 out(sz[0], outSz[0], outSz[1], outSz[2], volumes);
Cx5 out(sz[0], outSz[0], outSz[1], outSz[2], volumes), resid;
if (coreOpts.residImage) {
resid.resize(sz[0], outSz[0], outSz[1], outSz[2], volumes);
}

std::vector<std::shared_ptr<LinOps::Op<Cx>>> reg_ops;
std::vector<std::shared_ptr<Prox<Cx>>> prox;
Expand Down Expand Up @@ -144,14 +145,18 @@ int main_admm(args::Subparser &parser)
debug_x};
auto const &all_start = Log::Now();
for (Index iv = 0; iv < volumes; iv++) {
auto x = admm.run(&allData(0, 0, 0, 0, iv), ρ.Get());
if (ext_x) {
x = ext_x->forward(x);
auto x = ext_x->forward(admm.run(&allData(0, 0, 0, 0, iv), ρ.Get()));
auto xm = Tensorfy(x, sz);
out.chip<4>(iv) = out_cropper.crop4(xm) / out.chip<4>(iv).constant(scale);
if (coreOpts.residImage || coreOpts.residKSpace) {
allData.chip<4>(iv) -= recon->forward(xm);
}
if (coreOpts.residImage) {
xm = recon->adjoint(allData.chip<4>(iv));
resid.chip<4>(iv) = out_cropper.crop4(xm) / resid.chip<4>(iv).constant(scale);
}
x = N->forward(x);
out.chip<4>(iv) = out_cropper.crop4(Tensorfy(x, sz)) / out.chip<4>(iv).constant(scale);
}
Log::Print("All Volumes: {}", Log::ToNow(all_start));
WriteOutput(coreOpts, out, parser.GetCommand().Name(), traj, Log::Saved());
WriteOutput(coreOpts, out, parser.GetCommand().Name(), traj, Log::Saved(), resid, allData);
return EXIT_SUCCESS;
}
21 changes: 16 additions & 5 deletions src/cmd/cg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,26 @@ int main_cg(args::Subparser &parser)
float const scale = Scaling(coreOpts.scaling, recon, CChipMap(allData, 0));
allData.device(Threads::GlobalDevice()) = allData * allData.constant(scale);
Index const volumes = allData.dimension(4);
Cx5 out(sz[0], outSz[0], outSz[1], outSz[2], volumes);
Cx5 out(sz[0], outSz[0], outSz[1], outSz[2], volumes), resid;
if (coreOpts.residImage) {
resid.resize(sz[0], outSz[0], outSz[1], outSz[2], volumes);
}

auto const &all_start = Log::Now();
for (Index iv = 0; iv < volumes; iv++) {
auto const &vol_start = Log::Now();
auto b = recon->adjoint(CChipMap(allData, iv));
out.chip<4>(iv) = out_cropper.crop4(Tensorfy(cg.run(b.data()), sz)) / out.chip<4>(iv).constant(scale);
Log::Print("Volume {}: {}", iv, Log::ToNow(vol_start));
auto x = cg.run(b.data());
auto xm = Tensorfy(x, sz);
out.chip<4>(iv) = out_cropper.crop4(xm) / out.chip<4>(iv).constant(scale);
if (coreOpts.residImage || coreOpts.residKSpace) {
allData.chip<4>(iv) -= recon->forward(xm);
}
if (coreOpts.residImage) {
xm = recon->adjoint(allData.chip<4>(iv));
resid.chip<4>(iv) = out_cropper.crop4(xm) / resid.chip<4>(iv).constant(scale);
}
}
Log::Print("All Volumes: {}", Log::ToNow(all_start));
WriteOutput(coreOpts, out, parser.GetCommand().Name(), traj, Log::Saved());
WriteOutput(coreOpts, out, parser.GetCommand().Name(), traj, Log::Saved(), resid, allData);
return EXIT_SUCCESS;
}
38 changes: 22 additions & 16 deletions src/cmd/lsmr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ int main_lsmr(args::Subparser &parser)
args::ValueFlag<Index> its(parser, "N", "Max iterations (8)", {'i', "max-its"}, 8);
args::ValueFlag<std::string> pre(parser, "P", "Pre-conditioner (none/kspace/filename)", {"pre"}, "kspace");
args::ValueFlag<float> preBias(parser, "BIAS", "Pre-conditioner Bias (1)", {"pre-bias", 'b'}, 1.f);
args::ValueFlag<std::vector<float>, VectorReader<float>> basisScales(parser, "S", "Basis scales", {"basis-scales"});
args::ValueFlag<float> atol(parser, "A", "Tolerance on A (1e-6)", {"atol"}, 1.e-6f);
args::ValueFlag<float> btol(parser, "B", "Tolerance on b (1e-6)", {"btol"}, 1.e-6f);
args::ValueFlag<float> ctol(parser, "C", "Tolerance on cond(A) (1e-6)", {"ctol"}, 1.e-6f);
Expand All @@ -34,31 +33,38 @@ int main_lsmr(args::Subparser &parser)
HD5::Reader reader(coreOpts.iname.Get());
Trajectory traj(reader);
Info const &info = traj.info();
auto recon = make_recon(coreOpts, sdcOpts, senseOpts, traj, reader);
auto M = make_kspace_pre(pre.Get(), recon->oshape, traj, ReadBasis(coreOpts.basisFile.Get()), preBias.Get());
auto N = make_scales_pre(basisScales.Get(), recon->ishape);
auto A = std::make_shared<LinOps::Multiply<Cx>>(recon, N);
auto debug = [&recon](Index const i, LSMR::Vector const &x) {
Log::Tensor(fmt::format("lsmr-x-{:02d}", i), recon->ishape, x.data());
auto A = make_recon(coreOpts, sdcOpts, senseOpts, traj, reader);
auto M = make_kspace_pre(pre.Get(), A->oshape, traj, ReadBasis(coreOpts.basisFile.Get()), preBias.Get());
auto debug = [&A](Index const i, LSMR::Vector const &x) {
Log::Tensor(fmt::format("lsmr-x-{:02d}", i), A->ishape, x.data());
};
LSMR lsmr{A, M, its.Get(), atol.Get(), btol.Get(), ctol.Get(), debug};
auto sz = recon->ishape;
auto sz = A->ishape;
Cropper out_cropper(info.matrix, LastN<3>(sz), info.voxel_size, coreOpts.fov.Get());
Sz3 const outSz = out_cropper.size();
Cx5 allData = reader.readTensor<Cx5>(HD5::Keys::Noncartesian);
float const scale = Scaling(coreOpts.scaling, recon, M->adjoint(CChipMap(allData, 0)));
float const scale = Scaling(coreOpts.scaling, A, M->adjoint(CChipMap(allData, 0)));
allData.device(Threads::GlobalDevice()) = allData * allData.constant(scale);
Index const volumes = allData.dimension(4);
Cx5 out(sz[0], outSz[0], outSz[1], outSz[2], volumes);
Cx5 out(sz[0], outSz[0], outSz[1], outSz[2], volumes), resid;
if (coreOpts.residImage) {
resid.resize(sz[0], outSz[0], outSz[1], outSz[2], volumes);
}

auto const &all_start = Log::Now();
for (Index iv = 0; iv < volumes; iv++) {
auto const &vol_start = Log::Now();
out.chip<4>(iv).device(Threads::GlobalDevice()) =
out_cropper.crop4(Tensorfy(N->forward(lsmr.run(&allData(0, 0, 0, 0, iv), λ.Get())), recon->ishape)) /
out.chip<4>(iv).constant(scale);
Log::Print("Volume {}: {}", iv, Log::ToNow(vol_start));
auto x = lsmr.run(&allData(0, 0, 0, 0, iv), λ.Get());
auto xm = Tensorfy(x, sz);
out.chip<4>(iv) = out_cropper.crop4(xm) / out.chip<4>(iv).constant(scale);
if (coreOpts.residImage || coreOpts.residKSpace) {
allData.chip<4>(iv) -= A->forward(xm);
}
if (coreOpts.residImage) {
xm = A->adjoint(allData.chip<4>(iv));
resid.chip<4>(iv) = out_cropper.crop4(xm) / resid.chip<4>(iv).constant(scale);
}
}
Log::Print("All Volumes: {}", Log::ToNow(all_start));
WriteOutput(coreOpts, out, parser.GetCommand().Name(), traj, Log::Saved());
WriteOutput(coreOpts, out, parser.GetCommand().Name(), traj, Log::Saved(), resid, allData);
return EXIT_SUCCESS;
}
34 changes: 20 additions & 14 deletions src/cmd/lsqr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ int main_lsqr(args::Subparser &parser)
args::ValueFlag<Index> its(parser, "N", "Max iterations (8)", {'i', "max-its"}, 8);
args::ValueFlag<std::string> pre(parser, "P", "Pre-conditioner (none/kspace/filename)", {"pre"}, "kspace");
args::ValueFlag<float> preBias(parser, "BIAS", "Pre-conditioner Bias (1)", {"pre-bias", 'b'}, 1.f);
args::ValueFlag<std::vector<float>, VectorReader<float>> basisScales(parser, "S", "Basis scales", {"basis-scales"});
args::ValueFlag<float> atol(parser, "A", "Tolerance on A (1e-6)", {"atol"}, 1.e-6f);
args::ValueFlag<float> btol(parser, "B", "Tolerance on b (1e-6)", {"btol"}, 1.e-6f);
args::ValueFlag<float> ctol(parser, "C", "Tolerance on cond(A) (1e-6)", {"ctol"}, 1.e-6f);
Expand All @@ -33,28 +32,35 @@ int main_lsqr(args::Subparser &parser)
HD5::Reader reader(coreOpts.iname.Get());
Trajectory traj(reader);
Info const &info = traj.info();
auto recon = make_recon(coreOpts, sdcOpts, senseOpts, traj, reader);
auto M = make_kspace_pre(pre.Get(), recon->oshape, traj, ReadBasis(coreOpts.basisFile.Get()), preBias.Get());
auto N = make_scales_pre(basisScales.Get(), recon->ishape);
auto A = std::make_shared<LinOps::Multiply<Cx>>(recon, N);
auto A = make_recon(coreOpts, sdcOpts, senseOpts, traj, reader);
auto M = make_kspace_pre(pre.Get(), A->oshape, traj, ReadBasis(coreOpts.basisFile.Get()), preBias.Get());
LSQR lsqr{A, M, its.Get(), atol.Get(), btol.Get(), ctol.Get(), true};
auto sz = recon->ishape;
auto sz = A->ishape;
Cropper out_cropper(info.matrix, LastN<3>(sz), info.voxel_size, coreOpts.fov.Get());
Sz3 const outSz = out_cropper.size();
Cx5 allData = reader.readTensor<Cx5>(HD5::Keys::Noncartesian);
float const scale = Scaling(coreOpts.scaling, recon, M->adjoint(CChipMap(allData, 0)));
float const scale = Scaling(coreOpts.scaling, A, M->adjoint(CChipMap(allData, 0)));
allData.device(Threads::GlobalDevice()) = allData * allData.constant(scale);
Index const volumes = allData.dimension(4);
Cx5 out(sz[0], outSz[0], outSz[1], outSz[2], volumes);
Cx5 out(sz[0], outSz[0], outSz[1], outSz[2], volumes), resid;
if (coreOpts.residImage) {
resid.resize(sz[0], outSz[0], outSz[1], outSz[2], volumes);
}

auto const &all_start = Log::Now();
for (Index iv = 0; iv < volumes; iv++) {
auto const &vol_start = Log::Now();
out.chip<4>(iv).device(Threads::GlobalDevice()) =
out_cropper.crop4(Tensorfy(N->inverse(lsqr.run(&allData(0, 0, 0, 0, iv), λ.Get())), recon->ishape)) /
out.chip<4>(iv).constant(scale);
Log::Print("Volume {}: {}", iv, Log::ToNow(vol_start));
auto x = lsqr.run(&allData(0, 0, 0, 0, iv), λ.Get());
auto xm = Tensorfy(x, sz);
out.chip<4>(iv) = out_cropper.crop4(xm) / out.chip<4>(iv).constant(scale);
if (coreOpts.residImage || coreOpts.residKSpace) {
allData.chip<4>(iv) -= A->forward(xm);
}
if (coreOpts.residImage) {
xm = A->adjoint(allData.chip<4>(iv));
resid.chip<4>(iv) = out_cropper.crop4(xm) / resid.chip<4>(iv).constant(scale);
}
}
Log::Print("All Volumes: {}", Log::ToNow(all_start));
WriteOutput(coreOpts, out, parser.GetCommand().Name(), traj, Log::Saved());
WriteOutput(coreOpts, out, parser.GetCommand().Name(), traj, Log::Saved(), resid, allData);
return EXIT_SUCCESS;
}
2 changes: 2 additions & 0 deletions src/io/hd5-keys.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ std::string const Norm = "norm";
std::string const Parameters = "parameters";
std::string const Precond = "precond";
std::string const ProtonDensity = "pd";
std::string const ResidualImage = "resid-image";
std::string const ResidualKSpace = "resid-noncartesian";
std::string const SDC = "sdc";
std::string const SENSE = "sense";
std::string const Trajectory = "trajectory";
Expand Down
10 changes: 10 additions & 0 deletions src/parse_args.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ CoreOpts::CoreOpts(args::Subparser &parser)
, osamp(parser, "O", "Grid oversampling factor (2)", {'s', "osamp"}, 2.f)
, fov(parser, "FOV", "Final FoV in mm (default header value)", {"fov"}, -1)
, bucketSize(parser, "B", "Gridding bucket size (32)", {"bucket-size"}, 32)
, residImage(parser, "R", "Write residuals in image space", {"resid-image"})
, residKSpace(parser, "R", "Write residuals in k-space", {"resid-kspace"})
, keepTrajectory(parser, "", "Keep the trajectory in the output file", {"keep"})
{
}
Expand Down Expand Up @@ -173,6 +175,8 @@ void WriteOutput(
std::string const &suffix,
rl::Trajectory const &traj,
std::string const &log,
rl::Cx5 const &residImage,
rl::Cx5 const &residKSpace,
std::map<std::string, float> const &meta)
{
auto const fname = OutName(opts.iname.Get(), opts.oname.Get(), suffix, "h5");
Expand All @@ -185,4 +189,10 @@ void WriteOutput(
writer.writeInfo(traj.info());
}
writer.writeString("log", log);
if (opts.residImage) {
writer.writeTensor(HD5::Keys::ResidualImage, residImage.dimensions(), residImage.data());
}
if (opts.residKSpace) {
writer.writeTensor(HD5::Keys::ResidualKSpace, residKSpace.dimensions(), residKSpace.data());
}
}
4 changes: 3 additions & 1 deletion src/parse_args.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct CoreOpts
args::ValueFlag<std::string> oname, basisFile, ktype, scaling;
args::ValueFlag<float> osamp, fov;
args::ValueFlag<Index> bucketSize;
args::Flag keepTrajectory;
args::Flag residImage, residKSpace, keepTrajectory;
};

void WriteOutput(
Expand All @@ -57,5 +57,7 @@ void WriteOutput(
std::string const &suffix,
rl::Trajectory const &traj,
std::string const &log,
rl::Cx5 const &residImage = rl::Cx5(),
rl::Cx5 const &residKSpace = rl::Cx5(),
std::map<std::string, float> const &meta = std::map<std::string, float>());

0 comments on commit 75f13de

Please sign in to comment.