Skip to content

Commit

Permalink
Merge pull request #51 from reginabarzilaygroup/v1.6.0_dev
Browse files Browse the repository at this point in the history
V1.6.0 dev
  • Loading branch information
jsilter authored Sep 9, 2024
2 parents 1d33c64 + 9b277e8 commit 479de0a
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 48 deletions.
41 changes: 41 additions & 0 deletions examples/local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python

__doc__ = """
Simple example script showing how to use the Sybil library locally to predict risk scores for a set of DICOM files.
"""

import sybil
from sybil import visualize_attentions

from utils import get_demo_data


def main():
    """Run the local Sybil demo end to end.

    Loads the "sybil_ensemble" model, predicts risk scores for the demo DICOM
    series with attentions enabled, prints the scores, and passes the
    attentions to ``visualize_attentions`` with ``sybil_attention_output`` as
    the save directory.
    """
    # Load a trained model
    ensemble = sybil.Sybil("sybil_ensemble")

    demo_paths = get_demo_data()

    # Get risk scores
    demo_serie = sybil.Serie(demo_paths)
    print(f"Processing {len(demo_paths)} DICOM files")
    result = ensemble.predict([demo_serie], return_attentions=True)
    risk_scores = result.scores

    print(f"Risk scores: {risk_scores}")

    # Visualize attention maps
    attention_dir = "sybil_attention_output"

    print(f"Writing attention images to {attention_dir}")
    # The return value is not needed here; the call is made for the images it saves.
    visualize_attentions(
        demo_serie,
        attentions=result.attentions,
        save_directory=attention_dir,
        gain=3,
    )

    print(f"Finished writing attention images to {attention_dir}")

if __name__ == "__main__":
main()
77 changes: 77 additions & 0 deletions examples/remote_ark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#!/usr/bin/env python

__doc__ = """
This example shows how to use a client to access a
remote Sybil server (running Ark) to predict risk scores for a set of DICOM files.
The server must be started separately.
https://github.com/reginabarzilaygroup/Sybil/wiki
https://github.com/reginabarzilaygroup/ark/wiki
"""
import json
import os

import numpy as np
import requests

import sybil.utils.visualization

from utils import get_demo_data

if __name__ == "__main__":

    dicom_files = get_demo_data()
    serie = sybil.Serie(dicom_files)

    # Set the URL of the remote Sybil server
    ark_hostname = "localhost"
    ark_port = 5000
    ark_host = f"http://{ark_hostname}:{ark_port}"

    # Request options travel as a JSON string in the "data" form field.
    data_dict = {"return_attentions": True}
    payload = {"data": json.dumps(data_dict)}

    # Check if the server is running and reachable
    resp = requests.get(f"{ark_host}/info")
    if resp.status_code != 200:
        raise ValueError(f"Failed to connect to ARK server. Status code: {resp.status_code}")

    info_data = resp.json()["data"]
    # Raise explicitly instead of asserting: asserts are stripped under `python -O`.
    if info_data["modelName"].lower() != "sybil":
        raise ValueError("The ARK server is not running Sybil")
    print(f"ARK server info: {info_data}")

    # Submit prediction to ARK server.
    # Handles are collected one by one so that every file opened so far is
    # closed even if a later open() or the POST itself raises.
    files = []
    try:
        for file_path in dicom_files:
            files.append(("dicom", open(file_path, "rb")))
        r = requests.post(f"{ark_host}/dicom/files", files=files, data=payload)
    finally:
        for _, handle in files:
            handle.close()
    if r.status_code != 200:
        raise ValueError(f"Error occurred processing DICOM files. Status code: {r.status_code}.\n{r.text}")

    r_json = r.json()
    predictions = r_json["data"]["predictions"]

    # The response packs the scores first and the attentions second.
    scores = predictions[0]
    print(f"Risk scores: {scores}")

    attentions = predictions[1]
    attentions = np.array(attentions)
    print(f"Ark received attention shape: {attentions.shape}")

    # Visualize attention maps
    save_directory = "remote_ark_sybil_attention_output"

    print(f"Writing attention images to {save_directory}")

    images = serie.get_raw_images()
    overlayed_images = sybil.utils.visualization.build_overlayed_images(images, attentions, gain=3)

    # Set save_directory to None above to skip writing images to disk.
    if save_directory is not None:
        serie_idx = 0
        save_path = os.path.join(save_directory, f"serie_{serie_idx}")
        sybil.utils.visualization.save_images(overlayed_images, save_path, f"serie_{serie_idx}")

    print(f"Finished writing attention images to {save_directory}")

