Source code for cmp.stages.diffusion.tracking

# Copyright (C) 2009-2022, Ecole Polytechnique Federale de Lausanne (EPFL) and
# Hospital Center and University of Lausanne (UNIL-CHUV), Switzerland, and CMP3 contributors
# All rights reserved.
#
#  This software is distributed under the open-source license Modified BSD.

"""Tracking methods and workflows of the diffusion stage."""

from traits.api import *

from nipype.interfaces.base import traits
import nipype.pipeline.engine as pe
import nipype.interfaces.utility as util
from nipype import logging
# import matplotlib.pyplot as plt

from cmtklib.interfaces.mrtrix3 import Erode, StreamlineTrack, FilterTractogram
from cmtklib.interfaces.dipy import (
    DirectionGetterTractography,
    TensorInformedEudXTractography,
)
from cmtklib.interfaces.misc import ExtractHeaderVoxel2WorldMatrix
from cmtklib.diffusion import Tck2Trk
# from cmtklib.diffusion import filter_fibers


iflogger = logging.getLogger("nipype.interface")


[docs]class DipyTrackingConfig(HasTraits): """Class used to store Dipy diffusion reconstruction sub-workflow configuration parameters. Attributes ---------- imaging_model : traits.Str Diffusion imaging model (For example 'DTI') tracking_mode : traits.Str Type of local tractography algorithm (Can be "Deterministic" or "Probabilistic") SD : traits.Bool If `True`, inputs are coming from Constrained Spherical Deconvolution reconstruction number_of_seeds : traits.Int Number of seeds (Default: 1000) seed_density : traits.Float Number of seeds to place along each direction where a density of 2 is the same as [2, 2, 2] and will result in a total of 8 seeds per voxel (Default: 1.0) fa_thresh : traits.Float Fractional Anisotropy (FA) threshold (Default: 0.2) step_size : traits.traits.Float Tractography algorithm step size (Default: 0.5) max_angle : traits.Float Maximum streamline angle allowed (Default: 25.0) sh_order : traits.Int Order used for Constrained Spherical Deconvolution reconstruction (Default: 8) use_act : traits.Bool Use FAST for partial volume estimation and Anatomically-Constrained Tractography (ACT) tissue classifier (Default: False) seed_from_gmwmi : traits.Bool Seed from Grey Matter / White Matter interface (requires Anatomically-Constrained Tractography (ACT)) (Default: False) """ imaging_model = Str tracking_mode = Str SD = Bool number_of_seeds = Int(1000) seed_density = Float( 1.0, desc="Number of seeds to place along each direction. " "A density of 2 is the same as [2, 2, 2] and will result in a total of 8 seeds per voxel.", ) fa_thresh = Float(0.2) step_size = traits.Float(0.5) max_angle = Float(25.0) sh_order = Int(8) use_act = traits.Bool( False, desc="Use FAST for partial volume estimation and Anatomically-Constrained Tractography (ACT) tissue classifier", ) seed_from_gmwmi = traits.Bool( False, desc="Seed from Grey Matter / White Matter interface (requires Anatomically-Constrained Tractography (ACT))", ) # fast_number_of_classes = Int(3) def _SD_changed(self, new): """Update ``curvature`` when ``SD`` is updated. Parameters ---------- new New value of ``SD`` """ if self.tracking_mode == "Deterministic" and not new: self.curvature = 2.0 elif self.tracking_mode == "Deterministic" and new: self.curvature = 0.0 elif self.tracking_mode == "Probabilistic": self.curvature = 1.0 def _tracking_mode_changed(self, new): """Update ``curvature``, ``use_act`` and ``seed_from_gmwmi`` when ``tracking_mode`` is updated. Parameters ---------- new New value of ``tracking_mode`` """ if new == "Deterministic" and not self.SD: self.curvature = 2.0 self.use_act = False self.seed_from_gmwmi = False elif new == "Deterministic" and self.SD: self.curvature = 0.0 self.use_act = False self.seed_from_gmwmi = False elif new == "Probabilistic": self.curvature = 1.0 def _curvature_changed(self, new): """Set ``curvature`` to 0 if ``curvature`` is updated to a value <= 0.000001. Parameters ---------- new New value of ``curvature`` """ if new <= 0.000001: self.curvature = 0.0 def _use_act_changed(self, new): """Set ``seed_from_gmwmi`` if ``use_act`` has been updated to `False`. Parameters ---------- new New value of ``use_act`` """ if new is False: self.seed_from_gmwmi = False
[docs]class MRtrixTrackingConfig(HasTraits): """Class used to store Dipy diffusion reconstruction sub-workflow configuration parameters. Attributes ---------- tracking_mode : traits.Str Type of local tractography algorithm (Can be "Deterministic" or "Probabilistic") SD : traits.Bool If `True`, inputs are coming from Constrained Spherical Deconvolution reconstruction desired_number_of_tracks : traits.Int Desired number of output streamlines in the tractogram (Default: 1M) curvature = Float Maximum streamline curvature (Default: 2.0) min_length = Float Minimal streamline length (Default: 5) max_length = Float Maximal streamline length (Default: 500) angle : traits.Float Maximum streamline angle allowed (Default: 45.0) cutoff_value : traits.Float Cut-off value to terminate streamline (Default: 0.05) use_act : traits.Bool Use `5ttgen` for brain tissue types estimation and Anatomically-Constrained Tractography (ACT) tissue classifier (Default: False) seed_from_gmwmi : traits.Bool Seed from Grey Matter / White Matter interface (requires Anatomically-Constrained Tractography (ACT)) (Default: False) crop_at_gmwmi : traits.Bool Crop streamline endpoints more precisely as they cross the GM-WM interface (requires Anatomically-Constrained Tractography (ACT)) (Default: True) backtrack : traits.Bool Allow tracks to be truncated (requires Anatomically-Constrained Tractography (ACT)) (Default: True) sift : traits.Bool Filter tractogram using mrtrix3 SIFT (Default: True) """ tracking_mode = Str SD = Bool desired_number_of_tracks = Int(1000000) # max_number_of_seeds = Int(1000000000) curvature = Float(2.0) step_size = Float(0.5) min_length = Float(5) max_length = Float(500) angle = Float(45) cutoff_value = Float(0.05) use_act = traits.Bool( True, desc="Anatomically-Constrained Tractography (ACT) based on Freesurfer parcellation", ) seed_from_gmwmi = traits.Bool( False, desc="Seed from Grey Matter / White Matter interface (requires Anatomically-Constrained Tractography (ACT))", ) crop_at_gmwmi = traits.Bool( True, desc="Crop streamline endpoints more precisely as they cross the GM-WM interface " "(requires Anatomically-Constrained Tractography (ACT))", ) backtrack = traits.Bool( True, desc="Allow tracks to be truncated (requires Anatomically-Constrained Tractography (ACT))", ) sift = traits.Bool(True, desc="Filter tractogram using mrtrix3 SIFT") def _SD_changed(self, new): """Update ``curvature`` when ``SD`` is updated. Parameters ---------- new New value of ``SD`` """ if self.tracking_mode == "Deterministic" and not new: self.curvature = 2.0 elif self.tracking_mode == "Deterministic" and new: self.curvature = 0.0 elif self.tracking_mode == "Probabilistic": self.curvature = 1.0 def _use_act_changed(self, new): if new is False: self.crop_at_gmwmi = False self.seed_from_gmwmi = False self.backtrack = False def _tracking_mode_changed(self, new): """Update ``curvature`` when ``tracking_mode`` is updated. Parameters ---------- new New value of ``tracking_mode`` """ if new == "Deterministic" and not self.SD: self.curvature = 2.0 elif new == "Deterministic" and self.SD: self.curvature = 0.0 elif new == "Probabilistic": self.curvature = 1.0 def _curvature_changed(self, new): """Set ``curvature`` to 0 if ``curvature`` is updated to a value <= 0.000001. Parameters ---------- new New value of ``curvature`` """ if new <= 0.000001: self.curvature = 0.0
[docs]def create_dipy_tracking_flow(config): """Create the tractography sub-workflow of the `DiffusionStage` using Dipy. Parameters ---------- config : DipyTrackingConfig Sub-workflow configuration object Returns ------- flow : nipype.pipeline.engine.Workflow Built tractography sub-workflow """ flow = pe.Workflow(name="tracking") # inputnode inputnode = pe.Node( interface=util.IdentityInterface( fields=[ "DWI", "fod_file", "FA", "T1", "partial_volumes", "wm_mask_resampled", "gmwmi_file", "gm_registered", "bvals", "bvecs", "model", ] ), name="inputnode", ) # outputnode outputnode = pe.Node( interface=util.IdentityInterface(fields=["track_file"]), name="outputnode" ) if not config.SD and config.imaging_model != "DSI": # If tensor fitting was used dipy_tracking = pe.Node( interface=TensorInformedEudXTractography(), name="dipy_dtieudx_tracking" ) dipy_tracking.inputs.num_seeds = config.number_of_seeds dipy_tracking.inputs.fa_thresh = config.fa_thresh dipy_tracking.inputs.max_angle = config.max_angle dipy_tracking.inputs.step_size = config.step_size # fmt:off flow.connect( [ (inputnode, dipy_tracking, [("wm_mask_resampled", "seed_mask")]), (inputnode, dipy_tracking, [("DWI", "in_file")]), (inputnode, dipy_tracking, [("model", "in_model")]), (inputnode, dipy_tracking, [("FA", "in_fa")]), (inputnode, dipy_tracking, [("wm_mask_resampled", "tracking_mask")]), (dipy_tracking, outputnode, [("tracks", "track_file")]), ] ) # fmt:on else: # If CSD was used if config.tracking_mode == "Deterministic": dipy_tracking = pe.Node( interface=DirectionGetterTractography(), name="dipy_deterministic_tracking", ) dipy_tracking.inputs.algo = "deterministic" dipy_tracking.inputs.num_seeds = config.number_of_seeds dipy_tracking.inputs.fa_thresh = config.fa_thresh dipy_tracking.inputs.max_angle = config.max_angle dipy_tracking.inputs.step_size = config.step_size dipy_tracking.inputs.use_act = config.use_act dipy_tracking.inputs.use_act = config.seed_from_gmwmi dipy_tracking.inputs.seed_density = config.seed_density # dipy_tracking.inputs.fast_number_of_classes = config.fast_number_of_classes if config.imaging_model == "DSI": dipy_tracking.inputs.recon_model = "SHORE" else: dipy_tracking.inputs.recon_model = "CSD" dipy_tracking.inputs.recon_order = config.sh_order if config.imaging_model == "DSI": # fmt:off flow.connect( [ (inputnode, dipy_tracking, [("fod_file", "fod_file")]), ] ) # fmt:on # fmt:off flow.connect( [ (inputnode, dipy_tracking, [("DWI", "in_file")]), (inputnode, dipy_tracking, [("partial_volumes", "in_partial_volume_files")],), (inputnode, dipy_tracking, [("model", "in_model")]), (inputnode, dipy_tracking, [("FA", "in_fa")]), (inputnode, dipy_tracking, [("wm_mask_resampled", "seed_mask")]), (inputnode, dipy_tracking, [("gmwmi_file", "gmwmi_file")]), (inputnode, dipy_tracking, [("wm_mask_resampled", "tracking_mask")],), (dipy_tracking, outputnode, [("tracks", "track_file")]), ] ) # fmt:on elif config.tracking_mode == "Probabilistic": dipy_tracking = pe.Node( interface=DirectionGetterTractography(), name="dipy_probabilistic_tracking", ) dipy_tracking.inputs.algo = "probabilistic" dipy_tracking.inputs.num_seeds = config.number_of_seeds dipy_tracking.inputs.fa_thresh = config.fa_thresh dipy_tracking.inputs.max_angle = config.max_angle dipy_tracking.inputs.step_size = config.step_size dipy_tracking.inputs.use_act = config.use_act dipy_tracking.inputs.seed_from_gmwmi = config.seed_from_gmwmi dipy_tracking.inputs.seed_density = config.seed_density # dipy_tracking.inputs.fast_number_of_classes = config.fast_number_of_classes if config.imaging_model == "DSI": dipy_tracking.inputs.recon_model = "SHORE" else: dipy_tracking.inputs.recon_model = "CSD" dipy_tracking.inputs.recon_order = config.sh_order if config.imaging_model == "DSI": # fmt:off flow.connect( [ (inputnode, dipy_tracking, [("fod_file", "fod_file")]), ] ) # fmt:on # fmt:off flow.connect( [ (inputnode, dipy_tracking, [("DWI", "in_file")]), (inputnode, dipy_tracking, [("partial_volumes", "in_partial_volume_files")]), (inputnode, dipy_tracking, [("model", "in_model")]), (inputnode, dipy_tracking, [("FA", "in_fa")]), (inputnode, dipy_tracking, [("wm_mask_resampled", "seed_mask")]), (inputnode, dipy_tracking, [("gmwmi_file", "gmwmi_file")]), (inputnode, dipy_tracking, [("wm_mask_resampled", "tracking_mask")]), (dipy_tracking, outputnode, [("tracks", "track_file")]), ] ) # fmt:on return flow
[docs]def get_freesurfer_parcellation(roi_files): """Return the first file in the list of parcellation files Parameters ---------- roi_files : list of traits.File List of parcellation files """ print("%s" % roi_files[0]) return roi_files[0]
[docs]def create_mrtrix_tracking_flow(config): """Create the tractography sub-workflow of the `DiffusionStage` using MRtrix3. Parameters ---------- config : MRtrixTrackingConfig Sub-workflow configuration object Returns ------- flow : nipype.pipeline.engine.Workflow Built tractography sub-workflow """ flow = pe.Workflow(name="tracking") # inputnode inputnode = pe.Node( interface=util.IdentityInterface( fields=[ "DWI", "wm_mask_resampled", "gm_registered", "act_5tt_registered", "gmwmi_registered", "grad", ] ), name="inputnode", ) # outputnode outputnode = pe.Node( interface=util.IdentityInterface(fields=["track_file"]), name="outputnode" ) # Compute single fiber voxel mask wm_erode = pe.Node( interface=Erode(out_filename="wm_mask_resampled.nii.gz"), name="wm_erode" ) wm_erode.inputs.number_of_passes = 3 wm_erode.inputs.filtertype = "erode" flow.connect([(inputnode, wm_erode, [("wm_mask_resampled", "in_file")])]) if config.tracking_mode == "Deterministic": mrtrix_tracking = pe.Node( interface=StreamlineTrack(), name="mrtrix_deterministic_tracking" ) mrtrix_tracking.inputs.desired_number_of_tracks = ( config.desired_number_of_tracks ) # mrtrix_tracking.inputs.maximum_number_of_seeds = config.max_number_of_seeds mrtrix_tracking.inputs.maximum_tract_length = config.max_length mrtrix_tracking.inputs.minimum_tract_length = config.min_length mrtrix_tracking.inputs.step_size = config.step_size mrtrix_tracking.inputs.angle = config.angle mrtrix_tracking.inputs.cutoff_value = config.cutoff_value # mrtrix_tracking.inputs.args = '2>/dev/null' if config.curvature >= 0.000001: mrtrix_tracking.inputs.rk4 = True mrtrix_tracking.inputs.inputmodel = "SD_Stream" else: mrtrix_tracking.inputs.inputmodel = "SD_Stream" # fmt:off flow.connect( [(inputnode, mrtrix_tracking, [("grad", "gradient_encoding_file")])] ) # fmt:on voxel2WorldMatrixExtracter = pe.Node( interface=ExtractHeaderVoxel2WorldMatrix(), name="voxel2WorldMatrixExtracter", ) # fmt:off flow.connect( [ (inputnode, voxel2WorldMatrixExtracter, [("wm_mask_resampled", "in_file")],) ] ) # fmt:on if config.use_act: # fmt:off flow.connect( [ (inputnode, mrtrix_tracking, [("act_5tt_registered", "act_file")]), ] ) # fmt:on mrtrix_tracking.inputs.backtrack = config.backtrack mrtrix_tracking.inputs.crop_at_gmwmi = config.crop_at_gmwmi else: # fmt:off flow.connect( [ (inputnode, mrtrix_tracking, [("wm_mask_resampled", "mask_file")]), ] ) # fmt:on if config.seed_from_gmwmi: # fmt:off flow.connect( [ (inputnode, mrtrix_tracking, [("gmwmi_registered", "seed_gmwmi")]), ] ) # fmt:on else: # fmt:off flow.connect( [ (inputnode, mrtrix_tracking, [("wm_mask_resampled", "seed_file")]), ] ) # fmt:on # converter = pe.Node(interface=mrtrix.MRTrix2TrackVis(),name="trackvis") converter = pe.Node(interface=Tck2Trk(), name="trackvis") converter.inputs.out_tracks = "converted.trk" if config.sift: filter_tractogram = pe.Node(interface=FilterTractogram(), name="sift_node") filter_tractogram.inputs.out_file = "sift-filtered_tractogram.tck" # fmt:off flow.connect( [ (mrtrix_tracking, filter_tractogram, [("tracked", "in_tracks")]), (inputnode, filter_tractogram, [("DWI", "in_fod")]), ] ) # fmt:on if config.use_act: # fmt:off flow.connect( [ (inputnode, filter_tractogram, [("act_5tt_registered", "act_file")],), ] ) # fmt:on # fmt:off flow.connect( [(filter_tractogram, converter, [("out_tracks", "in_tracks")])] ) # fmt:on else: # fmt:off flow.connect( [ (mrtrix_tracking, converter, [("tracked", "in_tracks")]), ] ) # fmt:on # fmt:off flow.connect( [ (inputnode, mrtrix_tracking, [("DWI", "in_file")]), (inputnode, converter, [("wm_mask_resampled", "in_image")]), (converter, outputnode, [("out_tracks", "track_file")]), ] ) # fmt:on elif config.tracking_mode == "Probabilistic": mrtrix_tracking = pe.Node( interface=StreamlineTrack(), name="mrtrix_probabilistic_tracking" ) mrtrix_tracking.inputs.desired_number_of_tracks = ( config.desired_number_of_tracks ) # mrtrix_tracking.inputs.maximum_number_of_seeds = config.max_number_of_seeds mrtrix_tracking.inputs.maximum_tract_length = config.max_length mrtrix_tracking.inputs.minimum_tract_length = config.min_length mrtrix_tracking.inputs.step_size = config.step_size mrtrix_tracking.inputs.angle = config.angle mrtrix_tracking.inputs.cutoff_value = config.cutoff_value # mrtrix_tracking.inputs.args = '2>/dev/null' # if config.curvature >= 0.000001: # mrtrix_tracking.inputs.rk4 = True if config.SD: mrtrix_tracking.inputs.inputmodel = "iFOD2" else: mrtrix_tracking.inputs.inputmodel = "Tensor_Prob" converter = pe.Node(interface=Tck2Trk(), name="trackvis") converter.inputs.out_tracks = "converted.trk" if config.use_act: # fmt:off flow.connect( [ (inputnode, mrtrix_tracking, [("act_5tt_registered", "act_file")]), ] ) # fmt:on mrtrix_tracking.inputs.backtrack = config.backtrack mrtrix_tracking.inputs.crop_at_gmwmi = config.crop_at_gmwmi else: # fmt:off flow.connect( [ (inputnode, mrtrix_tracking, [("wm_mask_resampled", "mask_file")]), ] ) # fmt:on if config.seed_from_gmwmi: # fmt:off flow.connect( [ (inputnode, mrtrix_tracking, [("gmwmi_registered", "seed_gmwmi")]), ] ) # fmt:on else: # fmt:off flow.connect( [ (inputnode, mrtrix_tracking, [("wm_mask_resampled", "seed_file")]), ] ) # fmt:on if config.sift: filter_tractogram = pe.Node(interface=FilterTractogram(), name="sift_node") filter_tractogram.inputs.out_file = "sift-filtered_tractogram.tck" # fmt:off flow.connect( [ (mrtrix_tracking, filter_tractogram, [("tracked", "in_tracks")]), (inputnode, filter_tractogram, [("DWI", "in_fod")]), ] ) # fmt:on if config.use_act: # fmt:off flow.connect( [ (inputnode, filter_tractogram, [("act_5tt_registered", "act_file")],), ] ) # fmt:on # fmt:off flow.connect( [(filter_tractogram, converter, [("out_tracks", "in_tracks")])] ) # fmt:on else: # fmt:off flow.connect( [ (mrtrix_tracking, converter, [("tracked", "in_tracks")]), ] ) # fmt:on # fmt:off flow.connect( [ (inputnode, mrtrix_tracking, [("DWI", "in_file")]), (inputnode, converter, [("wm_mask_resampled", "in_image")]), (converter, outputnode, [("out_tracks", "track_file")]), ] ) # fmt:on return flow