Skip to content

Commit

Permalink
♻️ Retry BBR on failure
Browse files Browse the repository at this point in the history
  • Loading branch information
shnizzedy committed Oct 10, 2022
1 parent e9359c2 commit 3f530ff
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 54 deletions.
17 changes: 16 additions & 1 deletion CPAC/pipeline/nipype_pipeline_engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,22 @@


def connect_from_spec(wf, spec, original_spec, exclude=None):
"""Function to connect all original inputs to a new spec"""
"""Function to connect all original inputs to a new spec
Parameters
----------
wf : Workflow
spec : dict
original_spec : dict
exclude : list, tuple, or dict, optional
Returns
-------
Workflow
"""
for _item, _value in original_spec.items():
if isinstance(exclude, (list, tuple)):
if _item not in exclude:
Expand Down
26 changes: 18 additions & 8 deletions CPAC/registration/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,36 +120,46 @@ def registration_guardrail_node(name=None):
function=registration_guardrail), name=name)


def registration_guardrail_workflow(registration_node, retry=True):
def registration_guardrail_workflow(registration_node, retry=True, spec=None):
"""A workflow to handle hitting a registration guardrail
Parameters
----------
name : str
registration_node : Node
registration_node : Node or Workflow
retry : bool, optional
spec : dict, required for guardrailing function nodes
Resource pool keys for reference and registered resources, in
the format ``{'reference': str, 'registered': str}``
Returns
-------
Workflow
See Also
--------
spec_key
"""
name = f'{registration_node.name}_guardrail'
wf = Workflow(name=f'{name}_wf')
outputspec = deepcopy(registration_node.outputs)
guardrail = registration_guardrail_node(name)
outkey = spec_key(registration_node, 'registered')
if spec is None:
spec = {key: spec_key(registration_node, key) for
key in ['reference', 'registered']}
wf.connect([
(registration_node, guardrail, [
(spec_key(registration_node, 'reference'), 'reference')]),
(registration_node, guardrail, [(outkey, 'registered')])])
(registration_node, guardrail, [(spec['reference'], 'reference')]),
(registration_node, guardrail, [(spec['registered'], 'registered')])])
if retry:
wf = retry_registration(wf, registration_node,
guardrail.outputs.registered)
else:
wf.connect(guardrail, 'registered', outputspec, outkey)
wf = connect_from_spec(wf, outputspec, registration_node, outkey)
wf.connect(guardrail, 'registered', outputspec, spec['registered'])
wf = connect_from_spec(wf, outputspec,
registration_node, spec['registered'])
return wf


Expand Down
98 changes: 53 additions & 45 deletions CPAC/registration/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@

