From e5f9a9f22cd93925c035f28a8556f1de5e32d2d8 Mon Sep 17 00:00:00 2001 From: lext Date: Thu, 13 Sep 2018 16:17:12 +0300 Subject: [PATCH] improved an inference wrapper for gradcam. Now ready for deployment everywhere --- .gitignore | 2 +- .idea/vcs.xml | 6 - .idea/workspace.xml | 532 ----------------------------------- create_conda_env.sh | 6 +- own_codes/predict.py | 1 - own_codes/produce_gradcam.py | 204 +++++++++----- 6 files changed, 130 insertions(+), 621 deletions(-) delete mode 100644 .idea/vcs.xml delete mode 100644 .idea/workspace.xml diff --git a/.gitignore b/.gitignore index 8eac2b9..6872782 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ __pycache__/ *~ # C extensions *.so - +.idea/* # Distribution / packaging .Python env/ diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 94a25f7..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/workspace.xml b/.idea/workspace.xml deleted file mode 100644 index 7e7b127..0000000 --- a/.idea/workspace.xml +++ /dev/null @@ -1,532 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - cuda( - cuda - load_p - .D - test_files - transforms - Variable - load_img - is_cda - maybe_ - compute_gradcam - - - $PROJECT_DIR$/own_codes - - - - - - - - - - - true - DEFINITION_ORDER - - - - - - - - - - Python - - - - - PyCompatibilityInspection - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 1533717985514 - - - 1533723349354 - - - 1533725197867 - - - 1533725908704 - - - 1535797488065 - - - 1535797504981 - - - 1535798202447 - - - 1535828738284 - - - 1535966131000 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - inp = load_img(fname, CenterCrop(300), patch_transform) - Python - CODE_FRAGMENT - - - load_img(fname) - Python - CODE_FRAGMENT - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/create_conda_env.sh b/create_conda_env.sh index 25fb569..f93cc8a 100644 --- a/create_conda_env.sh +++ b/create_conda_env.sh @@ -3,9 +3,9 @@ conda create -y -n deep_knee python=3.6 source activate deep_knee -conda install -y numpy opencv scipy pyyaml cython matplotlib scikit-learn -conda install -y pytorch==0.3.1 torchvision -c soumith -conda install -y git-lfs -c conda-forge +conda install -y -n deep_knee numpy opencv scipy pyyaml cython matplotlib scikit-learn +conda install -y -n deep_knee pytorch==0.3.1 torchvision -c soumith +conda install -y -n deep_knee git-lfs -c conda-forge pip install pip -U pip install pydicom diff --git a/own_codes/predict.py b/own_codes/predict.py index 1376301..ea568cf 100644 --- a/own_codes/predict.py +++ b/own_codes/predict.py @@ -5,7 +5,6 @@ (c) Aleksei Tiulpin, University of Oulu, 2017 """ -import sys import os import argparse import numpy as np diff --git a/own_codes/produce_gradcam.py b/own_codes/produce_gradcam.py index 214ffce..79c6bee 100644 --- a/own_codes/produce_gradcam.py +++ b/own_codes/produce_gradcam.py @@ -19,7 +19,7 @@ from dataset import get_pair from augmentation import CenterCrop - +from tqdm import tqdm if torch.cuda.is_available(): maybe_cuda = 'cuda' @@ -27,28 +27,6 @@ maybe_cuda = 'cpu' -def load_picture16bit(fname): - patch_transform = transforms.Compose([ - transforms.ToTensor(), - lambda x: x.float(), - normTransform, - ]) - - img = Image.open(fname) - - tmp = np.array(img, dtype=float) - img = Image.fromarray(np.uint8(255 * (tmp / 65535.))) - - cropper = CenterCrop(300) - - l, m = get_pair(cropper(img)) - - l = patch_transform(l) - m = patch_transform(m) - - return cropper(img), l.view(1, 1, 128, 128), m.view(1, 1, 128, 128) - - def smooth_edge_mask(s, w): res = np.zeros((s + w * 2, s + w * 2)) res[w:w + s, w:w + s] = 1 @@ -74,8 +52,45 @@ def inverse_pair_mapping(l, m, s, ps=128, smoothing=7): class KneeNetEnsemble(nn.Module): - def __init__(self, nets): + def __init__(self, snapshots_paths, mean_std_path): super().__init__() + self.states = [] + for snap_path in snapshots_paths: + self.states.append(torch.load(snap_path, map_location=maybe_cuda)) + + self.net1 = None + self.net2 = None + self.net3 = None + + self.grads_l1 = None + self.grads_m1 = None + + self.grads_l2 = None + self.grads_m2 = None + + self.grads_l3 = None + self.grads_m3 = None + self.sm = torch.nn.Softmax(1) + self.mean_std_path = mean_std_path + + def init_networks_from_states(self): + mean_vector, std_vector = np.load(self.mean_std_path) + normTransform = transforms.Normalize(mean_vector, std_vector) + self.patch_transform = transforms.Compose([ + transforms.ToTensor(), + lambda x: x.float(), + normTransform, + ]) + + nets = [] + for state in self.states: + if torch.cuda.is_available(): + net = nn.DataParallel(KneeNet(64, 0.2, True)).cuda() + else: + net = nn.DataParallel(KneeNet(64, 0.2, True)) + net.load_state_dict(state) + nets.append(net.module) + net1 = nets[0] net1.final = nets[0].final[1] @@ -98,15 +113,49 @@ def __init__(self, nets): self.grads_l3 = [] self.grads_m3 = [] - def decopmpose_forward_avg(self, net, l, m): + def load_picture(self, fname, nbits=16): + """ + + :param fname: str or numpy.ndarray + Takes either full path to the image or the numpy array + :return: + """ + + if isinstance(fname, str): + img = Image.open(fname) + elif isinstance(fname, np.ndarray): + img = fname + if nbits == 16: + img = Image.fromarray(np.uint8(255 * (img / 65535.))) + elif nbits == 8: + if img.dtype != np.uint8: + raise TypeError + img = Image.fromarray(img) + else: + raise TypeError + else: + raise TypeError + + cropper = CenterCrop(300) + + l, m = get_pair(cropper(img)) + + l = self.patch_transform(l) + m = self.patch_transform(m) + + return cropper(img), l.view(1, 1, 128, 128), m.view(1, 1, 128, 128) + + @staticmethod + def decompose_forward_avg(net, l, m): l_o = net.branch(l) m_o = net.branch(m) concat = torch.cat([l_o, m_o], 1) - o = net.final(concat.view(l.size(0), 512)) + o = net.final(concat.view(l.size(0), net.final.in_features)) return l_o, m_o, o - def extract_features_branch(self, net, l, m, wl, wm): + @staticmethod + def extract_features_branch(net, l, m, wl, wm): def weigh_maps(weights, maps): maps = Variable(maps.squeeze()) weights = weights.squeeze() @@ -166,20 +215,60 @@ def forward(self, l, m): # Producing the branch outputs and registering the corresponding hooks for attention maps # Net 1 - l_o1, m_o1, o1 = self.decopmpose_forward_avg(self.net1, l, m) + l_o1, m_o1, o1 = self.decompose_forward_avg(self.net1, l, m) l_o1.register_hook(lambda grad: self.grads_l1.append(grad)) m_o1.register_hook(lambda grad: self.grads_m1.append(grad)) # Net 2 - l_o2, m_o2, o2 = self.decopmpose_forward_avg(self.net2, l, m) + l_o2, m_o2, o2 = self.decompose_forward_avg(self.net2, l, m) l_o2.register_hook(lambda grad: self.grads_l2.append(grad)) m_o2.register_hook(lambda grad: self.grads_m2.append(grad)) # Net 3 - l_o3, m_o3, o3 = self.decopmpose_forward_avg(self.net3, l, m) + l_o3, m_o3, o3 = self.decompose_forward_avg(self.net3, l, m) l_o3.register_hook(lambda grad: self.grads_l3.append(grad)) m_o3.register_hook(lambda grad: self.grads_m3.append(grad)) return o1 + o2 + o3 + def predict(self, x, nbits=16): + """Makes a prediction from file or a pre-loaded image + + :param x: str or numpy.array + :param nbits: int + By default we load 16 bit images produced by CropROI Object and convert them to 8bit. + :return: tuple + Image, Heatmap, probabilities + """ + self.init_networks_from_states() + img, l, m = self.load_picture(x, nbits=nbits) + self.train(True) + self.zero_grad() + + if torch.cuda.is_available(): + out = self.forward(Variable(l.cuda()), Variable(m.cuda())) + else: + out = self.forward(Variable(l), Variable(m)) + + probs = self.sm(out).data.cpu().numpy() + + ohe = OneHotEncoder(sparse=False, n_values=5) + index = np.argmax(out.cpu().data.numpy(), axis=1).reshape(-1, 1) + + if torch.cuda.is_available(): + out.backward(torch.from_numpy(ohe.fit_transform(index)).float().cuda()) + else: + out.backward(torch.from_numpy(ohe.fit_transform(index)).float()) + + if torch.cuda.is_available(): + heatmap = self.compute_gradcam( + Variable(l.cuda()), Variable(m.cuda()), 300, 128, 7) + else: + heatmap = self.compute_gradcam( + Variable(l), Variable(m), 300, 128, 7) + + return img, heatmap, probs.squeeze() + + + def parse_args(): parser = argparse.ArgumentParser() @@ -196,64 +285,24 @@ def parse_args(): if __name__ == '__main__': config = parse_args() - mean_vector, std_vector = np.load('../snapshots_knee_grading/mean_std.npy') - normTransform = transforms.Normalize(mean_vector, std_vector) - patch_transform = transforms.Compose([ - transforms.ToTensor(), - lambda x: x.float(), - normTransform, - ]) - avg_preds = {} labels = {} - nets = [] - - for fold in config.snapshots: - for snap_path in glob(os.path.join(config.path_folds, fold, '*.pth')): - - if torch.cuda.is_available(): - net = nn.DataParallel(KneeNet(64, 0.2, True)).cuda() - else: - net = nn.DataParallel(KneeNet(64, 0.2, True)) + nets_snapshots_names = [] - net.load_state_dict(torch.load(snap_path, map_location=maybe_cuda)) - nets.append(deepcopy(net.module)) + for snp in config.snapshots: + nets_snapshots_names.extend(glob(os.path.join(config.path_folds, snp, '*.pth'))) - net = nn.DataParallel(KneeNetEnsemble(nets)) + net = KneeNetEnsemble(nets_snapshots_names, mean_std_path='../snapshots_knee_grading/mean_std.npy') if torch.cuda.is_available(): net.cuda() - # Producing the GradCAM output using the equations provided in the article - paths_test_files = glob(os.path.join(config.path_input, '**', '*.png')) + paths_test_files = glob(os.path.join(config.path_input, '*', '*.png')) if not os.path.exists(config.path_output): os.makedirs(config.path_output) - for path_test_file in paths_test_files: - img, l, m = load_picture16bit(path_test_file) - - net.train(True) - net.zero_grad() - - if torch.cuda.is_available(): - out = net.module(Variable(l.cuda()), Variable(m.cuda())) - else: - out = net.module(Variable(l), Variable(m)) - - ohe = OneHotEncoder(sparse=False, n_values=5) - index = np.argmax(out.cpu().data.numpy(), axis=1).reshape(-1, 1) - - if torch.cuda.is_available(): - out.backward(torch.from_numpy(ohe.fit_transform(index)).float().cuda()) - else: - out.backward(torch.from_numpy(ohe.fit_transform(index)).float()) - - if torch.cuda.is_available(): - heatmap = net.module.compute_gradcam( - Variable(l.cuda()), Variable(m.cuda()), 300, 128, 7) - else: - heatmap = net.module.compute_gradcam( - Variable(l), Variable(m), 300, 128, 7) + for path_test_file in tqdm(paths_test_files, total=len(paths_test_files)): + img, heatmap, probs = net.predict(path_test_file, 16) plt.figure(figsize=(7, 7)) plt.imshow(np.asarray(img), cmap=plt.cm.Greys_r) @@ -266,7 +315,6 @@ def parse_args(): plt.close() plt.figure(figsize=(7, 1)) - probs = F.softmax(out).cpu().data[0].numpy() for kl in range(5): plt.text(kl - 0.2, 0.35, "%.2f" % np.round(probs[kl], 2), fontsize=15) plt.bar(np.array([0, 1, 2, 3, 4]), probs, color='red', align='center',