def main():
    """Score the demo DICOM series locally and request attention overlays.

    Loads the ``sybil_ensemble`` model, wraps the demo files in a single
    ``sybil.Serie``, prints the predicted risk scores, and asks
    ``visualize_attentions`` to save its output under
    ``sybil_attention_output``.
    """
    # Load a trained model
    ensemble = sybil.Sybil("sybil_ensemble")

    dicom_files = get_demo_data()

    # Get risk scores (attentions are requested for the visualization step)
    demo_serie = sybil.Serie(dicom_files)
    print(f"Processing {len(dicom_files)} DICOM files")
    result = ensemble.predict([demo_serie], return_attentions=True)
    scores = result.scores

    print(f"Risk scores: {scores}")

    # Visualize attention maps
    output_dir = "sybil_attention_output"
    print(f"Writing attention images to {output_dir}")

    overlay_options = {
        "attentions": result.attentions,
        "save_directory": output_dir,
        "gain": 3,
    }
    visualize_attentions(demo_serie, **overlay_options)

    print(f"Finished writing attention images to {output_dir}")


if __name__ == "__main__":
    main()
def download_file(url, filepath):
    """Download ``url`` and write the response body to ``filepath``.

    The parent directory of ``filepath`` is created when it is missing.
    The body is written only when the response status is 200; any other
    status is reported on stdout and nothing is written.

    Args:
        url: Address handed to ``urllib.request.urlopen``.
        filepath: Destination path for the downloaded bytes.

    Returns:
        ``filepath``, unchanged, whether or not the file was written.
    """
    # The context manager closes the connection even if writing fails.
    with urlopen(url) as response:
        target_dir = os.path.dirname(filepath)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        # urlopen responses expose `.status`; `.status_code` is the
        # `requests` spelling and raised AttributeError on this path.
        status = response.status
        if status == 200:
            with open(filepath, "wb") as f:
                f.write(response.read())
        else:
            print(f"Failed to download file. Status code: {status}")

    return filepath


def get_demo_data():
    """Return paths to the demo DICOM files, fetching them on first use.

    The demo archive is cached as ``~/.sybil/sybil_example.zip`` and
    extracted to ``~/.sybil/sybil_example``; each step is skipped when its
    output already exists.

    Returns:
        Sorted list of paths, one per entry of the extracted
        ``sybil_demo_data`` directory.
    """
    import zipfile

    demo_data_url = "https://www.dropbox.com/scl/fi/covbvo6f547kak4em3cjd/sybil_example.zip?rlkey=7a13nhlc9uwga9x7pmtk1cf1c&st=dqi0cf9k&dl=1"

    zip_file_name = "sybil_example.zip"
    cache_dir = os.path.expanduser("~/.sybil")
    zip_file_path = os.path.join(cache_dir, zip_file_name)
    os.makedirs(cache_dir, exist_ok=True)
    if not os.path.exists(zip_file_path):
        print(f"Downloading demo data to {zip_file_path}")
        download_file(demo_data_url, zip_file_path)

    demo_data_dir = os.path.join(cache_dir, "sybil_example")
    image_data_dir = os.path.join(demo_data_dir, "sybil_demo_data")
    if not os.path.exists(demo_data_dir):
        print(f"Extracting demo data to {demo_data_dir}")
        with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
            zip_ref.extractall(demo_data_dir)

    # os.listdir order is arbitrary; sort so repeated runs see the same list.
    return [os.path.join(image_data_dir, x) for x in sorted(os.listdir(image_data_dir))]


if __name__ == "__main__":
    # Only the script path needs this, so it is imported locally.
    from contextlib import ExitStack

    dicom_files = get_demo_data()
    serie = sybil.Serie(dicom_files)

    # Set the URL of the remote Sybil server
    ark_hostname = "localhost"
    ark_port = 5000
    ark_host = f"http://{ark_hostname}:{ark_port}"

    data_dict = {"return_attentions": True}
    payload = {"data": json.dumps(data_dict)}

    # Check if the server is running and reachable
    resp = requests.get(f"{ark_host}/info")
    if resp.status_code != 200:
        raise ValueError(f"Failed to connect to ARK server. Status code: {resp.status_code}")

    info_data = resp.json()["data"]
    # Raise explicitly: an `assert` would be stripped when run with -O.
    if info_data["modelName"].lower() != "sybil":
        raise ValueError("The ARK server is not running Sybil")
    print(f"ARK server info: {info_data}")

    # Submit prediction to ARK server. ExitStack closes every opened DICOM
    # file even when the upload raises.
    with ExitStack() as stack:
        files = [
            ("dicom", stack.enter_context(open(file_path, "rb")))
            for file_path in dicom_files
        ]
        r = requests.post(f"{ark_host}/dicom/files", files=files, data=payload)
    if r.status_code != 200:
        raise ValueError(f"Error occurred processing DICOM files. Status code: {r.status_code}.\n{r.text}")

    r_json = r.json()
    predictions = r_json["data"]["predictions"]

    scores = predictions[0]
    print(f"Risk scores: {scores}")

    attentions = np.array(predictions[1])
    print(f"Ark received attention shape: {attentions.shape}")

    # Visualize attention maps
    save_directory = "remote_ark_sybil_attention_output"

    print(f"Writing attention images to {save_directory}")

    images = serie.get_raw_images()
    overlayed_images = sybil.utils.visualization.build_overlayed_images(images, attentions, gain=3)

    serie_idx = 0
    save_path = os.path.join(save_directory, f"serie_{serie_idx}")
    sybil.utils.visualization.save_images(overlayed_images, save_path, f"serie_{serie_idx}")

    print(f"Finished writing attention images to {save_directory}")