from CPAC.anat_preproc.lesion_preproc import create_lesion_preproc
from CPAC.func_preproc.utils import chunk_ts, split_ts_chunks
from CPAC.registration.guardrails import registration_guardrail_node
from CPAC.registration.guardrails import registration_guardrail_node, \
registration_guardrail_workflow
from CPAC.registration.utils import seperate_warps_list, \
check_transforms, \
generate_inverse_transform_flags, \
Expand Down Expand Up @@ -3686,14 +3687,23 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool,
"space-template_desc-brain_bold",
"space-template_desc-bold_mask"]}
""" # noqa: 501
subwf = pe.Workflow('single_step_resample_timeseries_to_T1template_'
f'{pipe_num}')
guardrail_preproc = registration_guardrail_workflow(
subwf,
spec={'reference': f'convert_bbr2itk_{pipe_num}.reference_file',
'registered': f'merge_func_to_standard_{pipe_num}.merged_file'})
guardrail_brain = registration_guardrail_workflow(
subwf,
spec={'reference': f'applyxfm_func_to_standard_{pipe_num}.'
'reference_image',
'registered': f'get_func_brain_to_standard_{pipe_num}.out_file'})
bbr2itk = pe.Node(util.Function(input_names=['reference_file',
'source_file',
'transform_file'],
output_names=['itk_transform'],
function=run_c3d),
name=f'convert_bbr2itk_{pipe_num}')
guardrail_preproc = registration_guardrail_node(
'single-step-resampling-preproc_guardrail')
if cfg.registration_workflows['functional_registration'][
'coregistration']['boundary_based_registration'][
'reference'] == 'whole-head':
Expand All @@ -3702,23 +3712,22 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool,
'coregistration']['boundary_based_registration'][
'reference'] == 'brain':
node, out = strat_pool.get_data('desc-brain_T1w')
wf.connect(node, out, bbr2itk, 'reference_file')
wf.connect(node, out, guardrail_preproc, 'reference')
subwf.connect(node, out, bbr2itk, 'reference_file')

node, out = strat_pool.get_data(['desc-reginput_bold', 'desc-mean_bold'])
wf.connect(node, out, bbr2itk, 'source_file')
subwf.connect(node, out, bbr2itk, 'source_file')

node, out = strat_pool.get_data('from-bold_to-T1w_mode-image_desc-linear_'
'xfm')
wf.connect(node, out, bbr2itk, 'transform_file')
subwf.connect(node, out, bbr2itk, 'transform_file')

split_func = pe.Node(interface=fsl.Split(),
name=f'split_func_{pipe_num}')

split_func.inputs.dimension = 't'

node, out = strat_pool.get_data('desc-stc_bold')
wf.connect(node, out, split_func, 'in_file')
subwf.connect(node, out, split_func, 'in_file')

### Loop starts! ###
motionxfm2itk = pe.MapNode(util.Function(
Expand All @@ -3731,54 +3740,53 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool,
iterfield=['transform_file'])

node, out = strat_pool.get_data('motion-basefile')
wf.connect(node, out, motionxfm2itk, 'reference_file')
wf.connect(node, out, motionxfm2itk, 'source_file')
subwf.connect(node, out, motionxfm2itk, 'reference_file')
subwf.connect(node, out, motionxfm2itk, 'source_file')

node, out = strat_pool.get_data('coordinate-transformation')
motion_correct_tool = check_prov_for_motion_tool(
strat_pool.get_cpac_provenance('coordinate-transformation'))
if motion_correct_tool == 'mcflirt':
wf.connect(node, out, motionxfm2itk, 'transform_file')
subwf.connect(node, out, motionxfm2itk, 'transform_file')
elif motion_correct_tool == '3dvolreg':
convert_transform = pe.Node(util.Function(
input_names=['one_d_filename'],
output_names=['transform_directory'],
function=one_d_to_mat,
imports=['import os', 'import numpy as np']),
name=f'convert_transform_{pipe_num}')
wf.connect(node, out, convert_transform, 'one_d_filename')
wf.connect(convert_transform, 'transform_directory',
motionxfm2itk, 'transform_file')
subwf.connect(node, out, convert_transform, 'one_d_filename')
subwf.connect(convert_transform, 'transform_directory',
motionxfm2itk, 'transform_file')

collectxfm = pe.MapNode(util.Merge(4),
name=f'collectxfm_func_to_standard_{pipe_num}',
iterfield=['in4'])

node, out = strat_pool.get_data('from-T1w_to-template_mode-image_xfm')
wf.connect(node, out, collectxfm, 'in1')
wf.connect(bbr2itk, 'itk_transform', collectxfm, 'in2')
subwf.connect(node, out, collectxfm, 'in1')
subwf.connect(bbr2itk, 'itk_transform', collectxfm, 'in2')

collectxfm.inputs.in3 = 'identity'

wf.connect(motionxfm2itk, 'itk_transform',
collectxfm, 'in4')
subwf.connect(motionxfm2itk, 'itk_transform',
collectxfm, 'in4')

applyxfm_func_to_standard = pe.MapNode(interface=ants.ApplyTransforms(),
name=f'applyxfm_func_to_standard_{pipe_num}',
iterfield=['input_image', 'transforms'])
name='applyxfm_func_to_standard_'
f'{pipe_num}',
iterfield=['input_image',
'transforms'])

applyxfm_func_to_standard.inputs.float = True
applyxfm_func_to_standard.inputs.interpolation = 'LanczosWindowedSinc'
guardrail_brain = registration_guardrail_node(
'single-step-resampling-brain_guardrail')

wf.connect(split_func, 'out_files',
applyxfm_func_to_standard, 'input_image')
subwf.connect(split_func, 'out_files',
applyxfm_func_to_standard, 'input_image')

node, out = strat_pool.get_data('T1w-brain-template-funcreg')
wf.connect(node, out, applyxfm_func_to_standard, 'reference_image')
wf.connect(node, out, guardrail_brain, 'reference')
wf.connect(collectxfm, 'out', applyxfm_func_to_standard, 'transforms')
subwf.connect(node, out, applyxfm_func_to_standard, 'reference_image')
subwf.connect(collectxfm, 'out', applyxfm_func_to_standard, 'transforms')

### Loop ends! ###

Expand All @@ -3787,8 +3795,8 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool,

merge_func_to_standard.inputs.dimension = 't'

wf.connect(applyxfm_func_to_standard, 'output_image',
merge_func_to_standard, 'in_files')
subwf.connect(applyxfm_func_to_standard, 'output_image',
merge_func_to_standard, 'in_files')

applyxfm_func_mask_to_standard = pe.Node(interface=ants.ApplyTransforms(),
name='applyxfm_func_mask_to_'
Expand All @@ -3797,34 +3805,34 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool,
applyxfm_func_mask_to_standard.inputs.interpolation = 'MultiLabel'

node, out = strat_pool.get_data('space-bold_desc-brain_mask')
wf.connect(node, out, applyxfm_func_mask_to_standard, 'input_image')
subwf.connect(node, out, applyxfm_func_mask_to_standard, 'input_image')

node, out = strat_pool.get_data('T1w-brain-template-funcreg')
wf.connect(node, out, applyxfm_func_mask_to_standard, 'reference_image')
subwf.connect(node, out, applyxfm_func_mask_to_standard, 'reference_image')

collectxfm_mask = pe.Node(util.Merge(2),
name=f'collectxfm_func_mask_to_standard_{pipe_num}')
name='collectxfm_func_mask_to_standard_'
f'{pipe_num}')

node, out = strat_pool.get_data('from-T1w_to-template_mode-image_xfm')
wf.connect(node, out, collectxfm_mask, 'in1')
wf.connect(bbr2itk, 'itk_transform', collectxfm_mask, 'in2')
wf.connect(collectxfm_mask, 'out',
applyxfm_func_mask_to_standard, 'transforms')
subwf.connect(node, out, collectxfm_mask, 'in1')
subwf.connect(bbr2itk, 'itk_transform', collectxfm_mask, 'in2')
subwf.connect(collectxfm_mask, 'out',
applyxfm_func_mask_to_standard, 'transforms')

apply_mask = pe.Node(interface=fsl.maths.ApplyMask(),
name=f'get_func_brain_to_standard_{pipe_num}')

wf.connect(merge_func_to_standard, 'merged_file',
apply_mask, 'in_file')
wf.connect(applyxfm_func_mask_to_standard, 'output_image',
apply_mask, 'mask_file')
wf.connect(merge_func_to_standard, 'merged_file',
guardrail_preproc, 'registered')
wf.connect(apply_mask, 'out_file', guardrail_brain, 'registered')
subwf.connect(merge_func_to_standard, 'merged_file',
apply_mask, 'in_file')
subwf.connect(applyxfm_func_mask_to_standard, 'output_image',
apply_mask, 'mask_file')

outputs = {
'space-template_desc-preproc_bold': (guardrail_preproc, 'registered'),
'space-template_desc-brain_bold': (guardrail_brain, 'registered'),
'space-template_desc-preproc_bold': (guardrail_preproc,
'outputspec.merged_file'),
'space-template_desc-brain_bold': (guardrail_brain,
'outputspec.out_file'),
'space-template_desc-bold_mask': (applyxfm_func_mask_to_standard,
'output_image'),
}
Expand Down

0 comments on commit 3f530ff

Please sign in to comment.