Skip to content

Commit

Permalink
♻️ Refactor bbreg guardrails
Browse files Browse the repository at this point in the history
  • Loading branch information
shnizzedy committed Oct 25, 2022
1 parent e28b1a7 commit be69c25
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 127 deletions.
30 changes: 30 additions & 0 deletions CPAC/registration/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,36 @@
Registration: {'reference': 'reference', 'registered': 'out_file'}}


def connect_retries(wf, nodes, connections):
"""Function to generalize making the same connections to try and
retry nodes.
For each 3-tuple (``conn``) in ``connections``, will do
.. code-block:: Python
wf.connect(conn[0], node, conn[1], conn[2])
for each node in nodes
Parameters
----------
wf : Workflow
nodes : iterable of Nodes
connections : iterable of 3-tuples of (Node, str or tuple, str)
Returns
-------
Workflow
"""
for node in nodes:
for conn in connections:
wf.connect(conn[0], node, conn[1], conn[2])
return wf


def guardrail_selection(wf: 'Workflow', node1: 'Node', node2: 'Node',
output_key: str = 'registered',
guardrail_node: 'Node' = None) -> Node:
Expand Down
213 changes: 86 additions & 127 deletions CPAC/registration/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@
"""Registration functions"""
# pylint: disable=too-many-lines,ungrouped-imports,wrong-import-order
# TODO: replace Tuple with tuple, Union with |, once Python >= 3.9, 3.10
from sqlite3 import connect
from typing import Optional, Tuple, Union
from CPAC.pipeline import nipype_pipeline_engine as pe
from nipype.interfaces import afni, ants, c3, fsl, utility as util
from nipype.interfaces.afni import utils as afni_utils
from CPAC.anat_preproc.lesion_preproc import create_lesion_preproc
from CPAC.func_preproc.utils import chunk_ts, split_ts_chunks
from CPAC.pipeline.random_state.seed import increment_seed
from CPAC.registration.guardrails import guardrail_selection, \
from CPAC.registration.guardrails import connect_retries, \
guardrail_selection, \
registration_guardrail_node
from CPAC.registration.utils import seperate_warps_list, \
check_transforms, \
Expand Down Expand Up @@ -974,7 +976,6 @@ def create_bbregister_func_to_anat(phase_diff_distcor=False,
outputspec.anat_func : string (nifti file)
Functional data in anatomical space
"""
from CPAC.pipeline.random_state.seed import seed_plus_1
register_bbregister_func_to_anat = pe.Workflow(name=name)

inputspec = pe.Node(util.IdentityInterface(fields=['func',
Expand All @@ -999,14 +1000,10 @@ def create_bbregister_func_to_anat(phase_diff_distcor=False,

wm_bb_mask = pe.Node(interface=fsl.ImageMaths(),
name='wm_bb_mask')
if retry:
seed = seed_plus_1()
wm_bb_mask.seed = seed

register_bbregister_func_to_anat.connect(
inputspec, 'bbr_wm_mask_args',
wm_bb_mask, 'op_string')

register_bbregister_func_to_anat.connect(inputspec,
'anat_wm_segmentation',
wm_bb_mask, 'in_file')
Expand All @@ -1017,50 +1014,52 @@ def bbreg_args(bbreg_target):
bbreg_func_to_anat = pe.Node(interface=fsl.FLIRT(),
name='bbreg_func_to_anat')
bbreg_func_to_anat.inputs.dof = 6
guardrail_bbreg_func_to_anat = registration_guardrail_node(
f'{bbreg_func_to_anat.name}_guardrail')
nodes = [bbreg_func_to_anat]
guardrails = [guardrail_bbreg_func_to_anat]
if retry:
bbreg_func_to_anat.seed = seed

register_bbregister_func_to_anat.connect(
inputspec, 'bbr_schedule',
bbreg_func_to_anat, 'schedule')
register_bbregister_func_to_anat.connect(
wm_bb_mask, ('out_file', bbreg_args),
bbreg_func_to_anat, 'args')
register_bbregister_func_to_anat.connect(
inputspec, 'func',
bbreg_func_to_anat, 'in_file')
register_bbregister_func_to_anat.connect(
inputspec, 'anat',
bbreg_func_to_anat, 'reference')
register_bbregister_func_to_anat.connect(
inputspec, 'linear_reg_matrix',
bbreg_func_to_anat, 'in_matrix_file')

retry_bbreg_func_to_anat = increment_seed(bbreg_func_to_anat.clone(
f'retry_{bbreg_func_to_anat.name}'))
guardrail_retry_bbreg_func_to_anat = registration_guardrail_node(
f'{retry_bbreg_func_to_anat.name}_guardrail')
nodes += [retry_bbreg_func_to_anat]
guardrails += [guardrail_retry_bbreg_func_to_anat]
register_bbregister_func_to_anat = connect_retries(
register_bbregister_func_to_anat, nodes, [
(inputspec, 'bbr_schedule', 'schedule'),
(wm_bb_mask, ('out_file', bbreg_args), 'args'),
(inputspec, 'func', 'in_file'),
(inputspec, 'anat', 'reference'),
(inputspec, 'linear_reg_matrix', 'in_matrix_file')])
if phase_diff_distcor:
register_bbregister_func_to_anat = connect_retries(
register_bbregister_func_to_anat, nodes, [
(inputNode_pedir, ('pedir', convert_pedir), 'pedir'),
(inputspec, 'fieldmap', 'fieldmap'),
(inputspec, 'fieldmapmask', 'fieldmapmask'),
(inputNode_echospacing, 'echospacing', 'echospacing')])
for i, node in enumerate(nodes):
register_bbregister_func_to_anat.connect(inputspec, 'anat',
guardrails[i], 'reference')
register_bbregister_func_to_anat.connect(node, 'out_file',
guardrails[i], 'registered')
if retry:
# pylint: disable=no-value-for-parameter
outfile = guardrail_selection(register_bbregister_func_to_anat,
*guardrails)
matrix = guardrail_selection(register_bbregister_func_to_anat, *nodes,
'out_matrix_file', guardrails[0])
register_bbregister_func_to_anat.connect(
inputNode_pedir, ('pedir', convert_pedir),
bbreg_func_to_anat, 'pedir')
register_bbregister_func_to_anat.connect(
inputspec, 'fieldmap',
bbreg_func_to_anat, 'fieldmap')
register_bbregister_func_to_anat.connect(
inputspec, 'fieldmapmask',
bbreg_func_to_anat, 'fieldmapmask')
matrix, 'out', outputspec, 'func_to_anat_linear_xfm')
register_bbregister_func_to_anat.connect(outfile, 'out',
outputspec, 'anat_func')
else:
register_bbregister_func_to_anat.connect(
inputNode_echospacing, 'echospacing',
bbreg_func_to_anat, 'echospacing')

guardrail = registration_guardrail_node(name=f'{name}_guardrail')
register_bbregister_func_to_anat.connect(inputspec, 'anat',
guardrail, 'reference')
register_bbregister_func_to_anat.connect(
bbreg_func_to_anat, 'out_matrix_file',
outputspec, 'func_to_anat_linear_xfm')
register_bbregister_func_to_anat.connect(bbreg_func_to_anat, 'out_file',
guardrail, 'registered')
register_bbregister_func_to_anat.connect(guardrail, 'registered',
outputspec, 'anat_func')

bbreg_func_to_anat, 'out_matrix_file',
outputspec, 'func_to_anat_linear_xfm')
register_bbregister_func_to_anat.connect(guardrails[0], 'registered',
outputspec, 'anat_func')
return register_bbregister_func_to_anat


Expand Down Expand Up @@ -2871,9 +2870,11 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None):
(func_to_anat, 'outputspec.func_to_anat_linear_xfm_nobbreg')
}

if opt in [True, "fallback"]:
if opt in [True, 'fallback']:
fallback = opt == 'fallback'
func_to_anat_bbreg = create_bbregister_func_to_anat(
diff_complete, f'func_to_anat_bbreg{bbreg_status}_{pipe_num}')
diff_complete, f'func_to_anat_bbreg{bbreg_status}_{pipe_num}',
opt is True)
func_to_anat_bbreg.inputs.inputspec.bbr_schedule = \
cfg.registration_workflows['functional_registration'][
'coregistration']['boundary_based_registration'][
Expand All @@ -2882,56 +2883,31 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None):
cfg.registration_workflows['functional_registration'][
'coregistration']['boundary_based_registration'][
'bbr_wm_mask_args']
bbreg_guardrail = registration_guardrail_node(
f'bbreg{bbreg_status}_guardrail_{pipe_num}')
if opt is True:
# Retry once on failure
retry_node = create_bbregister_func_to_anat(diff_complete,
f'retry_func_to_anat_'
f'bbreg_{pipe_num}',
retry=True)
retry_node.inputs.inputspec.bbr_schedule = cfg[
'registration_workflows', 'functional_registration',
'coregistration', 'boundary_based_registration',
'bbr_schedule']
retry_node.inputs.inputspec.bbr_wm_mask_args = cfg[
'registration_workflows', 'functional_registration',
'coregistration', 'boundary_based_registration',
'bbr_wm_mask_args']
retry_guardrail = registration_guardrail_node(
f'retry_bbreg_guardrail_{pipe_num}')
if fallback:
bbreg_guardrail = registration_guardrail_node(
f'bbreg{bbreg_status}_guardrail_{pipe_num}')

node, out = strat_pool.get_data('desc-reginput_bold')
wf.connect(node, out, func_to_anat_bbreg, 'inputspec.func')
if opt is True:
wf.connect(node, out, retry_node, 'inputspec.func')

if cfg.registration_workflows['functional_registration'][
'coregistration']['boundary_based_registration'][
'reference'] == 'whole-head':
node, out = strat_pool.get_data('T1w')
wf.connect(node, out, func_to_anat_bbreg, 'inputspec.anat')
wf.connect(node, out, bbreg_guardrail, 'reference')
if opt is True:
wf.connect(node, out, retry_node, 'inputspec.anat')
wf.connect(node, out, retry_guardrail, 'reference')
if fallback:
wf.connect(node, out, bbreg_guardrail, 'reference')

elif cfg.registration_workflows['functional_registration'][
'coregistration']['boundary_based_registration'][
'reference'] == 'brain':
node, out = strat_pool.get_data('desc-brain_T1w')
wf.connect(node, out, func_to_anat_bbreg, 'inputspec.anat')
wf.connect(node, out, bbreg_guardrail, 'reference')
if opt is True:
wf.connect(node, out, retry_node, 'inputspec.anat')
wf.connect(node, out, retry_guardrail, 'reference')
if fallback:
wf.connect(node, out, bbreg_guardrail, 'reference')

wf.connect(func_to_anat, 'outputspec.func_to_anat_linear_xfm_nobbreg',
func_to_anat_bbreg, 'inputspec.linear_reg_matrix')
if opt is True:
wf.connect(func_to_anat,
'outputspec.func_to_anat_linear_xfm_nobbreg',
retry_node, 'inputspec.linear_reg_matrix')

if strat_pool.check_rpool('space-bold_label-WM_mask'):
node, out = strat_pool.get_data(["space-bold_label-WM_mask"])
Expand All @@ -2948,76 +2924,59 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None):
"label-WM_mask"])
wf.connect(node, out,
func_to_anat_bbreg, 'inputspec.anat_wm_segmentation')
if opt is True:
wf.connect(node, out, retry_node, 'inputspec.anat_wm_segmentation')

if diff_complete:
node, out = strat_pool.get_data('effectiveEchoSpacing')
wf.connect(node, out,
func_to_anat_bbreg, 'echospacing_input.echospacing')
if opt is True:
wf.connect(node, out,
retry_node, 'echospacing_input.echospacing')

node, out = strat_pool.get_data('diffphase-pedir')
wf.connect(node, out, func_to_anat_bbreg, 'pedir_input.pedir')
if opt is True:
wf.connect(node, out, retry_node, 'pedir_input.pedir')

node, out = strat_pool.get_data("despiked-fieldmap")
wf.connect(node, out, func_to_anat_bbreg, 'inputspec.fieldmap')
if opt is True:
wf.connect(node, out, retry_node, 'inputspec.fieldmap')

node, out = strat_pool.get_data("fieldmap-mask")
wf.connect(node, out,
func_to_anat_bbreg, 'inputspec.fieldmapmask')
if opt is True:
wf.connect(node, out, retry_node, 'inputspec.fieldmapmask')

wf.connect(func_to_anat_bbreg, 'outputspec.anat_func',
bbreg_guardrail, 'registered')
if opt is True:
wf.connect(func_to_anat_bbreg, 'outputspec.anat_func',
retry_guardrail, 'registered')

mean_bolds = pe.Node(util.Merge(2), run_without_submitting=True,
name=f'bbreg_mean_bold_choices_{pipe_num}')
xfms = pe.Node(util.Merge(2), run_without_submitting=True,
name=f'bbreg_xfm_choices_{pipe_num}')
fallback_mean_bolds = pe.Node(util.Select(),
run_without_submitting=True,
name=f'bbreg_choose_mean_bold_{pipe_num}'
)
fallback_xfms = pe.Node(util.Select(), run_without_submitting=True,
name=f'bbreg_choose_xfm_{pipe_num}')
if opt is True:
wf.connect([
(bbreg_guardrail, mean_bolds, [('registered', 'in1')]),
(retry_guardrail, mean_bolds, [('registered', 'in2')]),
(func_to_anat_bbreg, xfms, [
('outputspec.func_to_anat_linear_xfm', 'in1')]),
(retry_node, xfms, [
('outputspec.func_to_anat_linear_xfm', 'in2')])])
else:
if fallback:
# Fall back to no-BBReg
mean_bolds = pe.Node(util.Merge(2), run_without_submitting=True,
name=f'bbreg_mean_bold_choices_{pipe_num}')
xfms = pe.Node(util.Merge(2), run_without_submitting=True,
name=f'bbreg_xfm_choices_{pipe_num}')
fallback_mean_bolds = pe.Node(util.Select(),
run_without_submitting=True,
name='bbreg_choose_mean_bold_'
f'{pipe_num}')
fallback_xfms = pe.Node(util.Select(), run_without_submitting=True,
name=f'bbreg_choose_xfm_{pipe_num}')
wf.connect([
(func_to_anat_bbreg, bbreg_guardrail, [
('outputspec.anat_func', 'registered')]),
(bbreg_guardrail, mean_bolds, [('registered', 'in1')]),
(func_to_anat, mean_bolds, [('outputspec.anat_func_nobbreg',
'in2')]),
(func_to_anat_bbreg, xfms, [
('outputspec.func_to_anat_linear_xfm', 'in1')]),
(func_to_anat, xfms, [
('outputspec.func_to_anat_linear_xfm_nobbreg', 'in2')])])
wf.connect([
(mean_bolds, fallback_mean_bolds, [('out', 'inlist')]),
(xfms, fallback_xfms, [('out', 'inlist')]),
(bbreg_guardrail, fallback_mean_bolds, [('failed_qc', 'index')]),
(bbreg_guardrail, fallback_xfms, [('failed_qc', 'index')])])
outputs = {
'space-T1w_desc-mean_bold': (fallback_mean_bolds, 'out'),
'from-bold_to-T1w_mode-image_desc-linear_xfm': (fallback_xfms,
'out')}
('outputspec.func_to_anat_linear_xfm_nobbreg', 'in2')]),
(mean_bolds, fallback_mean_bolds, [('out', 'inlist')]),
(xfms, fallback_xfms, [('out', 'inlist')]),
(bbreg_guardrail, fallback_mean_bolds, [
('failed_qc', 'index')]),
(bbreg_guardrail, fallback_xfms, [('failed_qc', 'index')])])
outputs = {
'space-T1w_desc-mean_bold': (fallback_mean_bolds, 'out'),
'from-bold_to-T1w_mode-image_desc-linear_xfm': (fallback_xfms,
'out')}
else:
outputs = {
'space-T1w_desc-mean_bold': (func_to_anat_bbreg,
'outputspec.anat_func'),
'from-bold_to-T1w_mode-image_desc-linear_xfm': (
func_to_anat_bbreg,
'outputspec.func_to_anat_linear_xfm')}
return wf, outputs


Expand Down

0 comments on commit be69c25

Please sign in to comment.