Source code for cmp.stages.diffusion.reconstruction

# Copyright (C) 2009-2022, Ecole Polytechnique Federale de Lausanne (EPFL) and
# Hospital Center and University of Lausanne (UNIL-CHUV), Switzerland, and CMP3 contributors
# All rights reserved.
#
#  This software is distributed under the open-source license Modified BSD.

"""Reconstruction methods and workflows."""

# General imports

from traits.api import *

import nipype.pipeline.engine as pe
import nipype.interfaces.utility as util

from nipype.interfaces.base import traits
from nipype.interfaces.mrtrix3.utils import TensorMetrics

from nipype import logging

from cmtklib.interfaces.mrtrix3 import (
    Erode,
    MRtrix_mul,
    MRThreshold,
    MRConvert,
    EstimateResponseForSH,
    ConstrainedSphericalDeconvolution,
    DWI2Tensor,
    Tensor2Vector,
)

# from nipype.interfaces.mrtrix3.preprocess import ResponseSD
from cmtklib.diffusion import FlipTable, FlipBvec
from cmtklib.interfaces.dipy import DTIEstimateResponseSH, CSD, SHORE, MAPMRI

# from nipype.interfaces.dipy import CSD


# Module-level logger obtained from nipype's logging manager under the
# "nipype.interface" name.
iflogger = logging.getLogger("nipype.interface")


