Skip to content

Commit

Permalink
🚧 WIP 🥅 Iterate guardrail installation
Browse files Browse the repository at this point in the history
  • Loading branch information
shnizzedy committed Oct 19, 2022
1 parent b33c472 commit 7c77f4b
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 68 deletions.
3 changes: 1 addition & 2 deletions CPAC/pipeline/nipype_pipeline_engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@
"""Custom nipype utilities"""


def connect_from_spec(wf, spec, original_spec, exclude=None):
def connect_from_spec(spec, original_spec, exclude=None):
"""Function to connect all original inputs to a new spec"""
for _item, _value in original_spec.items():
if isinstance(exclude, (list, tuple)):
if _item not in exclude:
setattr(spec.inputs, _item, _value)
elif _item != exclude:
setattr(spec.inputs, _item, _value)
return wf
20 changes: 19 additions & 1 deletion CPAC/pipeline/random_state/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def random_seed():
-------
seed : int or None
'''
if _seed['seed'] == 'random':
if _seed['seed'] in ['random', None]:
_seed['seed'] = random_random_seed()
return _seed['seed']

Expand Down Expand Up @@ -153,6 +153,24 @@ def _reusable_flags():
}


def seed_plus_1(seed=None):
'''Increment seed, looping back to 1 at MAX_SEED
Parameters
----------
seed : int, optional
Uses configured seed if not specified
Returns
-------
int
'''
seed = random_seed() if seed is None else int(seed)
if seed < MAX_SEED: # increment random seed
return seed + 1
return 1 # loop back to 1


def set_up_random_state(seed):
'''Set global random seed
Expand Down
13 changes: 7 additions & 6 deletions CPAC/pipeline/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
# pylint: disable=too-many-lines
import re
from itertools import chain, permutations
import numpy as np
from pathvalidate import sanitize_filename
from voluptuous import All, ALLOW_EXTRA, Any, Capitalize, Coerce, \
from voluptuous import All, ALLOW_EXTRA, Any, Capitalize, Coerce, Equal, \
ExactSequence, ExclusiveInvalid, In, Length, Lower, \
Match, Maybe, Optional, Range, Required, Schema
from CPAC import docs_prefix
Expand Down Expand Up @@ -492,7 +491,6 @@ def sanitize(filename):
'interpolation': In({'trilinear', 'sinc', 'spline'}),
'using': str,
'input': str,
'interpolation': str,
'cost': str,
'dof': int,
'arguments': Maybe(str),
Expand All @@ -510,11 +508,14 @@ def sanitize(filename):
},
},
'boundary_based_registration': {
'run': forkable,
'run': All(Coerce(ListFromItem),
[Any(bool, All(Lower, Equal('fallback')))],
Length(max=3)),
'bbr_schedule': str,
'bbr_wm_map': In({'probability_map', 'partial_volume_map'}),
'bbr_wm_map': In(('probability_map',
'partial_volume_map')),
'bbr_wm_mask_args': str,
'reference': In({'whole-head', 'brain'})
'reference': In(('whole-head', 'brain'))
},
},
'EPI_registration': {
Expand Down
2 changes: 0 additions & 2 deletions CPAC/registration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
create_fsl_fnirt_nonlinear_reg_nhp, \
create_register_func_to_anat, \
create_register_func_to_anat_use_T2, \
create_bbregister_func_to_anat, \
create_wf_calculate_ants_warp

from .output_func_to_standard import output_func_to_standard
Expand All @@ -13,6 +12,5 @@
'create_fsl_fnirt_nonlinear_reg_nhp',
'create_register_func_to_anat',
'create_register_func_to_anat_use_T2',
'create_bbregister_func_to_anat',
'create_wf_calculate_ants_warp',
'output_func_to_standard']
42 changes: 24 additions & 18 deletions CPAC/registration/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# You should have received a copy of the GNU Lesser General Public
# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>.
"""Guardrails to protect against bad registrations"""
import logging
from copy import deepcopy
from nipype.interfaces.ants import Registration
from nipype.interfaces.fsl import FLIRT
Expand All @@ -23,7 +24,7 @@
from CPAC.pipeline.nipype_pipeline_engine.utils import connect_from_spec
from CPAC.qc import qc_masks, REGISTRATION_GUARDRAIL_THRESHOLDS


