# Source code for cmp.stages.diffusion.tracking

# Copyright (C) 2009-2020, Ecole Polytechnique Federale de Lausanne (EPFL) and
# Hospital Center and University of Lausanne (UNIL-CHUV), Switzerland
# All rights reserved.
#
#  This software is distributed under the open-source license Modified BSD.

"""Tracking methods and workflows of the diffusion stage."""

from traits.api import *

from nipype.interfaces.base import traits
import nipype.pipeline.engine as pe
import nipype.interfaces.utility as util
from nipype import logging

# import matplotlib.pyplot as plt

from cmtklib.interfaces.mrtrix3 import Erode, StreamlineTrack, FilterTractogram
from cmtklib.interfaces.dipy import DirectionGetterTractography, TensorInformedEudXTractography
from cmtklib.interfaces.misc import ExtractHeaderVoxel2WorldMatrix
from cmtklib.diffusion import Tck2Trk, Make_Mrtrix_Seeds

# from cmtklib.diffusion import filter_fibers

iflogger = logging.getLogger('nipype.interface')


class Dipy_tracking_config(HasTraits):
    """Class used to store Dipy tractography sub-workflow configuration parameters.

    Attributes
    ----------
    imaging_model : traits.Str
        Diffusion imaging model
        (For example 'DTI')

    tracking_mode : traits.Str
        Type of local tractography algorithm
        (Can be "Deterministic" or "Probabilistic")

    SD : traits.Bool
        If `True`, inputs are coming from Constrained Spherical Deconvolution reconstruction

    number_of_seeds : traits.Int
        Number of seeds
        (Default: 1000)

    seed_density : traits.Float
        Number of seeds to place along each direction
        where a density of 2 is the same as [2, 2, 2]
        and will result in a total of 8 seeds per voxel
        (Default: 1.0)

    fa_thresh : traits.Float
        Fractional Anisotropy (FA) threshold
        (Default: 0.2)

    step_size : traits.Float
        Tractography algorithm step size
        (Default: 0.5)

    max_angle : traits.Float
        Maximum streamline angle allowed
        (Default: 25.0)

    sh_order : traits.Int
        Order used for Constrained Spherical Deconvolution reconstruction
        (Default: 8)

    use_act : traits.Bool
        Use FAST for partial volume estimation and
        Anatomically-Constrained Tractography (ACT) tissue classifier
        (Default: False)

    seed_from_gmwmi : traits.Bool
        Seed from Grey Matter / White Matter interface
        (requires Anatomically-Constrained Tractography (ACT))
        (Default: False)
    """

    imaging_model = Str
    tracking_mode = Str
    SD = Bool
    number_of_seeds = Int(1000)
    seed_density = Float(1.0, desc='Number of seeds to place along each direction. '
                                   'A density of 2 is the same as [2, 2, 2] and will result in a total of 8 seeds per voxel.')
    fa_thresh = Float(0.2)
    step_size = traits.Float(0.5)
    max_angle = Float(25.0)
    sh_order = Int(8)
    use_act = traits.Bool(False, desc='Use FAST for partial volume estimation and Anatomically-Constrained Tractography (ACT) tissue classifier')
    seed_from_gmwmi = traits.Bool(False, desc="Seed from Grey Matter / White Matter interface (requires Anatomically-Constrained Tractography (ACT))")
    # fast_number_of_classes = Int(3)

    # NOTE(review): ``curvature`` is assigned by the handlers below but is not
    # declared as a trait in this class, so ``_curvature_changed`` may never be
    # triggered as a static change handler. Confirm whether a ``curvature``
    # trait should be declared. Adding one would also change which traits are
    # saved with the configuration.

    def _SD_changed(self, new):
        """Update ``curvature`` when ``SD`` is updated.

        Parameters
        ----------
        new
            New value of ``SD``
        """
        if self.tracking_mode == "Deterministic" and not new:
            self.curvature = 2.0
        elif self.tracking_mode == "Deterministic" and new:
            self.curvature = 0.0
        elif self.tracking_mode == "Probabilistic":
            self.curvature = 1.0

    def _tracking_mode_changed(self, new):
        """Update ``curvature``, ``use_act`` and ``seed_from_gmwmi`` when ``tracking_mode`` is updated.

        Switching to "Deterministic" also disables ACT and GM/WM-interface seeding.

        Parameters
        ----------
        new
            New value of ``tracking_mode``
        """
        if new == "Deterministic" and not self.SD:
            self.curvature = 2.0
            self.use_act = False
            self.seed_from_gmwmi = False
        elif new == "Deterministic" and self.SD:
            self.curvature = 0.0
            self.use_act = False
            self.seed_from_gmwmi = False
        elif new == "Probabilistic":
            self.curvature = 1.0

    def _curvature_changed(self, new):
        """Set ``curvature`` to 0 if ``curvature`` is updated to a value <= 0.000001.

        Parameters
        ----------
        new
            New value of ``curvature``
        """
        if new <= 0.000001:
            self.curvature = 0.0

    def _use_act_changed(self, new):
        """Reset ``seed_from_gmwmi`` to `False` when ``use_act`` is disabled.

        Seeding from the GM/WM interface requires ACT.

        Parameters
        ----------
        new
            New value of ``use_act``
        """
        if not new:
            self.seed_from_gmwmi = False