class DipyReconConfig(HasTraits):
    """Class used to store Dipy diffusion reconstruction sub-workflow configuration parameters.

    Attributes
    ----------
    imaging_model : traits.Str
        Diffusion imaging model (For instance 'DTI')

    flip_table_axis : traits.List(['x', 'y', 'z'])
        Axis to be flipped in the gradient table.

    local_model_editor : {False: '1:Tensor', True: '2:Constrained Spherical Deconvolution'}
        Choices of reconstruction models offered for ``local_model``

    local_model : traits.Bool
        Reconstruction model selected (See `local_model_editor`).
        ``True`` selects Constrained Spherical Deconvolution,
        ``False`` selects the Tensor model only (Default: True)

    lmax_order : traits.Enum([2, 4, 6, 8, 10, 12, 14, 16])
        Choices of maximal order to use for Constrained Spherical Deconvolution
        (Default: 2, the first choice)

    single_fib_thr : traits.Float(0.7, min=0, max=1)
        FA threshold (Default: 0.7)

    recon_mode : traits.Str
        Can be "Probabilistic" or "Deterministic"

    mapmri : traits.Bool(False)
        Perform MAP-MRI reconstruction if `True` (Default: False)

    tracking_processing_tool : traits.Enum('MRtrix', 'Dipy')
        Tool used for tractography (Default: 'MRtrix')

    laplacian_regularization : traits.Bool
        Apply laplacian regularization in MAP-MRI if `True` (Default: True)

    laplacian_weighting : traits.Float
        Laplacian regularization weight in MAP-MRI (Default: 0.05)

    positivity_constraint : traits.Bool
        Apply positivity constraint in MAP-MRI if `True` (Default: True)

    radial_order : traits.Int
        MAP-MRI radial order (Default: 8)

    small_delta : traits.Float
        Small delta for gradient table (pulse duration) used by MAP-MRI (Default: 0.02)

    big_delta : traits.Float
        Big delta for gradient table (time interval) used by MAP-MRI (Default: 0.5)

    radial_order_values : traits.List([2, 4, 6, 8, 10, 12])
        Choices of radial order values used by SHORE

    shore_radial_order : traits.Enum
        Even number that represents the order of the basis, chosen among
        `radial_order_values` (Default: 6)

    shore_zeta : traits.Int
        Scale factor in SHORE (Default: 700)

    shore_lambda_n : traits.Float
        Radial regularisation constant in SHORE (Default: 1e-8)

    shore_lambda_l : traits.Float
        Angular regularisation constant in SHORE (Default: 1e-8)

    shore_tau : traits.Float
        Diffusion time used by SHORE. By default the value that makes *q*
        equal to the square root of the b-value (Default: 0.025330295910584444)

    shore_constrain_e0 : traits.Bool
        Constrain SHORE optimization such that E(0) = 1 (Default: False)

    shore_positive_constraint : traits.Bool
        Constrain the SHORE propagator to be positive (Default: False)
    """

    imaging_model = Str
    flip_table_axis = List(["x", "y", "z"])
    local_model_editor = Dict(
        {False: "1:Tensor", True: "2:Constrained Spherical Deconvolution"}
    )
    local_model = Bool(True)
    lmax_order = Enum([2, 4, 6, 8, 10, 12, 14, 16])
    # NOTE(review): with ``Float``, ``min``/``max`` are presumably stored as
    # metadata only and not enforced as bounds -- confirm; ``Range`` validates.
    single_fib_thr = Float(0.7, min=0, max=1)
    recon_mode = Str
    mapmri = Bool(False)
    tracking_processing_tool = Enum("MRtrix", "Dipy")

    # MAP-MRI parameters
    laplacian_regularization = traits.Bool(
        True, usedefault=True, desc="Apply laplacian regularization"
    )
    laplacian_weighting = traits.Float(
        0.05, usedefault=True, desc="Regularization weight"
    )
    positivity_constraint = traits.Bool(
        True, usedefault=True, desc="Apply positivity constraint"
    )
    radial_order = traits.Int(8, usedefault=True, desc="radial order")
    small_delta = traits.Float(
        0.02, mandatory=True, desc="Small delta for gradient table (pulse duration)"
    )
    big_delta = traits.Float(
        0.5, mandatory=True, desc="Big delta for gradient table (time interval)"
    )

    # SHORE parameters
    radial_order_values = traits.List([2, 4, 6, 8, 10, 12])
    shore_radial_order = Enum(
        6,
        values="radial_order_values",
        usedefault=True,
        desc="Even number that represents the order of the basis",
    )
    shore_zeta = traits.Int(700, usedefault=True, desc="Scale factor")
    shore_lambda_n = traits.Float(
        1e-8, usedefault=True, desc="radial regularisation constant"
    )
    shore_lambda_l = traits.Float(
        1e-8, usedefault=True, desc="angular regularisation constant"
    )
    shore_tau = traits.Float(
        0.025330295910584444,
        desc=(
            "Diffusion time. By default the value that makes q equal to the square root of the b-value."
        ),
    )
    shore_constrain_e0 = traits.Bool(
        False, usedefault=True, desc=("Constrain the optimization such that E(0) = 1.")
    )
    shore_positive_constraint = traits.Bool(
        False, usedefault=True, desc="Constrain the propagator to be positive."
    )

    def _imaging_model_changed(self, new):
        """Update ``local_model_editor`` and ``self.local_model`` when ``imaging_model`` is updated.

        For "DTI" both Tensor and CSD are offered. For "multishell" and "HARDI"
        only CSD is offered and selected. "DSI" leaves the editor untouched.

        Parameters
        ----------
        new : string
            New value of ``imaging_model``
        """
        if new == "DSI":
            pass
        elif new == "DTI":
            self.local_model_editor = {
                False: "1:Tensor",
                True: "2:Constrained Spherical Deconvolution",
            }
        elif new == "multishell" or new == "HARDI":
            self.local_model_editor = {True: "Constrained Spherical Deconvolution"}
            self.local_model = True

    def _recon_mode_changed(self, new):
        """Update ``local_model_editor`` and ``self.local_model`` when ``recon_mode`` is updated.

        "Probabilistic" restricts the choice to CSD (and selects it) unless the
        imaging model is "DSI". Any other mode offers both Tensor and CSD.

        Parameters
        ----------
        new : string
            New value of ``recon_mode``
        """
        if new == "Probabilistic" and self.imaging_model != "DSI":
            self.local_model_editor = {True: "Constrained Spherical Deconvolution"}
            self.local_model = True
        elif new == "Probabilistic" and self.imaging_model == "DSI":
            pass
        else:
            self.local_model_editor = {
                False: "1:Tensor",
                True: "2:Constrained Spherical Deconvolution",
            }