42 changes: 42 additions & 0 deletions examples/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
from urllib.request import urlopen


def download_file(url, filepath):
    """Download ``url`` and write the response body to ``filepath``.

    Args:
        url: Address to open with ``urllib.request.urlopen``.
        filepath: Destination path. Its parent directory is created if needed.

    Returns:
        ``filepath``. If the response status is not 200, a message is printed
        and nothing is written, but the path is still returned.
    """
    # The context manager closes the response even if writing the file fails.
    with urlopen(url) as response:
        target_dir = os.path.dirname(filepath)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        # Check if the request was successful
        if response.status == 200:
            with open(filepath, 'wb') as f:
                f.write(response.read())
        else:
            # urllib responses expose `status`; `status_code` does not exist
            # on them and raised AttributeError here.
            print(f"Failed to download file. Status code: {response.status}")

    return filepath

def get_demo_data():
    """Return the paths of the demo DICOM files, fetching them on first use.

    The demo archive is cached under ``~/.sybil``: it is downloaded only when
    the zip file is missing and extracted only when the extraction directory
    is missing. The returned list contains one joined path per entry of the
    ``sybil_demo_data`` directory inside the extracted archive.
    """
    demo_data_url = "https://www.dropbox.com/scl/fi/covbvo6f547kak4em3cjd/sybil_example.zip?rlkey=7a13nhlc9uwga9x7pmtk1cf1c&st=dqi0cf9k&dl=1"

    cache_root = os.path.expanduser("~/.sybil")
    archive_path = os.path.join(cache_root, "sybil_example.zip")
    os.makedirs(cache_root, exist_ok=True)

    # Download the archive only when it is not cached yet.
    archive_cached = os.path.exists(archive_path)
    if not archive_cached:
        print(f"Downloading demo data to {archive_path}")
        download_file(demo_data_url, archive_path)

    extract_root = os.path.join(cache_root, "sybil_example")
    images_root = os.path.join(extract_root, "sybil_demo_data")

    # Extract the archive only when the target directory does not exist yet.
    already_extracted = os.path.exists(extract_root)
    if not already_extracted:
        print(f"Extracting demo data to {extract_root}")
        import zipfile
        with zipfile.ZipFile(archive_path, 'r') as archive:
            archive.extractall(extract_root)

    return [os.path.join(images_root, entry) for entry in os.listdir(images_root)]
5 changes: 3 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,13 @@ python_requires = >=3.8,<3.11
# For more information, check out https://semver.org/.
install_requires =
importlib-metadata; python_version>="3.8"
albumentations==1.1.0
imageio==2.34.1
numpy==1.24.1
opencv-python==4.5.4.60
opencv-python-headless==4.5.4.60
pillow>=10.2.0
pydicom==2.3.0
pylibjpeg[all]==2.0.0
scikit-learn==1.0.2
torch==1.13.1+cu117; platform_machine == "x86_64"
torch==1.13.1; platform_machine != "x86_64"
torchio==0.18.74
Expand All @@ -68,8 +67,10 @@ testing =
mypy
black
train =
albumentations==1.1.0
lifelines==0.26.4
pytorch_lightning==1.6.0
scikit-learn==1.0.2

[options.entry_points]
console_scripts =
Expand Down
4 changes: 2 additions & 2 deletions sybil/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from sybil.model import Sybil
from sybil.serie import Serie
from sybil.utils.visualization import visualize_attentions
from sybil.utils.visualization import visualize_attentions, collate_attentions
import sybil.utils.logging_utils

__all__ = ["Sybil", "Serie", "visualize_attentions", "__version__"]
__all__ = ["Sybil", "Serie", "visualize_attentions", "collate_attentions", "__version__"]
28 changes: 24 additions & 4 deletions sybil/augmentations.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import cv2
import torch
import torchvision
import albumentations as A
from albumentations.pytorch import ToTensorV2

from typing import Literal
from abc import ABCMeta, abstractmethod
import numpy as np
import random

try:
import albumentations as A
except ImportError:
# albumentations is not installed, training with augmentations will not be possible
A = None


def get_augmentations(split: Literal["train", "dev", "test"], args):
if split == "train":
Expand Down Expand Up @@ -94,7 +100,6 @@ class ToTensor(Abstract_augmentation):

def __init__(self):
super(ToTensor, self).__init__()
self.transform = ToTensorV2()
self.name = "totensor"

