From d4f5af602c17543c96ee8db08a4af14e691ef127 Mon Sep 17 00:00:00 2001 From: Jon Clucas Date: Mon, 31 Oct 2022 12:15:35 -0400 Subject: [PATCH] :goal_net: Guardrail BBR only --- .../nipype_pipeline_engine/__init__.py | 35 ++- .../pipeline/nipype_pipeline_engine/engine.py | 138 +++++++++--- CPAC/pipeline/schema.py | 9 +- CPAC/qc/__init__.py | 24 +- CPAC/qc/globals.py | 42 ++++ CPAC/qc/qcmetrics.py | 163 +++++++++++--- CPAC/registration/exceptions.py | 41 ++++ CPAC/registration/guardrails.py | 208 ++++++++++++++++++ CPAC/registration/registration.py | 201 +++++++++-------- .../configs/pipeline_config_default.yml | 3 +- .../configs/pipeline_config_rbc-options.yml | 3 + CPAC/utils/docs.py | 37 ++++ 12 files changed, 725 insertions(+), 179 deletions(-) create mode 100644 CPAC/qc/globals.py create mode 100644 CPAC/registration/exceptions.py create mode 100644 CPAC/registration/guardrails.py diff --git a/CPAC/pipeline/nipype_pipeline_engine/__init__.py b/CPAC/pipeline/nipype_pipeline_engine/__init__.py index 48b445241b..66f4111cce 100644 --- a/CPAC/pipeline/nipype_pipeline_engine/__init__.py +++ b/CPAC/pipeline/nipype_pipeline_engine/__init__.py @@ -1,25 +1,24 @@ -'''Module to import Nipype Pipeline engine and override some Classes. -See https://fcp-indi.github.io/docs/developer/nodes -for C-PAC-specific documentation. -See https://nipype.readthedocs.io/en/latest/api/generated/nipype.pipeline.engine.html -for Nipype's documentation. - -Copyright (C) 2022 C-PAC Developers +# Copyright (C) 2022 C-PAC Developers -This file is part of C-PAC. +# This file is part of C-PAC. -C-PAC is free software: you can redistribute it and/or modify it under -the terms of the GNU Lesser General Public License as published by the -Free Software Foundation, either version 3 of the License, or (at your -option) any later version. +# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. -C-PAC is distributed in the hope that it will be useful, but WITHOUT -ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or -FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public -License for more details. +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. -You should have received a copy of the GNU Lesser General Public -License along with C-PAC. If not, see .''' # noqa: E501 +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . +'''Module to import Nipype Pipeline engine and override some Classes. +See https://fcp-indi.github.io/docs/developer/nodes +for C-PAC-specific documentation. +See https://nipype.readthedocs.io/en/latest/api/generated/nipype.pipeline.engine.html +for Nipype's documentation.''' # noqa: E501 # pylint: disable=line-too-long from nipype.pipeline import engine as pe # import everything in nipype.pipeline.engine.__all__ from nipype.pipeline.engine import * # noqa: F401,F403 diff --git a/CPAC/pipeline/nipype_pipeline_engine/engine.py b/CPAC/pipeline/nipype_pipeline_engine/engine.py index 12e8808f1f..8695b1e536 100644 --- a/CPAC/pipeline/nipype_pipeline_engine/engine.py +++ b/CPAC/pipeline/nipype_pipeline_engine/engine.py @@ -1,43 +1,57 @@ -'''Module to import Nipype Pipeline engine and override some Classes. -See https://fcp-indi.github.io/docs/developer/nodes -for C-PAC-specific documentation. -See https://nipype.readthedocs.io/en/latest/api/generated/nipype.pipeline.engine.html -for Nipype's documentation. +# STATEMENT OF CHANGES: +# This file is derived from sources licensed under the Apache-2.0 terms, +# and this file has been changed. -STATEMENT OF CHANGES: - This file is derived from sources licensed under the Apache-2.0 terms, - and this file has been changed. +# CHANGES: +# * Supports just-in-time dynamic memory allocation +# * Skips doctests that require files that we haven't copied over +# * Applies a random seed +# * Supports overriding memory estimates via a log file and a buffer +# * Adds quotation marks around strings in dotfiles -CHANGES: - * Supports just-in-time dynamic memory allocation - * Skips doctests that require files that we haven't copied over - * Applies a random seed - * Supports overriding memory estimates via a log file and a buffer +# ORIGINAL WORK'S ATTRIBUTION NOTICE: +# Copyright (c) 2009-2016, Nipype developers -ORIGINAL WORK'S ATTRIBUTION NOTICE: - Copyright (c) 2009-2016, Nipype developers +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 - http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +# Prior to release 0.12, Nipype was licensed under a BSD license. - Prior to release 0.12, Nipype was licensed under a BSD license. +# Modifications Copyright (C) 2022 C-PAC Developers -Modifications Copyright (C) 2022 C-PAC Developers +# This file is part of C-PAC. -This file is part of C-PAC.''' # noqa: E501 +# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. + +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . +'''Module to import Nipype Pipeline engine and override some Classes. +See https://fcp-indi.github.io/docs/developer/nodes +for C-PAC-specific documentation. +See https://nipype.readthedocs.io/en/latest/api/generated/nipype.pipeline.engine.html +for Nipype's documentation.''' # noqa: E501 # pylint: disable=line-too-long import os import re -from logging import getLogger from inspect import Parameter, Signature, signature +from logging import getLogger +from typing import Iterable, Tuple, Union from nibabel import load from nipype import logging from nipype.interfaces.utility import Function @@ -53,6 +67,7 @@ UNDEFINED_SIZE = (42, 42, 42, 1200) random_state_logger = getLogger('random') +logger = getLogger("nipype.workflow") def _check_mem_x_path(mem_x_path): @@ -399,10 +414,9 @@ def run(self, updatehash=False): if self.seed is not None: self._apply_random_seed() if self.seed_applied: - random_state_logger.info('%s', - '%s # (Atropos constant)' % - self.name if 'atropos' in - self.name else self.name) + random_state_logger.info('%s\t%s', '# (Atropos constant)' if + 'atropos' in self.name else + str(self.seed), self.name) return super().run(updatehash) @@ -483,6 +497,40 @@ def _configure_exec_nodes(self, graph): TypeError): self._handle_just_in_time_exception(node) + def connect_retries(self, nodes: Iterable['Node'], + connections: Iterable[Tuple['Node', Union[str, tuple], + str]]) -> None: + """Method to generalize making the same connections to try and + retry nodes. + + For each 3-tuple (``conn``) in ``connections``, will do + ``wf.connect(conn[0], conn[1], node, conn[2])`` for each ``node`` + in ``nodes`` + + Parameters + ---------- + nodes : iterable of Nodes + + connections : iterable of 3-tuples of (Node, str or tuple, str) + """ + wrong_conn_type_msg = (r'connect_retries `connections` argument ' + 'must be an iterable of (Node, str or ' + 'tuple, str) tuples.') + if not isinstance(connections, (list, tuple)): + raise TypeError(f'{wrong_conn_type_msg}: Given {connections}') + for node in nodes: + if not isinstance(node, Node): + raise TypeError('connect_retries requires an iterable ' + r'of nodes for the `nodes` parameter: ' + f'Given {node}') + for conn in connections: + if not all((isinstance(conn, (list, tuple)), len(conn) == 3, + isinstance(conn[0], Node), + isinstance(conn[1], (tuple, str)), + isinstance(conn[2], str))): + raise TypeError(f'{wrong_conn_type_msg}: Given {conn}') + self.connect(*conn[:2], node, conn[2]) + def _handle_just_in_time_exception(self, node): # pylint: disable=protected-access if hasattr(self, '_local_func_scans'): @@ -492,6 +540,32 @@ def _handle_just_in_time_exception(self, node): # TODO: handle S3 files node._apply_mem_x(UNDEFINED_SIZE) # noqa: W0212 + def nodes_and_guardrails(self, *nodes, registered, add_clones=True): + """Returns a two tuples of Nodes: (try, retry) and their + respective guardrails + + Parameters + ---------- + nodes : any number of Nodes + + Returns + ------- + nodes : tuple of Nodes + + guardrails : tuple of Nodes + """ + from CPAC.registration.guardrails import registration_guardrail_node, \ + retry_clone + nodes = list(nodes) + if add_clones is True: + nodes.extend([retry_clone(node) for node in nodes]) + guardrails = [None] * len(nodes) + for i, node in enumerate(nodes): + guardrails[i] = registration_guardrail_node( + f'guardrail_{node.name}', i) + self.connect(node, registered, guardrails[i], 'registered') + return tuple(nodes), tuple(guardrails) + def get_data_size(filepath, mode='xyzt'): """Function to return the size of a functional image (x * y * z * t) diff --git a/CPAC/pipeline/schema.py b/CPAC/pipeline/schema.py index 19764b6cfb..dfbcc92fd1 100644 --- a/CPAC/pipeline/schema.py +++ b/CPAC/pipeline/schema.py @@ -21,7 +21,7 @@ from itertools import chain, permutations import numpy as np from pathvalidate import sanitize_filename -from voluptuous import All, ALLOW_EXTRA, Any, Capitalize, Coerce, \ +from voluptuous import All, ALLOW_EXTRA, Any, Capitalize, Coerce, Equal, \ ExactSequence, ExclusiveInvalid, In, Length, Lower, \ Match, Maybe, Optional, Range, Required, Schema from CPAC import docs_prefix @@ -526,9 +526,12 @@ def sanitize(filename): }, }, 'boundary_based_registration': { - 'run': forkable, + 'run': All(Coerce(ListFromItem), + [Any(bool1_1, All(Lower, Equal('fallback')))], + Length(max=3)), 'bbr_schedule': str, - 'bbr_wm_map': In({'probability_map', 'partial_volume_map'}), + 'bbr_wm_map': In({'probability_map', + 'partial_volume_map'}), 'bbr_wm_mask_args': str, 'reference': In({'whole-head', 'brain'}) }, diff --git a/CPAC/qc/__init__.py b/CPAC/qc/__init__.py index 75ee654fec..810d06aedc 100644 --- a/CPAC/qc/__init__.py +++ b/CPAC/qc/__init__.py @@ -1,2 +1,22 @@ -from .utils import * -from .qc import * +# Copyright (C) 2013-2022 C-PAC Developers + +# This file is part of C-PAC. + +# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. + +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . +"""Quality control utilities for C-PAC""" +from CPAC.qc.globals import registration_guardrail_thresholds, \ + update_thresholds +from CPAC.qc.qcmetrics import qc_masks +__all__ = ['qc_masks', 'registration_guardrail_thresholds', + 'update_thresholds'] diff --git a/CPAC/qc/globals.py b/CPAC/qc/globals.py new file mode 100644 index 0000000000..e4a05d8d9d --- /dev/null +++ b/CPAC/qc/globals.py @@ -0,0 +1,42 @@ +# Copyright (C) 2022 C-PAC Developers + +# This file is part of C-PAC. + +# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. + +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . +"""Global QC values""" +_REGISTRATION_GUARDRAIL_THRESHOLDS = {'thresholds': {}} + + +def registration_guardrail_thresholds() -> dict: + """Get registration guardrail thresholds + + Returns + ------- + dict + """ + return _REGISTRATION_GUARDRAIL_THRESHOLDS['thresholds'] + + +def update_thresholds(thresholds) -> None: + """Set a registration guardrail threshold + + Parameters + ---------- + thresholds : dict of {str: float or int} + + Returns + ------- + None + """ + _REGISTRATION_GUARDRAIL_THRESHOLDS['thresholds'].update(thresholds) diff --git a/CPAC/qc/qcmetrics.py b/CPAC/qc/qcmetrics.py index 6db977c495..b45430020c 100644 --- a/CPAC/qc/qcmetrics.py +++ b/CPAC/qc/qcmetrics.py @@ -1,24 +1,88 @@ +# Modifications: Copyright (C) 2022 C-PAC Developers + +# This file is part of C-PAC. + +# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. + +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . + +# Original code: BSD 3-Clause License + +# Copyright (c) 2020, Lifespan Informatics and Neuroimaging Center + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. """QC metrics from XCP-D v0.0.9 Ref: https://github.com/PennLINC/xcp_d/tree/0.0.9 """ +# LGPL-3.0-or-later: Module docstring and lint exclusions # pylint: disable=invalid-name, redefined-outer-name +# BSD-3-Clause: imports and unspecified sections import nibabel as nb import numpy as np -def regisQ(bold2t1w_mask, t1w_mask, bold2template_mask, template_mask): - reg_qc = {'coregDice': [dc(bold2t1w_mask, t1w_mask)], - 'coregJaccard': [jc(bold2t1w_mask, t1w_mask)], - 'coregCrossCorr': [crosscorr(bold2t1w_mask, t1w_mask)], - 'coregCoverage': [coverage(bold2t1w_mask, t1w_mask)], - 'normDice': [dc(bold2template_mask, template_mask)], - 'normJaccard': [jc(bold2template_mask, template_mask)], - 'normCrossCorr': [crosscorr(bold2template_mask, template_mask)], - 'normCoverage': [coverage(bold2template_mask, template_mask)]} - return reg_qc +# BSD-3-Clause +def coverage(input1, input2): + """Estimate the coverage between two masks.""" + input1 = nb.load(input1).get_fdata() + input2 = nb.load(input2).get_fdata() + input1 = np.atleast_1d(input1.astype(np.bool)) + input2 = np.atleast_1d(input2.astype(np.bool)) + intsec = np.count_nonzero(input1 & input2) + if np.sum(input1) > np.sum(input2): + smallv = np.sum(input2) + else: + smallv = np.sum(input1) + cov = float(intsec)/float(smallv) + return cov + + +# BSD-3-Clause +def crosscorr(input1, input2): + r"""cross correlation: compute cross correction bewteen input masks""" + input1 = nb.load(input1).get_fdata() + input2 = nb.load(input2).get_fdata() + input1 = np.atleast_1d(input1.astype(np.bool)).flatten() + input2 = np.atleast_1d(input2.astype(np.bool)).flatten() + cc = np.corrcoef(input1, input2)[0][1] + return cc +# BSD-3-Clause def dc(input1, input2): r""" Dice coefficient @@ -71,6 +135,7 @@ def dc(input1, input2): return dc +# BSD-3-Clause def jc(input1, input2): r""" Jaccard coefficient @@ -106,26 +171,62 @@ def jc(input1, input2): return jc -def crosscorr(input1, input2): - r"""cross correlation: compute cross correction bewteen input masks""" - input1 = nb.load(input1).get_fdata() - input2 = nb.load(input2).get_fdata() - input1 = np.atleast_1d(input1.astype(np.bool)).flatten() - input2 = np.atleast_1d(input2.astype(np.bool)).flatten() - cc = np.corrcoef(input1, input2)[0][1] - return cc +# LGPL-3.0-or-later +def _prefix_regqc_keys(qc_dict: dict, prefix: str) -> str: + """Prepend string to each key in a qc dict + Parameters + ---------- + qc_dict : dict + output of ``qc_masks`` -def coverage(input1, input2): - """Estimate the coverage between two masks.""" - input1 = nb.load(input1).get_fdata() - input2 = nb.load(input2).get_fdata() - input1 = np.atleast_1d(input1.astype(np.bool)) - input2 = np.atleast_1d(input2.astype(np.bool)) - intsec = np.count_nonzero(input1 & input2) - if np.sum(input1) > np.sum(input2): - smallv = np.sum(input2) - else: - smallv = np.sum(input1) - cov = float(intsec)/float(smallv) - return cov + prefix : str + string to prepend + + Returns + ------- + dict + """ + return {f'{prefix}{_key}': _value for _key, _value in qc_dict.items()} + + +# BSD-3-Clause: logic +# LGPL-3.0-or-later: docstring and refactored function +def qc_masks(registered_mask: str, native_mask: str) -> dict: + """Return QC measures for coregistration + + Parameters + ---------- + registered_mask : str + path to registered mask + + native_mask : str + path to native-space mask + + Returns + ------- + dict + """ + return {'Dice': [dc(registered_mask, native_mask)], + 'Jaccard': [jc(registered_mask, native_mask)], + 'CrossCorr': [crosscorr(registered_mask, native_mask)], + 'Coverage': [coverage(registered_mask, native_mask)]} + + +# BSD-3-Clause: name and signature +# LGPL-3.0-or-later: docstring and refactored function +def regisQ(bold2t1w_mask: str, t1w_mask: str, bold2template_mask: str, + template_mask: str) -> dict: + """Collect coregistration QC measures + + Parameters + ---------- + bold2t1w_mask, t1w_mask, bold2template_mask, template_mask : str + + Returns + ------- + dict + """ + return {**_prefix_regqc_keys(qc_masks(bold2t1w_mask, t1w_mask), 'coreg'), + **_prefix_regqc_keys(qc_masks(bold2template_mask, template_mask), + 'norm')} diff --git a/CPAC/registration/exceptions.py b/CPAC/registration/exceptions.py new file mode 100644 index 0000000000..d962ddfa30 --- /dev/null +++ b/CPAC/registration/exceptions.py @@ -0,0 +1,41 @@ +# Copyright (C) 2022 C-PAC Developers + +# This file is part of C-PAC. + +# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. + +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . +"""Custom registration exceptions""" + + +class BadRegistrationError(ValueError): + """Exception for when a QC measure for a registration falls below a + specified threshold""" + def __init__(self, *args, metric=None, value=None, threshold=None, + **kwargs): + """ + Parameters + ---------- + metric : str + QC metric + + value : float + calculated QC value + + threshold : float + specified threshold + """ + msg = "Registration failed quality control" + if all(arg is not None for arg in (metric, value, threshold)): + msg += f" ({metric}: {value} < {threshold})" + msg += "." + super().__init__(msg, *args, **kwargs) diff --git a/CPAC/registration/guardrails.py b/CPAC/registration/guardrails.py new file mode 100644 index 0000000000..1329cad97f --- /dev/null +++ b/CPAC/registration/guardrails.py @@ -0,0 +1,208 @@ +# Copyright (C) 2022 C-PAC Developers + +# This file is part of C-PAC. + +# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. + +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . +"""Guardrails to protect against bad registrations""" +import logging +from typing import Tuple +from nipype.interfaces.utility import Function, Merge, Select +# pylint: disable=unused-import +from CPAC.pipeline.nipype_pipeline_engine import Node, Workflow +from CPAC.pipeline.random_state.seed import increment_seed +from CPAC.qc import qc_masks, registration_guardrail_thresholds +from CPAC.registration.exceptions import BadRegistrationError +from CPAC.registration.utils import hardcoded_reg +from CPAC.utils.docs import retry_docstring + + +# noqa: F401 +def guardrail_selection(wf: 'Workflow', node1: 'Node', node2: 'Node', + output_key: str = 'registered', + guardrail_node: 'Node' = None) -> Node: + """Generate requisite Nodes for choosing a path through the graph + with retries. + + Takes two nodes to choose an output from. These nodes are assumed + to be guardrail nodes if `output_key` and `guardrail_node` are not + specified. + + A ``nipype.interfaces.utility.Merge`` is generated, connecting + ``output_key`` from ``node1`` and ``node2`` in that order. + + A ``nipype.interfaces.utility.Select`` node is generated taking the + output from the generated ``Merge`` and using the ``failed_qc`` + output of ``guardrail_node`` (``node1`` if ``guardrail_node`` is + unspecified). + + All relevant connections are made in the given Workflow. + + The ``Select`` node is returned; its output is keyed ``out`` and + contains the value of the given ``output_key`` (``registered`` if + unspecified). + + Parameters + ---------- + wf : Workflow + + node1, node2 : Node + first try, retry + + output_key : str + field to choose + + guardrail_node : Node + guardrail to collect 'failed_qc' from if not node1 + + Returns + ------- + select : Node + """ + # pylint: disable=redefined-outer-name,reimported,unused-import + from CPAC.pipeline.nipype_pipeline_engine import Node, Workflow + if guardrail_node is None: + guardrail_node = node1 + name = node1.name + if output_key != 'registered': + name = f'{name}_{output_key}' + choices = Node(Merge(2), run_without_submitting=True, + name=f'{name}_choices') + select = Node(Select(), run_without_submitting=True, + name=f'choose_{name}') + wf.connect([(node1, choices, [(output_key, 'in1')]), + (node2, choices, [(output_key, 'in2')]), + (choices, select, [('out', 'inlist')]), + (guardrail_node, select, [('failed_qc', 'index')])]) + return select + + +def registration_guardrail(registered: str, reference: str, + retry: bool = False, retry_num: int = 0 + ) -> Tuple[str, int]: + """Check QC metrics post-registration and throw an exception if + metrics are below given thresholds. + + If inputs point to images that are not masks, images will be + binarized before being compared. + + .. seealso:: + + :py:mod:`CPAC.qc.qcmetrics` + Documentation of the :py:mod:`CPAC.qc.qcmetrics` module. + + Parameters + ---------- + registered, reference : str + path to mask + + retry : bool, optional + can retry? + + retry_num : int, optional + how many previous tries? + + Returns + ------- + registered_mask : str + path to mask + + failed_qc : int + metrics met specified thresholds?, used as index for selecting + outputs + .. seealso:: + + :py:mod:`guardrail_selection` + """ + logger = logging.getLogger('nipype.workflow') + qc_metrics = qc_masks(registered, reference) + failed_qc = 0 + for metric, threshold in registration_guardrail_thresholds().items(): + if threshold is not None: + value = qc_metrics.get(metric) + if isinstance(value, list): + value = value[0] + if value < threshold: + failed_qc = 1 + with open(f'{registered}.failed_qc', 'w', + encoding='utf-8') as _f: + _f.write(f'{metric}: {value} < {threshold}') + if retry: + registered = f'{registered}-failed' + else: + bad_registration = BadRegistrationError( + metric=metric, value=value, threshold=threshold) + logger.error(str(bad_registration)) + if retry_num: + # if we've already retried, raise the error + raise bad_registration + return registered, failed_qc + + +def registration_guardrail_node(name=None, retry_num=0): + """Convenience method to get a new registration_guardrail Node + + Parameters + ---------- + name : str, optional + + retry_num : int, optional + how many previous tries? + + Returns + ------- + Node + """ + if name is None: + name = 'registration_guardrail' + node = Node(Function(input_names=['registered', 'reference', 'retry_num'], + output_names=['registered', 'failed_qc'], + imports=['import logging', + 'from typing import Tuple', + 'from CPAC.qc import qc_masks, ' + 'registration_guardrail_thresholds', + 'from CPAC.registration.guardrails ' + 'import BadRegistrationError'], + function=registration_guardrail), name=name) + if retry_num: + node.inputs.retry_num = retry_num + return node + + +def retry_clone(node: 'Node') -> 'Node': + """Function to clone a node, name the clone, and increment its + random seed + + Parameters + ---------- + node : Node + + Returns + ------- + Node + """ + return increment_seed(node.clone(f'retry_{node.name}')) + + +# pylint: disable=missing-function-docstring,too-many-arguments +@retry_docstring(hardcoded_reg) +def retry_hardcoded_reg(moving_brain, reference_brain, moving_skull, + reference_skull, ants_para, moving_mask=None, + reference_mask=None, fixed_image_mask=None, + interp=None, reg_with_skull=0, previous_failure=False): + if not previous_failure: + return [], None + return hardcoded_reg(moving_brain, reference_brain, moving_skull, + reference_skull, ants_para, moving_mask, + reference_mask, fixed_image_mask, interp, + reg_with_skull) diff --git a/CPAC/registration/registration.py b/CPAC/registration/registration.py index b2260a9641..3603c70b2e 100644 --- a/CPAC/registration/registration.py +++ b/CPAC/registration/registration.py @@ -23,6 +23,8 @@ from CPAC.anat_preproc.lesion_preproc import create_lesion_preproc from CPAC.func_preproc.utils import chunk_ts, split_ts_chunks +from CPAC.registration.guardrails import guardrail_selection, \ + registration_guardrail_node from CPAC.registration.utils import seperate_warps_list, \ check_transforms, \ generate_inverse_transform_flags, \ @@ -739,7 +741,7 @@ def create_register_func_to_anat(config, phase_diff_distcor=False, return register_func_to_anat -def create_register_func_to_anat_use_T2(config, name='register_func_to_anat_use_T2'): +def create_register_func_to_anat_use_T2(name='register_func_to_anat_use_T2'): # for monkey data # ref: https://github.com/DCAN-Labs/dcan-macaque-pipeline/blob/master/fMRIVolume/GenericfMRIVolumeProcessingPipeline.sh#L287-L295 # https://github.com/HechengJin0/dcan-macaque-pipeline/blob/master/fMRIVolume/GenericfMRIVolumeProcessingPipeline.sh#L524-L535 @@ -776,8 +778,6 @@ def create_register_func_to_anat_use_T2(config, name='register_func_to_anat_use_ outputspec.anat_func_nobbreg : string (nifti file) Functional scan registered to anatomical space """ - - register_func_to_anat_use_T2 = pe.Workflow(name=name) inputspec = pe.Node(util.IdentityInterface(fields=['func', @@ -877,13 +877,12 @@ def create_register_func_to_anat_use_T2(config, name='register_func_to_anat_use_ def create_bbregister_func_to_anat(phase_diff_distcor=False, - name='bbregister_func_to_anat'): - + name='bbregister_func_to_anat', + retry=False): """ Registers a functional scan in native space to structural. This is meant to be used after create_nonlinear_register() has been run and relies on some of its outputs. - Parameters ---------- fieldmap_distortion : bool, optional @@ -891,6 +890,8 @@ def create_bbregister_func_to_anat(phase_diff_distcor=False, take in the appropriate field map-related inputs. name : string, optional Name of the workflow. + retry : bool + Try twice? Returns ------- @@ -919,7 +920,6 @@ def create_bbregister_func_to_anat(phase_diff_distcor=False, outputspec.anat_func : string (nifti file) Functional data in anatomical space """ - register_bbregister_func_to_anat = pe.Workflow(name=name) inputspec = pe.Node(util.IdentityInterface(fields=['func', @@ -948,7 +948,6 @@ def create_bbregister_func_to_anat(phase_diff_distcor=False, register_bbregister_func_to_anat.connect( inputspec, 'bbr_wm_mask_args', wm_bb_mask, 'op_string') - register_bbregister_func_to_anat.connect(inputspec, 'anat_wm_segmentation', wm_bb_mask, 'in_file') @@ -959,49 +958,38 @@ def bbreg_args(bbreg_target): bbreg_func_to_anat = pe.Node(interface=fsl.FLIRT(), name='bbreg_func_to_anat') bbreg_func_to_anat.inputs.dof = 6 - - register_bbregister_func_to_anat.connect( - inputspec, 'bbr_schedule', - bbreg_func_to_anat, 'schedule') - - register_bbregister_func_to_anat.connect( - wm_bb_mask, ('out_file', bbreg_args), - bbreg_func_to_anat, 'args') - - register_bbregister_func_to_anat.connect( - inputspec, 'func', - bbreg_func_to_anat, 'in_file') - - register_bbregister_func_to_anat.connect( - inputspec, 'anat', - bbreg_func_to_anat, 'reference') - - register_bbregister_func_to_anat.connect( - inputspec, 'linear_reg_matrix', - bbreg_func_to_anat, 'in_matrix_file') - + nodes, guardrails = register_bbregister_func_to_anat.nodes_and_guardrails( + bbreg_func_to_anat, registered='out_file', add_clones=bool(retry)) + register_bbregister_func_to_anat.connect_retries(nodes, [ + (inputspec, 'bbr_schedule', 'schedule'), + (wm_bb_mask, ('out_file', bbreg_args), 'args'), + (inputspec, 'func', 'in_file'), + (inputspec, 'anat', 'reference'), + (inputspec, 'linear_reg_matrix', 'in_matrix_file')]) if phase_diff_distcor: + register_bbregister_func_to_anat.connect_retries(nodes, [ + (inputNode_pedir, ('pedir', convert_pedir), 'pedir'), + (inputspec, 'fieldmap', 'fieldmap'), + (inputspec, 'fieldmapmask', 'fieldmapmask'), + (inputNode_echospacing, 'echospacing', 'echospacing')]) + register_bbregister_func_to_anat.connect_retries(guardrails, [ + (inputspec, 'anat', 'reference')]) + if retry: + # pylint: disable=no-value-for-parameter + outfile = guardrail_selection(register_bbregister_func_to_anat, + *guardrails) + matrix = guardrail_selection(register_bbregister_func_to_anat, *nodes, + 'out_matrix_file', guardrails[0]) register_bbregister_func_to_anat.connect( - inputNode_pedir, ('pedir', convert_pedir), - bbreg_func_to_anat, 'pedir') - register_bbregister_func_to_anat.connect( - inputspec, 'fieldmap', - bbreg_func_to_anat, 'fieldmap') - register_bbregister_func_to_anat.connect( - inputspec, 'fieldmapmask', - bbreg_func_to_anat, 'fieldmapmask') + matrix, 'out', outputspec, 'func_to_anat_linear_xfm') + register_bbregister_func_to_anat.connect(outfile, 'out', + outputspec, 'anat_func') + else: register_bbregister_func_to_anat.connect( - inputNode_echospacing, 'echospacing', - bbreg_func_to_anat, 'echospacing') - - register_bbregister_func_to_anat.connect( - bbreg_func_to_anat, 'out_matrix_file', - outputspec, 'func_to_anat_linear_xfm') - - register_bbregister_func_to_anat.connect( - bbreg_func_to_anat, 'out_file', - outputspec, 'anat_func') - + bbreg_func_to_anat, 'out_matrix_file', + outputspec, 'func_to_anat_linear_xfm') + register_bbregister_func_to_anat.connect(guardrails[0], 'registered', + outputspec, 'anat_func') return register_bbregister_func_to_anat @@ -2754,8 +2742,8 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): "config": ["registration_workflows", "functional_registration", "coregistration"], "switch": ["run"], - "option_key": "None", - "option_val": "None", + "option_key": ["boundary_based_registration", "run"], + "option_val": [True, False, "fallback"], "inputs": [("sbref", "desc-motion_bold", "space-bold_label-WM_mask", @@ -2766,7 +2754,6 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): ("desc-preproc_T1w", "desc-restore-brain_T1w", "desc-preproc_T2w", - "desc-preproc_T2w", "T2w", ["label-WM_probseg", "label-WM_mask"], ["label-WM_pveseg", "label-WM_mask"], @@ -2775,22 +2762,19 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): "from-bold_to-T1w_mode-image_desc-linear_xfm", "from-bold_to-T1w_mode-image_desc-linear_warp"]} ''' - - diff_complete = False - if strat_pool.check_rpool("despiked-fieldmap") and \ - strat_pool.check_rpool("fieldmap-mask"): - diff_complete = True - + diff_complete = (strat_pool.check_rpool("despiked-fieldmap") and + strat_pool.check_rpool("fieldmap-mask")) + bbreg_status = "On" if opt is True else "Off" if isinstance( + opt, bool) else opt.title() + subwfname = f'func_to_anat_FLIRT_bbreg{bbreg_status}_{pipe_num}' if strat_pool.check_rpool('T2w') and cfg.anatomical_preproc['run_t2']: # monkey data - func_to_anat = create_register_func_to_anat_use_T2(cfg, - f'func_to_anat_FLIRT_' - f'{pipe_num}') + func_to_anat = create_register_func_to_anat_use_T2(subwfname) # https://github.com/DCAN-Labs/dcan-macaque-pipeline/blob/master/fMRIVolume/GenericfMRIVolumeProcessingPipeline.sh#L177 # fslmaths "$fMRIFolder"/"$NameOffMRI"_mc -Tmean "$fMRIFolder"/"$ScoutName"_gdc func_mc_mean = pe.Node(interface=afni_utils.TStat(), - name=f'func_motion_corrected_mean_{pipe_num}') + name=f'func_motion_corrected_mean_{pipe_num}') func_mc_mean.inputs.options = '-mean' func_mc_mean.inputs.outputtype = 'NIFTI_GZ' @@ -2813,24 +2797,23 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): # if field map-based distortion correction is on, but BBR is off, # send in the distortion correction files here func_to_anat = create_register_func_to_anat(cfg, diff_complete, - f'func_to_anat_FLIRT_' - f'{pipe_num}') + subwfname) func_to_anat.inputs.inputspec.dof = cfg.registration_workflows[ - 'functional_registration']['coregistration']['dof'] + 'functional_registration']['coregistration']['dof'] func_to_anat.inputs.inputspec.interp = cfg.registration_workflows[ - 'functional_registration']['coregistration']['interpolation'] + 'functional_registration']['coregistration']['interpolation'] node, out = strat_pool.get_data('sbref') wf.connect(node, out, func_to_anat, 'inputspec.func') if cfg.registration_workflows['functional_registration'][ - 'coregistration']['reference'] == 'brain': + 'coregistration']['reference'] == 'brain': # TODO: use JSON meta-data to confirm node, out = strat_pool.get_data('desc-preproc_T1w') elif cfg.registration_workflows['functional_registration'][ - 'coregistration']['reference'] == 'restore-brain': + 'coregistration']['reference'] == 'restore-brain': node, out = strat_pool.get_data('desc-restore-brain_T1w') wf.connect(node, out, func_to_anat, 'inputspec.anat') @@ -2864,22 +2847,22 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): (func_to_anat, 'outputspec.func_to_anat_linear_xfm_nobbreg') } - if True in cfg.registration_workflows['functional_registration'][ - 'coregistration']["boundary_based_registration"]["run"]: - - func_to_anat_bbreg = create_bbregister_func_to_anat(diff_complete, - f'func_to_anat_' - f'bbreg_' - f'{pipe_num}') + if opt in [True, 'fallback']: + fallback = opt == 'fallback' + func_to_anat_bbreg = create_bbregister_func_to_anat( + diff_complete, f'func_to_anat_bbreg{bbreg_status}_{pipe_num}', + opt is True) func_to_anat_bbreg.inputs.inputspec.bbr_schedule = \ cfg.registration_workflows['functional_registration'][ 'coregistration']['boundary_based_registration'][ 'bbr_schedule'] - func_to_anat_bbreg.inputs.inputspec.bbr_wm_mask_args = \ cfg.registration_workflows['functional_registration'][ 'coregistration']['boundary_based_registration'][ 'bbr_wm_mask_args'] + if fallback: + bbreg_guardrail = registration_guardrail_node( + f'bbreg{bbreg_status}_guardrail_{pipe_num}', 1) node, out = strat_pool.get_data('sbref') wf.connect(node, out, func_to_anat_bbreg, 'inputspec.func') @@ -2889,31 +2872,35 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): 'reference'] == 'whole-head': node, out = strat_pool.get_data('desc-head_T1w') wf.connect(node, out, func_to_anat_bbreg, 'inputspec.anat') + if fallback: + wf.connect(node, out, bbreg_guardrail, 'reference') elif cfg.registration_workflows['functional_registration'][ 'coregistration']['boundary_based_registration'][ 'reference'] == 'brain': node, out = strat_pool.get_data('desc-preproc_T1w') wf.connect(node, out, func_to_anat_bbreg, 'inputspec.anat') + if fallback: + wf.connect(node, out, bbreg_guardrail, 'reference') wf.connect(func_to_anat, 'outputspec.func_to_anat_linear_xfm_nobbreg', func_to_anat_bbreg, 'inputspec.linear_reg_matrix') if strat_pool.check_rpool('space-bold_label-WM_mask'): node, out = strat_pool.get_data(["space-bold_label-WM_mask"]) - wf.connect(node, out, - func_to_anat_bbreg, 'inputspec.anat_wm_segmentation') else: - if cfg.registration_workflows['functional_registration'][ - 'coregistration']['boundary_based_registration']['bbr_wm_map'] == 'probability_map': + if cfg['registration_workflows', 'functional_registration', + 'coregistration', 'boundary_based_registration', + 'bbr_wm_map'] == 'probability_map': node, out = strat_pool.get_data(["label-WM_probseg", "label-WM_mask"]) - elif cfg.registration_workflows['functional_registration'][ - 'coregistration']['boundary_based_registration']['bbr_wm_map'] == 'partial_volume_map': + elif cfg['registration_workflows', 'functional_registration', + 'coregistration', 'boundary_based_registration', + 'bbr_wm_map'] == 'partial_volume_map': node, out = strat_pool.get_data(["label-WM_pveseg", "label-WM_mask"]) - wf.connect(node, out, - func_to_anat_bbreg, 'inputspec.anat_wm_segmentation') + wf.connect(node, out, + func_to_anat_bbreg, 'inputspec.anat_wm_segmentation') if diff_complete: node, out = strat_pool.get_data('effectiveEchoSpacing') @@ -2929,15 +2916,45 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): node, out = strat_pool.get_data("fieldmap-mask") wf.connect(node, out, func_to_anat_bbreg, 'inputspec.fieldmapmask') - - outputs = { - 'space-T1w_sbref': - (func_to_anat_bbreg, 'outputspec.anat_func'), - 'from-bold_to-T1w_mode-image_desc-linear_xfm': - (func_to_anat_bbreg, 'outputspec.func_to_anat_linear_xfm') - } - - return (wf, outputs) + if fallback: + # Fall back to no-BBReg + mean_bolds = pe.Node(util.Merge(2), run_without_submitting=True, + name=f'bbreg_mean_bold_choices_{pipe_num}') + xfms = pe.Node(util.Merge(2), run_without_submitting=True, + name=f'bbreg_xfm_choices_{pipe_num}') + fallback_mean_bolds = pe.Node(util.Select(), + run_without_submitting=True, + name='bbreg_choose_mean_bold_' + f'{pipe_num}') + fallback_xfms = pe.Node(util.Select(), run_without_submitting=True, + name=f'bbreg_choose_xfm_{pipe_num}') + wf.connect([ + (func_to_anat_bbreg, bbreg_guardrail, [ + ('outputspec.anat_func', 'registered')]), + (bbreg_guardrail, mean_bolds, [('registered', 'in1')]), + (func_to_anat, mean_bolds, [('outputspec.anat_func_nobbreg', + 'in2')]), + (func_to_anat_bbreg, xfms, [ + ('outputspec.func_to_anat_linear_xfm', 'in1')]), + (func_to_anat, xfms, [ + ('outputspec.func_to_anat_linear_xfm_nobbreg', 'in2')]), + (mean_bolds, fallback_mean_bolds, [('out', 'inlist')]), + (xfms, fallback_xfms, [('out', 'inlist')]), + (bbreg_guardrail, fallback_mean_bolds, [ + ('failed_qc', 'index')]), + (bbreg_guardrail, fallback_xfms, [('failed_qc', 'index')])]) + outputs = { + 'space-T1w_sbref': (fallback_mean_bolds, 'out'), + 'from-bold_to-T1w_mode-image_desc-linear_xfm': (fallback_xfms, + 'out')} + else: + outputs = { + 'space-T1w_sbref': (func_to_anat_bbreg, + 'outputspec.anat_func'), + 'from-bold_to-T1w_mode-image_desc-linear_xfm': ( + func_to_anat_bbreg, + 'outputspec.func_to_anat_linear_xfm')} + return wf, outputs def create_func_to_T1template_xfm(wf, cfg, strat_pool, pipe_num, opt=None): diff --git a/CPAC/resources/configs/pipeline_config_default.yml b/CPAC/resources/configs/pipeline_config_default.yml index be84b46008..e05bb3778f 100644 --- a/CPAC/resources/configs/pipeline_config_default.yml +++ b/CPAC/resources/configs/pipeline_config_default.yml @@ -784,7 +784,8 @@ registration_workflows: boundary_based_registration: # this is a fork point - # run: [On, Off] - this will run both and fork the pipeline + # run: [On, Off, fallback] - this will run both and fork the pipeline + # if 'fallback' is one of the selected options, BBR will run and, if its output fails quality_thresholds, the pipeline will fallback to BBR's input image run: [On] # Standard FSL 5.0 Scheduler used for Boundary Based Registration. diff --git a/CPAC/resources/configs/pipeline_config_rbc-options.yml b/CPAC/resources/configs/pipeline_config_rbc-options.yml index b79a016de3..5b5d89f83d 100644 --- a/CPAC/resources/configs/pipeline_config_rbc-options.yml +++ b/CPAC/resources/configs/pipeline_config_rbc-options.yml @@ -46,6 +46,9 @@ registration_workflows: T1w_brain_template_mask: $FSLDIR/data/standard/MNI152_T1_${resolution_for_anat}_brain_mask.nii.gz functional_registration: + coregistration: + boundary_based_registration: + run: [fallback] func_registration_to_template: output_resolution: diff --git a/CPAC/utils/docs.py b/CPAC/utils/docs.py index b1ee23df0b..181df9aa98 100644 --- a/CPAC/utils/docs.py +++ b/CPAC/utils/docs.py @@ -71,4 +71,41 @@ def grab_docstring_dct(fn): return dct +def retry_docstring(orig): + """Decorator to autodocument retries. + + Examples + -------- + >>> @retry_docstring(grab_docstring_dct) + ... def do_nothing(): + ... '''Does this do anything?''' + ... pass + >>> print(do_nothing.__doc__) + Does this do anything? + Retries the following after a failed QC check: + Function to grab a NodeBlock dictionary from a docstring. + + Parameters + ---------- + fn : function + The NodeBlock function with the docstring to be parsed. + + Returns + ------- + dct : dict + A NodeBlock configuration dictionary. + + """ + def retry(obj): + if obj.__doc__ is None: + obj.__doc__ = '' + origdoc = (f'{orig.__module__}.{orig.__name__}' if + orig.__doc__ is None else orig.__doc__) + obj.__doc__ = '\n'.join([ + obj.__doc__, 'Retries the following after a failed QC check:', + origdoc]) + return obj + return retry + + DOCS_URL_PREFIX = _docs_url_prefix()