Skip to content

Commit

Permalink
♻️ Rewire guardrail for anat_mni_ants_register
Browse files Browse the repository at this point in the history
  • Loading branch information
shnizzedy committed Oct 22, 2022
1 parent 88b5589 commit c83ed62
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 57 deletions.
32 changes: 31 additions & 1 deletion CPAC/registration/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from copy import deepcopy
from nipype.interfaces.ants import Registration
from nipype.interfaces.fsl import FLIRT
from nipype.interfaces.utility import Function
from nipype.interfaces.utility import Function, Merge, Select
from CPAC.pipeline.nipype_pipeline_engine import Node, Workflow
# from CPAC.pipeline.nipype_pipeline_engine.utils import connect_from_spec
from CPAC.qc import qc_masks, registration_guardrail_thresholds
Expand Down Expand Up @@ -56,6 +56,36 @@ def __init__(self, *args, metric=None, value=None, threshold=None,
super().__init__(msg, *args, **kwargs)


def guardrail_selection(wf: 'Workflow', node1: 'Node', node2: 'Node',
) -> Node:
"""Generate requisite Nodes for choosing a path through the graph
with retries
Parameters
----------
wf : Workflow
node1, node2 : Node
try guardrail, retry guardrail
Returns
-------
select : Node
"""
# pylint: disable=redefined-outer-name,reimported,unused-import
from CPAC.pipeline.nipype_pipeline_engine import Node, Workflow
name = node1.name
choices = Node(Merge(2), run_without_submitting=True,
name=f'{name}_choices')
select = Node(Select(), run_without_submitting=True,
name=f'choose_{name}')
wf.connect([(node1, choices, [('registered', 'in1')]),
(node2, choices, [('registered', 'in2')]),
(choices, select, [('out', 'inlist')]),
(node1, select, [('failed_qc', 'index')])])
return select


def registration_guardrail(registered: str, reference: str,
retry: bool = False, retry_num: int = 0
) -> Tuple[str, int]:
Expand Down
115 changes: 59 additions & 56 deletions CPAC/registration/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from CPAC.pipeline import nipype_pipeline_engine as pe
from nipype.interfaces import afni, ants, c3, fsl, utility as util
from nipype.interfaces.afni import utils as afni_utils
from nipype.interfaces.utility import Merge, Select
from CPAC.anat_preproc.lesion_preproc import create_lesion_preproc
from CPAC.func_preproc.utils import chunk_ts, split_ts_chunks
from CPAC.registration.guardrails import registration_guardrail_node
from CPAC.registration.guardrails import guardrail_selection, \
registration_guardrail_node
from CPAC.registration.utils import seperate_warps_list, \
check_transforms, \
generate_inverse_transform_flags, \
Expand Down Expand Up @@ -1174,20 +1174,14 @@ def create_wf_calculate_ants_warp(name='create_wf_calculate_ants_warp',
.. image::
:width: 500
'''

from CPAC.registration.guardrails import retry_hardcoded_reg
calc_ants_warp_wf = pe.Workflow(name=name)

inputspec = pe.Node(util.IdentityInterface(
fields=['moving_brain',
'reference_brain',
'moving_skull',
'reference_skull',
'reference_mask',
'moving_mask',
'fixed_image_mask',
'ants_para',
'interp']),
name='inputspec')
warp_inputs = ['moving_brain', 'reference_brain', 'moving_skull',
'reference_skull', 'ants_para', 'moving_mask',
'reference_mask', 'fixed_image_mask', 'interp']
inputspec = pe.Node(util.IdentityInterface(fields=warp_inputs),
name='inputspec')

outputspec = pe.Node(util.IdentityInterface(
fields=['ants_initial_xfm',
Expand All @@ -1208,27 +1202,30 @@ def create_wf_calculate_ants_warp(name='create_wf_calculate_ants_warp',
calculate_ants_warp.inputs.initial_moving_transform_com = 0
'''
reg_imports = ['import os', 'import subprocess']
calculate_ants_warp = \
pe.Node(interface=util.Function(input_names=['moving_brain',
'reference_brain',
'moving_skull',
'reference_skull',
'ants_para',
'moving_mask',
'reference_mask',
'fixed_image_mask',
'interp',
'reg_with_skull'],
output_names=['warp_list',
'warped_image'],
function=hardcoded_reg,
imports=reg_imports),
name='calc_ants_warp',
mem_gb=2.8,
mem_x=(2e-7, 'moving_brain', 'xyz'))
warp_inputs += ['reg_with_skull']
warp_outputs = ['warp_list', 'warped_image']
calculate_ants_warp = pe.Node(
interface=util.Function(input_names=warp_inputs,
output_names=warp_outputs,
function=hardcoded_reg,
imports=reg_imports),
name='calc_ants_warp', mem_gb=2.8,
mem_x=(2e-7, 'moving_brain', 'xyz'))
retry_calculate_ants_warp = pe.Node(
interface=util.Function(input_names=[*warp_inputs, 'previous_failure'],
output_names=warp_outputs,
function=retry_hardcoded_reg,
imports=['from CPAC.registration.utils '
'import hardcoded_reg',
'from CPAC.utils.docs import '
'retry_docstring']),
name='retry_calc_ants_warp', mem_gb=2.8,
mem_x=(2e-7, 'moving_brain', 'xyz'))
guardrails = tuple(registration_guardrail_node(
f'{_try}{name}_guardrail', i) for i, _try in enumerate(('', 'retry_')))

calculate_ants_warp.interface.num_threads = num_threads

retry_calculate_ants_warp.interface.num_threads = num_threads
select_forward_initial = pe.Node(util.Function(
input_names=['warp_list', 'selection'],
output_names=['selected_warp'],
Expand Down Expand Up @@ -1264,13 +1261,10 @@ def create_wf_calculate_ants_warp(name='create_wf_calculate_ants_warp',

select_inverse_warp.inputs.selection = "Inverse"

guardrail = registration_guardrail_node(f'{name}_guardrail')
calc_ants_warp_wf.connect(inputspec, 'moving_brain',
calculate_ants_warp, 'moving_brain')
calc_ants_warp_wf.connect(inputspec, 'reference_brain',
calculate_ants_warp, 'reference_brain')
calc_ants_warp_wf.connect(inputspec, 'reference_brain',
guardrail, 'reference')

if reg_ants_skull == 1:
calculate_ants_warp.inputs.reg_with_skull = 1
Expand All @@ -1279,11 +1273,17 @@ def create_wf_calculate_ants_warp(name='create_wf_calculate_ants_warp',
calculate_ants_warp, 'moving_skull')
calc_ants_warp_wf.connect(inputspec, 'reference_skull',
calculate_ants_warp, 'reference_skull')
for guardrail in guardrails:
calc_ants_warp_wf.connect(inputspec, 'reference_skull',
guardrail, 'reference')
else:
calc_ants_warp_wf.connect(inputspec, 'moving_brain',
calculate_ants_warp, 'moving_skull')
calc_ants_warp_wf.connect(inputspec, 'reference_brain',
calculate_ants_warp, 'reference_skull')
for guardrail in guardrails:
calc_ants_warp_wf.connect(inputspec, 'reference_brain',
guardrail, 'reference')

calc_ants_warp_wf.connect(inputspec, 'fixed_image_mask',
calculate_ants_warp, 'fixed_image_mask')
Expand Down Expand Up @@ -1317,9 +1317,11 @@ def create_wf_calculate_ants_warp(name='create_wf_calculate_ants_warp',
outputspec, 'warp_field')
calc_ants_warp_wf.connect(select_inverse_warp, 'selected_warp',
outputspec, 'inverse_warp_field')
calc_ants_warp_wf.connect(calculate_ants_warp, 'warped_image',
guardrail, 'registered')
calc_ants_warp_wf.connect(guardrail, 'registered',
for guardrail in guardrails:
calc_ants_warp_wf.connect(calculate_ants_warp, 'warped_image',
guardrail, 'registered')
select = guardrail_selection(calc_ants_warp_wf, *guardrails)
calc_ants_warp_wf.connect(select, 'out',
outputspec, 'normalized_output_brain')

return calc_ants_warp_wf
Expand Down Expand Up @@ -2928,38 +2930,39 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None):
wf.connect(func_to_anat_bbreg, 'outputspec.anat_func',
retry_guardrail, 'registered')

mean_bolds = pe.Node(Merge(2), run_without_submitting=True,
mean_bolds = pe.Node(util.Merge(2), run_without_submitting=True,
name=f'bbreg_mean_bold_choices_{pipe_num}')
xfms = pe.Node(Merge(2), run_without_submitting=True,
xfms = pe.Node(util.Merge(2), run_without_submitting=True,
name=f'bbreg_xfm_choices_{pipe_num}')
fallback_mean_bolds = pe.Node(Select, run_without_submitting=True,
fallback_mean_bolds = pe.Node(util.Select(),
run_without_submitting=True,
name=f'bbreg_choose_mean_bold_{pipe_num}'
)
fallback_xfms = pe.Node(Select, run_without_submitting=True,
fallback_xfms = pe.Node(util.Select(), run_without_submitting=True,
name=f'bbreg_choose_xfm_{pipe_num}')
if opt is True:
wf.connect([
(bbreg_guardrail, mean_bolds, ['registered', 'in1']),
(retry_guardrail, mean_bolds, ['registered', 'in1']),
(bbreg_guardrail, mean_bolds, [('registered', 'in1')]),
(retry_guardrail, mean_bolds, [('registered', 'in2')]),
(func_to_anat_bbreg, xfms, [
'outputspec.func_to_anat_linear_xfm', 'in2']),
('outputspec.func_to_anat_linear_xfm', 'in1')]),
(retry_node, xfms, [
'outputspec.func_to_anat_linear_xfm_nobbreg', 'in2'])])
('outputspec.func_to_anat_linear_xfm', 'in2')])])
else:
# Fall back to no-BBReg
wf.connect([
(bbreg_guardrail, mean_bolds, ['registered', 'in1']),
(func_to_anat, mean_bolds, ['outputspec.anat_func_nobbreg',
'in1']),
(bbreg_guardrail, mean_bolds, [('registered', 'in1')]),
(func_to_anat, mean_bolds, [('outputspec.anat_func_nobbreg',
'in2')]),
(func_to_anat_bbreg, xfms, [
'outputspec.func_to_anat_linear_xfm', 'in2']),
('outputspec.func_to_anat_linear_xfm', 'in1')]),
(func_to_anat, xfms, [
'outputspec.func_to_anat_linear_xfm_nobbreg', 'in2'])])
('outputspec.func_to_anat_linear_xfm_nobbreg', 'in2')])])
wf.connect([
(mean_bolds, fallback_mean_bolds, ['out', 'inlist']),
(xfms, fallback_xfms, ['out', 'inlist']),
(bbreg_guardrail, fallback_mean_bolds, ['failed_qc', 'index']),
(bbreg_guardrail, fallback_xfms, ['failed_qc', 'index'])])
(mean_bolds, fallback_mean_bolds, [('out', 'inlist')]),
(xfms, fallback_xfms, [('out', 'inlist')]),
(bbreg_guardrail, fallback_mean_bolds, [('failed_qc', 'index')]),
(bbreg_guardrail, fallback_xfms, [('failed_qc', 'index')])])
outputs = {
'space-T1w_desc-mean_bold': (fallback_mean_bolds, 'out'),
'from-bold_to-T1w_mode-image_desc-linear_xfm': (fallback_xfms,
Expand Down

0 comments on commit c83ed62

Please sign in to comment.