logger = logging.getLogger('nipype.workflow')
_SPEC_KEYS = {
FLIRT: {'reference': 'reference', 'registered': 'out_file'},
Registration: {'reference': 'reference', 'registered': 'out_file'}}
Expand Down Expand Up @@ -54,7 +55,7 @@ def __init__(self, *args, metric=None, value=None, threshold=None,


def registration_guardrail(registered: str, reference: str, retry: bool = False
) -> str:
):
"""Check QC metrics post-registration and throw an exception if
metrics are below given thresholds.
Expand All @@ -78,23 +79,29 @@ def registration_guardrail(registered: str, reference: str, retry: bool = False
-------
registered_mask : str
path to mask
failed_qc : int
metrics met specified thresholds?, used as index for selecting
outputs
"""
qc_metrics = qc_masks(registered, reference)
failed_qc = 0
for metric, threshold in REGISTRATION_GUARDRAIL_THRESHOLDS.items():
if threshold is not None:
value = qc_metrics.get(metric)
if isinstance(value, list):
value = value[0]
if value < threshold:
failed_qc = 1
with open(f'{registered}.failed_qc', 'w',
encoding='utf-8') as _f:
_f.write(f'{metric}: {value} < {threshold}')
if retry:
registered = f'{registered}-failed'
else:
raise BadRegistrationError(metric=metric, value=value,
threshold=threshold)
return registered
logger.error(str(BadRegistrationError(
metric=metric, value=value, threshold=threshold)))
return registered, failed_qc


def registration_guardrail_node(name=None):
Expand All @@ -112,7 +119,8 @@ def registration_guardrail_node(name=None):
name = 'registration_guardrail'
return Node(Function(input_names=['registered',
'reference'],
output_names=['registered'],
output_names=['registered',
'failed_qc'],
imports=['from CPAC.qc import qc_masks, '
'REGISTRATION_GUARDRAIL_THRESHOLDS',
'from CPAC.registration.guardrails '
Expand Down Expand Up @@ -146,10 +154,10 @@ def registration_guardrail_workflow(registration_node, retry=True):
(registration_node, guardrail, [(outkey, 'registered')])])
if retry:
wf = retry_registration(wf, registration_node,
guardrail.outputs.registered)
guardrail.outputs.registered)[0]
else:
wf.connect(guardrail, 'registered', outputspec, outkey)
wf = connect_from_spec(wf, outputspec, registration_node, outkey)
connect_from_spec(outputspec, registration_node, outkey)
return wf


Expand All @@ -167,6 +175,8 @@ def retry_registration(wf, registration_node, registered):
Returns
-------
Workflow
Node
"""
name = f'retry_{registration_node.name}'
retry_node = Node(Function(function=retry_registration_node,
Expand All @@ -177,14 +187,14 @@ def retry_registration(wf, registration_node, registered):
outputspec = registration_node.outputs
outkey = spec_key(registration_node, 'registered')
guardrail = registration_guardrail_node(f'{name}_guardrail')
wf = connect_from_spec(wf, inputspec, retry_node)
connect_from_spec(inputspec, retry_node)
wf.connect([
(inputspec, guardrail, [
(spec_key(retry_node, 'reference'), 'reference')]),
(retry_node, guardrail, [(outkey, 'registered')]),
(guardrail, outputspec, [('registered', outkey)])])
wf = connect_from_spec(wf, retry_node, outputspec, registered)
return wf
connect_from_spec(retry_node, outputspec, registered)
return wf, retry_node


def retry_registration_node(registered, registration_node):
Expand All @@ -200,16 +210,12 @@ def retry_registration_node(registered, registration_node):
-------
Node
"""
from CPAC.pipeline.random_state.seed import MAX_SEED, random_seed
seed = random_seed()
from CPAC.pipeline.random_state.seed import seed_plus_1
if registered.endswith('-failed'):
retry_node = registration_node.clone(
name=f'{registration_node.name}-retry')
if isinstance(seed, int):
if seed < MAX_SEED: # increment random seed
retry_node.seed = seed + 1
else: # loop back to minumum seed
retry_node.seed = 1
if isinstance(retry_node.seed, int):
retry_node.seed = seed_plus_1()
return retry_node
return registration_node

Expand Down
Loading

0 comments on commit 7c77f4b

Please sign in to comment.