class MRtrixReconConfig(HasTraits):
    """Class used to store MRtrix diffusion reconstruction sub-workflow configuration parameters.

    Attributes
    ----------
    flip_table_axis : traits.List(['x', 'y', 'z'])
        Axis to be flipped in the gradient table.

    local_model_editor : {False: '1:Tensor', True: '2:Constrained Spherical Deconvolution'}
        Choices of reconstruction models offered for ``local_model``

    local_model : traits.Bool
        Reconstruction model selected (See `local_model_editor`).
        ``True`` selects Constrained Spherical Deconvolution,
        ``False`` selects the Tensor model only (Default: True)

    lmax_order : traits.Enum([2, 4, 6, 8, 10, 12, 14, 16])
        Choices of maximal order to use for Constrained Spherical Deconvolution
        (Default: 2, the first choice)

    normalize_to_B0 : traits.Bool
        Normalization flag (Default: False).
        NOTE(review): presumably "normalize to the b=0 image"; not read by the
        workflow code in this module -- confirm before relying on it.

    single_fib_thr : traits.Float(0.7, min=0, max=1)
        FA threshold (Default: 0.7)

    recon_mode : traits.Str
        Can be "Probabilistic" or "Deterministic"
    """

    flip_table_axis = List(["x", "y", "z"])
    local_model_editor = Dict(
        {False: "1:Tensor", True: "2:Constrained Spherical Deconvolution"}
    )
    local_model = Bool(True)
    lmax_order = Enum([2, 4, 6, 8, 10, 12, 14, 16])
    normalize_to_B0 = Bool(False)
    # NOTE(review): with ``Float``, ``min``/``max`` are presumably stored as
    # metadata only and not enforced as bounds -- confirm; ``Range`` validates.
    single_fib_thr = Float(0.7, min=0, max=1)
    recon_mode = Str

    def _imaging_model_changed(self, new):
        """Update ``local_model_editor`` and ``self.local_model`` when ``imaging_model`` is updated.

        For "DTI" both Tensor and CSD are offered. For "multishell" and "HARDI"
        only CSD is offered and selected.

        NOTE(review): ``imaging_model`` is not declared as a trait on this class;
        confirm this handler is triggered when callers assign it.

        Parameters
        ----------
        new : string
            New value of ``imaging_model``
        """
        if new == "DTI":
            self.local_model_editor = {
                False: "1:Tensor",
                True: "2:Constrained Spherical Deconvolution",
            }
        elif new == "multishell" or new == "HARDI":
            self.local_model_editor = {True: "Constrained Spherical Deconvolution"}
            self.local_model = True

    def _recon_mode_changed(self, new):
        """Update ``local_model_editor`` and ``self.local_model`` when ``recon_mode`` is updated.

        "Probabilistic" restricts the choice to CSD and selects it. Any other
        mode offers both Tensor and CSD.

        Parameters
        ----------
        new : string
            New value of ``recon_mode``
        """
        if new == "Probabilistic":
            self.local_model_editor = {True: "Constrained Spherical Deconvolution"}
            self.local_model = True
        else:
            self.local_model_editor = {
                False: "1:Tensor",
                True: "2:Constrained Spherical Deconvolution",
            }
