# Copyright (C) 2009-2022, Ecole Polytechnique Federale de Lausanne (EPFL) and
# Hospital Center and University of Lausanne (UNIL-CHUV), Switzerland, and CMP3 contributors
# All rights reserved.
#
# This software is distributed under the open-source license Modified BSD.
"""The PyCartool module provides Nipype interfaces with Cartool using pycartool."""
# General imports
import os
import pickle
import nibabel
import numpy as np
import scipy.io as sio
# Nipype imports
from nipype.interfaces.base import (
BaseInterface, BaseInterfaceInputSpec,
TraitedSpec, traits
)
# EEG package imports
import mne
import pycartool as cart
class CartoolInverseSolutionROIExtractionInputSpec(BaseInterfaceInputSpec):
epochs_file = traits.File(
exists=True, desc='eeg * epochs in .set format', mandatory=True)
invsol_file = traits.File(
exists=True, desc='Inverse solution (.is file loaded with pycartool)', mandatory=True)
mapping_spi_rois_file = traits.File(
exists=True,
desc='Cartool-reconstructed sources / parcellation ROI mapping file, loaded with pickle',
mandatory=True
)
lamb = traits.Int(6, desc='Regularization weight')
svd_toi_begin = traits.Float(0, desc='Start TOI for SVD projection')
svd_toi_end = traits.Float(0.25, desc='End TOI for SVD projection')
out_roi_ts_fname_prefix = traits.Str(
exists=False,
desc="Output name prefix (no extension) for rois * time series files",
mandatory=True
)
class CartoolInverseSolutionROIExtractionOutputSpec(TraitedSpec):
roi_ts_npy_file = traits.File(desc="Path to output ROI time series file in .npy format")
roi_ts_mat_file = traits.File(desc="Path to output ROI time series file in .mat format")
class CartoolInverseSolutionROIExtraction(BaseInterface):
"""Use Pycartool to load inverse solutions estimated by Cartool.
Examples
--------
>>> from cmtklib.interfaces.pycartool import CartoolInverseSolutionROIExtraction
>>> cartool_inv_sol = CartoolInverseSolutionROIExtraction()
>>> cartool_inv_sol.inputs.epochs_file = 'sub-01_task-faces_desc-preproc_eeg.set'
>>> cartool_inv_sol.inputs.invsol_file = 'sub-01_eeg.LORETA.is'
>>> cartool_inv_sol.inputs.mapping_spi_rois_file = 'sub-01_atlas-L2018_res-scale1.pickle.rois'
>>> cartool_inv_sol.inputs.lamd = 6
>>> cartool_inv_sol.inputs.svd_toi_begin = 0
>>> cartool_inv_sol.inputs.svd_toi_end = 0.25
>>> cartool_inv_sol.inputs.out_roi_ts_fname_prefix = 'sub-01_task-faces_atlas-L2008_res-scale1_rec-LORETA_timeseries'
>>> cartool_inv_sol.run() # doctest: +SKIP
References
----------
- https://pycartool.readthedocs.io/en/latest/pycartool.io.html#module-pycartool.io.inverse_solution
"""
input_spec = CartoolInverseSolutionROIExtractionInputSpec
output_spec = CartoolInverseSolutionROIExtractionOutputSpec
def _run_interface(self, runtime):
svd_params = {
"toi_begin": self.inputs.svd_toi_begin,
"toi_end": self.inputs.svd_toi_end
}
roi_tcs = self.apply_inverse_epochs_cartool(
self.inputs.epochs_file,
self.inputs.invsol_file,
self.inputs.lamb,
self.inputs.mapping_spi_rois_file,
svd_params
)
np.save(self._gen_output_filename_roi_ts(extension=".npy"), roi_tcs)
sio.savemat(self._gen_output_filename_roi_ts(extension=".mat"), {"ts": roi_tcs})
return runtime
def _list_outputs(self):
outputs = self._outputs().get()
outputs["roi_ts_npy_file"] = self._gen_output_filename_roi_ts(extension=".npy")
outputs["roi_ts_mat_file"] = self._gen_output_filename_roi_ts(extension=".mat")
return outputs
def _gen_output_filename_roi_ts(self, extension):
# Return the absolute path of the output ROI time series file
return os.path.abspath(self.inputs.out_roi_ts_fname_prefix + extension)
class CreateSpiRoisMappingInputSpec(BaseInterfaceInputSpec):
roi_volume_file = traits.File(exits=True, desc="Parcellation file in nifti format", mandatory=True)
spi_file = traits.File(exits=True, desc="Cartool reconstructed sources file in spi format", mandatory=True)
out_mapping_spi_rois_fname = traits.Str(
desc="Name of output sources / parcellation ROI mapping file in .pickle.rois format",
mandatory=True
)
class CreateSpiRoisMappingOutputSpec(TraitedSpec):
mapping_spi_rois_file = traits.File(
desc="Path to output Cartool-reconstructed sources / parcellation ROI mapping file "
"in .pickle.rois format"
)
class CreateSpiRoisMapping(BaseInterface):
"""Create Cartool-reconstructed sources / parcellation ROI mapping file.
Examples
--------
>>> from cmtklib.interfaces.pycartool import CreateSpiRoisMapping
>>> createrois = CreateSpiRoisMapping()
>>> createrois.inputs.roi_volume_file = '/path/to/sub-01_atlas-L2018_res-scale1_dseg.nii.gz'
>>> createrois.inputs.spi_file = '/path/to/sub-01_eeg.spi'
>>> createrois.inputs.out_mapping_spi_rois_fname = 'sub-01_atlas-L2018_res-scale1_eeg.pickle.rois'
>>> createrois.run() # doctest: +SKIP
"""
input_spec = CreateSpiRoisMappingInputSpec
output_spec = CreateSpiRoisMappingOutputSpec
def _run_interface(self, runtime):
mapping_spi_roi = self._create_mapping_spi_rois(
self.inputs.roi_volume_file,
self.inputs.spi_file
)
with open(self._gen_output_filename_mapping_spi_rois(), "wb") as f:
pickle.dump(mapping_spi_roi, f)
return runtime
@staticmethod
def _create_mapping_spi_rois(roi_volume_file, spi_file):
# Load input parcellation and spi files
source = cart.source_space.read_spi(spi_file)
imdata = nibabel.load(roi_volume_file).get_fdata()
x, y, z = np.where(imdata)
center_brain = [np.mean(x), np.mean(y), np.mean(z)]
source.coordinates[:, 0] = -source.coordinates[:, 0]
source.coordinates = source.coordinates - source.coordinates.mean(0) + center_brain
xyz = source.get_coordinates()
xyz = np.round(xyz).astype(int)
num_spi = len(xyz)
# label positions
rois_file = np.zeros(num_spi)
x_roi, y_roi, z_roi = np.where((imdata > 0) & (imdata < np.unique(imdata)[-1]))
# For each coordinate
for spi_id, spi in enumerate(xyz):
distances = ((spi.reshape(-1, 1) - [x_roi, y_roi, z_roi]) ** 2).sum(0)
roi_id = np.argmin(distances)
rois_file[spi_id] = imdata[x_roi[roi_id], y_roi[roi_id], z_roi[roi_id]]
groups_of_indexes = [np.where(rois_file == roi)[0].tolist() for roi in np.unique(rois_file)]
names = [str(int(i)) for i in np.unique(rois_file) if i != 0]
mapping_spi_roi = cart.regions_of_interest.RegionsOfInterest(
names=names, groups_of_indexes=groups_of_indexes, source_space=source
)
return mapping_spi_roi
def _list_outputs(self):
outputs = self._outputs().get()
outputs["mapping_spi_rois_file"] = self._gen_output_filename_mapping_spi_rois()
return outputs
def _gen_output_filename_mapping_spi_rois(self):
# Return the absolute path of the output inverse operator file
return os.path.abspath(self.inputs.out_mapping_spi_rois_fname)