Merge pull request #44 from reginabarzilaygroup/v1.5.0_dev
V1.5.0 dev
pgmikhael authored Aug 6, 2024
2 parents a8c4369 + 552eda0 commit 2b9c5fc
Showing 7 changed files with 564 additions and 64 deletions.
3 changes: 2 additions & 1 deletion scripts/data/create_nlst_metadata_json.py
@@ -120,7 +120,7 @@ def make_metadata_dict(dataframe, pid, timepoint, series_id, use_timepoint = Fal
         'slice_number': [slicenumber],
         'pixel_spacing': pixel_spacing,
         'slice_thickness': slice_thickness,
-        'img_position': img_posn,
+        'img_position': [img_posn],
         'series_data': make_metadata_dict(image_data, pid, timepoint, series_id, use_timepoint_and_studyinstance = True)
     }

@@ -135,6 +135,7 @@ def make_metadata_dict(dataframe, pid, timepoint, series_id, use_timepoint = Fal
             json_dataset[pt_idx]['accessions'][exam_idx]['image_series'][series_id]['paths'].append(path)
             json_dataset[pt_idx]['accessions'][exam_idx]['image_series'][series_id]['slice_location'].append(slicelocation)
             json_dataset[pt_idx]['accessions'][exam_idx]['image_series'][series_id]['slice_number'].append(slicenumber)
+            json_dataset[pt_idx]['accessions'][exam_idx]['image_series'][series_id]['img_position'].append(img_posn)
         else:
             exam_dict['image_series'] = {series_id: img_series_dict}
             json_dataset[pt_idx]['accessions'].append(exam_dict)
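For context: with this fix, img_position is accumulated per slice, parallel to slice_location and slice_number, instead of holding a single value. A minimal sketch of one resulting image_series entry (field names from the script above; all values illustrative):

    img_series_dict = {
        'paths': ['1-001.dcm', '1-002.dcm'],         # illustrative file names
        'slice_location': [-120.50, -119.25],        # one entry per slice
        'slice_number': [1, 2],
        'pixel_spacing': [0.7031, 0.7031],
        'slice_thickness': 1.25,
        'img_position': [[-170.0, -180.0, -120.50],  # now a list with one
                         [-170.0, -180.0, -119.25]], # position per slice
    }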
14 changes: 5 additions & 9 deletions sybil/model.py
@@ -7,10 +7,10 @@

 import torch
 import numpy as np
-import pickle

 from sybil.serie import Serie
 from sybil.models.sybil import SybilNet
+from sybil.models.calibrator import SimpleClassifierGroup
 from sybil.utils.logging_utils import get_logger
 from sybil.utils.device_utils import get_default_device, get_most_free_gpu, get_device_mem_info
 from sybil.utils.metrics import get_survival_metrics

@@ -67,7 +67,7 @@
     },
 }

-CHECKPOINT_URL = "https://github.com/reginabarzilaygroup/Sybil/releases/download/v1.0.3/sybil_checkpoints.zip"
+CHECKPOINT_URL = os.getenv("SYBIL_CHECKPOINT_URL", "https://www.dropbox.com/scl/fi/45rtadfdci0bj8dbpotmr/sybil_checkpoints_v1.5.0.zip?rlkey=n8n7pvhb89pjoxgvm90mtbtuk&dl=1")


 class Prediction(NamedTuple):

@@ -91,7 +91,7 @@ def download_sybil(name, cache) -> Tuple[List[str], str]:
     # Download models
     model_files = NAME_TO_FILE[name]
     checkpoints = model_files["checkpoint"]
-    download_calib_path = os.path.join(cache, f"{name}.p")
+    download_calib_path = os.path.join(cache, f"{name}_simple_calibrator.json")
     have_all_files = os.path.exists(download_calib_path)

     download_model_paths = []

@@ -187,7 +187,7 @@ def __init__(
         self.to(self.device)

         if calibrator_path is not None:
-            self.calibrator = pickle.load(open(calibrator_path, "rb"))
+            self.calibrator = SimpleClassifierGroup.from_json_grouped(calibrator_path)
         else:
             self.calibrator = None

@@ -227,8 +227,6 @@ def _calibrate(self, scores: np.ndarray) -> np.ndarray:
         Parameters
         ----------
-        calibrator: Optional[dict]
-            Dictionary of sklearn.calibration.CalibratedClassifierCV for each year, otherwise None.
         scores: np.ndarray
             risk scores as numpy array

@@ -242,9 +240,7 @@ def _calibrate(self, scores: np.ndarray) -> np.ndarray:
         calibrated_scores = []
         for YEAR in range(scores.shape[1]):
             probs = scores[:, YEAR].reshape(-1, 1)
-            probs = self.calibrator["Year{}".format(YEAR + 1)].predict_proba(probs)[
-                :, 1
-            ]
+            probs = self.calibrator["Year{}".format(YEAR + 1)].predict_proba(probs)[:, -1]
             calibrated_scores.append(probs)

         return np.stack(calibrated_scores, axis=1)
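Taken together, the model.py changes swap the pickled sklearn calibrators for the JSON-backed SimpleClassifierGroup. A minimal sketch of the new calibration path, assuming a calibrator file produced as in sybil/models/calibrator.py below (file name and score values are illustrative):

    import numpy as np
    from sybil.models.calibrator import SimpleClassifierGroup

    # Loads a dict like {"Year1": SimpleClassifierGroup, ...} from JSON.
    calibrators = SimpleClassifierGroup.from_json_grouped("sybil_ensemble_simple_calibrator.json")

    scores = np.array([[0.02, 0.05, 0.09]])  # illustrative raw risk scores, one row per serie
    calibrated = np.stack(
        [calibrators[f"Year{y + 1}"].predict_proba(scores[:, y].reshape(-1, 1))[:, -1]
         for y in range(scores.shape[1])],
        axis=1,
    )  # mirrors Sybil._calibrate above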
168 changes: 168 additions & 0 deletions sybil/models/calibrator.py
@@ -0,0 +1,168 @@
import json
import os
from typing import List

import numpy as np

"""
Calibrator for Sybil prediction models.
We calibrate probabilities using isotonic regression.
Previously this was done with scikit-learn; here we use a custom implementation to avoid versioning issues.
"""


class SimpleClassifierGroup:
"""
A class to represent a calibrator for prediction models.
Behavior and coefficients are taken from the sklearn.calibration.CalibratedClassifierCV class.
Make a custom class to avoid sklearn versioning issues.
"""

def __init__(self, calibrators: List["SimpleIsotonicRegressor"]):
self.calibrators = calibrators

def predict_proba(self, X, expand=False):
"""
Predict class probabilities for X.
Parameters
----------
X : array-like of shape (n_probabilities,)
The input probabilities to recalibrate.
    expand : bool, default=False
        Whether to return the probabilities for each class separately.
        This is intended for binary classification, which can be done in 1D;
        expand=True returns the negative- and positive-class probabilities
        stacked along a new leading axis.
    Returns
    -------
    proba : ndarray
        The recalibrated probabilities, averaged over the group's
        calibrators; the positive class only unless expand=True.
"""
proba = np.array([calibrator.transform(X) for calibrator in self.calibrators])
pos_prob = np.mean(proba, axis=0)
if expand and len(self.calibrators) == 1:
return np.array([1.-pos_prob, pos_prob])
else:
return pos_prob

def to_json(self):
return [calibrator.to_json() for calibrator in self.calibrators]

@classmethod
def from_json(cls, json_list):
return cls([SimpleIsotonicRegressor.from_json(json_dict) for json_dict in json_list])

@classmethod
def from_json_grouped(cls, json_path):
"""
We store calibrators in a dictionary of {year (str): [calibrators]}.
This is a convenience method to load that dictionary from a file path.
"""
json_dict = json.load(open(json_path, "r"))
output_dict = {key: cls.from_json(json_list) for key, json_list in json_dict.items()}
return output_dict


class SimpleIsotonicRegressor:
def __init__(self, coef, intercept, x0, y0, x_min=-np.inf, x_max=np.inf):
self.coef = coef
self.intercept = intercept
self.x0 = x0
self.y0 = y0
self.x_min = x_min
self.x_max = x_max

def transform(self, X):
T = X
T = T @ self.coef + self.intercept
T = np.clip(T, self.x_min, self.x_max)
return np.interp(T, self.x0, self.y0)

@classmethod
def from_classifier(cls, classifier: "_CalibratedClassifier"):
    assert len(classifier.calibrators) == 1, "Only one calibrator per classifier is supported."
    calibrator = classifier.calibrators[0]
    return cls(classifier.base_estimator.coef_, classifier.base_estimator.intercept_,
               calibrator.f_.x, calibrator.f_.y, calibrator.X_min_, calibrator.X_max_)

def to_json(self):
return {
"coef": self.coef.tolist(),
"intercept": self.intercept.tolist(),
"x0": self.x0.tolist(),
"y0": self.y0.tolist(),
"x_min": self.x_min,
"x_max": self.x_max
}

@classmethod
def from_json(cls, json_dict):
return cls(
np.array(json_dict["coef"]),
np.array(json_dict["intercept"]),
np.array(json_dict["x0"]),
np.array(json_dict["y0"]),
json_dict["x_min"],
json_dict["x_max"]
)

def __repr__(self):
return f"SimpleIsotonicRegressor(x={self.x0}, y={self.y0})"


def export_calibrator(input_path, output_path):
import pickle
import sklearn
sk_cal_dict = pickle.load(open(input_path, "rb"))
simple_cal_dict = dict()
for key, cal in sk_cal_dict.items():
calibrators = [SimpleIsotonicRegressor.from_classifier(classifier) for classifier in cal.calibrated_classifiers_]
simple_cal_dict[key] = SimpleClassifierGroup(calibrators).to_json()

json.dump(simple_cal_dict, open(output_path, "w"), indent=2)


def export_by_name(base_dir, model_name, overwrite=False):
sk_input_path = os.path.expanduser(f"{base_dir}/{model_name}.p")
simple_output_path = os.path.expanduser(f"{base_dir}/{model_name}_simple_calibrator.json")

version = "1.4.0"
scores_output_path = f"{base_dir}/{model_name}_v{version}_calibrations.json"

if overwrite or not os.path.exists(scores_output_path):
    run_test_calibrations(sk_input_path, scores_output_path)

if overwrite or not os.path.exists(simple_output_path):
export_calibrator(sk_input_path, simple_output_path)


def export_all_default_calibrators(base_dir="~/.sybil", overwrite=False):
base_dir = os.path.expanduser(base_dir)
model_names = ["sybil_1", "sybil_2", "sybil_3", "sybil_4", "sybil_5", "sybil_ensemble"]
for model_name in model_names:
export_by_name(base_dir, model_name, overwrite=overwrite)


def run_test_calibrations(sk_input_path, scores_output_path, overwrite=False):
"""
For regression testing. Output calibrated probabilities for a range of input probabilities.
"""
import pickle
sk_cal_dict = pickle.load(open(sk_input_path, "rb"))

test_probs = np.arange(0, 1, 0.001).reshape(-1, 1)

output_dict = {"x": test_probs.flatten().tolist()}
for key, model in sk_cal_dict.items():
output_dict[key] = model.predict_proba(test_probs)[:, -1].flatten().tolist()

if overwrite or not os.path.exists(scores_output_path):
with open(scores_output_path, "w") as f:
json.dump(output_dict, f, indent=2)


if __name__ == "__main__":
export_all_default_calibrators(overwrite=False)
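A minimal sketch of the export/reload round trip this module enables, assuming a legacy sklearn pickle in the default cache directory (paths and model name are illustrative):

    import os
    import numpy as np
    from sybil.models.calibrator import SimpleClassifierGroup, export_calibrator

    pkl = os.path.expanduser("~/.sybil/sybil_ensemble.p")                       # legacy sklearn pickle
    out = os.path.expanduser("~/.sybil/sybil_ensemble_simple_calibrator.json")  # new format
    export_calibrator(pkl, out)  # one-time conversion; requires scikit-learn installed

    # Reload with no scikit-learn dependency and recalibrate some raw scores.
    calibrators = SimpleClassifierGroup.from_json_grouped(out)
    probs = calibrators["Year1"].predict_proba(np.array([[0.03], [0.12]]))[:, -1]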
13 changes: 13 additions & 0 deletions sybil/models/sybil.py
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 import torchvision
 from sybil.models.cumulative_probability_layer import Cumulative_Probability_Layer

@@ -29,6 +30,7 @@ def forward(self, x, batch=None):
         pool_output = self.aggregate_and_classify(x)
         output["activ"] = x
         output.update(pool_output)
+        output["prob"] = pool_output["logit"].sigmoid()

         return output

@@ -41,6 +43,17 @@ def aggregate_and_classify(self, x):

         return pool_output

+    @staticmethod
+    def load(path):
+        checkpoint = torch.load(path, map_location="cpu")
+        args = checkpoint["args"]
+        model = SybilNet(args)
+
+        # Remove the 'model.' prefix from param names
+        state_dict = {k[6:]: v for k, v in checkpoint["state_dict"].items()}
+        model.load_state_dict(state_dict)  # type: ignore
+        return model
+

 class RiskFactorPredictor(SybilNet):
     def __init__(self, args):
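A minimal usage sketch for the new static loader, assuming a checkpoint dict saved with "args" and a "state_dict" whose parameter names carry the "model." prefix that k[6:] strips (the path is illustrative):

    from sybil.models.sybil import SybilNet

    # Illustrative path; any checkpoint with {"args": ..., "state_dict": ...}
    # and "model."-prefixed parameter names should load this way.
    model = SybilNet.load("sybil_1.ckpt")
    model.eval()  # forward() now also returns output["prob"] = sigmoid(logit)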
3 changes: 2 additions & 1 deletion sybil/predict.py
@@ -49,7 +49,6 @@ def _get_parser():
         help="Generate images with attention overlap. Sets --return-attentions (if not already set).",
     )

-
     parser.add_argument(
         "--file-type",
         default="auto",

@@ -90,6 +89,8 @@ def predict(
 ):
     logger = sybil.utils.logging_utils.get_logger()

+    return_attentions |= write_attention_images
+
     input_files = os.listdir(image_dir)
     input_files = [os.path.join(image_dir, x) for x in input_files if not x.startswith(".")]
     input_files = [x for x in input_files if os.path.isfile(x)]
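The added in-place OR makes attention-image output imply attention return, matching the --help text above. In effect (names from the diff; sketch only):

    # Writing attention images needs the attention maps, so requesting images
    # force-enables return_attentions even without --return-attentions.
    return_attentions = return_attentions or write_attention_images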
2 changes: 2 additions & 0 deletions sybil/serie.py
@@ -1,3 +1,4 @@
+import functools
 from typing import List, Optional, NamedTuple, Literal
 from argparse import Namespace

@@ -137,6 +138,7 @@ def get_raw_images(self) -> List[np.ndarray]:
         images = [i["input"] for i in input_dicts]
         return images

+    @functools.lru_cache
     def get_volume(self) -> torch.Tensor:
         """
         Load the 3D CT volume
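One caveat worth noting: functools.lru_cache on an instance method keys entries on self, so repeated calls on the same Serie return the same tensor object, and the cache keeps the Serie alive until the entry is evicted. A minimal sketch of the effect:

    # serie stands for any constructed sybil.Serie instance; illustrative only.
    vol_a = serie.get_volume()
    vol_b = serie.get_volume()
    assert vol_a is vol_b  # second call hits the lru_cache, no DICOM reload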