def _dipy_recon_add_tensor_csd(
    flow, config, inputnode, outputnode, bvec_flipper, wm_eroder, tool_flag
):
    """Add the DTI node, and the CSD node when ``config.local_model`` is set (non-DSI data)."""
    tensor_node = pe.Node(interface=DTIEstimateResponseSH(), name="dipy_tensor")
    tensor_node.inputs.auto = True
    tensor_node.inputs.roi_radius = 10
    tensor_node.inputs.fa_thresh = config.single_fib_thr

    # fmt:off
    flow.connect(
        [
            (inputnode, tensor_node, [("diffusion_resampled", "in_file")]),
            (inputnode, tensor_node, [("bvals", "in_bval")]),
            (bvec_flipper, tensor_node, [("bvecs_flipped", "in_bvec")]),
            (wm_eroder, tensor_node, [("out_file", "in_mask")]),
            (tensor_node, outputnode, [("response", "RF")]),
            (tensor_node, outputnode, [("fa_file", "FA")]),
            (tensor_node, outputnode, [("ad_file", "AD")]),
            (tensor_node, outputnode, [("md_file", "MD")]),
            (tensor_node, outputnode, [("rd_file", "RD")]),
        ]
    )
    # fmt:on

    if not config.local_model:
        # Tensor-only: the DTI model is the output model and the resampled
        # DWI is forwarded unchanged.
        # fmt:off
        flow.connect(
            [
                (inputnode, outputnode, [("diffusion_resampled", "DWI")]),
                (tensor_node, outputnode, [("dti_model", "model")]),
            ]
        )
        # fmt:on
        return

    # Constrained Spherical Deconvolution
    csd_node = pe.Node(interface=CSD(), name="dipy_CSD")
    csd_node.inputs.save_shm_coeff = True
    csd_node.inputs.out_shm_coeff = "diffusion_shm_coeff.nii.gz"
    if tool_flag is not None:
        csd_node.inputs.tracking_processing_tool = tool_flag
    if config.lmax_order != "Auto":
        csd_node.inputs.sh_order = config.lmax_order
    csd_node.inputs.fa_thresh = config.single_fib_thr

    # fmt:off
    flow.connect(
        [
            (inputnode, csd_node, [("diffusion_resampled", "in_file")]),
            (inputnode, csd_node, [("bvals", "in_bval")]),
            (bvec_flipper, csd_node, [("bvecs_flipped", "in_bvec")]),
            (inputnode, csd_node, [("brain_mask_resampled", "in_mask")]),
            (csd_node, outputnode, [("model", "model")]),
        ]
    )
    # fmt:on

    # With Dipy tracking the resampled DWI is the "DWI" output; with any
    # other tool the spherical-harmonics coefficients take its place.
    if config.tracking_processing_tool == "Dipy":
        dwi_node, dwi_field = inputnode, "diffusion_resampled"
    else:
        dwi_node, dwi_field = csd_node, "out_shm_coeff"
    flow.connect([(dwi_node, outputnode, [(dwi_field, "DWI")])])


def _dipy_recon_add_shore(flow, config, inputnode, outputnode, bvec_flipper, tool_flag):
    """Add the SHORE node (DSI data) and merge its GFA/MSD/RTOP maps into ``shore_maps``."""
    shore_node = pe.Node(interface=SHORE(), name="dipy_SHORE")
    if tool_flag is not None:
        shore_node.inputs.tracking_processing_tool = tool_flag
    shore_node.inputs.radial_order = int(config.shore_radial_order)
    shore_node.inputs.zeta = config.shore_zeta
    shore_node.inputs.lambda_n = config.shore_lambda_n
    shore_node.inputs.lambda_l = config.shore_lambda_l
    shore_node.inputs.tau = config.shore_tau
    shore_node.inputs.constrain_e0 = config.shore_constrain_e0
    shore_node.inputs.positive_constraint = config.shore_positive_constraint

    shore_merge = pe.Node(interface=util.Merge(3), name="merge_shore_maps")

    # fmt:off
    flow.connect(
        [
            (inputnode, shore_node, [("diffusion_resampled", "in_file")]),
            (inputnode, shore_node, [("bvals", "in_bval")]),
            (bvec_flipper, shore_node, [("bvecs_flipped", "in_bvec")]),
            (inputnode, shore_node, [("brain_mask_resampled", "in_mask")]),
            (shore_node, outputnode, [("model", "model")]),
            (shore_node, outputnode, [("fodf", "fod")]),
            (shore_node, outputnode, [("GFA", "FA")]),
            (shore_node, shore_merge, [("GFA", "in1"), ("MSD", "in2"), ("RTOP", "in3")]),
            (shore_merge, outputnode, [("out", "shore_maps")]),
            (inputnode, outputnode, [("diffusion_resampled", "DWI")]),
        ]
    )
    # fmt:on


