Skip to content

Commit

Permalink
Merge pull request #29 from reginabarzilaygroup/GH28_individual_infer…
Browse files Browse the repository at this point in the history
…ence

GH28 Individual inference
  • Loading branch information
pgmikhael authored Mar 18, 2024
2 parents 864c9d5 + e6f87a5 commit faf596b
Show file tree
Hide file tree
Showing 10 changed files with 194 additions and 25 deletions.
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/pgmikhael/Sybil/blob/main/LICENSE.txt) ![version](https://img.shields.io/badge/version-1.0.2-success)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/pgmikhael/Sybil/blob/main/LICENSE.txt) ![version](https://img.shields.io/badge/version-1.2.0-success)

# Sybil

Lung Cancer Risk Prediction
Lung Cancer Risk Prediction.

Additional documentation can be found on the [GitHub Wiki](https://github.com/reginabarzilaygroup/Sybil/wiki).

## Run a regression test

Expand All @@ -21,7 +23,7 @@ You can load our pretrained model trained on the NLST dataset, and score a given
from sybil import Serie, Sybil

# Load a trained model
model = Sybil("sybil_base")
model = Sybil("sybil_ensemble")

# Get risk scores
serie = Serie([dicom_path_1, dicom_path_2, ...])
Expand All @@ -32,9 +34,9 @@ serie = Serie([dicom_path_1, dicom_path_2, ...], label=1)
results = model.evaluate([serie])
```

Models available include: `sybil_base` and `sybil_ensemble`.
Models available include: `sybil_1`, `sybil_2`, `sybil_3`, `sybil_4`, `sybil_5` and `sybil_ensemble`.

All model files are available [here](https://drive.google.com/drive/folders/1nBp05VV9mf5CfEO6W5RY4ZpcpxmPDEeR?usp=sharing).
All model files are available on [GitHub releases](https://github.com/reginabarzilaygroup/Sybil/releases) as well as [here](https://drive.google.com/drive/folders/1nBp05VV9mf5CfEO6W5RY4ZpcpxmPDEeR?usp=sharing).

## Replicating results

Expand Down
148 changes: 148 additions & 0 deletions scripts/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import argparse
import datetime
import json
import logging
import os
import pickle

from sybil import Serie, Sybil, visualize_attentions

script_directory = os.path.dirname(os.path.abspath(__file__))
project_directory = os.path.dirname(script_directory)


def _get_parser():
parser = argparse.ArgumentParser(description=__doc__)

parser.add_argument(
"image_dir",
default=None,
help="Path to directory containing DICOM/PNG files (from a single exam) to run inference on."
"Every file in the directory will be included.",
)

parser.add_argument(
"--output-dir",
default="sybil_result",
dest="output_dir",
help="Output directory in which to save prediction results."
"Prediction will be printed to stdout as well.",
)

parser.add_argument(
"--return-attentions",
default=False,
action="store_true",
help="Generate an image which overlaps attention scores.",
)

parser.add_argument(
"--file-type",
default="auto",
dest="file_type",
choices={"dicom", "png", "auto"},
help="File type of input images. "
"If not provided, the file type will be inferred from input extensions.",
)

parser.add_argument(
"--model-name",
default="sybil_ensemble",
dest="model_name",
help="Name of the model to use for prediction. Default: sybil_ensemble",
)

parser.add_argument("-l", "--log", "--loglevel", default="INFO", dest="loglevel")

return parser


def logging_basic_config(args):
info_fmt = "[%(asctime)s] - %(message)s"
debug_fmt = "[%(asctime)s] [%(filename)s:%(lineno)d] %(levelname)s - %(message)s"
fmt = debug_fmt if args.loglevel.upper() == "DEBUG" else info_fmt

logging.basicConfig(
format=fmt, datefmt="%Y-%m-%d %H:%M:%S", level=args.loglevel.upper()
)


def inference(
image_dir,
output_dir,
model_name="sybil_ensemble",
return_attentions=False,
file_type="auto",
):
logger = logging.getLogger("inference")

input_files = os.listdir(image_dir)
input_files = [os.path.join(image_dir, x) for x in input_files if not x.startswith(".")]
input_files = [x for x in input_files if os.path.isfile(x)]
extensions = {os.path.splitext(x)[1] for x in input_files}
if len(extensions) > 1:
raise ValueError(
f"Multiple file types found in {image_dir}: {','.join(extensions)}"
)

if file_type == "auto":
extension = extensions.pop()
file_type = "dicom"
if extension.lower() in {".png", "png"}:
file_type = "png"
assert file_type in {"dicom", "png"}

num_files = len(input_files)

# Load a trained model
model = Sybil(model_name)

logger.debug(f"Beginning prediction using {num_files} {file_type} files from {image_dir}")

# Get risk scores
serie = Serie(input_files, file_type=file_type)
series = [serie]
prediction = model.predict(series, return_attentions=return_attentions)
prediction_scores = prediction.scores[0]

logger.debug(f"Prediction finished. Results:\n{prediction_scores}")

prediction_path = os.path.join(output_dir, "prediction_scores.json")
pred_dict = {"predictions": prediction.scores}
with open(prediction_path, "w") as f:
json.dump(pred_dict, f, indent=2)

if return_attentions:
attention_path = os.path.join(output_dir, "attention_scores.pkl")
with open(attention_path, "wb") as f:
pickle.dump(prediction, f)

series_with_attention = visualize_attentions(
series,
attentions=prediction.attentions,
save_directory=output_dir,
gain=3,
)

return pred_dict


def main():
args = _get_parser().parse_args()
logging_basic_config(args)

os.makedirs(args.output_dir, exist_ok=True)

pred_dict = inference(
args.image_dir,
args.output_dir,
args.model_name,
args.return_attentions,
file_type=args.file_type,
)

print(json.dumps(pred_dict, indent=2))


if __name__ == "__main__":
main()
20 changes: 20 additions & 0 deletions scripts/run_inference_demo.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/bin/bash

# Run inference on the demo data
# The output will be printed to the console


demo_scan_dir=sybil_demo_data

# Download the demo data if it doesn't exist
if [ ! -d "$demo_scan_dir" ]; then
# Download example data
curl -L -o sybil_example.zip "https://www.dropbox.com/scl/fi/covbvo6f547kak4em3cjd/sybil_example.zip?rlkey=7a13nhlc9uwga9x7pmtk1cf1c&dl=1"
tar -xf sybil_example.zip
fi

python3 scripts/inference.py \
--loglevel DEBUG \
--output-dir demo_prediction \
--return-attentions \
$demo_scan_dir
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ author_email =
license_file = LICENSE.txt
long_description = file: README.md
long_description_content_type = text/markdown; charset=UTF-8; variant=GFM
version = 1.0.3
version = 1.2.0
# url =
project_urls =
; Documentation = https://.../docs
Expand Down
4 changes: 2 additions & 2 deletions sybil/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self):
)

@abstractmethod
def __call__(self, img, mask=None, additional=None):
def __call__(self, input_dict):
pass

def set_seed(self, seed):
Expand Down Expand Up @@ -141,7 +141,7 @@ def __init__(self, args, kwargs):
self.transform = A.Rotate(limit=self.max_angle, p=0.5)

def __call__(self, input_dict, sample=None):
if "seed" in sample:
if sample and "seed" in sample:
self.set_seed(sample["seed"])
out = self.transform(
image=input_dict["input"], mask=input_dict.get("mask", None)
Expand Down
14 changes: 7 additions & 7 deletions sybil/loaders/abstract_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,18 @@ def __init__(self, cache_path, augmentations, args, apply_augmentations=True):
self.composed_all_augmentations = ComposeAug(augmentations)

@abstractmethod
def load_input(self, path, sample):
def load_input(self, path):
pass

@property
@abstractmethod
def cached_extension(self):
pass

def configure_path(self, path, sample):
def configure_path(self, path, sample=None):
return path

def get_image(self, path, sample):
def get_image(self, path, sample=None):
"""
Returns a transformed image by its absolute path.
If cache is used - transformed image will be loaded if available,
Expand All @@ -166,18 +166,18 @@ def get_image(self, path, sample):
input_path = self.configure_path(path, sample)

if input_path == self.pad_token:
return self.load_input(input_path, sample)
return self.load_input(input_path)

if not self.use_cache:
input_dict = self.load_input(input_path, sample)
input_dict = self.load_input(input_path)
# hidden loaders typically do not use augmentation
if self.apply_augmentations:
input_dict = self.composed_all_augmentations(input_dict, sample)
return input_dict

if self.args.use_annotations:
input_dict["mask"] = get_scaled_annotation_mask(
sample["annotations"], self.args
input_dict["annotations"], self.args
)

for key, post_augmentations in self.split_augmentations:
Expand Down Expand Up @@ -207,7 +207,7 @@ def get_image(self, path, sample):
warnings.warn(CORUPTED_FILE_ERR.format(sys.exc_info()[0]))
self.cache.rem(input_path, key)
all_augmentations = self.split_augmentations[-1][1]
input_dict = self.load_input(input_path, sample)
input_dict = self.load_input(input_path)
if self.apply_augmentations:
input_dict = apply_augmentations_and_cache(
input_dict,
Expand Down
8 changes: 4 additions & 4 deletions sybil/loaders/image_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@

class OpenCVLoader(abstract_loader):

def load_input(self, path, sample):
def load_input(self, path):
"""
loads as grayscale image
"""
return {"input": cv2.imread(path, 0) }
return {"input": cv2.imread(path, 0)}

@property
def cached_extension(self):
Expand All @@ -27,12 +27,12 @@ def __init__(self, cache_path, augmentations, args, apply_augmentations=True):
self.window_center = -600
self.window_width = 1500

def load_input(self, path, sample):
def load_input(self, path):
try:
dcm = pydicom.dcmread(path)
dcm = apply_modality_lut(dcm.pixel_array, dcm)
arr = apply_windowing(dcm, self.window_center, self.window_width)
arr = arr//256 # parity with images loaded as 8 bit
arr = arr//256 # parity with images loaded as 8 bit
except Exception:
raise Exception(LOADING_ERROR.format("COULD NOT LOAD DICOM."))
return {"input": arr}
Expand Down
2 changes: 1 addition & 1 deletion sybil/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def download_and_extract(remote_model_url: str, local_model_dir) -> List[str]:
class Sybil:
def __init__(
self,
name_or_path: Union[List[str], str] = "sybil_base",
name_or_path: Union[List[str], str] = "sybil_ensemble",
cache: str = "~/.sybil/",
calibrator_path: Optional[str] = None,
device: Optional[str] = None,
Expand Down
6 changes: 2 additions & 4 deletions sybil/serie.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def get_raw_images(self) -> List[np.ndarray]:
"""

loader = get_sample_loader("test", self._args, apply_augmentations=False)
input_dicts = [loader.get_image(path, {}) for path in self._meta.paths]
input_dicts = [loader.get_image(path) for path in self._meta.paths]
images = [i["input"] for i in input_dicts]
return images

Expand All @@ -145,10 +145,8 @@ def get_volume(self) -> torch.Tensor:
CT volume of shape (1, C, N, H, W)
"""

sample = {"seed": np.random.randint(0, 2**32 - 1)}

input_dicts = [
self._loader.get_image(path, sample) for path in self._meta.paths
self._loader.get_image(path) for path in self._meta.paths
]

x = torch.cat([i["input"].unsqueeze(0) for i in input_dicts], dim=0)
Expand Down
3 changes: 2 additions & 1 deletion tests/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def main():
num_files = len(dicom_files)

# Load a trained model
model = Sybil("sybil_ensemble")
# model = Sybil("sybil_ensemble")
model = Sybil()

myprint(f"Beginning prediction using {num_files} files from {image_data_dir}")

Expand Down

0 comments on commit faf596b

Please sign in to comment.