class MRtrix_tracking_config(HasTraits):
    """Class used to store MRtrix3 tractography sub-workflow configuration parameters.

    Attributes
    ----------
    tracking_mode : traits.Str
        Type of local tractography algorithm
        (Can be "Deterministic" or "Probabilistic")

    SD : traits.Bool
        If `True`, inputs are coming from Constrained Spherical Deconvolution reconstruction

    desired_number_of_tracks : traits.Int
        Desired number of output streamlines in the tractogram
        (Default: 1M)

    curvature : traits.Float
        Curvature parameter, updated automatically from ``SD`` and
        ``tracking_mode`` and clamped to 0 when set to a value <= 0.000001
        (Default: 2.0)

    step_size : traits.Float
        Tractography step size
        (Default: 0.5)

    min_length : traits.Float
        Minimal streamline length
        (Default: 5)

    max_length : traits.Float
        Maximal streamline length
        (Default: 500)

    angle : traits.Float
        Maximum streamline angle allowed
        (Default: 45.0)

    cutoff_value : traits.Float
        Cut-off value to terminate streamline
        (Default: 0.05)

    use_act : traits.Bool
        Use Anatomically-Constrained Tractography (ACT)
        (Default: True)

    seed_from_gmwmi : traits.Bool
        Seed from Grey Matter / White Matter interface
        (requires Anatomically-Constrained Tractography (ACT))
        (Default: False)

    crop_at_gmwmi : traits.Bool
        Crop streamline endpoints more precisely as they cross the GM-WM interface
        (requires Anatomically-Constrained Tractography (ACT))
        (Default: True)

    backtrack : traits.Bool
        Allow tracks to be truncated
        (requires Anatomically-Constrained Tractography (ACT))
        (Default: True)

    sift : traits.Bool
        Filter tractogram using mrtrix3 SIFT
        (Default: True)
    """

    tracking_mode = Str
    SD = Bool
    desired_number_of_tracks = Int(1000000)
    # max_number_of_seeds = Int(1000000000)
    curvature = Float(2.0)
    step_size = Float(0.5)
    min_length = Float(5)
    max_length = Float(500)
    angle = Float(45)
    cutoff_value = Float(0.05)
    use_act = traits.Bool(
        True, desc="Anatomically-Constrained Tractography (ACT) based on Freesurfer parcellation")
    seed_from_gmwmi = traits.Bool(False, desc="Seed from Grey Matter / White Matter interface (requires Anatomically-Constrained Tractography (ACT))")
    crop_at_gmwmi = traits.Bool(True, desc='Crop streamline endpoints more precisely as they cross the GM-WM interface '
                                           '(requires Anatomically-Constrained Tractography (ACT))')
    backtrack = traits.Bool(True, desc="Allow tracks to be truncated (requires Anatomically-Constrained Tractography (ACT))")
    sift = traits.Bool(True, desc="Filter tractogram using mrtrix3 SIFT")

    def _SD_changed(self, new):
        """Update ``curvature`` when ``SD`` is updated.

        Parameters
        ----------
        new
            New value of ``SD``
        """
        if self.tracking_mode == "Deterministic" and not new:
            self.curvature = 2.0
        elif self.tracking_mode == "Deterministic" and new:
            self.curvature = 0.0
        elif self.tracking_mode == "Probabilistic":
            self.curvature = 1.0

    def _use_act_changed(self, new):
        """Disable the ACT-dependent options when ``use_act`` is disabled.

        ``crop_at_gmwmi``, ``seed_from_gmwmi`` and ``backtrack`` all require ACT,
        so they are reset to `False`.

        Parameters
        ----------
        new
            New value of ``use_act``
        """
        if not new:
            self.crop_at_gmwmi = False
            self.seed_from_gmwmi = False
            self.backtrack = False

    def _tracking_mode_changed(self, new):
        """Update ``curvature`` when ``tracking_mode`` is updated.

        Parameters
        ----------
        new
            New value of ``tracking_mode``
        """
        if new == "Deterministic" and not self.SD:
            self.curvature = 2.0
        elif new == "Deterministic" and self.SD:
            self.curvature = 0.0
        elif new == "Probabilistic":
            self.curvature = 1.0

    def _curvature_changed(self, new):
        """Set ``curvature`` to 0 if ``curvature`` is updated to a value <= 0.000001.

        Parameters
        ----------
        new
            New value of ``curvature``
        """
        if new <= 0.000001:
            self.curvature = 0.0