def _dipy_recon_add_mapmri(flow, config, inputnode, outputnode, bvec_flipper):
    """Add the MAP-MRI node and merge its eight scalar maps into ``mapmri_maps``."""
    mapmri_node = pe.Node(interface=MAPMRI(), name="dipy_mapmri")
    # The MAPMRI interface inputs share their names with the config traits.
    for param in (
        "laplacian_regularization",
        "laplacian_weighting",
        "positivity_constraint",
        "radial_order",
        "small_delta",
        "big_delta",
    ):
        setattr(mapmri_node.inputs, param, getattr(config, param))

    map_fields = (
        "rtop_file",
        "rtap_file",
        "rtpp_file",
        "msd_file",
        "qiv_file",
        "ng_file",
        "ng_perp_file",
        "ng_para_file",
    )
    mapmri_merge = pe.Node(interface=util.Merge(len(map_fields)), name="merge_mapmri_maps")
    merge_links = [
        (field, "in%d" % position) for position, field in enumerate(map_fields, start=1)
    ]

    # fmt:off
    flow.connect(
        [
            (inputnode, mapmri_node, [("diffusion_resampled", "in_file")]),
            (inputnode, mapmri_node, [("bvals", "in_bval")]),
            (bvec_flipper, mapmri_node, [("bvecs_flipped", "in_bvec")]),
            (mapmri_node, mapmri_merge, merge_links),
            (mapmri_merge, outputnode, [("out", "mapmri_maps")]),
        ]
    )
    # fmt:on


def create_dipy_recon_flow(config):
    """Create the reconstruction sub-workflow of the `DiffusionStage` using Dipy.

    The b-vectors are always flipped according to ``config.flip_table_axis``,
    and the resampled white-matter mask is eroded once. If
    ``config.imaging_model`` is not "DSI", a DTI node (plus a CSD node when
    ``config.local_model`` is set) is added. For "DSI", a SHORE node is added
    instead. A MAP-MRI node is added on top when ``config.mapmri`` is set.

    Parameters
    ----------
    config : DipyReconConfig
        Workflow configuration

    Returns
    -------
    flow : nipype.pipeline.engine.Workflow
        Built reconstruction sub-workflow
    """
    flow = pe.Workflow(name="reconstruction")

    input_fields = [
        "diffusion",
        "diffusion_resampled",
        "brain_mask_resampled",
        "wm_mask_resampled",
        "bvals",
        "bvecs",
    ]
    output_fields = [
        "DWI",
        "FA",
        "AD",
        "MD",
        "RD",
        "fod",
        "model",
        "eigVec",
        "RF",
        "grad",
        "bvecs",
        "shore_maps",
        "mapmri_maps",
    ]
    inputnode = pe.Node(
        interface=util.IdentityInterface(fields=input_fields), name="inputnode"
    )
    outputnode = pe.Node(
        interface=util.IdentityInterface(fields=output_fields, mandatory_inputs=True),
        name="outputnode",
    )

    # Value passed to the Dipy interfaces' ``tracking_processing_tool`` input;
    # left unset (None) for any tool name not listed here.
    tool_flag = {"MRtrix": "mrtrix", "Dipy": "dipy"}.get(
        config.tracking_processing_tool
    )

    # Flip gradient table (space-delimited, horizontal, no header)
    bvec_flipper = pe.Node(interface=FlipBvec(), name="flip_bvecs")
    bvec_flipper.inputs.flipping_axis = config.flip_table_axis
    bvec_flipper.inputs.delimiter = " "
    bvec_flipper.inputs.header_lines = 0
    bvec_flipper.inputs.orientation = "h"

    # fmt:off
    flow.connect(
        [
            (inputnode, bvec_flipper, [("bvecs", "bvecs")]),
            (bvec_flipper, outputnode, [("bvecs_flipped", "bvecs")]),
        ]
    )
    # fmt:on

    # One erosion pass of the white-matter mask; used as the DTI node's mask.
    wm_eroder = pe.Node(
        interface=Erode(out_filename="wm_mask_resampled.nii.gz"), name="dipy_erode"
    )
    wm_eroder.inputs.number_of_passes = 1
    wm_eroder.inputs.filtertype = "erode"
    flow.connect([(inputnode, wm_eroder, [("wm_mask_resampled", "in_file")])])

    if config.imaging_model == "DSI":
        _dipy_recon_add_shore(
            flow, config, inputnode, outputnode, bvec_flipper, tool_flag
        )
    else:
        _dipy_recon_add_tensor_csd(
            flow, config, inputnode, outputnode, bvec_flipper, wm_eroder, tool_flag
        )

    if config.mapmri:
        _dipy_recon_add_mapmri(flow, config, inputnode, outputnode, bvec_flipper)

    return flow