install_requires = importlib-metadata; python_version>="3.8" - albumentations==1.1.0 + imageio==2.34.1 numpy==1.24.1 opencv-python==4.5.4.60 opencv-python-headless==4.5.4.60 pillow>=10.2.0 pydicom==2.3.0 pylibjpeg[all]==2.0.0 - scikit-learn==1.0.2 torch==1.13.1+cu117; platform_machine == "x86_64" torch==1.13.1; platform_machine != "x86_64" torchio==0.18.74 @@ -68,8 +67,10 @@ testing = mypy black train = + albumentations==1.1.0 lifelines==0.26.4 pytorch_lightning==1.6.0 + scikit-learn==1.0.2 [options.entry_points] console_scripts = diff --git a/sybil/__init__.py b/sybil/__init__.py index d69c2d6..6bb2698 100644 --- a/sybil/__init__.py +++ b/sybil/__init__.py @@ -19,7 +19,7 @@ from sybil.model import Sybil from sybil.serie import Serie -from sybil.utils.visualization import visualize_attentions +from sybil.utils.visualization import visualize_attentions, collate_attentions import sybil.utils.logging_utils -__all__ = ["Sybil", "Serie", "visualize_attentions", "__version__"] +__all__ = ["Sybil", "Serie", "visualize_attentions", "collate_attentions", "__version__"] diff --git a/sybil/augmentations.py b/sybil/augmentations.py index 464c88d..075cc20 100644 --- a/sybil/augmentations.py +++ b/sybil/augmentations.py @@ -1,12 +1,18 @@ +import cv2 import torch import torchvision -import albumentations as A -from albumentations.pytorch import ToTensorV2 + from typing import Literal from abc import ABCMeta, abstractmethod import numpy as np import random +try: + import albumentations as A +except ImportError: + # albumentations is not installed, training with augmentations will not be possible + A = None + def get_augmentations(split: Literal["train", "dev", "test"], args): if split == "train": @@ -94,7 +100,6 @@ class ToTensor(Abstract_augmentation): def __init__(self): super(ToTensor, self).__init__() - self.transform = ToTensorV2() self.name = "totensor" def __call__(self, input_dict, sample=None): @@ -104,6 +109,20 @@ def __call__(self, input_dict, sample=None): return 
class ResizeTransform:
    """Callable that resizes an image and/or a mask to a fixed size via OpenCV.

    The target size is stored as ``width`` and ``height`` and handed to
    ``cv2.resize`` as ``dsize=(width, height)``.
    """

    def __init__(self, width, height):
        self.width, self.height = width, height

    def __call__(self, image=None, mask=None):
        """Resize whichever of ``image`` / ``mask`` is provided.

        Args:
            image: Array resized with ``cv2.INTER_LINEAR``, or ``None``.
            mask: Array resized with ``cv2.INTER_NEAREST``, or ``None``.

        Returns:
            Dict with keys ``"image"`` and ``"mask"``; an entry is ``None``
            when the matching argument was ``None``.
        """
        target_size = (self.width, self.height)

        resized_image = None
        if image is not None:
            resized_image = cv2.resize(image, dsize=target_size, interpolation=cv2.INTER_LINEAR)

        resized_mask = None
        if mask is not None:
            resized_mask = cv2.resize(mask, dsize=target_size, interpolation=cv2.INTER_NEAREST)

        return {"image": resized_image, "mask": resized_mask}