def __call__(self, input_dict, sample=None):
Expand All @@ -104,6 +109,20 @@ def __call__(self, input_dict, sample=None):
return input_dict


class ResizeTransform:
    """Callable that resizes an image and/or a mask to a fixed size with OpenCV.

    Calling an instance returns a dict with the keys "image" and "mask"; an
    input that was not supplied maps to ``None``.
    """

    def __init__(self, width, height):
        """Store the target ``width`` and ``height``."""
        self.width = width
        self.height = height

    def __call__(self, image=None, mask=None):
        """Resize whichever of ``image`` and ``mask`` is given.

        The image is resized with ``cv2.INTER_LINEAR`` and the mask with
        ``cv2.INTER_NEAREST``.
        """
        # cv2.resize takes dsize as (width, height).
        target_size = (self.width, self.height)

        resized_image = None
        if image is not None:
            resized_image = cv2.resize(image, dsize=target_size, interpolation=cv2.INTER_LINEAR)

        resized_mask = None
        if mask is not None:
            resized_mask = cv2.resize(mask, dsize=target_size, interpolation=cv2.INTER_NEAREST)

        return {"image": resized_image, "mask": resized_mask}


class Scale_2d(Abstract_augmentation):
"""
Given a PIL image, resize it to a set size
Expand All @@ -115,7 +134,7 @@ def __init__(self, args, kwargs):
assert len(kwargs.keys()) == 0
width, height = args.img_size
self.set_cachable(width, height)
self.transform = A.Resize(height, width)
self.transform = ResizeTransform(width, height)

def __call__(self, input_dict, sample=None):
out = self.transform(
Expand All @@ -138,6 +157,7 @@ def __init__(self, args, kwargs):
super(Rotate_Range, self).__init__()
assert len(kwargs.keys()) == 1
self.max_angle = int(kwargs["deg"])
assert A is not None, "albumentations is not installed"
self.transform = A.Rotate(limit=self.max_angle, p=0.5)

def __call__(self, input_dict, sample=None):
Expand Down
12 changes: 6 additions & 6 deletions sybil/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from sybil.models.calibrator import SimpleClassifierGroup
from sybil.utils.logging_utils import get_logger
from sybil.utils.device_utils import get_default_device, get_most_free_gpu, get_device_mem_info
from sybil.utils.metrics import get_survival_metrics


# Leaving this here for a bit; these are IDs to download the models from Google Drive
Expand Down Expand Up @@ -67,7 +66,7 @@
},
}

CHECKPOINT_URL = os.getenv("SYBIL_CHECKPOINT_URL", "https://www.dropbox.com/scl/fi/45rtadfdci0bj8dbpotmr/sybil_checkpoints_v1.5.0.zip?rlkey=n8n7pvhb89pjoxgvm90mtbtuk&dl=1")
CHECKPOINT_URL = os.getenv("SYBIL_CHECKPOINT_URL", "https://github.com/reginabarzilaygroup/Sybil/releases/download/v1.5.0/sybil_checkpoints.zip")


class Prediction(NamedTuple):
Expand Down Expand Up @@ -107,12 +106,12 @@ def download_sybil(name, cache) -> Tuple[List[str], str]:
return download_model_paths, download_calib_path


def download_and_extract(remote_url: str, local_dir: str) -> List[str]:
    """Download a zip archive and extract it into ``local_dir``.

    Args:
        remote_url: Address of the zip archive, opened with ``urlopen``.
        local_dir: Directory to extract into; created if it does not exist.

    Returns:
        The archive's member names as reported by ``ZipFile.namelist()``.
    """
    os.makedirs(local_dir, exist_ok=True)
    # Read the whole archive into memory, then release the connection before
    # extracting so the response is closed even if extraction fails.
    with urlopen(remote_url) as resp:
        archive_bytes = resp.read()
    with ZipFile(BytesIO(archive_bytes)) as zip_file:
        all_files_and_dirs = zip_file.namelist()
        zip_file.extractall(local_dir)
    return all_files_and_dirs


Expand Down Expand Up @@ -379,6 +378,7 @@ def evaluate(
Output evaluation. See details for :class:`~sybil.model.Evaluation`.
"""
from sybil.utils.metrics import get_survival_metrics
if isinstance(series, Serie):
series = [series]
elif not isinstance(series, list):
Expand Down
Loading

0 comments on commit 479de0a

Please sign in to comment.