def create_mrtrix_recon_flow(config):
    """Create the reconstruction sub-workflow of the `DiffusionStage` using MRtrix3.

    The steps are:

    * Flip the gradient table.
    * Fit a diffusion tensor, then derive FA and ADC maps (converted to NIfTI)
      and the tensor eigenvectors.
    * When ``config.local_model`` is ``True``, estimate a response function
      (``tournier`` algorithm). Its mask is the eroded white-matter mask
      multiplied by FA and thresholded at ``config.single_fib_thr``.
    * Run constrained spherical deconvolution with that response function.
      Its spherical-harmonics image, converted to NIfTI, becomes the ``DWI``
      output.
    * Otherwise, forward the resampled DWI as ``DWI``.

    Parameters
    ----------
    config : MRtrixReconConfig
        Workflow configuration

    Returns
    -------
    flow : nipype.pipeline.engine.Workflow
        Built reconstruction sub-workflow
    """
    # TODO: Add AD and RD maps
    flow = pe.Workflow(name="reconstruction")
    inputnode = pe.Node(
        interface=util.IdentityInterface(
            fields=["diffusion", "diffusion_resampled", "wm_mask_resampled", "grad"]
        ),
        name="inputnode",
    )
    outputnode = pe.Node(
        interface=util.IdentityInterface(
            fields=["DWI", "FA", "ADC", "tensor", "eigVec", "RF", "grad"],
            mandatory_inputs=True,
        ),
        name="outputnode",
    )

    # Flip gradient table (space-delimited, vertical, no header)
    flip_table = pe.Node(interface=FlipTable(), name="flip_table")
    flip_table.inputs.flipping_axis = config.flip_table_axis
    flip_table.inputs.delimiter = " "
    flip_table.inputs.header_lines = 0
    flip_table.inputs.orientation = "v"
    # fmt:off
    flow.connect(
        [
            (inputnode, flip_table, [("grad", "table")]),
            (flip_table, outputnode, [("table", "grad")]),
        ]
    )
    # fmt:on

    # Tensor
    mrtrix_tensor = pe.Node(interface=DWI2Tensor(), name="mrtrix_make_tensor")
    # fmt:off
    flow.connect(
        [
            (inputnode, mrtrix_tensor, [("diffusion_resampled", "in_file")]),
            (flip_table, mrtrix_tensor, [("table", "encoding_file")]),
        ]
    )
    # fmt:on

    # Tensor -> FA / ADC maps, each converted to NIfTI
    mrtrix_tensor_metrics = pe.Node(
        interface=TensorMetrics(out_fa="FA.mif", out_adc="ADC.mif"),
        name="mrtrix_tensor_metrics",
    )
    convert_Tensor = pe.Node(
        interface=MRConvert(out_filename="dwi_tensor.nii.gz"), name="convert_tensor"
    )
    convert_FA = pe.Node(
        interface=MRConvert(out_filename="FA.nii.gz"), name="convert_FA"
    )
    convert_ADC = pe.Node(
        interface=MRConvert(out_filename="ADC.nii.gz"), name="convert_ADC"
    )
    # fmt:off
    flow.connect(
        [
            (mrtrix_tensor, convert_Tensor, [("tensor", "in_file")]),
            (mrtrix_tensor, mrtrix_tensor_metrics, [("tensor", "in_file")]),
            (mrtrix_tensor_metrics, convert_FA, [("out_fa", "in_file")]),
            (mrtrix_tensor_metrics, convert_ADC, [("out_adc", "in_file")]),
            (convert_Tensor, outputnode, [("converted", "tensor")]),
            (convert_FA, outputnode, [("converted", "FA")]),
            (convert_ADC, outputnode, [("converted", "ADC")]),
        ]
    )
    # fmt:on

    # Tensor -> Eigenvectors
    mrtrix_eigVectors = pe.Node(interface=Tensor2Vector(), name="mrtrix_eigenvectors")
    # fmt:off
    flow.connect(
        [
            (mrtrix_tensor, mrtrix_eigVectors, [("tensor", "in_file")]),
            (mrtrix_eigVectors, outputnode, [("vector", "eigVec")]),
        ]
    )
    # fmt:on

    # Constrained Spherical Deconvolution
    if config.local_model:
        # Previously a stray debug ``print`` to stdout; route through the logger.
        iflogger.info("Constrained Spherical Deconvolution enabled")

        # Compute single fiber voxel mask:
        # eroded WM mask x FA, then thresholded at ``single_fib_thr``.
        mrtrix_erode = pe.Node(
            interface=Erode(out_filename="wm_mask_res_eroded.nii.gz"),
            name="mrtrix_erode",
        )
        mrtrix_erode.inputs.number_of_passes = 1
        mrtrix_erode.inputs.filtertype = "erode"
        mrtrix_mul_eroded_FA = pe.Node(
            interface=MRtrix_mul(), name="mrtrix_mul_eroded_FA"
        )
        mrtrix_mul_eroded_FA.inputs.out_filename = (
            "diffusion_resampled_tensor_FA_masked.mif"
        )
        mrtrix_thr_FA = pe.Node(
            interface=MRThreshold(out_file="FA_th.mif"), name="mrtrix_thr"
        )
        mrtrix_thr_FA.inputs.abs_value = config.single_fib_thr
        # fmt:off
        flow.connect(
            [
                (inputnode, mrtrix_erode, [("wm_mask_resampled", "in_file")]),
                (mrtrix_erode, mrtrix_mul_eroded_FA, [("out_file", "input2")]),
                (mrtrix_tensor_metrics, mrtrix_mul_eroded_FA, [("out_fa", "input1")]),
                (mrtrix_mul_eroded_FA, mrtrix_thr_FA, [("out_file", "in_file")]),
            ]
        )
        # fmt:on

        # Compute single fiber response function.
        # NOTE: ``config.normalize_to_B0`` is intentionally not forwarded here.
        mrtrix_rf = pe.Node(interface=EstimateResponseForSH(), name="mrtrix_rf")
        mrtrix_rf.inputs.maximum_harmonic_order = int(config.lmax_order)
        mrtrix_rf.inputs.algorithm = "tournier"
        # fmt:off
        flow.connect(
            [
                (inputnode, mrtrix_rf, [("diffusion_resampled", "in_file")]),
                (mrtrix_thr_FA, mrtrix_rf, [("thresholded", "mask_image")]),
                (flip_table, mrtrix_rf, [("table", "encoding_file")]),
            ]
        )
        # fmt:on

        # Perform spherical deconvolution (masked by the non-eroded WM mask)
        mrtrix_CSD = pe.Node(
            interface=ConstrainedSphericalDeconvolution(), name="mrtrix_CSD"
        )
        mrtrix_CSD.inputs.algorithm = "csd"
        mrtrix_CSD.inputs.maximum_harmonic_order = int(config.lmax_order)
        convert_CSD = pe.Node(
            interface=MRConvert(out_filename="spherical_harmonics_image.nii.gz"),
            name="convert_CSD",
        )
        # fmt:off
        flow.connect(
            [
                (inputnode, mrtrix_CSD, [("diffusion_resampled", "in_file")]),
                (mrtrix_rf, mrtrix_CSD, [("response", "response_file")]),
                (mrtrix_rf, outputnode, [("response", "RF")]),
                (inputnode, mrtrix_CSD, [("wm_mask_resampled", "mask_image")]),
                (flip_table, mrtrix_CSD, [("table", "encoding_file")]),
                (mrtrix_CSD, convert_CSD, [("spherical_harmonics_image", "in_file")]),
                (convert_CSD, outputnode, [("converted", "DWI")]),
            ]
        )
        # fmt:on
    else:
        # Tensor-only: forward the resampled DWI unchanged.
        # fmt:off
        flow.connect(
            [(inputnode, outputnode, [("diffusion_resampled", "DWI")])]
        )
        # fmt:on
    return flow