def create_dipy_tracking_flow(config):
    """Create the tractography sub-workflow of the `DiffusionStage` using Dipy.

    If no spherical deconvolution was used and the imaging model is not DSI,
    a tensor-informed EuDX tracking node is built. Otherwise a direction-getter
    tracking node is built. It uses the ``'deterministic'`` or ``'probabilistic'``
    algorithm according to ``config.tracking_mode``, on a SHORE (DSI) or CSD
    reconstruction. If ``config.tracking_mode`` is neither of those two values,
    no tracking node is added.

    Parameters
    ----------
    config : Dipy_tracking_config
        Sub-workflow configuration object

    Returns
    -------
    flow : nipype.pipeline.engine.Workflow
        Built tractography sub-workflow
    """
    flow = pe.Workflow(name="tracking")
    # inputnode
    inputnode = pe.Node(interface=util.IdentityInterface(
        fields=['DWI', 'fod_file', 'FA', 'T1', 'partial_volumes', 'wm_mask_resampled', 'gmwmi_file',
                'gm_registered', 'bvals', 'bvecs', 'model']), name='inputnode')
    # outputnode
    outputnode = pe.Node(interface=util.IdentityInterface(
        fields=["track_file"]), name='outputnode')

    if not config.SD and config.imaging_model != 'DSI':
        # Tensor fitting was used
        dipy_tracking = pe.Node(
            interface=TensorInformedEudXTractography(), name='dipy_dtieudx_tracking')
        dipy_tracking.inputs.num_seeds = config.number_of_seeds
        dipy_tracking.inputs.fa_thresh = config.fa_thresh
        dipy_tracking.inputs.max_angle = config.max_angle
        dipy_tracking.inputs.step_size = config.step_size

        flow.connect([
            (inputnode, dipy_tracking, [('wm_mask_resampled', 'seed_mask')]),
            (inputnode, dipy_tracking, [('DWI', 'in_file')]),
            (inputnode, dipy_tracking, [('model', 'in_model')]),
            (inputnode, dipy_tracking, [('FA', 'in_fa')]),
            (inputnode, dipy_tracking, [('wm_mask_resampled', 'tracking_mask')]),
            (dipy_tracking, outputnode, [('tracks', 'track_file')])
        ])
        return flow

    # CSD (or SHORE for DSI) was used: both tracking modes share the same
    # direction-getter node and differ only in the algorithm and node name.
    if config.tracking_mode == 'Deterministic':
        algo = 'deterministic'
    elif config.tracking_mode == 'Probabilistic':
        algo = 'probabilistic'
    else:
        # Unknown tracking mode: no tracking node is added to the workflow
        return flow

    dipy_tracking = pe.Node(
        interface=DirectionGetterTractography(), name='dipy_%s_tracking' % algo)
    dipy_tracking.inputs.algo = algo
    dipy_tracking.inputs.num_seeds = config.number_of_seeds
    dipy_tracking.inputs.fa_thresh = config.fa_thresh
    dipy_tracking.inputs.max_angle = config.max_angle
    dipy_tracking.inputs.step_size = config.step_size
    dipy_tracking.inputs.use_act = config.use_act
    # seed_from_gmwmi has its own input; it must not overwrite use_act
    dipy_tracking.inputs.seed_from_gmwmi = config.seed_from_gmwmi
    dipy_tracking.inputs.seed_density = config.seed_density
    if config.imaging_model == 'DSI':
        dipy_tracking.inputs.recon_model = 'SHORE'
    else:
        dipy_tracking.inputs.recon_model = 'CSD'
        dipy_tracking.inputs.recon_order = config.sh_order

    if config.imaging_model == 'DSI':
        flow.connect([
            (inputnode, dipy_tracking, [('fod_file', 'fod_file')]),
        ])

    flow.connect([
        (inputnode, dipy_tracking, [('DWI', 'in_file')]),
        (inputnode, dipy_tracking, [('partial_volumes', 'in_partial_volume_files')]),
        (inputnode, dipy_tracking, [('model', 'in_model')]),
        (inputnode, dipy_tracking, [('FA', 'in_fa')]),
        (inputnode, dipy_tracking, [('wm_mask_resampled', 'seed_mask')]),
        (inputnode, dipy_tracking, [('gmwmi_file', 'gmwmi_file')]),
        (inputnode, dipy_tracking, [('wm_mask_resampled', 'tracking_mask')]),
        (dipy_tracking, outputnode, [('tracks', 'track_file')])
    ])

    return flow
