Change automatic upsample to interpolate to match skimage preprocessing (#161)

* change to interpolate and refactor
ieee8023 authored Jan 3, 2025
1 parent 5a8984c commit 09bafae
Showing 9 changed files with 77 additions and 88 deletions.
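
The change swaps each model's `nn.Upsample` member for a shared `utils.fix_resolution` helper that calls `torch.nn.functional.interpolate` with `antialias=True`, which is closer to the anti-aliased resize that `skimage.transform.resize` applies during dataset preprocessing. A minimal sketch of the difference between the two resize paths (illustrative only, not part of the commit; the size of the numerical gap depends on the image and the scale factor):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 1, 1024, 1024)  # stand-in for a single-channel chest X-ray tensor

# Old path: plain bilinear resize without anti-aliasing.
old = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=False)(x)

# New path: bilinear resize with anti-aliasing, which low-pass filters
# before downsampling, similar to skimage.transform.resize.
new = F.interpolate(x, size=(512, 512), mode='bilinear', antialias=True)

print((old - new).abs().max())  # non-zero when downsampling: anti-aliasing changes the output
```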
8 changes: 4 additions & 4 deletions tests/test_models.py
@@ -90,14 +90,14 @@ def test_normalization_check():
         # so here the first 2 pixels are set to the limits
         test_x[0][0][0] = ra[0]
         test_x[0][0][1] = ra[1]
-        xrv.models.warning_log = {}
+        xrv.utils.warning_log = {}
         model(test_x)
-        assert xrv.models.warning_log['norm_correct'] == False, ra
+        assert xrv.utils.warning_log['norm_correct'] == False, ra
 
     for ra in correct_ranges:
         test_x = torch.zeros([1,1,224,224])
         test_x.uniform_(ra[0], ra[1])
-        xrv.models.warning_log = {}
+        xrv.utils.warning_log = {}
         model(test_x)
-        assert xrv.models.warning_log['norm_correct'] == True, ra
+        assert xrv.utils.warning_log['norm_correct'] == True, ra

10 changes: 4 additions & 6 deletions torchxrayvision/baseline_models/chestx_det/__init__.py
@@ -105,15 +105,13 @@ def __init__(self, cache_dir:str = None):
 
         model.eval()
         self.model = model
-        self.upsample = nn.Upsample(
-            size=(512, 512),
-            mode='bilinear',
-            align_corners=False,
-        )
 
     def forward(self, x):
 
         x = x.repeat(1, 3, 1, 1)
-        x = self.upsample(x)
+
+        x = utils.fix_resolution(x, 512, self)
+        utils.warn_normalization(x)
+
         # expecting values between [-1024,1024]
         x = (x + 1024) / 2048
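
With the `nn.Upsample` member gone, resizing now happens inside `forward` through `utils.fix_resolution`, so the model accepts inputs at other (square) resolutions and resizes them itself, printing a one-time warning per model instance. A usage sketch under the assumption that the `PSPNet` baseline class and its weight download are available:

```python
import torch
import torchxrayvision as xrv

model = xrv.baseline_models.chestx_det.PSPNet()  # native resolution is 512x512

# 224x224 input in the expected [-1024, 1024] value range; fix_resolution()
# interpolates it to 512x512 internally and warns once for this model instance.
x = torch.zeros(1, 1, 224, 224).uniform_(-1024, 1024)
with torch.no_grad():
    out = model(x)
print(out.shape)  # one output channel per segmentation target
```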
6 changes: 3 additions & 3 deletions torchxrayvision/baseline_models/chexpert/__init__.py
@@ -6,6 +6,7 @@
 import torch
 import torch.nn as nn
 from .model import Tasks2Models
+from ... import utils
 
 
 class DenseNet(nn.Module):
@@ -52,8 +53,6 @@ def __init__(self, weights_zip="", num_models=30):
             dynamic=False,
             use_gpu=self.use_gpu)
 
-        self.upsample = nn.Upsample(size=(320, 320), mode='bilinear', align_corners=False)
-
         self.pathologies = self.targets
 
     def forward(self, x):
@@ -80,7 +79,8 @@ def forward(self, x):
 
     def features(self, x):
         x = x.repeat(1, 3, 1, 1)
-        x = self.upsample(x)
+        x = utils.fix_resolution(x, 320, self)
+        utils.warn_normalization(x)
 
         # expecting values between [-1024,1024]
         x = x / 512
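
The `# expecting values between [-1024,1024]` comments document the shared input convention, but the baseline models map that range differently: this model and jfhealthcare divide by 512 (roughly [-2, 2]), while chestx_det, emory_hiti, riken, and xinario shift and scale to [0, 1]. A quick check of the two rescalings at the range endpoints:

```python
# Endpoints of the xrv input range under the two internal rescalings
# used by the baseline models touched in this commit.
for v in (-1024.0, 0.0, 1024.0):
    print(v, (v + 1024) / 2048, v / 512)
# -1024.0 -> 0.0, -2.0 ; 0.0 -> 0.5, 0.0 ; 1024.0 -> 1.0, 2.0
```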
12 changes: 4 additions & 8 deletions torchxrayvision/baseline_models/emory_hiti/__init__.py
@@ -7,7 +7,7 @@
 import torch.nn as nn
 import torchvision
 import torchxrayvision as xrv
-
+from ... import utils
 
 class RaceModel(nn.Module):
     """This model is from the work below and is trained to predict the
@@ -78,12 +78,6 @@ def __init__(self):
             print("Loading failure. Check weights file:", self.weights_filename_local)
             raise e
 
-        self.upsample = nn.Upsample(
-            size=(320, 320),
-            mode='bilinear',
-            align_corners=False,
-        )
-
         self.targets = ["Asian", "Black", "White"]
 
         self.mean = np.array([0.485, 0.456, 0.406])
@@ -93,7 +87,9 @@ def __init__(self):
 
     def forward(self, x):
         x = x.repeat(1, 3, 1, 1)
-        x = self.upsample(x)
+
+        x = utils.fix_resolution(x, 320, self)
+        utils.warn_normalization(x)
 
         # Expecting values between [-1024,1024]
         x = (x + 1024) / 2048
6 changes: 4 additions & 2 deletions torchxrayvision/baseline_models/jfhealthcare/__init__.py
@@ -9,6 +9,7 @@
 import pathlib
 import torch
 import torch.nn as nn
+from ... import utils
 
 
 class DenseNet(nn.Module):
@@ -76,13 +77,14 @@ def __init__(self, **entries):
             raise (e)
 
         self.model = model
-        self.upsample = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=False)
 
         self.pathologies = self.targets
 
     def forward(self, x):
         x = x.repeat(1, 3, 1, 1)
-        x = self.upsample(x)
+
+        x = utils.fix_resolution(x, 512, self)
+        utils.warn_normalization(x)
 
         # expecting values between [-1024,1024]
         x = x / 512
11 changes: 4 additions & 7 deletions torchxrayvision/baseline_models/riken/__init__.py
@@ -6,6 +6,7 @@
 import torchvision
 import pathlib
 import torchxrayvision as xrv
+from ... import utils
 
 
 class AgeModel(nn.Module):
@@ -71,20 +72,16 @@ def __init__(self):
             print("Loading failure. Check weights file:", self.weights_filename_local)
             raise e
 
-        self.upsample = nn.Upsample(
-            size=(320, 320),
-            mode='bilinear',
-            align_corners=False,
-        )
-
         self.norm = torchvision.transforms.Normalize(
             [0.485, 0.456, 0.406],
             [0.229, 0.224, 0.225],
         )
 
     def forward(self, x):
         x = x.repeat(1, 3, 1, 1)
-        x = self.upsample(x)
+
+        x = utils.fix_resolution(x, 320, self)
+        utils.warn_normalization(x)
 
         # expecting values between [-1024,1024]
         x = (x + 1024) / 2048
11 changes: 4 additions & 7 deletions torchxrayvision/baseline_models/xinario/__init__.py
@@ -6,6 +6,7 @@
 import torchvision
 import pathlib
 import torchxrayvision as xrv
+from ... import utils
 
 
 class ViewModel(nn.Module):
@@ -63,20 +64,16 @@ def __init__(self):
             print("Loading failure. Check weights file:", self.weights_filename_local)
             raise e
 
-        self.upsample = nn.Upsample(
-            size=(224, 224),
-            mode='bilinear',
-            align_corners=False,
-        )
-
         self.norm = torchvision.transforms.Normalize(
             [0.485, 0.456, 0.406],
             [0.229, 0.224, 0.225],
        )
 
     def forward(self, x):
         x = x.repeat(1, 3, 1, 1)
-        x = self.upsample(x)
+
+        x = utils.fix_resolution(x, 224, self)
+        utils.warn_normalization(x)
 
         # expecting values between [-1024,1024]
         x = (x + 1024) / 2048
61 changes: 10 additions & 51 deletions torchxrayvision/models.py
@@ -297,7 +297,7 @@ def __init__(self,
         self.weights_filename_local = get_weights(weights, cache_dir)
 
         try:
-            savedmodel = torch.load(self.weights_filename_local, map_location='cpu')
+            savedmodel = torch.load(self.weights_filename_local, map_location='cpu', weights_only=False)
             # patch to load old models https://github.com/pytorch/pytorch/issues/42242
             for mod in savedmodel.modules():
                 if not hasattr(mod, "_non_persistent_buffers_set"):
@@ -313,25 +313,24 @@
         if "op_threshs" in model_urls[weights]:
             self.op_threshs = torch.tensor(model_urls[weights]["op_threshs"])
 
-        self.upsample = nn.Upsample(size=(224, 224), mode='bilinear', align_corners=False)
-
     def __repr__(self):
         if self.weights is not None:
             return "XRV-DenseNet121-{}".format(self.weights)
         else:
             return "XRV-DenseNet"
 
     def features2(self, x):
-        x = fix_resolution(x, 224, self)
-        warn_normalization(x)
+        x = utils.fix_resolution(x, 224, self)
+        utils.warn_normalization(x)
 
         features = self.features(x)
         out = F.relu(features, inplace=True)
         out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1)
         return out
 
     def forward(self, x):
-        x = fix_resolution(x, 224, self)
+        x = utils.fix_resolution(x, 224, self)
+        utils.warn_normalization(x)
 
         features = self.features2(x)
         out = self.classifier(features)
@@ -412,16 +411,14 @@ def __init__(self, weights: str = None, apply_sigmoid: bool = False, cache_dir:
         self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
 
         try:
-            self.model.load_state_dict(torch.load(self.weights_filename_local))
+            self.model.load_state_dict(torch.load(self.weights_filename_local, map_location='cpu', weights_only=False))
         except Exception as e:
             print("Loading failure. Check weights file:", self.weights_filename_local)
             raise e
 
         if "op_threshs" in model_urls[weights]:
             self.register_buffer('op_threshs', torch.tensor(model_urls[weights]["op_threshs"]))
 
-        self.upsample = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=False)
-
         self.eval()
 
     def __repr__(self):
@@ -431,8 +428,8 @@ def __repr__(self):
         return "XRV-ResNet"
 
     def features(self, x):
-        x = fix_resolution(x, 512, self)
-        warn_normalization(x)
+        x = utils.fix_resolution(x, 512, self)
+        utils.warn_normalization(x)
 
         x = self.model.conv1(x)
         x = self.model.bn1(x)
@@ -449,8 +446,8 @@ def features(self, x):
         return x
 
     def forward(self, x):
-        x = fix_resolution(x, 512, self)
-        warn_normalization(x)
+        x = utils.fix_resolution(x, 512, self)
+        utils.warn_normalization(x)
 
         out = self.model(x)
 
@@ -463,44 +460,6 @@ def forward(self, x):
         return out
 
 
-warning_log = {}
-
-
-def fix_resolution(x, resolution: int, model: nn.Module):
-    """Check resolution of input and resize to match requested."""
-
-    # just skip it if upsample was removed somehow
-    if not hasattr(model, 'upsample') or (model.upsample == None):
-        return x
-
-    if (x.shape[2] != resolution) | (x.shape[3] != resolution):
-        if not hash(model) in warning_log:
-            print("Warning: Input size ({}x{}) is not the native resolution ({}x{}) for this model. A resize will be performed but this could impact performance.".format(x.shape[2], x.shape[3], resolution, resolution))
-            warning_log[hash(model)] = True
-        return model.upsample(x)
-    return x
-
-
-def warn_normalization(x):
-    """Check normalization of input and warn if possibly wrong. When
-    processing an image that may likely not have the correct
-    normalization we can issue a warning. But running min and max on
-    every image/batch is costly so we only do it on the first image/batch.
-    """
-
-    # Only run this check on the first image so we don't hurt performance.
-    if not "norm_check" in warning_log:
-        x_min = x.min()
-        x_max = x.max()
-        if torch.logical_or(-255 < x_min, x_max < 255) or torch.logical_or(x_min < -1024, 1024 < x_max):
-            print(f'Warning: Input image does not appear to be normalized correctly. The input image has the range [{x_min:.2f},{x_max:.2f}] which doesn\'t seem to be in the [-1024,1024] range. This warning may be wrong though. Only the first image is tested and we are only using a heuristic in an attempt to save a user from using the wrong normalization.')
-            warning_log["norm_correct"] = False
-        else:
-            warning_log["norm_correct"] = True
-
-    warning_log["norm_check"] = True
-
-
 def op_norm(outputs, op_threshs):
     """Normalize outputs according to operating points for a given model.
     Args:
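
The `weights_only=False` argument added to both `torch.load` calls above matters on recent PyTorch releases (2.6 and later default `weights_only` to `True`), where loading a fully pickled `nn.Module` checkpoint would otherwise fail. A hedged sketch of the behavior, with a hypothetical local checkpoint path:

```python
import torch

path = "densenet121-res224-all.pt"  # hypothetical path to a downloaded checkpoint

try:
    # Under the newer default (weights_only=True) this refuses to unpickle
    # arbitrary objects such as a saved nn.Module.
    model = torch.load(path, map_location='cpu')
except Exception as e:
    print("Default load failed:", e)
    # Opt back into full unpickling; only do this for checkpoints you trust,
    # since pickle can execute arbitrary code.
    model = torch.load(path, map_location='cpu', weights_only=False)
```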
40 changes: 40 additions & 0 deletions torchxrayvision/utils.py
@@ -144,3 +144,43 @@ def infer(model: torch.nn.Module, dataset: torch.utils.data.Dataset, threads=4,
             preds.append(output)
 
     return np.concatenate(preds)
+
+
+warning_log = {}
+
+def fix_resolution(x, resolution: int, model):
+    """Check resolution of input and resize to match requested."""
+
+    if len(x.shape) == 3:
+        # Extend to be 4D
+        x = x[None,...]
+
+    if x.shape[2] != x.shape[3]:
+        raise Exception(f"Height and width of the image must be the same. Input: {x.shape[2]} != {x.shape[3]}. Perform a center crop first.")
+
+    if (x.shape[2] != resolution) | (x.shape[3] != resolution):
+        if not hash(model) in warning_log:
+            print("Warning: Input size ({}x{}) is not the native resolution ({}x{}) for this model. A resize will be performed but this could impact performance.".format(x.shape[2], x.shape[3], resolution, resolution))
+            warning_log[hash(model)] = True
+        return torch.nn.functional.interpolate(x, size=(resolution, resolution), mode='bilinear', antialias=True)
+    return x
+
+
+def warn_normalization(x):
+    """Check normalization of input and warn if possibly wrong. When
+    processing an image that may likely not have the correct
+    normalization we can issue a warning. But running min and max on
+    every image/batch is costly so we only do it on the first image/batch.
+    """
+
+    # Only run this check on the first image so we don't hurt performance.
+    if not "norm_check" in warning_log:
+        x_min = x.min()
+        x_max = x.max()
+        if torch.logical_or(-255 < x_min, x_max < 255) or torch.logical_or(x_min < -1025, 1025 < x_max):
+            print(f'Warning: Input image does not appear to be normalized correctly. The input image has the range [{x_min:.2f},{x_max:.2f}] which doesn\'t seem to be in the [-1024,1024] range. This warning may be wrong though. Only the first image is tested and we are only using a heuristic in an attempt to save a user from using the wrong normalization.')
+            warning_log["norm_correct"] = False
+        else:
+            warning_log["norm_correct"] = True
+
+    warning_log["norm_check"] = True
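
The relocated helpers can also be exercised directly (the tests above reset `xrv.utils.warning_log` for exactly this reason). A minimal sketch, assuming the standard `densenet121-res224-all` weights are available:

```python
import torch
import torchxrayvision as xrv

model = xrv.models.DenseNet(weights="densenet121-res224-all")  # native resolution 224x224

x = torch.zeros(1, 1, 320, 320).uniform_(-1024, 1024)
x = xrv.utils.fix_resolution(x, 224, model)  # warns once per model, then antialiased bilinear resize
xrv.utils.warn_normalization(x)              # range heuristic, evaluated only for the first batch

print(x.shape)                # torch.Size([1, 1, 224, 224])
print(xrv.utils.warning_log)  # e.g. {<hash(model)>: True, 'norm_correct': True, 'norm_check': True}
```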
