This repository has been archived by the owner on May 19, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathverify.py
71 lines (59 loc) · 2.36 KB
/
verify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import argparse
from collections import Counter
import torch
from PIL import Image
from tqdm import tqdm
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from resnet import get_resnet, name_to_params
class ImagenetValidationDataset(Dataset):
    """ImageNet (ILSVRC2012) validation images paired with their labels.

    ``val_path`` must contain ``ILSVRC2012_validation_ground_truth.txt``
    (one class id per line, converted here to 0-based by subtracting 1) and
    the images ``ILSVRC2012_val_00000001.JPEG``, ``ILSVRC2012_val_00000002.JPEG``,
    ... where sample ``i`` (0-based) is the file numbered ``i + 1``.
    """

    def __init__(self, val_path, transform=None):
        """
        Args:
            val_path: directory holding the images and the ground-truth file.
            transform: optional callable applied to each RGB PIL image. When
                omitted, uses Resize(256) -> CenterCrop(224) -> ToTensor()
                (no normalization), exactly as before.
        """
        super().__init__()
        self.val_path = val_path
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ])
        self.transform = transform
        self.labels = self._load_labels(
            os.path.join(val_path, 'ILSVRC2012_validation_ground_truth.txt'))

    @staticmethod
    def _load_labels(gt_path):
        """Read one integer class id per line and return them shifted to 0-based.

        Blank lines (e.g. a stray empty line at EOF) are skipped instead of
        crashing on ``int('')``.
        """
        with open(gt_path, encoding='utf-8') as f:
            return [int(line) - 1 for line in f if line.strip()]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, item):
        path = os.path.join(self.val_path, f'ILSVRC2012_val_{item + 1:08d}.JPEG')
        # Close the file handle deterministically rather than relying on GC;
        # convert() returns a fully loaded copy, so it stays valid afterwards.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb), self.labels[item]
def accuracy(output, target, topk=(1,)):
    """Count samples whose target is among the top-k scored classes.

    Args:
        output: score tensor indexed as (batch, num_classes); topk is taken
            along dim 1 (presumably logits).
        target: ``batch`` integer class indices, on any device.
        topk: k values to evaluate.

    Returns:
        list of floats, one per entry of ``topk``, each the NUMBER of correct
        samples (a count, not a percentage).
    """
    maxk = max(topk)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t().cpu()  # (maxk, batch)
    # pred was forced to CPU; bring target along so eq() doesn't mix devices.
    target = target.view(1, -1).to(pred.device)
    correct = pred.eq(target.expand_as(pred))
    # ``correct`` can inherit the transposed (non-contiguous) layout of pred,
    # where view(-1) fails for k > 1; reshape copies only when necessary.
    return [correct[:k].reshape(-1).float().sum().item() for k in topk]
def _majority_vote_hits(targets, preds):
    """Sum, over true classes, the count of that class's most frequent prediction.

    Predicted indices are never compared to targets directly: each true class
    is credited with whichever predicted label occurs most often for it.
    NOTE(review): presumably this is because the checkpoint's class ordering
    may not match the ground-truth file's — confirm. The mapping is not forced
    to be one-to-one, so this can exceed plain top-1 accuracy.
    Classes with no samples contribute nothing (no empty-Counter IndexError).
    """
    per_class = {}
    for t, p in zip(targets, preds):
        per_class.setdefault(t, Counter())[p] += 1
    return sum(counts.most_common(1)[0][1] for counts in per_class.values())


@torch.no_grad()
def run(pth_path, val_path='./val/', batch_size=64, num_workers=8, device=None):
    """Evaluate a checkpoint on the ImageNet validation set and print ``ACC: <percent>``.

    Args:
        pth_path: checkpoint file; its ``'resnet'`` entry is loaded as the
            model state dict, and the architecture comes from
            ``name_to_params(pth_path)``.
        val_path: directory for ``ImagenetValidationDataset``.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.
        device: torch device string; defaults to CUDA when available, else CPU.

    Raises:
        ValueError: if the validation set yields no samples.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dataset = ImagenetValidationDataset(val_path)
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=torch.device(device).type == 'cuda',  # pinning only helps GPU copies
        num_workers=num_workers,
    )

    model, _ = get_resnet(*name_to_params(pth_path))
    # map_location='cpu' lets GPU-saved checkpoints load on any host; the model
    # is moved to ``device`` below.
    # NOTE(review): torch.load unpickles the file — only use trusted checkpoints.
    checkpoint = torch.load(pth_path, map_location='cpu')
    model.load_state_dict(checkpoint['resnet'])
    model = model.to(device).eval()

    preds, targets = [], []
    for images, labels in tqdm(data_loader):
        _, pred = model(images.to(device), apply_fc=True).topk(1, dim=1)
        preds.append(pred.squeeze(1).cpu())
        targets.append(labels)
    if not preds:
        raise ValueError(f'no validation samples found in {val_path!r}')

    p = torch.cat(preds).tolist()
    t = torch.cat(targets).tolist()
    # Denominator is the actual sample count, not a hard-coded 50000.
    print(f'ACC: {_majority_vote_hits(t, p) / len(t) * 100}')
if __name__ == '__main__':
    # Command-line entry point: the single positional argument is the checkpoint to verify.
    cli = argparse.ArgumentParser(description='SimCLR verifier')
    cli.add_argument('pth_path', help='path of the input checkpoint file', type=str)
    cli_args = cli.parse_args()
    run(cli_args.pth_path)