diff --git a/CPAC/registration/guardrails.py b/CPAC/registration/guardrails.py index b869227c70..ee3dff397e 100644 --- a/CPAC/registration/guardrails.py +++ b/CPAC/registration/guardrails.py @@ -33,6 +33,36 @@ Registration: {'reference': 'reference', 'registered': 'out_file'}} +def connect_retries(wf, nodes, connections): + """Function to generalize making the same connections to try and + retry nodes. + + For each 3-tuple (``conn``) in ``connections``, will do + + .. code-block:: Python + + wf.connect(conn[0], node, conn[1], conn[2]) + + for each node in nodes + + Parameters + ---------- + wf : Workflow + + nodes : iterable of Nodes + + connections : iterable of 3-tuples of (Node, str or tuple, str) + + Returns + ------- + Workflow + """ + for node in nodes: + for conn in connections: + wf.connect(conn[0], node, conn[1], conn[2]) + return wf + + def guardrail_selection(wf: 'Workflow', node1: 'Node', node2: 'Node', output_key: str = 'registered', guardrail_node: 'Node' = None) -> Node: diff --git a/CPAC/registration/registration.py b/CPAC/registration/registration.py index 32da26a2ff..e339874460 100644 --- a/CPAC/registration/registration.py +++ b/CPAC/registration/registration.py @@ -17,6 +17,7 @@ """Registration functions""" # pylint: disable=too-many-lines,ungrouped-imports,wrong-import-order # TODO: replace Tuple with tuple, Union with |, once Python >= 3.9, 3.10 +from sqlite3 import connect from typing import Optional, Tuple, Union from CPAC.pipeline import nipype_pipeline_engine as pe from nipype.interfaces import afni, ants, c3, fsl, utility as util @@ -24,7 +25,8 @@ from CPAC.anat_preproc.lesion_preproc import create_lesion_preproc from CPAC.func_preproc.utils import chunk_ts, split_ts_chunks from CPAC.pipeline.random_state.seed import increment_seed -from CPAC.registration.guardrails import guardrail_selection, \ +from CPAC.registration.guardrails import connect_retries, \ + guardrail_selection, \ registration_guardrail_node from CPAC.registration.utils import seperate_warps_list, \ check_transforms, \ @@ -974,7 +976,6 @@ def create_bbregister_func_to_anat(phase_diff_distcor=False, outputspec.anat_func : string (nifti file) Functional data in anatomical space """ - from CPAC.pipeline.random_state.seed import seed_plus_1 register_bbregister_func_to_anat = pe.Workflow(name=name) inputspec = pe.Node(util.IdentityInterface(fields=['func', @@ -999,14 +1000,10 @@ def create_bbregister_func_to_anat(phase_diff_distcor=False, wm_bb_mask = pe.Node(interface=fsl.ImageMaths(), name='wm_bb_mask') - if retry: - seed = seed_plus_1() - wm_bb_mask.seed = seed register_bbregister_func_to_anat.connect( inputspec, 'bbr_wm_mask_args', wm_bb_mask, 'op_string') - register_bbregister_func_to_anat.connect(inputspec, 'anat_wm_segmentation', wm_bb_mask, 'in_file') @@ -1017,50 +1014,52 @@ def bbreg_args(bbreg_target): bbreg_func_to_anat = pe.Node(interface=fsl.FLIRT(), name='bbreg_func_to_anat') bbreg_func_to_anat.inputs.dof = 6 + guardrail_bbreg_func_to_anat = registration_guardrail_node( + f'{bbreg_func_to_anat.name}_guardrail') + nodes = [bbreg_func_to_anat] + guardrails = [guardrail_bbreg_func_to_anat] if retry: - bbreg_func_to_anat.seed = seed - - register_bbregister_func_to_anat.connect( - inputspec, 'bbr_schedule', - bbreg_func_to_anat, 'schedule') - register_bbregister_func_to_anat.connect( - wm_bb_mask, ('out_file', bbreg_args), - bbreg_func_to_anat, 'args') - register_bbregister_func_to_anat.connect( - inputspec, 'func', - bbreg_func_to_anat, 'in_file') - register_bbregister_func_to_anat.connect( - inputspec, 'anat', - bbreg_func_to_anat, 'reference') - register_bbregister_func_to_anat.connect( - inputspec, 'linear_reg_matrix', - bbreg_func_to_anat, 'in_matrix_file') - + retry_bbreg_func_to_anat = increment_seed(bbreg_func_to_anat.clone( + f'retry_{bbreg_func_to_anat.name}')) + guardrail_retry_bbreg_func_to_anat = registration_guardrail_node( + f'{retry_bbreg_func_to_anat.name}_guardrail') + nodes += [retry_bbreg_func_to_anat] + guardrails += [guardrail_retry_bbreg_func_to_anat] + register_bbregister_func_to_anat = connect_retries( + register_bbregister_func_to_anat, nodes, [ + (inputspec, 'bbr_schedule', 'schedule'), + (wm_bb_mask, ('out_file', bbreg_args), 'args'), + (inputspec, 'func', 'in_file'), + (inputspec, 'anat', 'reference'), + (inputspec, 'linear_reg_matrix', 'in_matrix_file')]) if phase_diff_distcor: + register_bbregister_func_to_anat = connect_retries( + register_bbregister_func_to_anat, nodes, [ + (inputNode_pedir, ('pedir', convert_pedir), 'pedir'), + (inputspec, 'fieldmap', 'fieldmap'), + (inputspec, 'fieldmapmask', 'fieldmapmask'), + (inputNode_echospacing, 'echospacing', 'echospacing')]) + for i, node in enumerate(nodes): + register_bbregister_func_to_anat.connect(inputspec, 'anat', + guardrails[i], 'reference') + register_bbregister_func_to_anat.connect(node, 'out_file', + guardrails[i], 'registered') + if retry: + # pylint: disable=no-value-for-parameter + outfile = guardrail_selection(register_bbregister_func_to_anat, + *guardrails) + matrix = guardrail_selection(register_bbregister_func_to_anat, *nodes, + 'out_matrix_file', guardrails[0]) register_bbregister_func_to_anat.connect( - inputNode_pedir, ('pedir', convert_pedir), - bbreg_func_to_anat, 'pedir') - register_bbregister_func_to_anat.connect( - inputspec, 'fieldmap', - bbreg_func_to_anat, 'fieldmap') - register_bbregister_func_to_anat.connect( - inputspec, 'fieldmapmask', - bbreg_func_to_anat, 'fieldmapmask') + matrix, 'out', outputspec, 'func_to_anat_linear_xfm') + register_bbregister_func_to_anat.connect(outfile, 'out', + outputspec, 'anat_func') + else: register_bbregister_func_to_anat.connect( - inputNode_echospacing, 'echospacing', - bbreg_func_to_anat, 'echospacing') - - guardrail = registration_guardrail_node(name=f'{name}_guardrail') - register_bbregister_func_to_anat.connect(inputspec, 'anat', - guardrail, 'reference') - register_bbregister_func_to_anat.connect( - bbreg_func_to_anat, 'out_matrix_file', - outputspec, 'func_to_anat_linear_xfm') - register_bbregister_func_to_anat.connect(bbreg_func_to_anat, 'out_file', - guardrail, 'registered') - register_bbregister_func_to_anat.connect(guardrail, 'registered', - outputspec, 'anat_func') - + bbreg_func_to_anat, 'out_matrix_file', + outputspec, 'func_to_anat_linear_xfm') + register_bbregister_func_to_anat.connect(guardrails[0], 'registered', + outputspec, 'anat_func') return register_bbregister_func_to_anat @@ -2871,9 +2870,11 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): (func_to_anat, 'outputspec.func_to_anat_linear_xfm_nobbreg') } - if opt in [True, "fallback"]: + if opt in [True, 'fallback']: + fallback = opt == 'fallback' func_to_anat_bbreg = create_bbregister_func_to_anat( - diff_complete, f'func_to_anat_bbreg{bbreg_status}_{pipe_num}') + diff_complete, f'func_to_anat_bbreg{bbreg_status}_{pipe_num}', + opt is True) func_to_anat_bbreg.inputs.inputspec.bbr_schedule = \ cfg.registration_workflows['functional_registration'][ 'coregistration']['boundary_based_registration'][ @@ -2882,56 +2883,31 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): cfg.registration_workflows['functional_registration'][ 'coregistration']['boundary_based_registration'][ 'bbr_wm_mask_args'] - bbreg_guardrail = registration_guardrail_node( - f'bbreg{bbreg_status}_guardrail_{pipe_num}') - if opt is True: - # Retry once on failure - retry_node = create_bbregister_func_to_anat(diff_complete, - f'retry_func_to_anat_' - f'bbreg_{pipe_num}', - retry=True) - retry_node.inputs.inputspec.bbr_schedule = cfg[ - 'registration_workflows', 'functional_registration', - 'coregistration', 'boundary_based_registration', - 'bbr_schedule'] - retry_node.inputs.inputspec.bbr_wm_mask_args = cfg[ - 'registration_workflows', 'functional_registration', - 'coregistration', 'boundary_based_registration', - 'bbr_wm_mask_args'] - retry_guardrail = registration_guardrail_node( - f'retry_bbreg_guardrail_{pipe_num}') + if fallback: + bbreg_guardrail = registration_guardrail_node( + f'bbreg{bbreg_status}_guardrail_{pipe_num}') node, out = strat_pool.get_data('desc-reginput_bold') wf.connect(node, out, func_to_anat_bbreg, 'inputspec.func') - if opt is True: - wf.connect(node, out, retry_node, 'inputspec.func') if cfg.registration_workflows['functional_registration'][ 'coregistration']['boundary_based_registration'][ 'reference'] == 'whole-head': node, out = strat_pool.get_data('T1w') wf.connect(node, out, func_to_anat_bbreg, 'inputspec.anat') - wf.connect(node, out, bbreg_guardrail, 'reference') - if opt is True: - wf.connect(node, out, retry_node, 'inputspec.anat') - wf.connect(node, out, retry_guardrail, 'reference') + if fallback: + wf.connect(node, out, bbreg_guardrail, 'reference') elif cfg.registration_workflows['functional_registration'][ 'coregistration']['boundary_based_registration'][ 'reference'] == 'brain': node, out = strat_pool.get_data('desc-brain_T1w') wf.connect(node, out, func_to_anat_bbreg, 'inputspec.anat') - wf.connect(node, out, bbreg_guardrail, 'reference') - if opt is True: - wf.connect(node, out, retry_node, 'inputspec.anat') - wf.connect(node, out, retry_guardrail, 'reference') + if fallback: + wf.connect(node, out, bbreg_guardrail, 'reference') wf.connect(func_to_anat, 'outputspec.func_to_anat_linear_xfm_nobbreg', func_to_anat_bbreg, 'inputspec.linear_reg_matrix') - if opt is True: - wf.connect(func_to_anat, - 'outputspec.func_to_anat_linear_xfm_nobbreg', - retry_node, 'inputspec.linear_reg_matrix') if strat_pool.check_rpool('space-bold_label-WM_mask'): node, out = strat_pool.get_data(["space-bold_label-WM_mask"]) @@ -2948,76 +2924,59 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): "label-WM_mask"]) wf.connect(node, out, func_to_anat_bbreg, 'inputspec.anat_wm_segmentation') - if opt is True: - wf.connect(node, out, retry_node, 'inputspec.anat_wm_segmentation') if diff_complete: node, out = strat_pool.get_data('effectiveEchoSpacing') wf.connect(node, out, func_to_anat_bbreg, 'echospacing_input.echospacing') - if opt is True: - wf.connect(node, out, - retry_node, 'echospacing_input.echospacing') node, out = strat_pool.get_data('diffphase-pedir') wf.connect(node, out, func_to_anat_bbreg, 'pedir_input.pedir') - if opt is True: - wf.connect(node, out, retry_node, 'pedir_input.pedir') node, out = strat_pool.get_data("despiked-fieldmap") wf.connect(node, out, func_to_anat_bbreg, 'inputspec.fieldmap') - if opt is True: - wf.connect(node, out, retry_node, 'inputspec.fieldmap') node, out = strat_pool.get_data("fieldmap-mask") wf.connect(node, out, func_to_anat_bbreg, 'inputspec.fieldmapmask') - if opt is True: - wf.connect(node, out, retry_node, 'inputspec.fieldmapmask') - - wf.connect(func_to_anat_bbreg, 'outputspec.anat_func', - bbreg_guardrail, 'registered') - if opt is True: - wf.connect(func_to_anat_bbreg, 'outputspec.anat_func', - retry_guardrail, 'registered') - - mean_bolds = pe.Node(util.Merge(2), run_without_submitting=True, - name=f'bbreg_mean_bold_choices_{pipe_num}') - xfms = pe.Node(util.Merge(2), run_without_submitting=True, - name=f'bbreg_xfm_choices_{pipe_num}') - fallback_mean_bolds = pe.Node(util.Select(), - run_without_submitting=True, - name=f'bbreg_choose_mean_bold_{pipe_num}' - ) - fallback_xfms = pe.Node(util.Select(), run_without_submitting=True, - name=f'bbreg_choose_xfm_{pipe_num}') - if opt is True: - wf.connect([ - (bbreg_guardrail, mean_bolds, [('registered', 'in1')]), - (retry_guardrail, mean_bolds, [('registered', 'in2')]), - (func_to_anat_bbreg, xfms, [ - ('outputspec.func_to_anat_linear_xfm', 'in1')]), - (retry_node, xfms, [ - ('outputspec.func_to_anat_linear_xfm', 'in2')])]) - else: + if fallback: # Fall back to no-BBReg + mean_bolds = pe.Node(util.Merge(2), run_without_submitting=True, + name=f'bbreg_mean_bold_choices_{pipe_num}') + xfms = pe.Node(util.Merge(2), run_without_submitting=True, + name=f'bbreg_xfm_choices_{pipe_num}') + fallback_mean_bolds = pe.Node(util.Select(), + run_without_submitting=True, + name='bbreg_choose_mean_bold_' + f'{pipe_num}') + fallback_xfms = pe.Node(util.Select(), run_without_submitting=True, + name=f'bbreg_choose_xfm_{pipe_num}') wf.connect([ + (func_to_anat_bbreg, bbreg_guardrail, [ + ('outputspec.anat_func', 'registered')]), (bbreg_guardrail, mean_bolds, [('registered', 'in1')]), (func_to_anat, mean_bolds, [('outputspec.anat_func_nobbreg', 'in2')]), (func_to_anat_bbreg, xfms, [ ('outputspec.func_to_anat_linear_xfm', 'in1')]), (func_to_anat, xfms, [ - ('outputspec.func_to_anat_linear_xfm_nobbreg', 'in2')])]) - wf.connect([ - (mean_bolds, fallback_mean_bolds, [('out', 'inlist')]), - (xfms, fallback_xfms, [('out', 'inlist')]), - (bbreg_guardrail, fallback_mean_bolds, [('failed_qc', 'index')]), - (bbreg_guardrail, fallback_xfms, [('failed_qc', 'index')])]) - outputs = { - 'space-T1w_desc-mean_bold': (fallback_mean_bolds, 'out'), - 'from-bold_to-T1w_mode-image_desc-linear_xfm': (fallback_xfms, - 'out')} + ('outputspec.func_to_anat_linear_xfm_nobbreg', 'in2')]), + (mean_bolds, fallback_mean_bolds, [('out', 'inlist')]), + (xfms, fallback_xfms, [('out', 'inlist')]), + (bbreg_guardrail, fallback_mean_bolds, [ + ('failed_qc', 'index')]), + (bbreg_guardrail, fallback_xfms, [('failed_qc', 'index')])]) + outputs = { + 'space-T1w_desc-mean_bold': (fallback_mean_bolds, 'out'), + 'from-bold_to-T1w_mode-image_desc-linear_xfm': (fallback_xfms, + 'out')} + else: + outputs = { + 'space-T1w_desc-mean_bold': (func_to_anat_bbreg, + 'outputspec.anat_func'), + 'from-bold_to-T1w_mode-image_desc-linear_xfm': ( + func_to_anat_bbreg, + 'outputspec.func_to_anat_linear_xfm')} return wf, outputs