-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
87 lines (66 loc) · 2.66 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
""" Inference code
TODO:
NOTES:
REFERENCE:
* MNC code template predict.py
UPDATED:
"""
import os
import random
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from modules.dataset import TestDataset
from modules.trainer import Trainer
from modules.utils import load_yaml, save_csv
import torch
from model.model import PestClassifier
import cv2
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from modules.metrics import get_metric_fn
import torch.nn as nn
# Expose only the CUDA device with index 1 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# ---------------------------------------------------------------------------
# CONFIG
# ---------------------------------------------------------------------------
# Directory that contains this script, and its parent directory.
PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_PROJECT_DIR = os.path.dirname(PROJECT_DIR)

# NOTE: the input shape has to be set by hand here.
INPUT_SHAPE = (528, 528)

# Test data location; this is a relative path and is not anchored to PROJECT_DIR.
DATA_DIR = '../shared/test'

# Prediction settings are read from a YAML file next to this script.
PREDICT_CONFIG_PATH = os.path.join(PROJECT_DIR, 'config/predict_config.yml')
config = load_yaml(PREDICT_CONFIG_PATH)

# Values taken from the loaded configuration: RNG seed and prediction batch size.
RANDOM_SEED = config['SEED']['random_seed']
BATCH_SIZE = config['PREDICT']['batch_size']
if __name__ == '__main__':
    # Script entry point: load the trained classifier, predict a class index
    # for every item yielded by the test dataloader, and write the results
    # to a CSV file with the columns `file_name` and `answer`.

    # Seed every RNG used here and request deterministic cuDNN behaviour so
    # repeated runs are reproducible.
    torch.manual_seed(RANDOM_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(RANDOM_SEED)
    random.seed(RANDOM_SEED)

    # NOTE: the best-model checkpoint and the location where the prediction
    # results are saved have to be set by hand here.
    TRAINED_MODEL_PATH = 'results/train/Efficientb6_20211102180641/best.pt'
    SAVE_PATH = 'results/csv/[0.9953]Efficientnetb6-layer(1280-500-250-10)-ES(50)-IS(528)_Aug(NoColor).csv'

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # shuffle=False keeps predictions in the dataset's own order.
    test_dataset = TestDataset(data_dir=DATA_DIR, input_shape=INPUT_SHAPE)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    print('Test set samples:',len(test_dataset))

    # Load the model. The checkpoint is deserialized onto the CPU first and
    # its 'model' entry is used as the state dict; load_state_dict then
    # copies the weights into the model that already lives on `device`.
    model = PestClassifier(num_class=10).to(device)
    checkpoint = torch.load(TRAINED_MODEL_PATH, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model'])

    # Switch to evaluation mode. torch.no_grad() only disables gradient
    # tracking; without eval() any dropout / batch-norm layers would keep
    # their training-time behaviour and distort the predictions.
    model.eval()

    pred_lst = []
    file_name_lst = []
    with torch.no_grad():
        # Each batch is unpacked as (images, file names) -- presumably the
        # format produced by TestDataset; confirm against modules/dataset.
        for img, file_name in test_dataloader:
            img = img.to(device)
            pred = model(img)
            # argmax over dim 1 picks one index per sample (assumes the
            # model output is laid out as (batch, num_class)).
            pred_lst.extend(pred.argmax(dim=1).cpu().tolist())
            file_name_lst.extend(file_name)

    df = pd.DataFrame({'file_name': file_name_lst, 'answer': pred_lst})

    # Create the output directory if needed so the predictions are not lost
    # to a missing-directory error after the whole inference pass has run.
    save_dir = os.path.dirname(SAVE_PATH)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    df.to_csv(SAVE_PATH, index=False)