def get_freesurfer_parcellation(roi_files):
    """Select the parcellation file to use: the first entry of ``roi_files``.

    The selected file is also printed to standard output.

    Parameters
    ----------
    roi_files : list of traits.File
        List of parcellation files

    Returns
    -------
    The first element of ``roi_files``.
    """
    selected_roi_file = roi_files[0]
    print("%s" % selected_roi_file)
    return selected_roi_file
def create_mrtrix_tracking_flow(config):
    """Create the tractography sub-workflow of the `DiffusionStage` using MRtrix3.

    A ``StreamlineTrack`` node is configured according to ``config.tracking_mode``:

    * "Deterministic": ``SD_Stream`` model. RK4 integration is enabled when
      ``config.curvature >= 0.000001``, and the ``grad`` input is connected.
    * "Probabilistic": ``iFOD2`` model when ``config.SD`` is set, otherwise
      ``Tensor_Prob``.

    In both modes the tracking is constrained by ACT (``act_5tt_registered``) or by
    the white-matter mask. It is seeded from the GM/WM interface or from the
    white-matter mask. It is optionally filtered with SIFT and finally converted
    with ``Tck2Trk``. For any other tracking mode, no tracking node is added.

    Parameters
    ----------
    config : MRtrix_tracking_config
        Sub-workflow configuration object

    Returns
    -------
    flow : nipype.pipeline.engine.Workflow
        Built tractography sub-workflow
    """
    flow = pe.Workflow(name="tracking")
    # inputnode
    inputnode = pe.Node(interface=util.IdentityInterface(
        fields=['DWI', 'wm_mask_resampled', 'gm_registered', 'act_5tt_registered', 'gmwmi_registered', 'grad']),
        name='inputnode')
    # outputnode
    outputnode = pe.Node(interface=util.IdentityInterface(
        fields=["track_file"]), name='outputnode')

    # Compute single fiber voxel mask
    # NOTE(review): wm_erode's output is not connected to any other node of this
    # sub-workflow; confirm whether it is still needed.
    wm_erode = pe.Node(interface=Erode(out_filename="wm_mask_resampled.nii.gz"), name='wm_erode')
    wm_erode.inputs.number_of_passes = 3
    wm_erode.inputs.filtertype = 'erode'
    flow.connect([
        (inputnode, wm_erode, [("wm_mask_resampled", 'in_file')])
    ])

    if config.tracking_mode == 'Deterministic':
        tracking_node_name = 'mrtrix_deterministic_tracking'
    elif config.tracking_mode == 'Probabilistic':
        tracking_node_name = 'mrtrix_probabilistic_tracking'
    else:
        # Unknown tracking mode: no tractography node is added
        return flow

    # NOTE(review): mrtrix_seeds receives inputs below but none of its outputs
    # are consumed in this sub-workflow; confirm whether it is still needed.
    mrtrix_seeds = pe.Node(interface=Make_Mrtrix_Seeds(), name='mrtrix_seeds')

    mrtrix_tracking = pe.Node(interface=StreamlineTrack(), name=tracking_node_name)
    mrtrix_tracking.inputs.desired_number_of_tracks = config.desired_number_of_tracks
    mrtrix_tracking.inputs.maximum_tract_length = config.max_length
    mrtrix_tracking.inputs.minimum_tract_length = config.min_length
    mrtrix_tracking.inputs.step_size = config.step_size
    mrtrix_tracking.inputs.angle = config.angle
    mrtrix_tracking.inputs.cutoff_value = config.cutoff_value

    if config.tracking_mode == 'Deterministic':
        # SD_Stream is used either way; a non-negligible curvature only
        # switches on RK4 integration.
        if config.curvature >= 0.000001:
            mrtrix_tracking.inputs.rk4 = True
        mrtrix_tracking.inputs.inputmodel = 'SD_Stream'
        flow.connect([
            (inputnode, mrtrix_tracking, [("grad", "gradient_encoding_file")])
        ])

        # NOTE(review): voxel2WorldMatrixExtracter has no outgoing connection.
        voxel2WorldMatrixExtracter = pe.Node(interface=ExtractHeaderVoxel2WorldMatrix(),
                                             name='voxel2WorldMatrixExtracter')
        flow.connect([
            (inputnode, voxel2WorldMatrixExtracter, [("wm_mask_resampled", "in_file")])
        ])
    else:
        # Probabilistic tracking
        if config.SD:
            mrtrix_tracking.inputs.inputmodel = 'iFOD2'
        else:
            mrtrix_tracking.inputs.inputmodel = 'Tensor_Prob'

    flow.connect([
        (inputnode, mrtrix_seeds, [('wm_mask_resampled', 'WM_file')]),
        (inputnode, mrtrix_seeds, [('gm_registered', 'ROI_files')]),
    ])

    # Tracking constraint: ACT 5TT image, or the white-matter mask otherwise
    if config.use_act:
        flow.connect([
            (inputnode, mrtrix_tracking, [('act_5tt_registered', 'act_file')]),
        ])
        mrtrix_tracking.inputs.backtrack = config.backtrack
        mrtrix_tracking.inputs.crop_at_gmwmi = config.crop_at_gmwmi
    else:
        flow.connect([
            (inputnode, mrtrix_tracking, [('wm_mask_resampled', 'mask_file')]),
        ])

    # Seeding: GM/WM interface, or the white-matter mask otherwise
    if config.seed_from_gmwmi:
        flow.connect([
            (inputnode, mrtrix_tracking, [('gmwmi_registered', 'seed_gmwmi')]),
        ])
    else:
        flow.connect([
            (inputnode, mrtrix_tracking, [('wm_mask_resampled', 'seed_file')]),
        ])

    converter = pe.Node(interface=Tck2Trk(), name='trackvis')
    converter.inputs.out_tracks = 'converted.trk'

    if config.sift:
        # The SIFT-filtered tractogram, rather than the raw one, feeds the converter
        filter_tractogram = pe.Node(interface=FilterTractogram(), name='sift_node')
        filter_tractogram.inputs.out_file = 'sift-filtered_tractogram.tck'
        flow.connect([
            (mrtrix_tracking, filter_tractogram, [('tracked', 'in_tracks')]),
            (inputnode, filter_tractogram, [('DWI', 'in_fod')])
        ])
        if config.use_act:
            flow.connect([
                (inputnode, filter_tractogram, [('act_5tt_registered', 'act_file')]),
            ])
        flow.connect([
            (filter_tractogram, converter, [('out_tracks', 'in_tracks')])
        ])
    else:
        flow.connect([
            (mrtrix_tracking, converter, [('tracked', 'in_tracks')]),
        ])

    flow.connect([
        (inputnode, mrtrix_tracking, [('DWI', 'in_file')]),
        (inputnode, converter, [('wm_mask_resampled', 'in_image')]),
        (converter, outputnode, [('out_tracks', 'track_file')])
    ])

    return flow