# utils.py
import torch
from PIL import Image
from tqdm import tqdm
from torch.utils.data import DataLoader


def mean_and_std(train_dataset, batch_size, num_workers):
    """Compute per-channel mean and std of a dataset in two passes.

    Assumes the dataset yields bare image tensors (no labels) and that all
    images share the same spatial size.
    """
    loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False
    )

    # first pass: accumulate the per-channel mean over all samples
    num_samples = 0.
    channel_mean = torch.Tensor([0., 0., 0.])
    channel_std = torch.Tensor([0., 0., 0.])
    for samples in tqdm(loader):
        X = samples
        channel_mean += X.mean((2, 3)).sum(0)
        num_samples += X.size(0)
    channel_mean /= num_samples

    # second pass: accumulate the per-channel variance around that mean
    for samples in tqdm(loader):
        X = samples
        batch_samples = X.size(0)
        X = X.permute(0, 2, 3, 1).reshape(-1, 3)
        channel_std += ((X - channel_mean) ** 2).mean(0) * batch_samples
    channel_std = torch.sqrt(channel_std / num_samples)

    mean, std = channel_mean.tolist(), channel_std.tolist()
    print('mean: {}'.format(mean))
    print('std: {}'.format(std))
    return mean, std
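
# A minimal usage sketch for mean_and_std (names and paths are hypothetical;
# remember it expects a dataset that yields bare image tensors, no labels):
#
#   stats_dataset = UnlabeledImageDataset('data/train')  # hypothetical dataset
#   mean, std = mean_and_std(stats_dataset, batch_size=64, num_workers=4)
#   normalize = transforms.Normalize(mean, std)  # from torchvision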


def save_weights(model, save_path):
    """Save a model's state dict, unwrapping DataParallel if necessary."""
    if isinstance(model, torch.nn.DataParallel):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(state_dict, save_path)
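
# Weights saved this way can be restored with the standard PyTorch pattern,
# e.g. model.load_state_dict(torch.load(save_path)); only the plain state
# dict is written, never the DataParallel wrapper.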


def print_msg(msg, appendixs=None):
    """Print a message, plus optional extra lines, inside a '=' banner."""
    appendixs = appendixs or []  # avoid a mutable default argument
    max_len = len(max([msg, *appendixs], key=len))
    print('=' * max_len)
    print(msg)
    for appendix in appendixs:
        print(appendix)
    print('=' * max_len)


def print_config(configs):
    for name, config in configs.items():
        print('====={}====='.format(name))
        _print_config(config)
        print('=' * (len(name) + 10))
        print()


def _print_config(config, indentation=''):
    for key, value in config.items():
        if isinstance(value, dict):
            print('{}{}:'.format(indentation, key))
            _print_config(value, indentation + ' ')
        else:
            print('{}{}: {}'.format(indentation, key, value))


def print_dataset_info(datasets):
    train_dataset, test_dataset, val_dataset = datasets
    print('=========================')
    print('Dataset Loaded.')
    print('Training:\t{}'.format(len(train_dataset)))
    print('Validation:\t{}'.format(len(val_dataset)))
    print('Test:\t\t{}'.format(len(test_dataset)))
    print('=========================')


def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        # convert while the file is still open, since PIL loads lazily
        return img.convert('RGB')


# unnormalize image for visualization
def inverse_normalize(tensor, mean, std):
    # undoes transforms.Normalize channel by channel, mutating the input in place
    for t, m, s in zip(tensor, mean, std):
        t.mul_(s).add_(m)
    return tensor
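
# Sketch of the intended round trip (mean/std values are placeholders);
# note the function mutates its input, so clone tensors you still need:
#
#   normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
#   restored = inverse_normalize(normalize(img).clone(),
#                                [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])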


# convert labels to onehot
def one_hot(labels, num_classes, device, dtype):
    # rows of the identity matrix are exactly the one-hot codes
    y = torch.eye(num_classes, device=device, dtype=dtype)
    return y[labels]
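
# For example (illustrative values):
#   one_hot(torch.tensor([0, 2]), 3, 'cpu', torch.float32)
#   -> tensor([[1., 0., 0.], [0., 0., 1.]])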


# convert type of target according to criterion
def select_target_type(y, criterion):
    if criterion in ['cross_entropy', 'kappa_loss']:
        y = y.long()
    elif criterion in ['mean_square_root', 'L1', 'smooth_L1']:
        y = y.float()
    elif criterion in ['focal_loss']:
        y = y.to(dtype=torch.int64)
    else:
        raise NotImplementedError('Criterion {} is not implemented.'.format(criterion))
    return y


# convert output dimension of network according to criterion
def select_out_features(num_classes, criterion):
    # regression-style losses predict a single scalar instead of class logits
    out_features = num_classes
    if criterion in ['mean_square_root', 'L1', 'smooth_L1']:
        out_features = 1
    return out_features
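

if __name__ == '__main__':
    # Quick sanity-check sketch (illustrative values only): route targets and
    # the network head through the criterion helpers above.
    criterion = 'cross_entropy'
    labels = torch.tensor([0, 2, 1])
    targets = select_target_type(labels, criterion)   # -> int64 targets
    out_features = select_out_features(3, criterion)  # -> 3 logits
    codes = one_hot(labels, 3, labels.device, torch.float32)
    print(targets.dtype, out_features, codes.shape)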