diff --git a/.gitignore b/.gitignore index 9d2f3c0..460b3f6 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ CMakeLists.txt.user.* .8.un~ docs/_build *.asv +build/ +.vscode/ \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 2a994cf..20c0e1d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,7 +67,17 @@ FIND_PACKAGE(ITK 5.2.0 REQUIRED) INCLUDE(${ITK_USE_FILE}) # VTK - required for quality mesh transformations -FIND_PACKAGE(VTK 9.1.0 REQUIRED COMPONENTS CommonCore IOCore IOLegacy IOPLY IOGeometry FiltersModeling) +FIND_PACKAGE(VTK 9.1.0 REQUIRED COMPONENTS + CommonCore + IOCore + IOLegacy + IOPLY + IOGeometry + IOImage + IOXML + FiltersCore + FiltersGeneral + FiltersModeling) SET(GREEDY_VTK_LIBRARIES ${VTK_LIBRARIES}) # Deal with FFTW - only used by experimental LDDMM code @@ -240,8 +250,15 @@ SET(GREEDY_API_LIBS greedyapi ${FFTWF_LIB} ${FFTWF_THREADS_LIB} ${SPARSE_LIBRARY}) +# propagation api +add_subdirectory(src/propagation propagation) + +SET(PROPAGATION_SRC + src/propagation/greedy_propagation.cxx +) + # List of installable targets -SET(GREEDY_INSTALL_BIN_TARGETS greedy greedy_template_average multi_chunk_greedy) +SET(GREEDY_INSTALL_BIN_TARGETS greedy greedy_template_average multi_chunk_greedy greedy_propagation) SET(GREEDY_INSTALL_LIB_TARGETS greedyapi) @@ -300,8 +317,13 @@ IF(BUILD_CLI) ADD_EXECUTABLE(test_greedy testing/src/GreedyTestDriver.cxx) TARGET_LINK_LIBRARIES(test_greedy ${GREEDY_API_LIBS}) + ADD_EXECUTABLE(test_propagation testing/src/propagation/propagation_test.cxx) + TARGET_LINK_LIBRARIES(test_propagation propagationapi) + TARGET_INCLUDE_DIRECTORIES(test_propagation PUBLIC ${GREEDY_SOURCE_DIR}/src/propagation ) + ADD_EXECUTABLE(multi_chunk_greedy ${CHUNK_GREEDY_SRC}) TARGET_LINK_LIBRARIES(multi_chunk_greedy ${GREEDY_API_LIBS}) + ENDIF(BUILD_CLI) # Install command-line executables @@ -421,6 +443,9 @@ IF(NOT GREEDY_BUILD_AS_SUBPROJECT) ADD_TEST(NAME "Phantom_NCC_Sim_NoMask" COMMAND test_greedy phantom 1 3 NCC 7 0 WORKING_DIRECTORY ${TESTING_DATADIR}) ADD_TEST(NAME "Phantom_WNCC_Sim_NoMask" COMMAND test_greedy phantom 1 3 WNCC 7 0 WORKING_DIRECTORY ${TESTING_DATADIR}) + # Tests for greedy_propagation + ADD_TEST(NAME "propagation_basic" COMMAND test_propagation basic WORKING_DIRECTORY ${TESTING_DATADIR}) + ADD_TEST(NAME "propagation_extra_mesh" COMMAND test_propagation extra_mesh WORKING_DIRECTORY ${TESTING_DATADIR}) # Add tests for lmshoot IF(GREEDY_BUILD_LMSHOOT) ADD_TEST(NAME "lmshoot_regression" COMMAND lmshoot_test shoot_regression.mat WORKING_DIRECTORY ${TESTING_DATADIR}/lmshoot) diff --git a/src/CommandLineHelper.h b/src/CommandLineHelper.h index 9f0dfb3..089f6e2 100644 --- a/src/CommandLineHelper.h +++ b/src/CommandLineHelper.h @@ -256,6 +256,21 @@ class CommandLineHelper return val; } + /** + * Read output directory path + */ + std::string read_output_dir() + { + std::string dir = read_arg(); + if(this->data_root.length()) + dir = itksys::SystemTools::CollapseFullPath(dir, data_root); + + if(!itksys::SystemTools::PathExists(dir.c_str())) + throw GreedyException("Folder '%s' does not exist", dir.c_str()); + + return dir; + } + /** * Check if a string ends with another string and return the * substring without the suffix @@ -360,27 +375,26 @@ class CommandLineHelper return vector; } - std::vector read_int_vector() + std::vector read_int_vector(char delimiter = 'x') { std::string arg = read_arg(); std::istringstream f(arg); std::string s; std::vector vector; - while (getline(f, s, 'x')) + while (getline(f, s, delimiter)) { errno = 0; char *pend; long val = std::strtol(s.c_str(), &pend, 10); if(errno || *pend) - throw GreedyException("Expected an integer vector as parameter to '%s', instead got '%s'", - current_command.c_str(), arg.c_str()); + throw GreedyException("Expected an integer vector delimited by '%c' as parameter to '%s', instead got '%s'", + delimiter, current_command.c_str(), arg.c_str()); vector.push_back((int) val); } if(!vector.size()) - throw GreedyException("Expected an integer vector as parameter to '%s', instead got '%s'", - current_command.c_str(), arg.c_str()); - + throw GreedyException("Expected an integer vector delimited by '%c' as parameter to '%s', instead got '%s'", + delimiter, current_command.c_str(), arg.c_str()); return vector; } diff --git a/src/GreedyAPI.cxx b/src/GreedyAPI.cxx index 275ac3f..5ad7415 100644 --- a/src/GreedyAPI.cxx +++ b/src/GreedyAPI.cxx @@ -416,6 +416,28 @@ ::ReadImageViaCache(const std::string &filename, return pointer; } +template +typename GreedyApproach::MeshPointer +GreedyApproach +::ReadMeshViaCache(const std::string &filename) +{ + typename MeshCache::const_iterator it = m_MeshCache.find(filename); + if(it != m_MeshCache.cend()) + { + vtkObject *cached_object = it->second.target; + MeshType *mesh = dynamic_cast(cached_object); + if (!mesh) + throw GreedyException("Cached mesh %s cannot be cast to type %s", + filename.c_str(), typeid(MeshType).name()); + MeshPointer pMesh = DeepCopyMesh(mesh); // important to avoid in-place mutation + return pMesh; + } + + // Read the mesh using mesh io reader + return ReadMesh(filename.c_str()); +} + + template template TObject * @@ -542,6 +564,27 @@ ::WriteImageViaCache(TImage *img, const std::string &filename, itk::IOComponentE } } +template +void +GreedyApproach +::WriteMeshViaCache(MeshType *mesh, const std::string &filename) +{ + typename MeshCache::const_iterator it = m_MeshCache.find(filename); + if (it != m_MeshCache.end()) + { + auto *cached = dynamic_cast(it->second.target); + if (!cached) + throw GreedyException("Cached mesh %s cannot be cast to %s", + filename.c_str(), typeid(MeshType*).name()); + cached->DeepCopy(mesh); + } + + if (it == m_MeshCache.end() || it->second.force_write) + { + WriteMesh(mesh, filename.c_str()); + } +} + #include @@ -1719,7 +1762,7 @@ ::RunDeformable(GreedyParameters ¶m) if(param.tjr_param.weight > 0.0) { // Read the mesh - vtkSmartPointer point_set = ReadMesh(param.tjr_param.tetra_mesh.c_str()); + vtkSmartPointer point_set = ReadMeshViaCache(param.tjr_param.tetra_mesh.c_str()); vtkSmartPointer tetra = dynamic_cast(point_set.GetPointer()); if(!tetra) throw GreedyException("Mesh %s is not an UnstructuredGrid!", param.tjr_param.tetra_mesh.c_str()); @@ -1905,8 +1948,9 @@ ::RunDeformable(GreedyParameters ¶m) tm_Gradient.Stop(); // Print a report for this iteration - std::cout << this->PrintIter(level, iter, metric_report, reg_report) << std::endl; - fflush(stdout); + std::string iter_line = this->PrintIter(level, iter, metric_report, reg_report); + gout.printf("%s\n", iter_line.c_str()); + gout.flush(); // Record the metric value in the log this->RecordMetricValue(metric_report); @@ -2279,7 +2323,7 @@ ::RunDeformableOptimization(GreedyParameters ¶m) if(param.tjr_param.weight > 0.0) { // Read the mesh - vtkSmartPointer point_set = ReadMesh(param.tjr_param.tetra_mesh.c_str()); + vtkSmartPointer point_set = ReadMeshViaCache(param.tjr_param.tetra_mesh.c_str()); vtkSmartPointer tetra = dynamic_cast(point_set.GetPointer()); if(!tetra) throw GreedyException("Mesh %s is not an UnstructuredGrid!", param.tjr_param.tetra_mesh.c_str()); @@ -2451,8 +2495,9 @@ ::RunDeformableOptimization(GreedyParameters ¶m) } // Print a report for this iteration - std::cout << this->PrintIter(level, iter, metric_report, reg_report) << std::endl; - fflush(stdout); + std::string iter_line = this->PrintIter(level, iter, metric_report, reg_report); + gout.printf("%s\n", iter_line.c_str()); + gout.flush(); // Record the metric value in the log this->RecordMetricValue(metric_report); @@ -3251,7 +3296,7 @@ ::RunReslice(GreedyParameters ¶m) std::vector meshes, original_meshes; for(unsigned int i = 0; i < r_param.meshes.size(); i++) { - vtkSmartPointer mesh = ReadMesh(r_param.meshes[i].fixed.c_str()); + vtkSmartPointer mesh = ReadMeshViaCache(r_param.meshes[i].fixed.c_str()); meshes.push_back(mesh); if(r_param.meshes[i].jacobian_mode) @@ -3425,7 +3470,7 @@ ::RunReslice(GreedyParameters ¶m) if(r_param.meshes[i].jacobian_mode) WriteJacobianMesh(original_meshes[i], meshes[i], r_param.meshes[i].output.c_str()); else - WriteMesh(meshes[i], r_param.meshes[i].output.c_str()); + WriteMeshViaCache(meshes[i], r_param.meshes[i].output.c_str()); } @@ -3739,6 +3784,15 @@ ::AddCachedInputObject(std::string key, itk::Object *object) m_ImageCache[key].force_write = false; } +template +void GreedyApproach +::AddCachedInputObject(std::string key, vtkObject *object) +{ + m_MeshCache[key].target = object; + m_MeshCache[key].force_write = false; +} + + template void GreedyApproach ::AddCachedOutputObject(std::string key, itk::Object *object, bool force_write) @@ -3768,6 +3822,14 @@ ::GetCachedObjectNames() const return keys; } +template +void GreedyApproach +::AddCachedOutputObject(std::string key, vtkObject *object, bool force_write) +{ + m_MeshCache[key].target = object; + m_MeshCache[key].force_write = force_write; +} + template const typename GreedyApproach::MetricLogType & GreedyApproach diff --git a/src/GreedyAPI.h b/src/GreedyAPI.h index 58b97c7..5fcf7dc 100644 --- a/src/GreedyAPI.h +++ b/src/GreedyAPI.h @@ -37,6 +37,7 @@ #include #include "itkCommand.h" #include +#include template class MultiImageOpticalFlowHelper; @@ -85,7 +86,8 @@ class GreedyApproach }; // Mesh data structures - typedef vtkSmartPointer MeshPointer; + typedef vtkPointSet MeshType; + typedef vtkSmartPointer MeshPointer; typedef std::vector MeshArray; static void ConfigThreads(const GreedyParameters ¶m); @@ -141,6 +143,7 @@ class GreedyApproach * */ void AddCachedInputObject(std::string key, itk::Object *object); + void AddCachedInputObject(std::string key, vtkObject *object); /** * Add an image/matrix to the output cache. This has the same behavior as @@ -153,6 +156,7 @@ class GreedyApproach * will be allocated. It can then me accessed using GetCachedObject() */ void AddCachedOutputObject(std::string key, itk::Object *object, bool force_write = false); + void AddCachedOutputObject(std::string key, vtkObject *object, bool force_write = false); /** * Get a cached object by name @@ -291,9 +295,17 @@ class GreedyApproach bool force_write; }; + struct VTKCacheEntry { + vtkObject *target; + bool force_write; + }; + typedef std::map ImageCache; ImageCache m_ImageCache; + typedef std::map MeshCache; + MeshCache m_MeshCache; + // A log of metric values used during registration - so metric can be looked up // in the callbacks to RunAffine, etc. MetricLogType m_MetricLog; @@ -311,6 +323,8 @@ class GreedyApproach itk::SmartPointer ReadImageViaCache(const std::string &filename, itk::IOComponentEnum *comp_type = NULL); + MeshPointer ReadMeshViaCache(const std::string &filename); + template TObject *CheckCache(const std::string &filename) const; // Get a filename for dumping intermediate outputs @@ -330,6 +344,8 @@ class GreedyApproach void WriteCompressedWarpInPhysicalSpaceViaCache( ImageBaseType *moving_ref_space, VectorImageType *warp, const char *filename, double precision); + void WriteMeshViaCache(MeshType *mesh, const std::string &filename); + // Compute the moments of a composite image (mean and covariance matrix of coordinate weighted by intensity) void ComputeImageMoments(CompositeImageType *image, const vnl_vector &weights, VecFx &m1, MatFx &m2); diff --git a/src/GreedyMeshIO.cxx b/src/GreedyMeshIO.cxx index 9db9596..515ac05 100644 --- a/src/GreedyMeshIO.cxx +++ b/src/GreedyMeshIO.cxx @@ -12,6 +12,8 @@ #include #include #include +#include +#include #include #include #include @@ -73,6 +75,8 @@ vtkSmartPointer ReadMeshByExtension(const char *fname) else throw GreedyException("No mesh reader for file %s", fname); } + else if(fn_str.rfind(".vtp") == fn_str.length() - 4) + return ReadMesh(fname); else throw GreedyException("No mesh reader for file %s", fname); } @@ -96,6 +100,11 @@ void WriteMeshByExtension(TMesh *mesh, const char *fname) else if (usg) WriteMesh(usg, fname); } + else if (fn_str.rfind(".vtp") == fn_str.length() - 4) + { + vtkPolyData *pd = dynamic_cast(mesh); + WriteMesh(pd, fname); + } else throw GreedyException("No mesh writer for file %s", fname); } diff --git a/src/GreedyParameters.cxx b/src/GreedyParameters.cxx index 127a6ba..198f526 100644 --- a/src/GreedyParameters.cxx +++ b/src/GreedyParameters.cxx @@ -28,6 +28,9 @@ #include "CommandLineHelper.h" #include +const SmoothingParameters GreedyParameters::default_sigma_pre = { 1.7320508076, false }; +const SmoothingParameters GreedyParameters::default_sigma_post = { 0.7071067812, false }; + GreedyParameters::GreedyParameters() { input_groups.push_back(GreedyInputGroup()); @@ -440,7 +443,7 @@ bool GreedyParameters::ParseCommandLine(const std::string &cmd, CommandLineHelpe int level = cl.read_integer(); if(level < 0 || level >= VERB_INVALID) throw GreedyException("Invalid verbosity level %d", level); - + this->verbosity = (Verbosity)(level); } else if(cmd == "-lbfgs-ftol") @@ -494,6 +497,65 @@ operator << (std::ostream &oss, const SmoothingParameters &sp) return oss; } +// Copy affine registration settings +void +GreedyParameters +::CopyAffineSettings(const GreedyParameters &other) +{ + this->affine_dof = other.affine_dof; + this->affine_init_mode = other.affine_init_mode; + this->rigid_search = other.rigid_search; + this->affine_jitter = other.affine_jitter; + this->metric = other.metric; + this->metric_radius = other.metric_radius; + this->iter_per_level = other.iter_per_level; +} + +// Copy deformable registration settings +void +GreedyParameters +::CopyDeformableSettings(const GreedyParameters &other) +{ + this->metric = other.metric; + this->metric_radius = other.metric_radius; + this->iter_per_level = other.iter_per_level; + this->epsilon_per_level = other.epsilon_per_level; + this->time_step_mode = other.time_step_mode; + this->warp_precision = other.warp_precision; +} + +// Copy reslicing settings +void +GreedyParameters +::CopyReslicingSettings(const GreedyParameters &other) +{ + this->current_interp = other.current_interp; +} + +// Copy general settings +void +GreedyParameters +::CopyGeneralSettings(const GreedyParameters &other) +{ + // Common Debug Settings + this->flag_debug_deriv = other.flag_debug_deriv; + this->deriv_epsilon = other.deriv_epsilon; + this->flag_debug_aff_obj = other.flag_debug_aff_obj; + this->flag_dump_pyramid = other.flag_dump_pyramid; + this->flag_dump_moving = other.flag_dump_moving; + this->dump_frequency = other.dump_frequency; + this->dump_prefix = other.dump_prefix; + this->flag_powell = other.flag_powell; + this->verbosity = other.verbosity; + + // General Settings + this->flag_float_math = other.flag_float_math; + this->dim = other.dim; + this->threads = other.threads; + this->sigma_pre = other.sigma_pre; + this->sigma_post = other.sigma_post; +} + std::string GreedyParameters::GenerateCommandLine() { // Generate default parameters diff --git a/src/GreedyParameters.h b/src/GreedyParameters.h index cc2eede..f3cf941 100644 --- a/src/GreedyParameters.h +++ b/src/GreedyParameters.h @@ -59,6 +59,14 @@ struct SmoothingParameters bool operator != (const SmoothingParameters &other) { return sigma != other.sigma || physical_units != other.physical_units; } + + bool operator == (const SmoothingParameters &other) { + return sigma == other.sigma && physical_units == other.physical_units; + } + + bool operator == (const SmoothingParameters &other) const { + return sigma == other.sigma && physical_units == other.physical_units; + } }; enum RigidSearchRotationMode @@ -110,6 +118,9 @@ struct ResliceSpec struct ResliceMeshSpec { + ResliceMeshSpec() {} + ResliceMeshSpec(const std::string &_fixed, const std::string &_output) + :fixed(_fixed), output(_output) {} std::string fixed; std::string output; bool jacobian_mode = false; @@ -128,6 +139,7 @@ struct TransformSpec : filename(in_filename), exponent(in_exponent) {} }; + enum AffineInitMode { VOX_IDENTITY = 0, // Identity mapping in voxel space @@ -232,7 +244,7 @@ class PerLevelSpec else return false; } - void Print(std::ostream &oss) const + void Print(std::ostream &oss) const { if(m_UseCommon) oss << m_CommonValue; @@ -353,8 +365,12 @@ struct GreedyParameters double background = 0.0; // Smoothing parameters - SmoothingParameters sigma_pre = { 1.7320508076, false }; - SmoothingParameters sigma_post = { 0.7071067812, false }; + // -- the static default value is used to detect whether the user has set the smoothing parameters + // -- this is used by propagation to override the greedy defaults with the propagation defaults + static const SmoothingParameters default_sigma_pre; + static const SmoothingParameters default_sigma_post; + SmoothingParameters sigma_pre = default_sigma_pre; + SmoothingParameters sigma_post = default_sigma_post; // Which metric to use MetricType metric = SSD; @@ -449,7 +465,7 @@ struct GreedyParameters // Save format for new reslice image pairs itk::IOComponentEnum current_reslice_format = itk::IOComponentEnum::UNKNOWNCOMPONENTTYPE; - + // Verbosity flag Verbosity verbosity = VERB_DEFAULT; @@ -476,6 +492,13 @@ struct GreedyParameters // Generate a command line for current parameters std::string GenerateCommandLine(); + + // Methods for copy settings + // -- This is useful for using same settings for different GreedyAPI runs + void CopyAffineSettings(const GreedyParameters &other); + void CopyDeformableSettings(const GreedyParameters &other); + void CopyReslicingSettings(const GreedyParameters &other); + void CopyGeneralSettings(const GreedyParameters &other); }; diff --git a/src/propagation/CMakeLists.txt b/src/propagation/CMakeLists.txt new file mode 100644 index 0000000..cad1277 --- /dev/null +++ b/src/propagation/CMakeLists.txt @@ -0,0 +1,22 @@ +message("Configurating Greedy Propagation...") + +set(PROPAGATION_LIB_SRC + PropagationAPI.cxx + PropagationIO.cxx + PropagationInputBuilder.cxx + PropagationTools.txx +) + +set(PROPAGATION_INCLUDE_DIR + ${CMAKE_CURRENT_SOURCE_DIR} + ${GREEDY_SOURCE_DIR}/src +) + +add_library(propagationapi ${PROPAGATION_LIB_SRC}) +target_link_libraries(propagationapi PUBLIC ${GREEDY_API_LIBS}) +target_include_directories(propagationapi PUBLIC ${PROPAGATION_INCLUDE_DIR}) + +if (BUILD_CLI) + add_executable(greedy_propagation main.cxx) + target_link_libraries(greedy_propagation PRIVATE propagationapi) +endif (BUILD_CLI) diff --git a/src/propagation/PropagationAPI.cxx b/src/propagation/PropagationAPI.cxx new file mode 100644 index 0000000..80904f4 --- /dev/null +++ b/src/propagation/PropagationAPI.cxx @@ -0,0 +1,932 @@ +#include "PropagationAPI.h" +#include "PropagationTools.h" +#include "GreedyAPI.h" +#include "PropagationData.hxx" +#include "PropagationIO.h" +#include "GreedyMeshIO.h" + +#include +#include + +#include +#include +#include +#include +#include +#include + +using namespace propagation; + +template +PropagationAPI +::PropagationAPI(const std::shared_ptr> input) +{ + m_PParam = input->m_PropagationParam; + m_GParam = input->m_GreedyParam; + m_Data = input->m_Data; + m_StdOut = std::make_shared(m_PParam.verbosity); + ValidateInputData(); +} + +template +void +PropagationAPI +::ValidateInputData() +{ + if (!m_Data->img4d) + throw GreedyException("Reference 4D Image Input not found!"); + + if (!m_Data->seg_ref) + throw GreedyException("Reference segmentation Image not found!"); + + + // Validate inputs + // -- ref tp has to be within the range of the tp + uint16_t nt = m_Data->img4d->GetLargestPossibleRegion().GetSize()[3]; + if (m_PParam.refTP <= 0 || m_PParam.refTP > nt) + throw GreedyException("Reference tp %d is out of the tp range of the 4d image (1 to %d)", m_PParam.refTP, nt); + + + for (size_t tp : m_PParam.targetTPs) + if (tp <=0 || tp > nt) throw GreedyException("Target tp %d is out of the tp range of the 4d image (1 to %d)",tp, nt); +} + +template +PropagationAPI +::~PropagationAPI() +{ +} + +template +int +PropagationAPI +::Run() +{ + PropagationData pData; + + if (m_PParam.debug) + { + m_StdOut->printf("-- [Propagation] Debug Mode is ON \n"); + m_StdOut->printf("-- [Propagation] Debug Output Dir: %s \n", m_PParam.debug_dir.c_str()); + } + + PrepareTimePointData(); // Prepare initial timepoint data for propagation + CreateTimePointLists(); + + m_StdOut->printf("-- [Propagation] forward list: "); + for (auto tp : m_ForwardTPs) + m_StdOut->printf(" %d", tp); + m_StdOut->printf("\n"); + + m_StdOut->printf("-- [Propagation] backward list: "); + for (auto tp : m_BackwardTPs) + m_StdOut->printf(" %d", tp); + m_StdOut->printf("\n"); + + // Run forward propagation + if (m_ForwardTPs.size() > 1) + RunUnidirectionalPropagation(m_ForwardTPs); + + // Run backward propagation + if (m_BackwardTPs.size() > 1) + RunUnidirectionalPropagation(m_BackwardTPs); + + // Write out a 4D Segmentation + Generate4DSegmentation(); + + m_StdOut->printf("Run Completed! \n"); + return EXIT_SUCCESS; +} + +template +void +PropagationAPI +::ValidateInputOrientation() +{ + typename TImage3D::DirectionType img_direction, seg_direction; + img_direction = m_Data->tp_data[m_PParam.refTP].img->GetDirection(); + seg_direction = m_Data->seg_ref->GetDirection(); + + if (img_direction != seg_direction) + { + std::cerr << "Image Direction: " << std::endl << img_direction << std::endl; + std::cerr << "Segmentation Direction: " << std::endl << seg_direction << std::endl; + std::string fn_seg = m_PParam.use4DSegInput ? m_PParam.fn_seg4d : m_PParam.fn_seg3d; + throw GreedyException("Image and Segmentation orientations do not match. Segmentation file %s\n", + fn_seg.c_str()); + } +} + +template +void +PropagationAPI +::CreateReferenceMask() +{ + // Threshold, Dilate and Resample + auto thr_tail = PTools::template ThresholdImage(m_Data->seg_ref, 1, SHRT_MAX, 1, 0); + auto dlt_tail = PTools::template DilateImage(thr_tail, 10, 1); + m_Data->tp_data[m_PParam.refTP].seg_srs = PTools:: + template Resample3DImage(dlt_tail, 0.5, ResampleInterpolationMode::NearestNeighbor); + + // Create object name + m_Data->tp_data[m_PParam.refTP].seg_srs-> + SetObjectName(GenerateUnaryTPObjectName("mask_", m_PParam.refTP, nullptr, "_srs")); +} + +template +void +PropagationAPI +::PrepareTimePointData() +{ + m_StdOut->printf("-- [Propagation] Preparing Time Point Data \n"); + + std::vector tps(m_PParam.targetTPs); + if (std::find(tps.begin(), tps.end(), m_PParam.refTP) == tps.end()) + tps.push_back(m_PParam.refTP); // Add refTP to the tps to process + + for (size_t tp : tps) + { + TimePointDatatpData; + + // Extract full res image + tpData.img = PTools::template ExtractTimePointImage(m_Data->img4d, tp); + tpData.img->SetObjectName(GenerateUnaryTPObjectName("img_", tp)); + + // Generate resampled image + tpData.img_srs = PTools::template Resample3DImage(tpData.img, 0.5, ResampleInterpolationMode::Linear, 1); + m_Data->tp_data[tp] = tpData; + tpData.img_srs->SetObjectName(GenerateUnaryTPObjectName("img_", tp, nullptr, "_srs")); + } + + // Reference TP Segmentation + m_Data->tp_data[m_PParam.refTP].seg = m_Data->seg_ref; + m_Data->tp_data[m_PParam.refTP].seg->SetObjectName(GenerateUnaryTPObjectName("seg_", m_PParam.refTP)); + m_Data->tp_data[m_PParam.refTP].seg_mesh = PTools::GetMeshFromLabelImage(m_Data->seg_ref); + + // Write out the reference mesh + if (m_PParam.writeOutputToDisk) + { + WriteMesh(m_Data->tp_data[m_PParam.refTP].seg_mesh, + GenerateUnaryTPFileName(m_PParam.fnmeshout_pattern.c_str(), m_PParam.refTP, + m_PParam.outdir.c_str(), ".vtk").c_str()); + } + + ValidateInputOrientation(); + CreateReferenceMask(); + + // Debug: write out extracted tp images + if (m_PParam.debug) + { + for (auto &kv : m_Data->tp_data) + { + PTools::template WriteImage(kv.second.img, + GenerateUnaryTPObjectName("img_", kv.first, m_PParam.debug_dir.c_str(), nullptr, ".nii.gz")); + + PTools::template WriteImage(kv.second.img_srs, + GenerateUnaryTPObjectName("img_", kv.first, m_PParam.debug_dir.c_str(), "_srs", ".nii.gz")); + } + + PTools::template WriteImage(m_Data->tp_data[m_PParam.refTP].seg_srs, + GenerateUnaryTPObjectName("mask_", m_PParam.refTP, m_PParam.debug_dir.c_str(), "_srs", ".nii.gz")); + } +} + +template +void +PropagationAPI +::CreateTimePointLists() +{ + m_ForwardTPs.push_back(m_PParam.refTP); + m_BackwardTPs.push_back(m_PParam.refTP); + + for (unsigned int tp : m_PParam.targetTPs) + { + if (tp < m_PParam.refTP) + m_BackwardTPs.push_back(tp); + else if (tp == m_PParam.refTP) + continue; // ignore reference tp in the target list + else + m_ForwardTPs.push_back(tp); + } + + std::sort(m_ForwardTPs.begin(), m_ForwardTPs.end()); + std::sort(m_BackwardTPs.rbegin(), m_BackwardTPs.rend()); // sort backward reversely +} + +template +void +PropagationAPI +::RunUnidirectionalPropagation(const std::vector &tp_list) +{ + + m_StdOut->printf("-- [Propagation] Unidirectional Propagation for tp_list: "); + for (auto tp : tp_list) + m_StdOut->printf(" %d", tp); + m_StdOut->printf("\n"); + + RunDownSampledPropagation(tp_list); // Generate affine matrices and masks + GenerateFullResolutionMasks(tp_list); // Reslice downsampled masks to full-res + GenerateReferenceSpace(tp_list); // Generate reference space for faster run + + // Run reg between ref and target tp and warp reference segmentation to target segmentation + for (size_t crnt = 1; crnt < tp_list.size(); ++crnt) // this can be parallelized + { + RunFullResolutionPropagation(tp_list[crnt]); + } +} + +template +void +PropagationAPI +::RunFullResolutionPropagation(const unsigned int target_tp) +{ + m_StdOut->printf("-- [Propagation] Running Full Resolution Propagation from %02d to %02d\n", + m_PParam.refTP, target_tp); + + // Run Deformable Reg from target to ref + RunPropagationDeformable(target_tp, m_PParam.refTP, true); + + // Warp ref segmentation to target + RunPropagationReslice(m_PParam.refTP, target_tp, true); + + // Warp ref segmentation mesh to target + RunPropagationMeshReslice(m_PParam.refTP, target_tp); + + // Copy extra meshes to reference time point data + auto &ref_data = m_Data->tp_data[m_PParam.refTP]; + for (auto kv : m_Data->extra_mesh_cache) + { + ref_data.AddExtraMesh(kv.first, kv.second); + } +} + +template +void +PropagationAPI +::GenerateReferenceSpace(const std::vector &tp_list) +{ + // Add full-resolution masks together for trimming + using TAddFilter = itk::AddImageFilter; + auto fltAdd = TAddFilter::New(); + fltAdd->SetInput1(m_Data->tp_data[tp_list[0]].full_res_mask); + fltAdd->SetInput2(m_Data->tp_data[tp_list[1]].full_res_mask); + fltAdd->Update(); + auto img_tail = fltAdd->GetOutput(); + + for (size_t i = 2; i < tp_list.size(); ++i) + { + fltAdd->SetInput1(img_tail); + fltAdd->SetInput2(m_Data->tp_data[tp_list[i]].full_res_mask); + fltAdd->Update(); + img_tail = fltAdd->GetOutput(); + } + + typename TLabelImage3D::RegionType roi; + auto trimmed = PTools::TrimLabelImage(img_tail, 5, roi); + + // Move trimmed image to the roi region + trimmed->SetRegions(roi); + typename TLabelImage3D::PointType origin; + for (int i = 0; i < 3; ++i) + { + origin.SetElement(i, roi.GetIndex().GetElement(i)); + } + trimmed->SetOrigin(origin); + + auto ref_space = PTools::CastLabelToRealImage(trimmed); + + m_Data->full_res_ref_space = ref_space; + + if (m_PParam.debug) + { + std::ostringstream fnrs; + fnrs << m_PParam.debug_dir << PTools::GetPathSeparator() + << "full_res_reference_space.nii.gz"; + PTools::template WriteImage(ref_space, fnrs.str()); + } +} + +template +void +PropagationAPI +::GenerateFullResolutionMasks(const std::vector &tp_list) +{ + + m_StdOut->printf("-- [Propagation] Generating Full Resolution Masks \n"); + + for (size_t i = 0; i < tp_list.size(); ++i) + { + const unsigned int tp = tp_list[i]; + TimePointData &tp_data = m_Data->tp_data[tp]; + tp_data.full_res_mask = PropagationTools::ResliceLabelImageWithIdentityMatrix(tp_data.img, tp_data.seg_srs); + std::string fnmask = GenerateUnaryTPObjectName("mask_", tp); + if (m_PParam.debug) + { + fnmask = GenerateUnaryTPObjectName("mask_", tp, m_PParam.debug_dir.c_str(), nullptr, ".nii.gz"); + PTools::template WriteImage(tp_data.full_res_mask, fnmask); + } + tp_data.full_res_mask->SetObjectName(fnmask); + } +} + +template +void +PropagationAPI +::RunDownSampledPropagation(const std::vector &tp_list) +{ + m_StdOut->printf("-- [Propagation] Down Sampled Propagation started \n"); + + for (size_t i = 1; i < tp_list.size(); ++i) + { + unsigned int c = tp_list[i], p = tp_list[i - 1]; // current tp and previous tp + RunPropagationAffine(p, c); // affine reg current to prev + RunPropagationDeformable(p, c, false); // deformable reg current to prev + BuildTransformChainForReslice(p, c); // build transformation chain for current tp + RunPropagationReslice(m_PParam.refTP, c, false); // warp ref mask to current + } +} + +template +void +PropagationAPI +::RunPropagationAffine(unsigned int tp_fix, unsigned int tp_mov) +{ + m_StdOut->printf("-- [Propagation] Running Affine %02d to %02d \n", tp_mov, tp_fix); + TimePointData &df = m_Data->tp_data[tp_fix], &dm = m_Data->tp_data[tp_mov]; + + // Create a new GreedyAPI for affine run and configure + std::shared_ptr> GreedyAPI = std::make_shared>(); + + GreedyInputGroup ig; + ImagePairSpec ip; + ip.weight = 1.0; + + auto img_fix = df.img_srs; + auto img_mov = dm.img_srs; + + ip.fixed = img_fix->GetObjectName(); + ip.moving = img_mov->GetObjectName(); + ig.inputs.push_back(ip); + + typename TCompositeImage3D::Pointer casted_fix = PTools::CastImageToCompositeImage(img_fix); + typename TCompositeImage3D::Pointer casted_mov = PTools::CastImageToCompositeImage(img_mov); + + GreedyAPI->AddCachedInputObject(ip.fixed, casted_fix); + GreedyAPI->AddCachedInputObject(ip.moving, casted_mov); + + // Set dilated fix seg as mask + auto mask_fix = df.seg_srs; + ig.fixed_mask = mask_fix->GetObjectName(); + auto casted_mask = PTools::CastLabelToRealImage(mask_fix); + GreedyAPI->AddCachedInputObject(ig.fixed_mask, casted_mask); + + // Configure greedy parameters + GreedyParameters param; + param.mode = GreedyParameters::AFFINE; + param.CopyGeneralSettings(m_GParam); // copy general settings from user input + param.CopyAffineSettings(m_GParam); // copy affine settings from user input + // Override global default settings with propagation specific setting + param.affine_init_mode = AffineInitMode::RAS_IDENTITY; + param.affine_dof = GreedyParameters::DOF_RIGID; + + // Check smoothing parameters. If greedy default detected, change to propagation default. + // -- This is to ensure if user has not set the smoothing parameters, the propagation defaults are used instead of greedy defaults + // -- and at the same time user can still override the defaults if needed. + const SmoothingParameters prop_default_pre = { 3.0, true }, prop_default_post = { 1.5, true }; + param.sigma_pre = (m_GParam.sigma_pre == GreedyParameters::default_sigma_pre) ? prop_default_pre : m_GParam.sigma_pre; + param.sigma_post = (m_GParam.sigma_post == GreedyParameters::default_sigma_post) ? prop_default_post : m_GParam.sigma_post; + + // Add the input group to the parameters + param.input_groups.clear(); + param.input_groups.push_back(ig); + + // Configure output + bool force_write = false; + param.output = GenerateBinaryTPObjectName("affine_", tp_mov, tp_fix); + + if (m_PParam.debug) + { + force_write = true; + param.output = GenerateBinaryTPObjectName("affine_", tp_mov, tp_fix, + m_PParam.debug_dir.c_str(), ".mat"); + } + + m_StdOut->printf("-- [Propagation] Affine Command: %s \n", param.GenerateCommandLine().c_str()); + + dm.affine_to_prev->SetObjectName(param.output); + GreedyAPI->AddCachedOutputObject(param.output, dm.affine_to_prev, force_write); + + int ret = GreedyAPI->RunAffine(param); + + if (ret != 0) + throw GreedyException("GreedyAPI execution failed in Proapgation Affine Run: tp_fix = %d, tp_mov = %d", + tp_fix, tp_mov); +} + +template +void +PropagationAPI +::RunPropagationDeformable(unsigned int tp_fix, unsigned int tp_mov, bool isFullRes) +{ + m_StdOut->printf("-- [Propagation] Running %s Deformable %02d to %02d \n", + isFullRes ? "Full-resolution" : "Down-sampled", tp_mov, tp_fix); + + // Get relevant tp data + TimePointData &tpdata_fix = m_Data->tp_data[tp_fix]; + TimePointData &tpdata_mov = m_Data->tp_data[tp_mov]; + + // Set greedy parameters + std::shared_ptr> GreedyAPI = std::make_shared>(); + GreedyParameters param; + param.mode = GreedyParameters::GREEDY; + param.CopyDeformableSettings(m_GParam); + param.CopyGeneralSettings(m_GParam); + + // Set input images + GreedyInputGroup ig; + ImagePairSpec ip; + ip.weight = 1.0; + auto img_fix = isFullRes ? tpdata_fix.img : tpdata_fix.img_srs; + auto img_mov = isFullRes ? tpdata_mov.img : tpdata_mov.img_srs; + ip.fixed = img_fix->GetObjectName(); + ip.moving = img_mov->GetObjectName(); + + typename TCompositeImage3D::Pointer casted_fix = PTools::CastImageToCompositeImage(img_fix); + typename TCompositeImage3D::Pointer casted_mov = PTools::CastImageToCompositeImage(img_mov); + GreedyAPI->AddCachedInputObject(ip.fixed, casted_fix); + GreedyAPI->AddCachedInputObject(ip.moving, casted_mov); + ig.inputs.push_back(ip); + + // Set mask images + auto mask_fix = isFullRes ? tpdata_fix.full_res_mask : tpdata_fix.seg_srs; + ig.fixed_mask = mask_fix->GetObjectName(); + auto casted_mask = PTools::CastLabelToRealImage(mask_fix); + GreedyAPI->AddCachedInputObject(ig.fixed_mask, casted_mask); + + // Check smoothing parameters. If greedy default detected, change to propagation default. + const SmoothingParameters prop_default_pre = { 3.0, true }, prop_default_post = { 1.5, true }; + param.sigma_pre = (m_GParam.sigma_pre == GreedyParameters::default_sigma_pre) ? + prop_default_pre : m_GParam.sigma_pre; + param.sigma_post = (m_GParam.sigma_post == GreedyParameters::default_sigma_post) ? + prop_default_post : m_GParam.sigma_post; + + // Configure output + bool force_write = false; // Write out images for debugging + const char *suffix = isFullRes ? "" : "_srs"; + param.output = GenerateBinaryTPObjectName("warp_", tp_mov, tp_fix, nullptr, suffix); + param.inverse_warp = GenerateBinaryTPObjectName("warp_", tp_fix, tp_mov, nullptr, suffix); + + if (m_PParam.debug) + { + force_write = true; + param.output = GenerateBinaryTPObjectName("warp_", tp_mov, tp_fix, + m_PParam.debug_dir.c_str(), suffix, ".nii.gz"); + param.inverse_warp = GenerateBinaryTPObjectName("warp_", tp_fix, tp_mov, + m_PParam.debug_dir.c_str(), suffix, ".nii.gz"); + } + + using LDDMM3DType = LDDMMData; + + if (isFullRes) + { + // Set the transformation chain + for (size_t i = 0; i < tpdata_fix.transform_specs.size(); ++i) + { + auto &trans_spec = tpdata_fix.transform_specs[i]; + std::string affine_id = trans_spec.affine->GetObjectName(); + ig.moving_pre_transforms.push_back(TransformSpec(affine_id, -1.0)); + GreedyAPI->AddCachedInputObject(affine_id, trans_spec.affine.GetPointer()); + } + + // Set output objects + tpdata_fix.deform_from_ref = LDDMM3DType::new_vimg(tpdata_fix.img); + tpdata_fix.deform_from_ref->SetObjectName(param.output); + GreedyAPI->AddCachedOutputObject(param.output, tpdata_fix.deform_from_ref, force_write); + + tpdata_fix.deform_to_ref = LDDMM3DType::new_vimg(tpdata_fix.img); + tpdata_fix.deform_to_ref->SetObjectName(param.inverse_warp); + GreedyAPI->AddCachedOutputObject(param.inverse_warp, tpdata_fix.deform_to_ref, force_write); + } + else + { + // Set Initial affine transform + std::string it_name = tpdata_mov.affine_to_prev->GetObjectName(); + ig.moving_pre_transforms.push_back(TransformSpec(it_name, 1.0)); + GreedyAPI->AddCachedInputObject(it_name, tpdata_mov.affine_to_prev); + + // Set output objects + tpdata_mov.deform_to_prev = LDDMM3DType::new_vimg(tpdata_mov.img_srs); + tpdata_mov.deform_to_prev->SetObjectName(param.output); + GreedyAPI->AddCachedOutputObject(param.output, tpdata_mov.deform_to_prev, force_write); + + tpdata_mov.deform_from_prev = LDDMM3DType::new_vimg(tpdata_mov.img_srs); + tpdata_mov.deform_from_prev->SetObjectName(param.inverse_warp); + GreedyAPI->AddCachedOutputObject(param.inverse_warp, tpdata_mov.deform_from_prev, force_write); + } + + // Add the input group to the parameters + param.input_groups.clear(); + param.input_groups.push_back(ig); + + m_StdOut->printf("-- [Propagation] Deformable Command: %s \n", param.GenerateCommandLine().c_str()); + + int ret = GreedyAPI->RunDeformable(param); + + if (ret != 0) + throw GreedyException("GreedyAPI execution failed in Proapgation Deformable Run:" + " tp_fix = %d, tp_mov = %d, isFulRes = %d", + tp_fix, tp_mov, isFullRes); +} + +template +void +PropagationAPI +::RunPropagationReslice(unsigned int tp_in, unsigned int tp_out, bool isFullRes) +{ + m_StdOut->printf("-- [Propagation] Running %s Reslice %02d to %02d \n", + isFullRes ? "Full-resolution" : "Down-sampled", tp_in, tp_out); + + TimePointData &tpdata_in = m_Data->tp_data[tp_in]; + TimePointData &tpdata_out = m_Data->tp_data[tp_out]; + + // API and parameter configuration + using GreedyAPIType = GreedyApproach<3u, TReal>; + std::shared_ptr GreedyAPI = std::make_shared(); + GreedyParameters param; + param.mode = GreedyParameters::RESLICE; + param.CopyGeneralSettings(m_GParam); + param.CopyReslicingSettings(m_GParam); + + // Set reference image + auto img_ref = isFullRes ? tpdata_out.img : tpdata_out.img_srs; + param.reslice_param.ref_image = img_ref->GetObjectName(); + auto casted_ref = PTools::CastImageToCompositeImage(img_ref); + GreedyAPI->AddCachedInputObject(param.reslice_param.ref_image, casted_ref.GetPointer()); + + // Set input image + auto img_in = isFullRes ? tpdata_in.seg : tpdata_in.seg_srs; + std::string imgin_name = img_in->GetObjectName(); + auto casted_mov = PropagationTools + ::template CastToCompositeImage>(img_in); + GreedyAPI->AddCachedInputObject(imgin_name, casted_mov.GetPointer()); + + // Set output image + std::string imgout_name; + bool force_write = false; + if (isFullRes) + { + force_write = m_PParam.writeOutputToDisk; + imgout_name = GenerateUnaryTPFileName(m_PParam.fnsegout_pattern.c_str(), + tp_out, m_Data->outdir.c_str(), ".nii.gz"); + } + else if (m_PParam.debug) + { + force_write = true; + imgout_name = GenerateUnaryTPObjectName("mask_", tp_out, m_PParam.debug_dir.c_str(), "_srs", ".nii.gz"); + } + else // non debug, non full-res + { + imgout_name = GenerateUnaryTPObjectName("mask_", tp_out, nullptr, "_srs"); + } + + auto img_out = TLabelImage3D::New(); // create a new empty image + img_out->SetObjectName(imgout_name); + if (isFullRes) + tpdata_out.seg = img_out.GetPointer(); + else + tpdata_out.seg_srs = img_out.GetPointer(); + GreedyAPI->AddCachedOutputObject(imgout_name, img_out.GetPointer(), force_write); + + // Make a reslice spec with input-output pair and push to the parameter + ResliceSpec rspec(imgin_name, imgout_name, m_PParam.reslice_spec); + param.reslice_param.images.push_back(rspec); + + // Build transformation chain + if (isFullRes) + { + // Prepend deformation field before all affine matrices for full-res reslice + std::string deform_id = tpdata_out.deform_from_ref->GetObjectName(); + param.reslice_param.transforms.push_back(TransformSpec(deform_id)); + GreedyAPI->AddCachedInputObject(deform_id, tpdata_out.deform_from_ref.GetPointer()); + } + + for (size_t i = 0; i < tpdata_out.transform_specs.size(); ++i) + { + auto &trans_spec = tpdata_out.transform_specs[i]; + std::string affine_id = trans_spec.affine->GetObjectName(); + param.reslice_param.transforms.push_back(TransformSpec(affine_id, -1.0)); + GreedyAPI->AddCachedInputObject(affine_id, trans_spec.affine.GetPointer()); + + // Append deformation field to each affine matrix for downsampled propagation + if (!isFullRes) + { + std::string deform_id = trans_spec.deform->GetObjectName(); + param.reslice_param.transforms.push_back(TransformSpec(deform_id)); + GreedyAPI->AddCachedInputObject(deform_id, trans_spec.deform.GetPointer()); + } + } + + m_StdOut->printf("-- [Propagation] Reslice Command: %s \n", param.GenerateCommandLine().c_str()); + + int ret = GreedyAPI->RunReslice(param); + + if (ret != 0) + throw GreedyException("GreedyAPI execution failed in Proapgation Reslice Run: " + "tp_in = %d, tp_out = %d, isFulRes = %d", + tp_in, tp_out, isFullRes); +} + +template +void +PropagationAPI +::RunPropagationMeshReslice(unsigned int tp_in, unsigned int tp_out) +{ + m_StdOut->printf("-- [Propagation] Running Mesh Reslice %02d to %02d \n", tp_in, tp_out); + + TimePointData &tpdata_in = m_Data->tp_data[tp_in]; + TimePointData &tpdata_out = m_Data->tp_data[tp_out]; + + // API and parameter configuration + using GreedyAPIType = GreedyApproach<3u, TReal>; + std::shared_ptr GreedyAPI = std::make_shared(); + GreedyParameters param; + param.mode = GreedyParameters::RESLICE; + param.CopyGeneralSettings(m_GParam); + param.CopyReslicingSettings(m_GParam); + + // Set reference image + auto img_ref = tpdata_out.img; + param.reslice_param.ref_image = img_ref->GetObjectName(); + auto casted_ref = PTools::CastImageToCompositeImage(img_ref); + GreedyAPI->AddCachedInputObject(param.reslice_param.ref_image, casted_ref.GetPointer()); + + // Set input mesh + auto mesh_in = tpdata_in.seg_mesh; + std::string mesh_in_name = GenerateUnaryTPObjectName("mesh_", tp_in, nullptr, nullptr, ".vtk"); + GreedyAPI->AddCachedInputObject(mesh_in_name, mesh_in); + + // Set output image + std::string mesh_out_name = + GenerateUnaryTPFileName(m_PParam.fnmeshout_pattern.c_str(), tp_out, m_PParam.outdir.c_str(), ".vtk"); + + // Make a reslice spec with input-output pair and push to the parameter + ResliceMeshSpec rmspec(mesh_in_name, mesh_out_name); + param.reslice_param.meshes.push_back(rmspec); + tpdata_out.seg_mesh = TPropagationMesh::New(); + GreedyAPI->AddCachedOutputObject(mesh_out_name, tpdata_out.seg_mesh, m_PParam.writeOutputToDisk); + + // Add extra meshes to warp + for (auto &mesh_spec : m_PParam.extra_mesh_list) + { + ResliceMeshSpec rms; + + if (mesh_spec.cached) + { + auto tag = mesh_spec.fnout_pattern; + rms.fixed = tag; + + // add input to cache + auto mesh_in = m_Data->extra_mesh_cache[tag]; + GreedyAPI->AddCachedInputObject(tag, mesh_in); + + // configure output + tpdata_out.AddExtraMesh(tag, TPropagationMesh::New()); + std::string out_name = GenerateBinaryTPObjectName(tag.c_str(), tp_in, tp_out, + nullptr, nullptr, nullptr); + GreedyAPI->AddCachedOutputObject(out_name, tpdata_out.GetExtraMesh(tag), false); + rms.output = out_name; + } + else + { + auto pattern = mesh_spec.fnout_pattern; + rms.fixed = mesh_spec.fn_mesh; + std::string fn_mesh_ref = GenerateUnaryTPFileName(pattern.c_str(), m_PParam.refTP, + m_PParam.outdir.c_str(), ".vtk"); + itksys::SystemTools::CopyAFile(mesh_spec.fn_mesh, fn_mesh_ref); + rms.output = GenerateUnaryTPFileName(pattern.c_str(), tp_out, m_PParam.outdir.c_str(), ".vtk"); + } + + param.reslice_param.meshes.push_back(rms); + } + + // Build transformation chain + + // -- compose affine chain + for (int i = tpdata_out.transform_specs.size() - 1; i >= 0; --i) + { + auto &trans_spec = tpdata_out.transform_specs[i]; + std::string affine_id = trans_spec.affine->GetObjectName(); + param.reslice_param.transforms.push_back(TransformSpec(affine_id)); + GreedyAPI->AddCachedInputObject(affine_id, trans_spec.affine.GetPointer()); + } + + // -- append deformation field to the affine chain + std::string deform_id = tpdata_out.deform_to_ref->GetObjectName(); + param.reslice_param.transforms.push_back(TransformSpec(deform_id)); + GreedyAPI->AddCachedInputObject(deform_id, tpdata_out.deform_to_ref.GetPointer()); + + m_StdOut->printf("-- [Propagation] Mesh Reslice Command: %s \n", param.GenerateCommandLine().c_str()); + + int ret = GreedyAPI->RunReslice(param); + + if (ret != 0) + throw GreedyException("GreedyAPI execution failed in Proapgation Mesh Reslice Run: tp_in = %d, tp_out = %d", + tp_in, tp_out); +} + +template +void +PropagationAPI +::BuildTransformChainForReslice(unsigned int tp_prev, unsigned int tp_crnt) +{ + m_StdOut->printf("-- [Propagation] Building reslicing transformation chain for tp: %02d\n", tp_crnt); + + TimePointData &tpdata_crnt = m_Data->tp_data[tp_crnt]; + TimePointData &tpdata_prev = m_Data->tp_data[tp_prev]; + + // Copy previous transform specs as a starting point + for (auto &spec : tpdata_prev.transform_specs) + tpdata_crnt.transform_specs.push_back(spec); + + // Get current transformations + auto affine = tpdata_crnt.affine_to_prev; + auto deform = tpdata_crnt.deform_from_prev; + + // Build spec and append to existing list + TimePointTransformSpec spec(affine, deform, tp_crnt); + tpdata_crnt.transform_specs.push_back(spec); +} + +template +void +PropagationAPI +::Generate4DSegmentation() +{ + auto fltJoin = itk::JoinSeriesImageFilter::New(); + + for (size_t i = 1; i <= m_Data->GetNumberOfTimePoints(); ++i) + { + if (m_Data->tp_data.count(i)) + { + // Append slice to the list + fltJoin->PushBackInput(m_Data->tp_data[i].seg); + } + else + { + // Append an empty image + auto refseg = m_Data->tp_data[m_PParam.refTP].seg; + auto emptyImage = PTools::template CreateEmptyImage(refseg); + fltJoin->PushBackInput(emptyImage); + } + } + fltJoin->Update(); + m_Data->seg4d_out = fltJoin->GetOutput(); + + if (m_PParam.writeOutputToDisk) + { + std::ostringstream fnseg4d; + fnseg4d << m_PParam.outdir<< PTools::GetPathSeparator() + << "seg4d.nii.gz"; + PTools::template WriteImage(m_Data->seg4d_out, fnseg4d.str()); + } +} + +template +inline std::string +PropagationAPI +::GenerateUnaryTPObjectName(const char *base, unsigned int tp, + const char *debug_dir, const char *suffix, const char *file_ext) +{ + std::ostringstream oss; + if (debug_dir) + oss << debug_dir << PTools::GetPathSeparator(); + if (base) + oss << base; + + oss << setfill('0') << setw(2) << tp; + + if (suffix) + oss << suffix; + if (file_ext) + oss << file_ext; + + return oss.str(); +} + +template +inline std::string +PropagationAPI +::GenerateBinaryTPObjectName(const char *base, unsigned int tp1, unsigned int tp2, + const char *debug_dir, const char *suffix, const char *file_ext) +{ + std::ostringstream oss; + if (debug_dir) + oss << debug_dir << PTools::GetPathSeparator(); + if (base) + oss << base; + + oss << setfill('0') << setw(2) << tp1 << "_to_" << setfill('0') << setw(2) << tp2; + + if (suffix) + oss << suffix; + if (file_ext) + oss << file_ext; + + return oss.str(); +} + +template +inline std::string +PropagationAPI +::GenerateUnaryTPFileName(const char *pattern, unsigned int tp, + const char *output_dir, const char *file_ext) +{ + std::ostringstream oss; + if (output_dir) + oss << output_dir << PTools::GetPathSeparator(); + + if (strchr(pattern, '%') == NULL) + { + // % pattern not found, append tp after the pattern + oss << pattern << "_" << setfill('0') << setw(2) << tp; + if (file_ext) + oss << file_ext; + } + else + { + // use pattern specified by user + oss << PTools::ssprintf(pattern, tp); + } + + return oss.str(); +} + +template +std::shared_ptr> +PropagationAPI +::GetOutput() +{ + std::shared_ptr> ret = + std::make_shared>(); + + ret->Initialize(m_Data); + return ret; +} + +//================================================= +// PropagationStdOut Definition +//================================================= + +PropagationStdOut +::PropagationStdOut(PropagationParameters::Verbosity verbosity, FILE *f_out) +: m_Verbosity(verbosity), m_Output(f_out ? f_out : stdout) +{ + +} + +PropagationStdOut +::~PropagationStdOut() +{ + +} + +void +PropagationStdOut +::printf(const char *format, ...) +{ + if(m_Verbosity > PropagationParameters::VERB_NONE) + { + char buffer[4096]; + va_list args; + va_start (args, format); + vsnprintf (buffer, 4096, format, args); + va_end (args); + + fprintf(m_Output, "%s", buffer); + } +} + +void +PropagationStdOut +::print_verbose(const char *format, ...) +{ + if (m_Verbosity == PropagationParameters::VERB_VERBOSE) + { + va_list args; + va_start (args, format); + this->printf(format, args); + } +} + +void +PropagationStdOut +::flush() +{ + fflush(m_Output); +} + +namespace propagation +{ + template class PropagationAPI; + template class PropagationAPI; +} diff --git a/src/propagation/PropagationAPI.h b/src/propagation/PropagationAPI.h new file mode 100644 index 0000000..0511dd7 --- /dev/null +++ b/src/propagation/PropagationAPI.h @@ -0,0 +1,121 @@ +#ifndef PROPAGATIONAPI_H +#define PROPAGATIONAPI_H + +#include "lddmm_data.h" +#include "GreedyParameters.h" +#include "PropagationParameters.hxx" +#include "PropagationCommon.hxx" + +#include +#include +#include +#include +#include +#include + +namespace propagation +{ +template +class PropagationData; + +template +class PropagationInput; + +template +class PropagationOutput; + +template +class PropagationTools; + +class PropagationStdOut +{ +public: + + PropagationStdOut(PropagationParameters::Verbosity verbosity, FILE *f_out = NULL); + ~PropagationStdOut(); + + void printf(const char *format, ...); + void print_verbose(const char *format, ...); + void flush(); + +private: + PropagationParameters::Verbosity m_Verbosity; + FILE *m_Output; +}; + +template +class PropagationAPI +{ +public: + using TImage4D = itk::Image; + using TImage3D = itk::Image; + using TLabelImage4D = itk::Image; + using TLabelImage3D = itk::Image; + using TLDDMM3D = LDDMMData; + using TVectorImage3D = typename TLDDMM3D::VectorImageType; + using TCompositeImage3D = typename TLDDMM3D::CompositeImageType; + using TTransform = itk::MatrixOffsetTransformBase; + using TPropagationMesh = vtkPolyData; + using TPropagationMeshPointer = vtkSmartPointer; + using PTools = PropagationTools; + + enum ResampleInterpolationMode { Linear=0, NearestNeighbor }; + + PropagationAPI() = delete; + + /** Specialized constructor for api run */ + PropagationAPI(const std::shared_ptr> input); + + ~PropagationAPI(); + PropagationAPI(const PropagationAPI &other) = delete; + PropagationAPI &operator=(const PropagationAPI &other) = delete; + /** Start the execution of the propagation pipeline */ + int Run(); + + /** Build an output obejct for API run */ + std::shared_ptr> GetOutput(); + +private: + void ValidateInputData(); + void PrepareTimePointData(); + void ValidateInputOrientation(); + void CreateReferenceMask(); + void CreateTimePointLists(); + void Generate4DSegmentation(); + + void RunUnidirectionalPropagation(const std::vector &tp_list); + void RunDownSampledPropagation(const std::vector &tp_list); + void GenerateFullResolutionMasks(const std::vector &tp_list); + void GenerateReferenceSpace(const std::vector &tp_list); + void RunFullResolutionPropagation(const unsigned int target_tp); + + void RunPropagationAffine(unsigned int tp_fix, unsigned int tp_mov); + void RunPropagationDeformable(unsigned int tp_fix, unsigned int tp_mov, bool isFullRes); + void RunPropagationReslice(unsigned int tp_in, unsigned int tp_out, bool isFullRes); + void RunPropagationMeshReslice(unsigned int tp_in, unsigned int tp_out); + void BuildTransformChainForReslice(unsigned int tp_prev, unsigned int tp_crnt); + + static inline std::string + GenerateUnaryTPObjectName(const char *base, unsigned int tp, + const char *debug_dir = nullptr, const char *suffix = nullptr, + const char *file_ext = nullptr); + + static inline std::string + GenerateBinaryTPObjectName(const char *base, unsigned int tp1, unsigned int tp2, + const char *debug_dir = nullptr, const char *suffix = nullptr, + const char *file_ext = nullptr); + + static inline std::string + GenerateUnaryTPFileName(const char *pattern, unsigned int tp, + const char *output_dir = nullptr, const char *file_ext = nullptr); + + std::shared_ptr> m_Data; + GreedyParameters m_GParam; + PropagationParameters m_PParam; + std::vector m_ForwardTPs; + std::vector m_BackwardTPs; + std::shared_ptr m_StdOut; +}; + +} +#endif // PROPAGATIONAPI_H diff --git a/src/propagation/PropagationCommon.hxx b/src/propagation/PropagationCommon.hxx new file mode 100644 index 0000000..2723888 --- /dev/null +++ b/src/propagation/PropagationCommon.hxx @@ -0,0 +1,22 @@ +#ifndef PROPAGATIONCOMMON_HXX +#define PROPAGATIONCOMMON_HXX + +namespace propagation +{ + +#define PROPAGATION_DATA_TYPEDEFS \ +using TImage4D = typename PropagationAPI::TImage4D; \ +using TImage3D = typename PropagationAPI::TImage3D; \ +using TLabelImage4D = typename PropagationAPI::TLabelImage4D; \ +using TLabelImage3D = typename PropagationAPI::TLabelImage3D; \ +using TLDDMM3D = typename PropagationAPI::TLDDMM3D; \ +using TVectorImage3D = typename PropagationAPI::TVectorImage3D; \ +using TCompositeImage3D = typename PropagationAPI::TCompositeImage3D; \ +using TTransform = typename PropagationAPI::TTransform; \ +using TPropagationMesh = typename PropagationAPI::TPropagationMesh; \ +using TPropagationMeshPointer = typename PropagationAPI::TPropagationMeshPointer; \ +using ResampleInterpolationMode = typename PropagationAPI::ResampleInterpolationMode; + +} + +#endif // PROPAGATIONCOMMON_HXX diff --git a/src/propagation/PropagationData.hxx b/src/propagation/PropagationData.hxx new file mode 100644 index 0000000..d4a2e2b --- /dev/null +++ b/src/propagation/PropagationData.hxx @@ -0,0 +1,148 @@ +#ifndef PROPAGATIONDATA_HXX +#define PROPAGATIONDATA_HXX + +#include +#include +#include +#include +#include +#include +#include "lddmm_data.h" +#include "PropagationAPI.h" +#include "PropagationCommon.hxx" + +namespace propagation +{ + +template +class PropagationMeshGroup +{ + PROPAGATION_DATA_TYPEDEFS +}; + +template +class TimePointTransformSpec +{ +public: + PROPAGATION_DATA_TYPEDEFS + TimePointTransformSpec(typename TTransform::Pointer _affine, + typename TVectorImage3D::Pointer _deform, unsigned int _crntTP) + : affine(_affine), deform (_deform), currentTP(_crntTP) {} + + TimePointTransformSpec(const TimePointTransformSpec &other) = default; + TimePointTransformSpec & operator=(const TimePointTransformSpec &other) = default; + + unsigned int currentTP; + typename TTransform::Pointer affine; + typename TVectorImage3D::Pointer deform; +}; + +template +class TimePointData +{ +public: + PROPAGATION_DATA_TYPEDEFS + TimePointData(); + ~TimePointData(); + TimePointData(const TimePointData &other) = default; + TimePointData &operator=(const TimePointData &other) = default; + + typename TImage3D::Pointer img; + typename TImage3D::Pointer img_srs; + typename TLabelImage3D::Pointer seg; + typename TLabelImage3D::Pointer seg_srs; + typename TLabelImage3D::Pointer full_res_mask; + typename TTransform::Pointer affine_to_prev; + typename TVectorImage3D::Pointer deform_to_prev; + typename TVectorImage3D::Pointer deform_to_ref; + typename TVectorImage3D::Pointer deform_from_prev; + typename TVectorImage3D::Pointer deform_from_ref; + TPropagationMeshPointer seg_mesh; // mesh warped from reference tp + + std::vector> transform_specs; + std::vector> full_res_label_trans_specs; + + void AddExtraMesh(std::string tag, TPropagationMeshPointer mesh) + { + m_ExtraMeshes.insert({tag, mesh}); + } + + TPropagationMeshPointer GetExtraMesh(std::string &tag) + { + TPropagationMeshPointer ret = nullptr; + if (m_ExtraMeshes.count(tag)) + ret = m_ExtraMeshes[tag]; + + return ret; + } + + std::vector GetExtraMeshTags() + { + std::vector ret; + + for (auto kv : m_ExtraMeshes) + ret.push_back(kv.first); + + return ret; + } + + size_t GetExtraMeshSize() { return m_ExtraMeshes.size(); } + +protected: + // warped extra meshes, indexed by tags + std::map m_ExtraMeshes; +}; + +template +class PropagationData +{ +public: + PROPAGATION_DATA_TYPEDEFS + PropagationData(); + + size_t GetNumberOfTimePoints(); + + std::map> tp_data; + typename TImage4D::Pointer img4d; + typename TLabelImage3D::Pointer seg_ref; + typename TLabelImage4D::Pointer seg4d_in; + typename TLabelImage4D::Pointer seg4d_out; + std::string outdir; + typename TImage3D::Pointer full_res_ref_space; + + // extra meshes for warping, indexed by tags + std::map extra_mesh_cache; +}; + +template +PropagationData +::PropagationData() +{ + +} + +template +size_t +PropagationData +::GetNumberOfTimePoints() +{ + return this->img4d->GetBufferedRegion().GetSize()[3]; +} + +template +TimePointData +::TimePointData() +{ + affine_to_prev = TTransform::New(); +} + +template +TimePointData +::~TimePointData() +{ + +} + +} // end of namespace propagation + +#endif // PROPAGATIONDATA_HXX diff --git a/src/propagation/PropagationIO.cxx b/src/propagation/PropagationIO.cxx new file mode 100644 index 0000000..2b52be9 --- /dev/null +++ b/src/propagation/PropagationIO.cxx @@ -0,0 +1,266 @@ +#include "PropagationIO.h" +#include "PropagationTools.h" +#include "PropagationCommon.hxx" + + +using namespace propagation; + +//================================================== +// PropagationInput Definitions +//================================================== + +template +PropagationInput +::PropagationInput() +{ + m_Data = std::make_shared>(); +} + +template +PropagationInput +::~PropagationInput() +{ + +} + +template +void +PropagationInput +::SetDefaultGreedyParameters() +{ + +} + +template +void +PropagationInput +::SetDefaultPropagationParameters() +{ + if (m_PropagationParam.fnsegout_pattern.size() == 0) + { + std::cout << "-- [Propagation] segmentation output filename pattern (-sps-op) has not been set. " + "Setting to default value \"segmentation_%02d_resliced.nii.gz\"" << std::endl; + m_PropagationParam.fnsegout_pattern = "segmentation_%02d_resliced.nii.gz"; + } + + if (m_PropagationParam.fnmeshout_pattern.size() == 0) + { + std::cout << "-- [Propagation] segmentation mesh output filename pattern (-sps-mop) has not been set. " + "Setting to default value \"segmentation_mesh_%02d_resliced.vtk\"" << std::endl; + m_PropagationParam.fnmeshout_pattern = "segmentation_mesh_%02d_resliced.vtk"; + } +} + +template +void +PropagationInput +::SetGreedyParameters(const GreedyParameters &gParam) +{ + this->m_GreedyParam = gParam; + + // dim = 3 is mandatory and is not configured by users + if (m_GreedyParam.dim != 3) + m_GreedyParam.dim = 3; + + SetDefaultGreedyParameters(); + ValidateGreedyParameters(); +} + +template +void +PropagationInput +::SetPropagationParameters(const PropagationParameters &pParam) +{ + this->m_PropagationParam = pParam; + SetDefaultPropagationParameters(); + ValidatePropagationParameters(); +} + +template +void +PropagationInput +::ValidateGreedyParameters() const +{ + +} + +template +void +PropagationInput +::ValidatePropagationParameters() const +{ + if (m_PropagationParam.writeOutputToDisk && m_PropagationParam.outdir.size() == 0) + throw GreedyException("Output directory (-spo) not provided!"); +} + +template +void +PropagationInput +::SetDataForAPIRun(std::shared_ptr> data) +{ + m_Data = data; +} + + +//================================================== +// PropagationOutput Definitions +//================================================== + +template +PropagationOutput +::PropagationOutput() +{ + m_Data = nullptr; +} + +template +PropagationOutput +::~PropagationOutput() +{ + +} + +template +void +PropagationOutput +::Initialize(std::shared_ptr> data) +{ + m_Data = data; +} + +template +typename PropagationOutput::TImage3D::Pointer +PropagationOutput +::GetImage3D(unsigned int tp) +{ + return m_Data->tp_data.at(tp).img; +} + +template +typename PropagationOutput::TLabelImage3D::Pointer +PropagationOutput +::GetSegmentation3D(unsigned int tp) +{ + return m_Data->tp_data.at(tp).seg; +} + +template +typename PropagationOutput +::TSegmentation3DSeries +PropagationOutput +::GetSegmentation3DSeries() +{ + TSegmentation3DSeries ret; + for (auto kv : m_Data->tp_data) + { + ret[kv.first] = kv.second.seg; + } + + return ret; +} + +template +typename PropagationOutput::TLabelImage4D::Pointer +PropagationOutput +::GetSegmentation4D() +{ + return m_Data->seg4d_out; +} + +template +typename PropagationOutput +::TMeshSeries +PropagationOutput +::GetMeshSeries() +{ + TMeshSeries ret; + for (auto &kv : m_Data->tp_data) + { + ret[kv.first] = kv.second.seg_mesh; + } + + return ret; +} + +template +typename PropagationOutput +::TPropagationMeshPointer +PropagationOutput +::GetMesh(unsigned int tp) +{ + return m_Data->tp_data.at(tp).seg_mesh; +} + +template +typename PropagationOutput +::TMeshSeries +PropagationOutput +::GetExtraMeshSeries(std::string tag) +{ + TMeshSeries ret; + for (auto &kv : m_Data->tp_data) + { + ret[kv.first] = kv.second.GetExtraMesh(tag); + } + + return ret; +} + +template +typename PropagationOutput +::TPropagationMeshPointer +PropagationOutput +::GetExtraMesh(std::string tag, unsigned int tp) +{ + return m_Data->tp_data.at(tp).GetExtraMesh(tag); +} + +template +typename PropagationOutput +::TMeshSeriesMap +PropagationOutput +::GetAllExtraMeshSeries() +{ + TMeshSeriesMap ret; + auto firstTP = m_Data->tp_data.begin()->second; + + // populate return map + for (auto tag : firstTP.GetExtraMeshTags()) + { + ret[tag] = this->GetExtraMeshSeries(tag); + } + + return ret; +} + +template +std::vector +PropagationOutput +::GetTimePointList() +{ + std::vector ret; + for (auto kv : m_Data->tp_data) + { + ret.push_back(kv.first); + } + + return ret; +} + +template +size_t +PropagationOutput +::GetNumberOfTimePoints() +{ + return m_Data->tp_data.size(); +} + + + +namespace propagation +{ + template class PropagationInput; + template class PropagationInput; + template class PropagationOutput; + template class PropagationOutput; +} diff --git a/src/propagation/PropagationIO.h b/src/propagation/PropagationIO.h new file mode 100644 index 0000000..b2fcf8a --- /dev/null +++ b/src/propagation/PropagationIO.h @@ -0,0 +1,107 @@ +#ifndef PROPAGATIONIO_H +#define PROPAGATIONIO_H + +#include +#include +#include "PropagationCommon.hxx" +#include "PropagationAPI.h" +#include "PropagationData.hxx" +#include "PropagationParameters.hxx" +#include "GreedyParameters.h" + + +namespace propagation +{ + +template +class PropagationInputBuilder; + +template +class PropagationInput +{ +public: + PropagationInput(); + ~PropagationInput(); + PropagationInput(const PropagationInput &other) = default; + PropagationInput &operator=(const PropagationInput &other) = default; + + /** Set Greedy Parameters for the input. The setter will also validate input + * and populate default values to fields where value are not provided by the user */ + void SetGreedyParameters(const GreedyParameters &gParam); + const GreedyParameters &GetConstGreedyParameters() const + { return m_GreedyParam; } + + /** Set Propagation Parameters for the input. The setter will also validate input + * and populate default values to fields where value are not provided by the user */ + void SetPropagationParameters(const PropagationParameters &pParam); + const PropagationParameters &GetConstPropagationParameters() const + { return m_PropagationParam; } + + void SetOutputDirectory(std::string outdir) { m_Data->outdir = outdir; } + + friend class PropagationInputBuilder; + friend class PropagationAPI; + +private: + void SetDefaultGreedyParameters(); + void SetDefaultPropagationParameters(); + + void ValidateGreedyParameters() const; + void ValidatePropagationParameters() const; + + /** For API Runs, image data is directly passed into the Input using the + * PropagationInputBuilder*/ + void SetDataForAPIRun(std::shared_ptr> data); + + std::shared_ptr> m_Data; + GreedyParameters m_GreedyParam; + PropagationParameters m_PropagationParam; +}; + +template +class PropagationOutput +{ +public: + PROPAGATION_DATA_TYPEDEFS + + PropagationOutput(); + ~PropagationOutput(); + PropagationOutput(const PropagationOutput &other) = default; + PropagationOutput &operator=(const PropagationOutput &other) = default; + + void Initialize(std::shared_ptr> data); + bool IsInitialized() const + { return m_Data != nullptr; } + + typedef std::map TMeshSeries; + typedef std::map TMeshSeriesMap; + typedef std::map TSegmentation3DSeries; + + /** Image Getter */ + typename TImage3D::Pointer GetImage3D(unsigned int tp); + + /** Segmentation Getters */ + typename TLabelImage4D::Pointer GetSegmentation4D(); + typename TLabelImage3D::Pointer GetSegmentation3D(unsigned int tp); + TSegmentation3DSeries GetSegmentation3DSeries(); + + /** Segmentation Mesh Getters*/ + TMeshSeries GetMeshSeries(); + TPropagationMeshPointer GetMesh(unsigned int tp); + + /** Extra Mesh Getters */ + TMeshSeries GetExtraMeshSeries(std::string tag); + TPropagationMeshPointer GetExtraMesh(std::string tag, unsigned int tp); + TMeshSeriesMap GetAllExtraMeshSeries(); + + size_t GetNumberOfTimePoints(); + std::vector GetTimePointList(); + +private: + std::shared_ptr> m_Data; +}; + +} + + +#endif // PROPAGATIONIO_H diff --git a/src/propagation/PropagationInputBuilder.cxx b/src/propagation/PropagationInputBuilder.cxx new file mode 100644 index 0000000..55835c6 --- /dev/null +++ b/src/propagation/PropagationInputBuilder.cxx @@ -0,0 +1,308 @@ +#include "PropagationInputBuilder.h" +#include "GreedyException.h" +#include "GreedyMeshIO.h" +#include "PropagationTools.h" + +using namespace propagation; + +template +PropagationInputBuilder +::PropagationInputBuilder() +{ + m_Data = std::make_shared>(); + m_PParam.writeOutputToDisk = false; +} + +template +PropagationInputBuilder +::~PropagationInputBuilder() +{ + +} + +template +void +PropagationInputBuilder +::Reset() +{ + m_Data = nullptr; + m_Data = std::make_shared>(); + m_GParam = GreedyParameters(); + m_PParam = PropagationParameters(); +} + + +template +void +PropagationInputBuilder +::SetImage4D(TImage4D *img4d) +{ + m_Data->img4d = img4d; +} + +template +void +PropagationInputBuilder +::SetReferenceSegmentationIn3D(TLabelImage3D *seg3d) +{ + m_Data->seg_ref = seg3d; +} + +template +void +PropagationInputBuilder +::SetReferenceSegmentationIn4D(TLabelImage4D *seg4d) +{ + m_Data->seg4d_in = seg4d; +} + +template +void +PropagationInputBuilder +::SetReferenceTimePoint(unsigned int refTP) +{ + m_PParam.refTP = refTP; +} + +template +void +PropagationInputBuilder +::SetTargetTimePoints(const std::vector &targetTPs) +{ + m_PParam.targetTPs = targetTPs; +} + +template +void +PropagationInputBuilder +::SetResliceMetric(const InterpSpec metric) +{ + m_PParam.reslice_spec = metric; +} + +template +void +PropagationInputBuilder +::SetResliceMetricToNearestNeighbor() +{ + m_PParam.reslice_spec = InterpSpec(InterpSpec::NEAREST); +} + +template +void +PropagationInputBuilder +::SetResliceMetricToLinear() +{ + m_PParam.reslice_spec = InterpSpec(InterpSpec::LINEAR); +} + +template +void +PropagationInputBuilder +::SetResliceMetricToLabel(double sigma, bool is_physical_unit) +{ + m_PParam.reslice_spec = InterpSpec(InterpSpec::LABELWISE, sigma, is_physical_unit); +} + +template +InterpSpec +PropagationInputBuilder +::GetResliceMetric() const +{ + return m_PParam.reslice_spec; +} + +template +void +PropagationInputBuilder +::SetDebugOn(std::string &debug_dir) +{ + m_PParam.debug = true; + m_PParam.debug_dir = debug_dir; +} + +template +void +PropagationInputBuilder +::SetRegistrationMetric(GreedyParameters::MetricType metric, std::vector metric_radius) +{ + m_GParam.metric = metric; + if (metric == GreedyParameters::NCC || metric == GreedyParameters::WNCC) + { + if (metric_radius.size() == 0) + throw GreedyException("PropagationInputBuilder::SetRegistrationMetric: metric_radius is required but not givien"); + + m_GParam.metric_radius = metric_radius; + } +} + +template +GreedyParameters::MetricType +PropagationInputBuilder +::GetRegistrationMetric() const +{ + return m_GParam.metric; +} + +template +std::vector +PropagationInputBuilder +::GetRegistrationMetricRadius() const +{ + return m_GParam.metric_radius; +} + +template +void +PropagationInputBuilder +::SetMultiResolutionSchedule(std::vector iter_per_level) +{ + m_GParam.iter_per_level = iter_per_level; +} + +template +std::vector +PropagationInputBuilder +::GetMultiResolutionSchedule() const +{ + return m_GParam.iter_per_level; +} + +template +void +PropagationInputBuilder +::SetAffineDOF(GreedyParameters::AffineDOF dof) +{ + m_GParam.affine_dof = dof; +} + +template +GreedyParameters::AffineDOF +PropagationInputBuilder +::GetAffineDOF() const +{ + return m_GParam.affine_dof; +} + +template +void +PropagationInputBuilder +::AddExtraMeshToWarp(std::string fnmesh, std::string outpattern) +{ + MeshSpec meshspec; + meshspec.fn_mesh = fnmesh; + meshspec.fnout_pattern = outpattern; + m_PParam.extra_mesh_list.push_back(meshspec); +} + +template +void +PropagationInputBuilder +::AddExtraMeshToWarp(TPropagationMeshPointer mesh, std::string tag) +{ + MeshSpec meshspec; + meshspec.fn_mesh = tag; // useless + meshspec.fnout_pattern = tag; + meshspec.cached = true; + m_PParam.extra_mesh_list.push_back(meshspec); + m_Data->extra_mesh_cache[tag] = mesh; +} + +template +std::vector +PropagationInputBuilder +::GetExtraMeshesToWarp() const +{ + return m_PParam.extra_mesh_list; +} + +template +void +PropagationInputBuilder +::SetGreedyVerbosity(GreedyParameters::Verbosity v) +{ + m_GParam.verbosity = v; +} + +template +GreedyParameters::Verbosity +PropagationInputBuilder +::GetGreedyVerbosity() const +{ + return m_GParam.verbosity; +} + +template +void +PropagationInputBuilder +::SetPropagationVerbosity(PropagationParameters::Verbosity v) +{ + m_PParam.verbosity = v; +} + + +template +PropagationParameters::Verbosity +PropagationInputBuilder +::GetPropagationVerbosity() const +{ + return m_PParam.verbosity; +} + +template +void +PropagationInputBuilder +::ConfigForCLI(const PropagationParameters &pParam, const GreedyParameters &gParam) +{ + // Copy the parameters, this will cover all settings not related to data reading + this->SetPropagationParameters(pParam); + this->SetGreedyParameters(gParam); + + // Read the data + this->SetImage4D(PropagationTools::template ReadImage(pParam.fn_img4d)); + + // Read Segmentation Image from the parameter + if (pParam.use4DSegInput) + { + this->SetReferenceSegmentationIn4D(PropagationTools + ::template ReadImage(pParam.fn_seg4d)); + } + else + { + this->SetReferenceSegmentationIn3D(PropagationTools + ::template ReadImage(pParam.fn_seg3d)); + } + + // Read extra meshes + for (auto &meshspec : pParam.extra_mesh_list) + { + auto meshData = ReadMesh(meshspec.fn_mesh.c_str()); + auto *pd = dynamic_cast(meshData.GetPointer()); + if (!pd) + throw GreedyException("Propagation: Extra mesh %s is not a vtkPolyData", meshspec.fn_mesh.c_str()); + + this->AddExtraMeshToWarp(pd, meshspec.fnout_pattern); + } +} + +template +std::shared_ptr> +PropagationInputBuilder +::BuildPropagationInput() +{ + std::shared_ptr> pInput = std::make_shared>(); + pInput->SetGreedyParameters(m_GParam); + pInput->SetPropagationParameters(m_PParam); + pInput->SetDataForAPIRun(m_Data); + + if (m_PParam.outdir.size() > 0) + pInput->SetOutputDirectory(m_PParam.outdir); + + return pInput; +} + +namespace propagation +{ + template class PropagationInputBuilder; + template class PropagationInputBuilder; +} \ No newline at end of file diff --git a/src/propagation/PropagationInputBuilder.h b/src/propagation/PropagationInputBuilder.h new file mode 100644 index 0000000..0aae8f8 --- /dev/null +++ b/src/propagation/PropagationInputBuilder.h @@ -0,0 +1,110 @@ + +#include +#include +#include "PropagationCommon.hxx" +#include "PropagationIO.h" + +namespace propagation +{ + +template +class PropagationInputBuilder +{ +public: + PROPAGATION_DATA_TYPEDEFS + + PropagationInputBuilder(); + ~PropagationInputBuilder(); + PropagationInputBuilder(const PropagationInputBuilder &other) = delete; + PropagationInputBuilder &operator=(const PropagationInputBuilder &other) = delete; + + void Reset(); + + //-------------------------------------------------- + // Propagation Parameters Configuration + + /** Set Reference 4D Image */ + void SetImage4D(TImage4D *img4d); + + /** Set 3D Segmentation Image for the reference time point.*/ + void SetReferenceSegmentationIn3D(TLabelImage3D *seg3d); + + /** Set 4D Segmentation Image with reference time segmentation. This will override + * 3D segmentation image input */ + void SetReferenceSegmentationIn4D(TLabelImage4D *seg4d); + + /** Set reference time point */ + void SetReferenceTimePoint(unsigned int refTP); + unsigned int GetReferenceTimePoint() const { return m_PParam.refTP; }; + + /** Set target time point list. Reference time point is ignored */ + void SetTargetTimePoints(const std::vector &targetTPs); + + /** Set Reslice Metric */ + void SetResliceMetric(const InterpSpec metric); + void SetResliceMetricToLinear(); + void SetResliceMetricToNearestNeighbor(); + void SetResliceMetricToLabel(double sigma, bool is_physical_unit); + InterpSpec GetResliceMetric() const; + + /** Add Extra Mesh to Warp */ + void AddExtraMeshToWarp(std::string fnmesh, std::string outpattern); + void AddExtraMeshToWarp(TPropagationMeshPointer mesh, std::string tag); + std::vector GetExtraMeshesToWarp() const; + + /** Set Propagation Verbosity */ + void SetPropagationVerbosity(PropagationParameters::Verbosity v); + PropagationParameters::Verbosity GetPropagationVerbosity() const; + + /** Turn on the debug mode (-sp-debug) for propagation. A debugging output directory is needed for + * dumping out intermediary files */ + void SetDebugOn(std::string &debug_dir); + + /** Turn off the debug mode for propagation */ + void SetDebugOff(); + + //-------------------------------------------------- + // Greedy Parameters Configuration + + /** Set the metric (-m) for greedy run */ + void SetRegistrationMetric(GreedyParameters::MetricType metric, std::vector metric_radius = std::vector()); + GreedyParameters::MetricType GetRegistrationMetric() const; + std::vector GetRegistrationMetricRadius() const; + + /** Set multi-resolution schedule (-n) default is 100x100 */ + void SetMultiResolutionSchedule(std::vector iter_per_level); + std::vector GetMultiResolutionSchedule() const; + + /** Set Affine degree of freedom */ + void SetAffineDOF(GreedyParameters::AffineDOF dof); + GreedyParameters::AffineDOF GetAffineDOF() const; + + /** Set Verbose Level */ + void SetGreedyVerbosity(GreedyParameters::Verbosity v); + GreedyParameters::Verbosity GetGreedyVerbosity() const; + + + void SetGreedyParameters(const GreedyParameters &gParam) { m_GParam = gParam; } + const GreedyParameters &GetConstGreedyParameters() const + { return m_GParam; } + + void SetPropagationParameters(const PropagationParameters &pParam) { m_PParam = pParam; } + const PropagationParameters &GetConstPropagationParameters() const + { return m_PParam; } + + // Parse the parameters, read data, configure this builder and build an input + void ConfigForCLI(const PropagationParameters &pParam, const GreedyParameters &gParam); + + // Actual method to build the input + std::shared_ptr> BuildPropagationInput(); + + +private: + //std::shared_ptr> m_Input; + GreedyParameters m_GParam; + PropagationParameters m_PParam; + std::shared_ptr> m_Data; +}; + +} // namespace propagation + diff --git a/src/propagation/PropagationParameters.hxx b/src/propagation/PropagationParameters.hxx new file mode 100644 index 0000000..38f7424 --- /dev/null +++ b/src/propagation/PropagationParameters.hxx @@ -0,0 +1,47 @@ +#ifndef PROPAGATIONPARAMETERS_H +#define PROPAGATIONPARAMETERS_H + +#include +#include +#include +#include "GreedyParameters.h" + +namespace propagation +{ + + +struct MeshSpec +{ + std::string fn_mesh; + std::string fnout_pattern; + bool cached = false; // whether input mesh is in the cache +}; + +// Parameters for the segmentation propagation +struct PropagationParameters +{ + enum Verbosity { VERB_NONE=0, VERB_DEFAULT, VERB_VERBOSE, VERB_INVALID }; + + std::string fn_img4d; + std::string fn_seg3d; + std::string fn_seg4d; + std::vector extra_mesh_list; + std::string fnsegout_pattern; + std::string fnmeshout_pattern; + std::string outdir; + unsigned int refTP; + std::vector targetTPs; + + // always use labelwise interpolation for segmentation + InterpSpec reslice_spec = InterpSpec(InterpSpec::LABELWISE, 0.2, false, 0.0); + + bool debug = false; + std::string debug_dir; + bool writeOutputToDisk = true; // whether to write final output data to disk + bool use4DSegInput = false; // whether to use 4d segmentation input as reference + Verbosity verbosity; +}; + +} // End of namespace propagation + +#endif // PROPAGATIONPARAMETERS_H diff --git a/src/propagation/PropagationTools.h b/src/propagation/PropagationTools.h new file mode 100644 index 0000000..c9414af --- /dev/null +++ b/src/propagation/PropagationTools.h @@ -0,0 +1,208 @@ +#ifndef PROPAGATIONTOOLS_H +#define PROPAGATIONTOOLS_H + +#include "PropagationCommon.hxx" +#include "PropagationAPI.h" +#include +#include +#include +#include +#include +#include +#include + +namespace propagation +{ + +template +class PropagationTools +{ +public: + PROPAGATION_DATA_TYPEDEFS + PropagationTools(); + ~PropagationTools(); + + static typename TImage3D::Pointer CastLabelToRealImage(TLabelImage3D *input); + + static typename TLabelImage3D::Pointer + ResliceLabelImageWithIdentityMatrix(TImage3D *ref, TLabelImage3D *src); + + static TPropagationMeshPointer GetMeshFromLabelImage(TLabelImage3D *img); + + static typename TLabelImage3D::Pointer + TrimLabelImage(TLabelImage3D *input, double vox, typename TLabelImage3D::RegionType &roi); + + static void ExpandRegion(itk::ImageRegion<3> ®ion, const itk::Index<3> &idx); + + static void ConnectITKToVTK(itk::VTKImageExport *fltExport,vtkImageImport *fltImport); + + /** + * This static function constructs a NIFTI matrix from the ITK direction + * cosines matrix and Spacing and Origin vectors + */ + static vnl_matrix_fixed + ConstructNiftiSform(vnl_matrix m_dir, vnl_vector v_origin, + vnl_vector v_spacing); + + static vnl_matrix_fixed + ConstructVTKtoNiftiTransform(vnl_matrix m_dir, vnl_vector v_origin, + vnl_vector v_spacing); + + template + static itk::SmartPointer ReadImage(const std::string &filename); + + template + static void WriteImage(TImage *img, const std::string &filename, + itk::IOComponentEnum comp = itk::IOComponentEnum::UNKNOWNCOMPONENTTYPE); + + template + static typename TTimepointImage::Pointer ExtractTimePointImage(TFullImage *full_img, unsigned int tp); + + template + static itk::SmartPointer + Resample3DImage(TImage *input, double factor, ResampleInterpolationMode intpMode, double smooth_sigma = 0); + + template + static typename TOutputImage::Pointer + ThresholdImage(TInputImage *img, typename TInputImage::PixelType lower, typename TInputImage::PixelType upper, + typename TOutputImage::PixelType value_in, typename TOutputImage::PixelType value_out); + + template + static typename TOutputImage::Pointer + DilateImage(TInputImage *img, size_t radius, typename TInputImage::PixelType value); + + template + static typename TCompositeImage3D::Pointer + CastToCompositeImage(TInputImage *img); + + static typename TCompositeImage3D::Pointer CastImageToCompositeImage(TImage3D *img); + + template + static typename TImage::Pointer CreateEmptyImage(TImage *sample); + + inline static char GetPathSeparator() + { + #ifdef _WIN32 + return '\\'; + #else + return '/'; + #endif + } + + inline static std::string ssprintf(const char *format, ...) + { + if(format && strlen(format)) + { + char buffer[4096]; + va_list args; + va_start (args, format); + vsnprintf(buffer, 4096, format, args); + va_end (args); + return std::string(buffer); + } + else + return std::string(); + } + +}; + +template +class UnaryFunctorImageToSingleComponentVectorImageFilter + : public itk::ImageToImageFilter +{ +public: + typedef UnaryFunctorImageToSingleComponentVectorImageFilter Self; + typedef itk::ImageToImageFilter Superclass; + typedef itk::SmartPointer Pointer; + typedef itk::SmartPointer< const Self > ConstPointer; + + typedef TInputImage InputImageType; + typedef TOutputImage OutputImageType; + typedef TFunctor FunctorType; + + typedef typename Superclass::OutputImageRegionType OutputImageRegionType; + + /** Run-time type information (and related methods). */ + itkTypeMacro(UnaryFunctorImageToSingleComponentVectorImageFilter, ImageToImageFilter) + itkNewMacro(Self) + + /** ImageDimension constants */ + itkStaticConstMacro(InputImageDimension, unsigned int, TInputImage::ImageDimension); + itkStaticConstMacro(OutputImageDimension, unsigned int, TOutputImage::ImageDimension); + + void SetFunctor(const FunctorType &functor) + { + if(m_Functor != functor) + { + m_Functor = functor; + this->Modified(); + } + } + + itkGetConstReferenceMacro(Functor, FunctorType) + + void DynamicThreadedGenerateData(const OutputImageRegionType & outputRegionForThread) override; + + +protected: + + UnaryFunctorImageToSingleComponentVectorImageFilter() {} + virtual ~UnaryFunctorImageToSingleComponentVectorImageFilter() {} + + FunctorType m_Functor; + +}; + +template +class LinearIntensityMapping +{ +public: + typedef LinearIntensityMapping Self; + + double operator() (TReal g) const { return MapInternalToNative(g); } + + double MapInternalToNative(TReal internal) const + { return internal * scale + shift; } + + double MapNativeToInternal(TReal native) const + { return (native - shift) / scale; } + + LinearIntensityMapping() : scale(1.0), shift(0.0) {} + LinearIntensityMapping(TReal a, TReal b) : scale(a), shift(b) {} + + bool operator != (const Self &other) const + { return scale != other.scale || shift != other.shift; } + +protected: + TReal scale; + TReal shift; +}; + +template +class IdentityIntensityMapping +{ +public: + TReal operator()(TReal g) const { return g; } + + TReal MapGradientMagnitudeToNative(TReal internalGM) const + { return internalGM; } + + TReal MapInternalToNative(TReal internal) const + { return internal; } + + TReal MapNativeToInternal(TReal native) const { return native; } + + virtual TReal GetScale() const { return 1; } + virtual TReal GetShift() const { return 0; } + + bool IsIdentity() const { return true; } + + bool operator!=(const IdentityIntensityMapping &) const + { return false; } +}; + +} // end of namespace propagation + +#include "PropagationTools.txx" + +#endif // PROPAGATIONTOOLS_H diff --git a/src/propagation/PropagationTools.txx b/src/propagation/PropagationTools.txx new file mode 100644 index 0000000..a138f5c --- /dev/null +++ b/src/propagation/PropagationTools.txx @@ -0,0 +1,617 @@ +#include "PropagationTools.h" +#include "GreedyException.h" +#include "ImageRegionConstIteratorWithIndexOverride.h" +#include "GreedyMeshIO.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace propagation +{ + +template +PropagationTools +::PropagationTools() +{ + +} + +template +PropagationTools +::~PropagationTools() +{ + +} + +template +typename PropagationTools::TImage3D::Pointer +PropagationTools +::CastLabelToRealImage(TLabelImage3D *input) +{ + auto output = TImage3D::New(); + output->SetRegions(input->GetLargestPossibleRegion()); + output->SetDirection(input->GetDirection()); + output->SetOrigin(input->GetOrigin()); + output->SetSpacing(input->GetSpacing()); + output->Allocate(); + + itk::ImageRegionIterator it_input( + input, input->GetLargestPossibleRegion()); + itk::ImageRegionIterator it_output( + output, output->GetLargestPossibleRegion()); + + // Deep copy pixels + while (!it_input.IsAtEnd()) + { + it_output.Set(it_input.Get()); + ++it_output; + ++it_input; + } + + return output; +} + +template +typename PropagationTools::TLabelImage3D::Pointer +PropagationTools +::ResliceLabelImageWithIdentityMatrix(TImage3D *ref, TLabelImage3D *src) +{ + // Code adapted from c3d command -reslice-identity + typedef itk::AffineTransform TranType; + typename TranType::Pointer atran = TranType::New(); + atran->SetIdentity(); + + // Build the resampling filter + typedef itk::ResampleImageFilter ResampleFilterType; + typename ResampleFilterType::Pointer fltSample = ResampleFilterType::New(); + + fltSample->SetTransform(atran); + + // Initialize the resampling filter with an identity transform + fltSample->SetInput(src); + + // Set the unknown intensity to positive value + fltSample->SetDefaultPixelValue(0); + + // Set the interpolator + typedef itk::NearestNeighborInterpolateImageFunction NNInterpolator; + fltSample->SetInterpolator(NNInterpolator::New()); + + // Calculate where the transform is taking things + itk::ContinuousIndex idx[3]; + for(size_t i = 0; i < 3; i++) + { + idx[0][i] = 0.0; + idx[1][i] = ref->GetBufferedRegion().GetSize(i) / 2.0; + idx[2][i] = ref->GetBufferedRegion().GetSize(i) - 1.0; + } + for(size_t j = 0; j < 3; j++) + { + itk::ContinuousIndex idxmov; + itk::Point pref, pmov; + ref->TransformContinuousIndexToPhysicalPoint(idx[j], pref); + pmov = atran->TransformPoint(pref); + src->TransformPhysicalPointToContinuousIndex(pmov, idxmov); + } + + // Set the spacing, origin, direction of the output + fltSample->UseReferenceImageOn(); + fltSample->SetReferenceImage(ref); + fltSample->Update(); + + return fltSample->GetOutput(); +} + +template +void +PropagationTools +::ConnectITKToVTK(itk::VTKImageExport *fltExport,vtkImageImport *fltImport) +{ + fltImport->SetUpdateInformationCallback( fltExport->GetUpdateInformationCallback()); + fltImport->SetPipelineModifiedCallback( fltExport->GetPipelineModifiedCallback()); + fltImport->SetWholeExtentCallback( fltExport->GetWholeExtentCallback()); + fltImport->SetSpacingCallback( fltExport->GetSpacingCallback()); + fltImport->SetOriginCallback( fltExport->GetOriginCallback()); + fltImport->SetScalarTypeCallback( fltExport->GetScalarTypeCallback()); + fltImport->SetNumberOfComponentsCallback( fltExport->GetNumberOfComponentsCallback()); + fltImport->SetPropagateUpdateExtentCallback( fltExport->GetPropagateUpdateExtentCallback()); + fltImport->SetUpdateDataCallback( fltExport->GetUpdateDataCallback()); + fltImport->SetDataExtentCallback( fltExport->GetDataExtentCallback()); + fltImport->SetBufferPointerCallback( fltExport->GetBufferPointerCallback()); + fltImport->SetCallbackUserData( fltExport->GetCallbackUserData()); +} + +template +typename PropagationTools::TPropagationMeshPointer +PropagationTools +::GetMeshFromLabelImage(TLabelImage3D *img) +{ + short imax = img->GetBufferPointer()[0]; + short imin = imax; + for(size_t i = 0; i < img->GetBufferedRegion().GetNumberOfPixels(); i++) + { + short x = img->GetBufferPointer()[i]; + imax = std::max(imax, x); + imin = std::min(imin, x); + } + + typedef itk::VTKImageExport ExporterType; + typename ExporterType::Pointer fltExport = ExporterType::New(); + fltExport->SetInput(img); + vtkImageImport *fltImport = vtkImageImport::New(); + ConnectITKToVTK(fltExport.GetPointer(), fltImport); + + // Append filter for assembling labels + vtkAppendPolyData *fltAppend = vtkAppendPolyData::New(); + + // Extracting one label at a time and assigning label value + for (short i = 1; i <= imax; i += 1.0) + { + // Extract one label + vtkDiscreteMarchingCubes *fltDMC = vtkDiscreteMarchingCubes::New(); + fltDMC->SetInputConnection(fltImport->GetOutputPort()); + fltDMC->ComputeGradientsOff(); + fltDMC->ComputeScalarsOff(); + fltDMC->SetNumberOfContours(1); + fltDMC->ComputeNormalsOn(); + fltDMC->SetValue(0, i); + fltDMC->Update(); + + vtkPolyData *labelMesh = fltDMC->GetOutput(); + + // Set scalar values for the label + vtkShortArray *scalar = vtkShortArray::New(); + scalar->SetNumberOfComponents(1); + scalar->SetNumberOfTuples(labelMesh->GetNumberOfPoints()); + scalar->Fill(i); + scalar->SetName("Label"); + labelMesh->GetPointData()->SetScalars(scalar); + fltAppend->AddInputData(labelMesh); + } + + fltAppend->Update(); + + // Compute the transform from VTK coordinates to NIFTI/RAS coordinates + // Create the transform filter + vtkTransformPolyDataFilter *fltTransform = vtkTransformPolyDataFilter::New(); + fltTransform->SetInputData(fltAppend->GetOutput()); + + typedef vnl_matrix_fixed Mat44; + Mat44 vtk2out; + Mat44 vtk2nii = ConstructVTKtoNiftiTransform( + img->GetDirection().GetVnlMatrix().as_ref(), + img->GetOrigin().GetVnlVector(), + img->GetSpacing().GetVnlVector()); + + vtk2out = vtk2nii; + + // Update the VTK transform to match + vtkTransform *transform = vtkTransform::New(); + transform->SetMatrix(vtk2out.data_block()); + fltTransform->SetTransform(transform); + fltTransform->Update(); + + // Get final output + return fltTransform->GetOutput(); +} + +template +typename PropagationTools::TLabelImage3D::Pointer +PropagationTools +::TrimLabelImage(TLabelImage3D *input, double vox, typename TLabelImage3D::RegionType &roi) +{ + typedef typename TLabelImage3D::RegionType RegionType; + typedef itk::ImageRegionIteratorWithIndex Iterator; + + // Initialize the bounding box + RegionType bbox; + + // Find the extent of the non-background region of the image + Iterator it(input, input->GetBufferedRegion()); + for( ; !it.IsAtEnd(); ++it) + if(it.Value() != 0) + ExpandRegion(bbox, it.GetIndex()); + + typename TLabelImage3D::SizeType radius; + for(size_t i = 0; i < 3; i++) + radius[i] = (int) ceil(vox); + bbox.PadByRadius(radius); + + // Make sure the bounding box is within the contents of the image + bbox.Crop(input->GetBufferedRegion()); + + // Chop off the region + typedef itk::RegionOfInterestImageFilter TrimFilter; + typename TrimFilter::Pointer fltTrim = TrimFilter::New(); + fltTrim->SetInput(input); + fltTrim->SetRegionOfInterest(bbox); + fltTrim->Update(); + + // Copy bounding box to output roi + roi = bbox; + + return fltTrim->GetOutput(); +} + +template +void +PropagationTools +::ExpandRegion(itk::ImageRegion<3> ®ion, const itk::Index<3> &idx) +{ + if(region.GetNumberOfPixels() == 0) + { + region.SetIndex(idx); + for(size_t i = 0; i < 3; i++) + region.SetSize(i, 1); + } + else + { + for(size_t i = 0; i < 3; i++) + { + if(region.GetIndex(i) > idx[i]) + { + region.SetSize(i, region.GetSize(i) + (region.GetIndex(i) - idx[i])); + region.SetIndex(i, idx[i]); + } + else if(region.GetIndex(i) + (long) region.GetSize(i) <= idx[i]) { + region.SetSize(i, 1 + idx[i] - region.GetIndex(i)); + } + } + } +} + + +template +vnl_matrix_fixed +PropagationTools +::ConstructNiftiSform(vnl_matrix m_dir, vnl_vector v_origin, + vnl_vector v_spacing) +{ + // Set the NIFTI/RAS transform + vnl_matrix m_ras_matrix; + vnl_diag_matrix m_scale, m_lps_to_ras; + vnl_vector v_ras_offset; + + // Compute the matrix + m_scale.set(v_spacing); + m_lps_to_ras.set(vnl_vector(3, 1.0)); + m_lps_to_ras[0] = -1; + m_lps_to_ras[1] = -1; + m_ras_matrix = m_lps_to_ras * m_dir * m_scale; + + // Compute the vector + v_ras_offset = m_lps_to_ras * v_origin; + + // Create the larger matrix + vnl_vector vcol(4, 1.0); + vcol.update(v_ras_offset); + + vnl_matrix_fixed m_sform; + m_sform.set_identity(); + m_sform.update(m_ras_matrix); + m_sform.set_column(3, vcol); + return m_sform; +} + +template +vnl_matrix_fixed +PropagationTools +::ConstructVTKtoNiftiTransform(vnl_matrix m_dir, vnl_vector v_origin, + vnl_vector v_spacing) +{ + vnl_matrix_fixed vox2nii = ConstructNiftiSform(m_dir, v_origin, v_spacing); + vnl_matrix_fixed vtk2vox; + vtk2vox.set_identity(); + for(size_t i = 0; i < 3; i++) + { + vtk2vox(i,i) = 1.0 / v_spacing[i]; + vtk2vox(i,3) = - v_origin[i] / v_spacing[i]; + } + return vox2nii * vtk2vox; +} + +template +template +itk::SmartPointer +PropagationTools +::ReadImage(const std::string &filename) +{ + using TReader = itk::ImageFileReader; + typename TReader::Pointer reader = TReader::New(); + reader->SetFileName(filename.c_str()); + reader->Update(); + return reader->GetOutput(); +} + +template +template +void +PropagationTools +::WriteImage(TImage *img, const std::string &filename, itk::IOComponentEnum comp) +{ + if(dynamic_cast(img)) + TLDDMM3D::vimg_write(dynamic_cast(img), filename.c_str(), comp); + else if(dynamic_cast(img)) + TLDDMM3D::img_write(dynamic_cast(img), filename.c_str(), comp); + else if(dynamic_cast(img)) + TLDDMM3D::cimg_write(dynamic_cast(img), filename.c_str(), comp); + else + { + // Some other type (e.g., LabelImage). We use the image writer and ignore the comp + typedef itk::ImageFileWriter WriterType; + typename WriterType::Pointer writer = WriterType::New(); + writer->SetFileName(filename.c_str()); + writer->SetUseCompression(true); + writer->SetInput(img); + writer->Update(); + } +} + +template +template +typename TTimePointImage::Pointer +PropagationTools +::ExtractTimePointImage(TFullImage *full_img, unsigned int tp) +{ + // Logic adapated from SNAP ImageWrapper method: + // ConfigureTimePointImageFromImage4D() + // Always use 1-based index for time point + assert(tp > 0); + + unsigned int nt = full_img->GetBufferedRegion().GetSize()[3u]; + unsigned int bytes_per_volume = full_img->GetPixelContainer()->Size() / nt; + + typename TImage3D::Pointer tp_img = TImage3D::New(); + + typename TImage3D::RegionType region; + typename TImage3D::SpacingType spacing; + typename TImage3D::PointType origin; + typename TImage3D::DirectionType dir; + for(unsigned int j = 0; j < 3; j++) + { + region.SetSize(j, full_img->GetBufferedRegion().GetSize()[j]); + region.SetIndex(j, full_img->GetBufferedRegion().GetIndex()[j]); + spacing[j] = full_img->GetSpacing()[j]; + origin[j] = full_img->GetOrigin()[j]; + for(unsigned int k = 0; k < 3; k++) + dir(j,k) = full_img->GetDirection()(j,k); + } + + // All of the information from the 4D image is propagaged to the 3D timepoints + tp_img->SetRegions(region); + tp_img->SetSpacing(spacing); + tp_img->SetOrigin(origin); + tp_img->SetDirection(dir); + tp_img->SetNumberOfComponentsPerPixel(full_img->GetNumberOfComponentsPerPixel()); + tp_img->Allocate(); + + // Set the buffer pointer + tp_img->GetPixelContainer()->SetImportPointer( + full_img->GetBufferPointer() + bytes_per_volume * (tp - 1), + bytes_per_volume); + + return tp_img; +} + +template +template +itk::SmartPointer +PropagationTools +::Resample3DImage(TImage* input, double factor, + ResampleInterpolationMode intpMode, double smooth_sigma) +{ + typedef itk::DiscreteGaussianImageFilter SmoothFilter; + typename TImage::Pointer imageToResample = input; + + // Smooth image if needed + if (smooth_sigma > 0) + { + typename SmoothFilter::Pointer fltDSSmooth = SmoothFilter::New(); + typename SmoothFilter::ArrayType variance; + for (int i = 0; i < 3; ++i) + variance[i] = smooth_sigma * smooth_sigma; + + fltDSSmooth->SetInput(input); + fltDSSmooth->SetVariance(variance); + fltDSSmooth->UseImageSpacingOn(); + fltDSSmooth->Update(); + imageToResample = fltDSSmooth->GetOutput(); + } + + // Create resampled images + typedef itk::ResampleImageFilter ResampleFilter; + typedef itk::LinearInterpolateImageFunction LinearInterpolator; + typedef itk::NearestNeighborInterpolateImageFunction NNInterpolator; + + typename ResampleFilter::Pointer fltResample = ResampleFilter::New(); + fltResample->SetInput(imageToResample); + fltResample->SetTransform(itk::IdentityTransform::New()); + + switch (intpMode) + { + case ResampleInterpolationMode::Linear: + fltResample->SetInterpolator(LinearInterpolator::New()); + break; + case ResampleInterpolationMode::NearestNeighbor: + fltResample->SetInterpolator(NNInterpolator::New()); + break; + default: + throw GreedyException("Unkown Interpolation Mode"); + } + + typename TImage::SizeType sz; + for(size_t i = 0; i < 3; i++) + sz[i] = (unsigned long)(imageToResample->GetBufferedRegion().GetSize(i) * factor + 0.5); + + // Compute the spacing of the new image + typename TImage::SpacingType spc_pre = imageToResample->GetSpacing(); + typename TImage::SpacingType spc_post = spc_pre; + for(size_t i = 0; i < 3; i++) + spc_post[i] *= imageToResample->GetBufferedRegion().GetSize()[i] * 1.0 / sz[i]; + + // Get the bounding box of the input image + typename TImage::PointType origin_pre = imageToResample->GetOrigin(); + + // Recalculate the origin. The origin describes the center of voxel 0,0,0 + // so that as the voxel size changes, the origin will change as well. + typename TImage::SpacingType off_pre = (imageToResample->GetDirection() * spc_pre) * 0.5; + typename TImage::SpacingType off_post = (imageToResample->GetDirection() * spc_post) * 0.5; + typename TImage::PointType origin_post = origin_pre - off_pre + off_post; + + // Set the image sizes and spacing. + fltResample->SetSize(sz); + fltResample->SetOutputSpacing(spc_post); + fltResample->SetOutputOrigin(origin_post); + fltResample->SetOutputDirection(imageToResample->GetDirection()); + + // Set the unknown intensity to positive value + fltResample->SetDefaultPixelValue(0); + + // Perform resampling + fltResample->UpdateLargestPossibleRegion(); + + return fltResample->GetOutput(); +} + +template +template +typename TOutputImage::Pointer +PropagationTools +::ThresholdImage(TInputImage *img, typename TInputImage::PixelType lower, typename TInputImage::PixelType upper, + typename TOutputImage::PixelType value_in, typename TOutputImage::PixelType value_out) +{ + using ThresholdFilter = itk::BinaryThresholdImageFilter; + typename ThresholdFilter::Pointer fltThreshold = ThresholdFilter::New(); + fltThreshold->SetInput(img); + fltThreshold->SetLowerThreshold(lower); + fltThreshold->SetUpperThreshold(upper); + fltThreshold->SetInsideValue(value_in); + fltThreshold->SetOutsideValue(value_out); + fltThreshold->Update(); + + return fltThreshold->GetOutput(); +} + +template +template +typename TOutputImage::Pointer +PropagationTools +::DilateImage(TInputImage *img, size_t radius, typename TInputImage::PixelType value) +{ + // Label dilation + using Element = itk::BinaryBallStructuringElement ; + typename Element::SizeType sz = { radius, radius, radius }; + Element elt; + elt.SetRadius(sz); + elt.CreateStructuringElement(); + + typedef itk::BinaryDilateImageFilter DilateFilter; + typename DilateFilter::Pointer fltDilation = DilateFilter::New(); + fltDilation->SetInput(img); + fltDilation->SetDilateValue(value); + fltDilation->SetKernel(elt); + fltDilation->Update(); + + return fltDilation->GetOutput(); +} + +template +template +typename PropagationTools::TCompositeImage3D::Pointer +PropagationTools +::CastToCompositeImage(TInputImage *img) +{ + typedef UnaryFunctorImageToSingleComponentVectorImageFilter< + TInputImage, TCompositeImage3D, TIntensityMapping> FilterType; + typedef itk::ImageSource VectorImageSource; + + TIntensityMapping intensityMapping; + itk::SmartPointer filter = FilterType::New(); + filter->SetInput(img); + filter->SetFunctor(intensityMapping); + itk::SmartPointer imgSource = filter.GetPointer(); + imgSource->UpdateOutputInformation(); + imgSource->Update(); + return imgSource->GetOutput(); +} + +template +typename PropagationTools::TCompositeImage3D::Pointer +PropagationTools +::CastImageToCompositeImage(TImage3D *img) +{ + auto flt = itk::ComposeImageFilter::New(); + flt->SetInput(0, img); + flt->Update(); + return flt->GetOutput(); +} + +template +template +typename TImage::Pointer +PropagationTools +::CreateEmptyImage(TImage *sample) +{ + auto duplicator = itk::ImageDuplicator::New(); + duplicator->SetInputImage(sample); + duplicator->Update(); + auto imgout = duplicator->GetOutput(); + imgout->FillBuffer(itk::NumericTraits::Zero); + return imgout; +} + + +template +void +UnaryFunctorImageToSingleComponentVectorImageFilter +::DynamicThreadedGenerateData(const OutputImageRegionType &outputRegionForThread) +{ + // Use our fast iterators for vector images + typedef itk::ImageLinearIteratorWithIndex IterBase; + typedef IteratorExtender IterType; + + typedef typename OutputImageType::InternalPixelType OutputComponentType; + typedef typename InputImageType::InternalPixelType InputComponentType; + + // Define the iterators + IterType outputIt(this->GetOutput(), outputRegionForThread); + int line_len = outputRegionForThread.GetSize(0); + + // Using a generic ITK iterator for the input because it supports RLE images and adaptors + itk::ImageScanlineConstIterator< InputImageType > inputIt(this->GetInput(), outputRegionForThread); + + while ( !inputIt.IsAtEnd() ) + { + // Get the pointer to the input and output pixel lines + OutputComponentType *out = outputIt.GetPixelPointer(this->GetOutput()); + + for(int i = 0; i < line_len; i++, ++inputIt) + { + out[i] = m_Functor(inputIt.Get()); + } + + outputIt.NextLine(); + inputIt.NextLine(); + } +} + + + +} // end of namespace propagation + diff --git a/src/propagation/greedy_propagation.md b/src/propagation/greedy_propagation.md new file mode 100644 index 0000000..28ac262 --- /dev/null +++ b/src/propagation/greedy_propagation.md @@ -0,0 +1,75 @@ +# Greedy Segmentation Propagation Tool + +Greedy Segmentation Propagation Tool applys greedy registration to warp a 3D segmentation from a reference time point to target timepoints in a 4D image. + +## Example Usage + +``` +greedy_propagation \ +-i img4d.nii.gz \ +-sr3 seg05.nii.gz \ +-tpr 5 \ +-tpt 1,2,3,4,6,7 \ +-o /your/output/directory \ +-debug /your/debug/output/directory \ +-n 100x100 -m SSD -s 3mm 1.5mm -threads 10 -V 1 +``` + +The above command warps the segmentation image `seg05.nii.gz` from reference time point `5` to time points `1,2,3,4,6,7` by running greedy registrations between time points with `img4d.nii.gz` as the reference 4d image. + + +## Propagation Parameters +### 4D Image Input `-i ` +Specifies the 4D image that is the base of the segmentation. Propagation algorithm will extract 3D time point images from the 4D image and use them as fix/moving images in the registration runs. + +### 3D Reference Segmentation Input `-sr3 ` +Specifies the 3D segmentation image of the reference time point. This option will override all previously specified reference segmentation inputs by `-sps` or `sps-4d`. + +### 4D Reference Segmentation Input `-sr4 ` +Specify the 4D segmentation image containing the segmentation slice for the reference time point. Only the segmentation image from the reference time point will be used for the propagation run. This option will override all previously specified reference segmentation inputs by `-sr3` or `sr4`. + +### Reference Time Point `-tpr