diff --git a/CPAC/pipeline/nipype_pipeline_engine/utils.py b/CPAC/pipeline/nipype_pipeline_engine/utils.py index 3542b4f70c..90f3984d32 100644 --- a/CPAC/pipeline/nipype_pipeline_engine/utils.py +++ b/CPAC/pipeline/nipype_pipeline_engine/utils.py @@ -18,7 +18,22 @@ def connect_from_spec(wf, spec, original_spec, exclude=None): - """Function to connect all original inputs to a new spec""" + """Function to connect all original inputs to a new spec + + Parameters + ---------- + wf : Workflow + + spec : dict + + original_spec : dict + + exclude : list, tuple, or dict, optional + + Returns + ------- + Workflow + """ for _item, _value in original_spec.items(): if isinstance(exclude, (list, tuple)): if _item not in exclude: diff --git a/CPAC/registration/guardrails.py b/CPAC/registration/guardrails.py index e53e7bca1f..69663aaf5e 100644 --- a/CPAC/registration/guardrails.py +++ b/CPAC/registration/guardrails.py @@ -120,36 +120,46 @@ def registration_guardrail_node(name=None): function=registration_guardrail), name=name) -def registration_guardrail_workflow(registration_node, retry=True): +def registration_guardrail_workflow(registration_node, retry=True, spec=None): """A workflow to handle hitting a registration guardrail Parameters ---------- name : str - registration_node : Node + registration_node : Node or Workflow retry : bool, optional + spec : dict, required for guardrailing function nodes + Resource pool keys for reference and registered resources, in + the format ``{'reference': str, 'registered': str}`` + Returns ------- Workflow + + See Also + -------- + spec_key """ name = f'{registration_node.name}_guardrail' wf = Workflow(name=f'{name}_wf') outputspec = deepcopy(registration_node.outputs) guardrail = registration_guardrail_node(name) - outkey = spec_key(registration_node, 'registered') + if spec is None: + spec = {key: spec_key(registration_node, key) for + key in ['reference', 'registered']} wf.connect([ - (registration_node, guardrail, [ - (spec_key(registration_node, 'reference'), 'reference')]), - (registration_node, guardrail, [(outkey, 'registered')])]) + (registration_node, guardrail, [(spec['reference'], 'reference')]), + (registration_node, guardrail, [(spec['registered'], 'registered')])]) if retry: wf = retry_registration(wf, registration_node, guardrail.outputs.registered) else: - wf.connect(guardrail, 'registered', outputspec, outkey) - wf = connect_from_spec(wf, outputspec, registration_node, outkey) + wf.connect(guardrail, 'registered', outputspec, spec['registered']) + wf = connect_from_spec(wf, outputspec, + registration_node, spec['registered']) return wf diff --git a/CPAC/registration/registration.py b/CPAC/registration/registration.py index e49b3f93ba..afa724fcb3 100644 --- a/CPAC/registration/registration.py +++ b/CPAC/registration/registration.py @@ -24,7 +24,8 @@ from CPAC.anat_preproc.lesion_preproc import create_lesion_preproc from CPAC.func_preproc.utils import chunk_ts, split_ts_chunks -from CPAC.registration.guardrails import registration_guardrail_node +from CPAC.registration.guardrails import registration_guardrail_node, \ + registration_guardrail_workflow from CPAC.registration.utils import seperate_warps_list, \ check_transforms, \ generate_inverse_transform_flags, \ @@ -3686,14 +3687,23 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool, "space-template_desc-brain_bold", "space-template_desc-bold_mask"]} """ # noqa: 501 + subwf = pe.Workflow('single_step_resample_timeseries_to_T1template_' + f'{pipe_num}') + guardrail_preproc = registration_guardrail_workflow( + subwf, + spec={'reference': f'convert_bbr2itk_{pipe_num}.reference_file', + 'registered': f'merge_func_to_standard_{pipe_num}.merged_file'}) + guardrail_brain = registration_guardrail_workflow( + subwf, + spec={'reference': f'applyxfm_func_to_standard_{pipe_num}.' + 'reference_image', + 'registered': f'get_func_brain_to_standard_{pipe_num}.out_file'}) bbr2itk = pe.Node(util.Function(input_names=['reference_file', 'source_file', 'transform_file'], output_names=['itk_transform'], function=run_c3d), name=f'convert_bbr2itk_{pipe_num}') - guardrail_preproc = registration_guardrail_node( - 'single-step-resampling-preproc_guardrail') if cfg.registration_workflows['functional_registration'][ 'coregistration']['boundary_based_registration'][ 'reference'] == 'whole-head': @@ -3702,15 +3712,14 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool, 'coregistration']['boundary_based_registration'][ 'reference'] == 'brain': node, out = strat_pool.get_data('desc-brain_T1w') - wf.connect(node, out, bbr2itk, 'reference_file') - wf.connect(node, out, guardrail_preproc, 'reference') + subwf.connect(node, out, bbr2itk, 'reference_file') node, out = strat_pool.get_data(['desc-reginput_bold', 'desc-mean_bold']) - wf.connect(node, out, bbr2itk, 'source_file') + subwf.connect(node, out, bbr2itk, 'source_file') node, out = strat_pool.get_data('from-bold_to-T1w_mode-image_desc-linear_' 'xfm') - wf.connect(node, out, bbr2itk, 'transform_file') + subwf.connect(node, out, bbr2itk, 'transform_file') split_func = pe.Node(interface=fsl.Split(), name=f'split_func_{pipe_num}') @@ -3718,7 +3727,7 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool, split_func.inputs.dimension = 't' node, out = strat_pool.get_data('desc-stc_bold') - wf.connect(node, out, split_func, 'in_file') + subwf.connect(node, out, split_func, 'in_file') ### Loop starts! ### motionxfm2itk = pe.MapNode(util.Function( @@ -3731,14 +3740,14 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool, iterfield=['transform_file']) node, out = strat_pool.get_data('motion-basefile') - wf.connect(node, out, motionxfm2itk, 'reference_file') - wf.connect(node, out, motionxfm2itk, 'source_file') + subwf.connect(node, out, motionxfm2itk, 'reference_file') + subwf.connect(node, out, motionxfm2itk, 'source_file') node, out = strat_pool.get_data('coordinate-transformation') motion_correct_tool = check_prov_for_motion_tool( strat_pool.get_cpac_provenance('coordinate-transformation')) if motion_correct_tool == 'mcflirt': - wf.connect(node, out, motionxfm2itk, 'transform_file') + subwf.connect(node, out, motionxfm2itk, 'transform_file') elif motion_correct_tool == '3dvolreg': convert_transform = pe.Node(util.Function( input_names=['one_d_filename'], @@ -3746,39 +3755,38 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool, function=one_d_to_mat, imports=['import os', 'import numpy as np']), name=f'convert_transform_{pipe_num}') - wf.connect(node, out, convert_transform, 'one_d_filename') - wf.connect(convert_transform, 'transform_directory', - motionxfm2itk, 'transform_file') + subwf.connect(node, out, convert_transform, 'one_d_filename') + subwf.connect(convert_transform, 'transform_directory', + motionxfm2itk, 'transform_file') collectxfm = pe.MapNode(util.Merge(4), name=f'collectxfm_func_to_standard_{pipe_num}', iterfield=['in4']) node, out = strat_pool.get_data('from-T1w_to-template_mode-image_xfm') - wf.connect(node, out, collectxfm, 'in1') - wf.connect(bbr2itk, 'itk_transform', collectxfm, 'in2') + subwf.connect(node, out, collectxfm, 'in1') + subwf.connect(bbr2itk, 'itk_transform', collectxfm, 'in2') collectxfm.inputs.in3 = 'identity' - wf.connect(motionxfm2itk, 'itk_transform', - collectxfm, 'in4') + subwf.connect(motionxfm2itk, 'itk_transform', + collectxfm, 'in4') applyxfm_func_to_standard = pe.MapNode(interface=ants.ApplyTransforms(), - name=f'applyxfm_func_to_standard_{pipe_num}', - iterfield=['input_image', 'transforms']) + name='applyxfm_func_to_standard_' + f'{pipe_num}', + iterfield=['input_image', + 'transforms']) applyxfm_func_to_standard.inputs.float = True applyxfm_func_to_standard.inputs.interpolation = 'LanczosWindowedSinc' - guardrail_brain = registration_guardrail_node( - 'single-step-resampling-brain_guardrail') - wf.connect(split_func, 'out_files', - applyxfm_func_to_standard, 'input_image') + subwf.connect(split_func, 'out_files', + applyxfm_func_to_standard, 'input_image') node, out = strat_pool.get_data('T1w-brain-template-funcreg') - wf.connect(node, out, applyxfm_func_to_standard, 'reference_image') - wf.connect(node, out, guardrail_brain, 'reference') - wf.connect(collectxfm, 'out', applyxfm_func_to_standard, 'transforms') + subwf.connect(node, out, applyxfm_func_to_standard, 'reference_image') + subwf.connect(collectxfm, 'out', applyxfm_func_to_standard, 'transforms') ### Loop ends! ### @@ -3787,8 +3795,8 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool, merge_func_to_standard.inputs.dimension = 't' - wf.connect(applyxfm_func_to_standard, 'output_image', - merge_func_to_standard, 'in_files') + subwf.connect(applyxfm_func_to_standard, 'output_image', + merge_func_to_standard, 'in_files') applyxfm_func_mask_to_standard = pe.Node(interface=ants.ApplyTransforms(), name='applyxfm_func_mask_to_' @@ -3797,34 +3805,34 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool, applyxfm_func_mask_to_standard.inputs.interpolation = 'MultiLabel' node, out = strat_pool.get_data('space-bold_desc-brain_mask') - wf.connect(node, out, applyxfm_func_mask_to_standard, 'input_image') + subwf.connect(node, out, applyxfm_func_mask_to_standard, 'input_image') node, out = strat_pool.get_data('T1w-brain-template-funcreg') - wf.connect(node, out, applyxfm_func_mask_to_standard, 'reference_image') + subwf.connect(node, out, applyxfm_func_mask_to_standard, 'reference_image') collectxfm_mask = pe.Node(util.Merge(2), - name=f'collectxfm_func_mask_to_standard_{pipe_num}') + name='collectxfm_func_mask_to_standard_' + f'{pipe_num}') node, out = strat_pool.get_data('from-T1w_to-template_mode-image_xfm') - wf.connect(node, out, collectxfm_mask, 'in1') - wf.connect(bbr2itk, 'itk_transform', collectxfm_mask, 'in2') - wf.connect(collectxfm_mask, 'out', - applyxfm_func_mask_to_standard, 'transforms') + subwf.connect(node, out, collectxfm_mask, 'in1') + subwf.connect(bbr2itk, 'itk_transform', collectxfm_mask, 'in2') + subwf.connect(collectxfm_mask, 'out', + applyxfm_func_mask_to_standard, 'transforms') apply_mask = pe.Node(interface=fsl.maths.ApplyMask(), name=f'get_func_brain_to_standard_{pipe_num}') - wf.connect(merge_func_to_standard, 'merged_file', - apply_mask, 'in_file') - wf.connect(applyxfm_func_mask_to_standard, 'output_image', - apply_mask, 'mask_file') - wf.connect(merge_func_to_standard, 'merged_file', - guardrail_preproc, 'registered') - wf.connect(apply_mask, 'out_file', guardrail_brain, 'registered') + subwf.connect(merge_func_to_standard, 'merged_file', + apply_mask, 'in_file') + subwf.connect(applyxfm_func_mask_to_standard, 'output_image', + apply_mask, 'mask_file') outputs = { - 'space-template_desc-preproc_bold': (guardrail_preproc, 'registered'), - 'space-template_desc-brain_bold': (guardrail_brain, 'registered'), + 'space-template_desc-preproc_bold': (guardrail_preproc, + 'outputspec.merged_file'), + 'space-template_desc-brain_bold': (guardrail_brain, + 'outputspec.out_file'), 'space-template_desc-bold_mask': (applyxfm_func_mask_to_standard, 'output_image'), }