Merge pull request #39 from reginabarzilaygroup/v1.3.0_dev
V1.3.0 dev
pgmikhael authored Jun 4, 2024
2 parents 1fd3d38 + bc5c599 commit 3a5604a
Showing 9 changed files with 70 additions and 102 deletions.
scripts/run_inference_demo.sh (3 changes: 2 additions & 1 deletion)
@@ -13,7 +13,8 @@ if [ ! -d "$demo_scan_dir" ]; then
    unzip -q sybil_example.zip
fi

python3 scripts/inference.py \
# Either python3 sybil/predict.py or sybil-predict (if installed via pip)
python3 sybil/predict.py \
    --loglevel DEBUG \
    --output-dir demo_prediction \
    --return-attentions \
setup.cfg (27 changes: 13 additions & 14 deletions)
@@ -7,7 +7,7 @@ author_email =
license_file = LICENSE.txt
long_description = file: README.md
long_description_content_type = text/markdown; charset=UTF-8; variant=GFM
version = 1.2.2
version = 1.3.0
# url =
project_urls =
    ; Documentation = https://.../docs
@@ -31,31 +31,27 @@ find_links =
zip_safe = False
packages = find:
include_package_data = True
python_requires = >=3.8
python_requires = >=3.8,<3.11
# Add here dependencies of your project (line-separated), e.g. requests>=2.2,<3.0.
# Version specifiers like >=2.2,<3.0 avoid problems due to API changes in
# new major versions. This works if the required packages follow Semantic Versioning.
# For more information, check out https://semver.org/.
install_requires =
    importlib-metadata; python_version>="3.8"
    albumentations==1.1.0
    numpy==1.24.1
    torch==1.13.1+cu117; sys_platform != "darwin"
    torch==1.13.1; sys_platform == "darwin"
    torchvision==0.14.1+cu117; sys_platform != "darwin"
    torchvision==0.14.1; sys_platform == "darwin"
    pytorch_lightning==1.6.0
    scikit-learn==1.0.2
    tqdm==4.62.3
    lifelines==0.26.4
    opencv-python==4.5.4.60
    opencv-python-headless==4.5.4.60
    albumentations==1.1.0
    pillow>=10.2.0
    pydicom==2.3.0
    pylibjpeg[all]==2.0.0
    scikit-learn==1.0.2
    torch==1.13.1+cu117; platform_machine == "x86_64"
    torch==1.13.1; platform_machine != "x86_64"
    torchio==0.18.74
    gdown==4.6.0

    torchvision==0.14.1+cu117; platform_machine == "x86_64"
    torchvision==0.14.1; platform_machine != "x86_64"
    tqdm==4.62.3

[options.packages.find]
exclude =
@@ -71,10 +67,13 @@ testing =
    flake8
    mypy
    black
train =
    lifelines==0.26.4
    pytorch_lightning==1.6.0

[options.entry_points]
console_scripts =
    sybil = sybil.main:main
    sybil-predict = sybil.predict:main


[bdist_wheel]
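The new console_scripts entry simply re-exposes the module's main() function, so pip installs gain a sybil-predict command equivalent to running the file directly. A minimal sketch of that equivalence (assuming the package is installed):

# After `pip install sybil`, these invocations are interchangeable:
#   $ sybil-predict <image_dir> --output-dir <out>
#   $ python3 sybil/predict.py <image_dir> --output-dir <out>
# because the entry point resolves to sybil.predict:main.
from sybil.predict import main

if __name__ == "__main__":
    main()  # parses sys.argv and runs prediction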
sybil/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -22,4 +22,4 @@
from sybil.utils.visualization import visualize_attentions
import sybil.utils.logging_utils

__all__ = ["Sybil", "Serie", "visualize_attentions"]
__all__ = ["Sybil", "Serie", "visualize_attentions", "__version__"]
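Exporting __version__ makes the package version queryable at runtime; a quick check (the expected value comes from this release's setup.cfg):

from sybil import __version__

print(__version__)  # expected: 1.3.0 for this release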
sybil/datasets/utils.py (14 changes: 0 additions & 14 deletions)
@@ -1,6 +1,5 @@
import numpy as np
import math
from lifelines import KaplanMeierFitter

# Error Messages
METAFILE_NOTFOUND_ERR = "Metadata file {} could not be parsed! Exception: {}!"
@@ -104,16 +103,3 @@ def get_scaled_annotation_area(sample, args):
        areas.append(mask.sum() / (mask.shape[0] * mask.shape[1]))
    return np.array(areas)


def get_censoring_dist(train_dataset):
    _dataset = train_dataset.dataset
    times, event_observed = (
        [d["time_at_event"] for d in _dataset],
        [d["y"] for d in _dataset],
    )
    all_observed_times = set(times)
    kmf = KaplanMeierFitter()
    kmf.fit(times, event_observed)

    censoring_dist = {str(time): kmf.predict(time) for time in all_observed_times}
    return censoring_dist
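This helper is not dropped outright; it reappears in sybil/utils/metrics.py below with a call-time lifelines import. For orientation, a self-contained sketch of the same Kaplan-Meier computation on toy data (values are illustrative):

from lifelines import KaplanMeierFitter

# Toy follow-up times and event indicators (1 = event observed).
times = [1, 2, 2, 3, 5, 5, 6]
event_observed = [1, 0, 1, 1, 0, 1, 0]

kmf = KaplanMeierFitter()
kmf.fit(times, event_observed)

# Survival probability at each distinct time, keyed by str(time),
# matching the dict shape the removed helper returned.
censoring_dist = {str(t): kmf.predict(t) for t in set(times)}
print(censoring_dist)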
sybil/model.py (47 changes: 3 additions & 44 deletions)
@@ -4,7 +4,6 @@
from typing import NamedTuple, Union, Dict, List, Optional, Tuple
from urllib.request import urlopen
from zipfile import ZipFile
# import gdown

import torch
import numpy as np
@@ -83,49 +82,6 @@ class Evaluation(NamedTuple):
    attentions: List[Dict[str, np.ndarray]] = None


def download_sybil_gdrive(name, cache):
    """Download trained models and calibrator from Google Drive
    Parameters
    ----------
    name (str): name of model to use. A key in NAME_TO_FILE
    cache (str): path to directory where files are downloaded
    Returns
    -------
    download_model_paths (list): paths to .ckpt models
    download_calib_path (str): path to calibrator
    """
    # Create cache folder if not exists
    cache = os.path.expanduser(cache)
    os.makedirs(cache, exist_ok=True)

    # Download if needed
    model_files = NAME_TO_FILE[name]

    # Download models
    download_model_paths = []
    for model_name, google_id in zip(
        model_files["checkpoint"], model_files["google_checkpoint_id"]
    ):
        model_path = os.path.join(cache, f"{model_name}.ckpt")
        if not os.path.exists(model_path):
            print(f"Downloading model to {cache}")
            gdown.download(id=google_id, output=model_path, quiet=False)
        download_model_paths.append(model_path)

    # download calibrator
    download_calib_path = os.path.join(cache, f"{name}.p")
    if not os.path.exists(download_calib_path):
        gdown.download(
            id=model_files["google_calibrator_id"],
            output=download_calib_path,
            quiet=False,
        )

    return download_model_paths, download_calib_path


def download_sybil(name, cache) -> Tuple[List[str], str]:
    """Download trained models and calibrator"""
    # Create cache folder if not exists
@@ -329,6 +285,9 @@ def _predict(
"volume_attention_1": out["volume_attention_1"]
.detach()
.cpu(),
"hidden": out["hidden"]
.detach()
.cpu(),
}
)

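The attention payload returned by _predict now also carries the model's pooled hidden features. A hedged sketch of reading them back (hypothetical DICOM path; assumes the ensemble aggregation keeps the "hidden" key intact):

from sybil import Serie, Sybil

model = Sybil("sybil_ensemble")
serie = Serie(["exam/slice_0001.dcm"])  # hypothetical DICOM file list
prediction = model.predict([serie], return_attentions=True)

hidden = prediction.attentions[0]["hidden"]  # new in this release
print(hidden.shape)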
scripts/inference.py → sybil/predict.py (61 changes: 42 additions & 19 deletions)
@@ -1,42 +1,55 @@
#!/usr/bin/env python

__doc__ = """
Use Sybil to run inference on a single exam.
"""

import argparse
import datetime
import json
import logging
import os
import pickle
import typing
from typing import Literal

import sybil.utils.logging_utils
from sybil import Serie, Sybil, visualize_attentions

script_directory = os.path.dirname(os.path.abspath(__file__))
project_directory = os.path.dirname(script_directory)
import sybil.datasets.utils
from sybil import Serie, Sybil, visualize_attentions, __version__


def _get_parser():
    parser = argparse.ArgumentParser(description=__doc__)
    description = __doc__ + f"\nVersion: {__version__}\n"
    parser = argparse.ArgumentParser(description=description)

    parser.add_argument(
        "image_dir",
        default=None,
        help="Path to directory containing DICOM/PNG files (from a single exam) to run inference on."
        help="Path to directory containing DICOM/PNG files (from a single exam) to run inference on. "
        "Every file in the directory will be included.",
    )

    parser.add_argument(
        "--output-dir",
        default="sybil_result",
        dest="output_dir",
        help="Output directory in which to save prediction results."
        help="Output directory in which to save prediction results. "
        "Prediction will be printed to stdout as well.",
    )

    parser.add_argument(
        "--return-attentions",
        default=False,
        action="store_true",
        help="Generate an image which overlaps attention scores.",
        help="Return hidden vectors and attention scores, write them to a pickle file.",
    )

    parser.add_argument(
        "--write-attention-images",
        default=False,
        action="store_true",
        help="Generate images with attention overlap. Sets --return-attentions (if not already set).",
    )


    parser.add_argument(
        "--file-type",
        default="auto",
@@ -56,33 +69,41 @@ def _get_parser():
parser.add_argument("-l", "--log", "--loglevel", "--log-level",
default="INFO", dest="loglevel")

parser.add_argument("-v", "--version", action="version", version=__version__)

return parser


def inference(
def predict(
    image_dir,
    output_dir,
    model_name="sybil_ensemble",
    return_attentions=False,
    file_type="auto",
    write_attention_images=False,
    file_type: Literal["auto", "dicom", "png"] = "auto",
):
    logger = sybil.utils.logging_utils.get_logger()

    input_files = os.listdir(image_dir)
    input_files = [os.path.join(image_dir, x) for x in input_files if not x.startswith(".")]
    input_files = [x for x in input_files if os.path.isfile(x)]
    extensions = {os.path.splitext(x)[1] for x in input_files}
    if len(extensions) > 1:
        raise ValueError(
            f"Multiple file types found in {image_dir}: {','.join(extensions)}"
        )

    voxel_spacing = None
    if file_type == "auto":
        extensions = {os.path.splitext(x)[1] for x in input_files}
        extension = extensions.pop()
        if len(extensions) > 1:
            raise ValueError(
                f"Multiple file types found in {image_dir}: {','.join(extensions)}"
            )

        file_type = "dicom"
        if extension.lower() in {".png", "png"}:
            file_type = "png"
            voxel_spacing = sybil.datasets.utils.VOXEL_SPACING
            logger.debug(f"Using default voxel spacing: {voxel_spacing}")
    assert file_type in {"dicom", "png"}
    file_type = typing.cast(Literal["dicom", "png"], file_type)

    num_files = len(input_files)

@@ -92,7 +113,7 @@ def inference(
    model = Sybil(model_name)

    # Get risk scores
    serie = Serie(input_files, file_type=file_type)
    serie = Serie(input_files, voxel_spacing=voxel_spacing, file_type=file_type)
    series = [serie]
    prediction = model.predict(series, return_attentions=return_attentions)
    prediction_scores = prediction.scores[0]
@@ -110,6 +131,7 @@ def inference(
        with open(attention_path, "wb") as f:
            pickle.dump(prediction, f)

    if write_attention_images:
        series_with_attention = visualize_attentions(
            series,
            attentions=prediction.attentions,
@@ -126,11 +148,12 @@ def main():

    os.makedirs(args.output_dir, exist_ok=True)

    pred_dict, series_with_attention = inference(
    pred_dict, series_with_attention = predict(
        args.image_dir,
        args.output_dir,
        args.model_name,
        args.return_attentions,
        args.write_attention_images,
        file_type=args.file_type,
    )

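The renamed function can also be called programmatically; a minimal sketch mirroring main(), with hypothetical paths:

from sybil.predict import predict

# image_dir must contain exactly one exam's DICOM or PNG files.
pred_dict, series_with_attention = predict(
    image_dir="sybil_demo_data",       # hypothetical input directory
    output_dir="demo_prediction",      # hypothetical output directory
    model_name="sybil_ensemble",
    return_attentions=True,
    write_attention_images=True,
)
print(pred_dict)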
sybil/serie.py (12 changes: 6 additions & 6 deletions)
@@ -59,6 +59,8 @@ def __init__(
"""
if label is not None and censor_time is None:
raise ValueError("censor_time should also provided with label.")
if file_type == "png" and voxel_spacing is None:
raise ValueError("voxel_spacing should be provided for PNG files.")

        self._censor_time = censor_time
        self._label = label
@@ -263,13 +265,11 @@ def _check_valid(self, args):
        - serie doesn't have a label, OR
        - slice thickness is too big
        """
        if (self._meta.thickness is None) or (
            self._meta.thickness > args.slice_thickness_filter
        ):
        if self._meta.thickness is None:
            raise ValueError("slice thickness not found")
        if self._meta.thickness > args.slice_thickness_filter:
            raise ValueError(
                "slice thickness is greater than {}.".format(
                    args.slice_thickness_filter
                )
                f"slice thickness {self._meta.thickness} is greater than {args.slice_thickness_filter}."
            )
        if self._meta.voxel_spacing is None:
            raise ValueError("voxel spacing either not set or not found in DICOM")
sybil/utils/logging_utils.py (2 changes: 1 addition & 1 deletion)
@@ -7,7 +7,7 @@


def _get_formatter(loglevel="INFO"):
    warn_fmt = "[%(asctime)s] %(levelname)s -%(message)s"
    warn_fmt = "[%(asctime)s] %(levelname)s - %(message)s"
    debug_fmt = "[%(asctime)s] [%(filename)s:%(lineno)d] %(levelname)s - %(message)s"
    fmt = debug_fmt if loglevel.upper() in {"DEBUG"} else warn_fmt
    return logging.Formatter(
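The one-character change restores the space after the dash in non-debug logs; a standalone check of the corrected format using only the standard library:

import logging

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s"))
logger = logging.getLogger("sybil_demo")
logger.addHandler(handler)
logger.warning("prediction complete")
# prints e.g.: [2024-06-04 12:00:00,000] WARNING - prediction complete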
sybil/utils/metrics.py (4 changes: 2 additions & 2 deletions)
@@ -10,8 +10,6 @@
    average_precision_score,
)
import numpy as np
from lifelines.utils.btree import _BTree
from lifelines import KaplanMeierFitter
import warnings

EPSILON = 1e-6
@@ -154,6 +152,7 @@ def include_exam_and_determine_label(prob_arr, censor_time, gold):


def get_censoring_dist(train_dataset):
    from lifelines import KaplanMeierFitter
    _dataset = train_dataset.dataset
    times, event_observed = (
        [d["time_at_event"] for d in _dataset],
@@ -309,6 +308,7 @@ def _concordance_summary_statistics(
    censored_truth = censored_truth[ix]
    censored_pred = predicted_event_times[~died_mask][ix]

    from lifelines.utils.btree import _BTree
    censored_ix = 0
    died_ix = 0
    times_to_compare = {}
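With lifelines moved to the train extra in setup.cfg, these call-time imports keep inference-only installs working. A hypothetical wrapper showing the same pattern with a friendlier failure message (the helper name and message are illustrative, not part of the codebase):

def _require_lifelines():
    """Import lifelines lazily; it ships only with the train extra."""
    try:
        from lifelines import KaplanMeierFitter
    except ImportError as exc:  # inference-only installs omit the extra
        raise ImportError(
            "lifelines is needed for training metrics; "
            "install it with: pip install sybil[train]"
        ) from exc
    return KaplanMeierFitter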
