From 30b9ff3c9da09014519e6159dab577b99fb322fe Mon Sep 17 00:00:00 2001 From: Jacob Silterra Date: Tue, 28 May 2024 11:47:14 -0400 Subject: [PATCH] Fix demo to use new predict location. Also tweak imports. --- scripts/run_inference_demo.sh | 3 ++- sybil/__init__.py | 2 +- sybil/predict.py | 7 ++++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/scripts/run_inference_demo.sh b/scripts/run_inference_demo.sh index 6f1088d..8ac20b2 100755 --- a/scripts/run_inference_demo.sh +++ b/scripts/run_inference_demo.sh @@ -13,7 +13,8 @@ if [ ! -d "$demo_scan_dir" ]; then unzip -q sybil_example.zip fi -python3 scripts/inference.py \ +# Either python3 sybil/predict.py or sybil-predict (if installed via pip) +python3 sybil/predict.py \ --loglevel DEBUG \ --output-dir demo_prediction \ --return-attentions \ diff --git a/sybil/__init__.py b/sybil/__init__.py index a37f041..d69c2d6 100644 --- a/sybil/__init__.py +++ b/sybil/__init__.py @@ -22,4 +22,4 @@ from sybil.utils.visualization import visualize_attentions import sybil.utils.logging_utils -__all__ = ["Sybil", "Serie", "visualize_attentions"] +__all__ = ["Sybil", "Serie", "visualize_attentions", "__version__"] diff --git a/sybil/predict.py b/sybil/predict.py index 2596129..703fb3c 100644 --- a/sybil/predict.py +++ b/sybil/predict.py @@ -13,11 +13,11 @@ import sybil.utils.logging_utils import sybil.datasets.utils -from sybil import Serie, Sybil, visualize_attentions +from sybil import Serie, Sybil, visualize_attentions, __version__ def _get_parser(): - description = __doc__ + f"\nVersion: {sybil.__version__}\n" + description = __doc__ + f"\nVersion: {__version__}\n" parser = argparse.ArgumentParser(description=description) parser.add_argument( @@ -69,7 +69,7 @@ def _get_parser(): parser.add_argument("-l", "--log", "--loglevel", "--log-level", default="INFO", dest="loglevel") - parser.add_argument("-v", "--version", action="version", version=sybil.__version__) + parser.add_argument("-v", "--version", action="version", version=__version__) return parser @@ -101,6 +101,7 @@ def predict( if extension.lower() in {".png", "png"}: file_type = "png" voxel_spacing = sybil.datasets.utils.VOXEL_SPACING + logger.debug(f"Using default voxel spacing: {voxel_spacing}") assert file_type in {"dicom", "png"} file_type = typing.cast(Literal["dicom", "png"], file_type)