def download_and_extract(remote_url: str, local_dir: str) -> List[str]:
    """Download a zip archive and unpack it into ``local_dir``.

    Args:
        remote_url: Address handed to ``urllib.request.urlopen``.
        local_dir: Extraction directory; created if it does not exist.

    Returns:
        The archive's member names, as reported by ``ZipFile.namelist()``.
    """
    os.makedirs(local_dir, exist_ok=True)
    # Close the response as soon as the payload has been read.
    with urlopen(remote_url) as resp:
        payload = resp.read()
    with ZipFile(BytesIO(payload)) as zip_file:
        all_files_and_dirs = zip_file.namelist()
        zip_file.extractall(local_dir)

    return all_files_and_dirs


def collate_attentions(attention_dict: Dict[str, np.ndarray], N: int, eps=1e-6) -> np.ndarray:
    """Combine image and volume attention into one ``(N, 512, 512)`` map.

    Both inputs are exponentiated and averaged over their first axis (the
    ensemble), multiplied together, viewed as ``(1, 25, 16, 16)`` and
    trilinearly upsampled to ``N`` slices of 512x512.

    Args:
        attention_dict: Must hold ``"image_attention_1"`` and
            ``"volume_attention_1"``. Their product must contain
            25 * 16 * 16 values for the reshape to succeed.
            NOTE(review): the ``exp`` suggests these are log-attentions; confirm.
        N: Number of output slices.
        eps: Values ``<= eps`` are zeroed; a falsy ``eps`` disables this.

    Returns:
        Array of shape ``(N, 512, 512)``, including when ``N == 1``.
    """
    a1 = torch.Tensor(attention_dict["image_attention_1"])
    v1 = torch.Tensor(attention_dict["volume_attention_1"])

    # take mean attention over ensemble
    a1 = torch.exp(a1).mean(0)
    v1 = torch.exp(v1).mean(0)

    attention = a1 * v1.unsqueeze(-1)
    attention = attention.view(1, 25, 16, 16)

    attention_up = F.interpolate(
        attention.unsqueeze(0), (N, 512, 512), mode="trilinear"
    )
    # Drop only the batch and channel axes. A bare squeeze() would also drop
    # the depth axis when N == 1 and break per-slice indexing downstream.
    attention_up = attention_up[0, 0].cpu().numpy()
    if eps:
        attention_up[attention_up <= eps] = 0.0

    return attention_up


def build_overlayed_images(images: List[np.ndarray], attention: np.ndarray, gain: int = 3):
    """Overlay per-slice attention on grayscale images.

    Channels 1 and 2 of each output receive the image unchanged; channel 0
    receives ``attention * gain * 256 + image`` clipped to ``[0, 255]``.
    Only channel 0 is clipped, so images are expected to already be within
    the uint8 range.

    Args:
        images: 2-D arrays, one per slice. Any height/width is accepted as
            long as the attention slices have the same shape.
        attention: Array indexed by slice along its first axis. A 2-D array
            is treated as the map for a single slice.
        gain: Multiplier applied to the attention before it is added.

    Returns:
        List of ``uint8`` arrays of shape ``(H, W, 3)``, one per image.
    """
    attention = np.asarray(attention)
    if attention.ndim == 2:
        # A single-slice map that lost its leading axis (e.g. via squeeze()).
        attention = attention[np.newaxis]

    overlayed_images = []
    for i, image in enumerate(images):
        image = np.asarray(image)
        # Size the canvas from the image instead of assuming 512x512.
        overlayed = np.zeros(image.shape + (3,))
        overlayed[..., 2] = image
        overlayed[..., 1] = image
        overlayed[..., 0] = np.clip(
            (attention[i, ...] * gain * 256) + image,
            a_min=0,
            a_max=255,
        )

        overlayed_images.append(np.uint8(overlayed))

    return overlayed_images
demo_data_url = "https://www.dropbox.com/sh/addq480zyguxbbg/AACJRVsKDL0gpq-G9o3rfCBQa?dl=1" + demo_data_url = "https://www.dropbox.com/scl/fi/covbvo6f547kak4em3cjd/sybil_example.zip?rlkey=7a13nhlc9uwga9x7pmtk1cf1c&st=dqi0cf9k&dl=1" expected_scores = [ 0.021628819563619374, 0.03857256315036462, @@ -128,9 +128,9 @@ def test_demo_data(self): 0.13568094038444453 ] - zip_file_name = "SYBIL.zip" + zip_file_name = "sybil_example.zip" cache_dir = os.path.expanduser("~/.sybil") - demo_data_dir = os.path.join(cache_dir, "SYBIL") + demo_data_dir = os.path.join(cache_dir, "sybil_example") image_data_dir = os.path.join(demo_data_dir, "sybil_demo_data") os.makedirs(cache_dir, exist_ok=True) download_and_extract_zip(zip_file_name, cache_dir, demo_data_url, demo_data_dir)