-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtrain.py
31 lines (28 loc) · 1.03 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Module imports
import argparse
import os

from mmcv import Config
from mmdet.apis import set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line (``sys.argv[1:]``). Useful when calling this from
            tests or other code. ``None`` (the default) reads the command line.

    Returns:
        argparse.Namespace with a ``configs`` attribute holding the config
        file name; it defaults to
        'swin_dyhead_baseline_lr_config_cosinerestart.py'.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--configs',
        type=str,
        default='swin_dyhead_baseline_lr_config_cosinerestart.py',
        help='The config file which train model',
    )
    # Passing None makes argparse fall back to sys.argv[1:], so existing
    # callers of parse_args() behave exactly as before.
    return parser.parse_args(argv)
# Directory holding the experiment config files. The --configs argument is
# resolved relative to it; an absolute --configs path is used as-is.
CONFIG_DIR = '/opt/ml/level2_objectdetection_cv-level2-cv-13/configs/'
# Global random seed used for reproducible training.
SEED = 2022


def main():
    """Train a detector from the config file named on the command line.

    Loads the config, seeds the RNGs, builds the training dataset and the
    detector, and runs non-distributed training with validation enabled.
    """
    args = parse_args()

    # os.path.join yields the same path as plain concatenation for a bare file
    # name, but also accepts an absolute path instead of producing
    # '<CONFIG_DIR>//abs/path'.
    config_path = os.path.join(CONFIG_DIR, args.configs)
    print(config_path)
    cfg = Config.fromfile(config_path)

    # Seed BEFORE building the dataset and model: previously seeding happened
    # after init_weights(), so weight initialisation was not reproducible
    # despite deterministic=True.
    set_random_seed(SEED, deterministic=True)
    # NOTE(review): keeps cfg.seed consistent with the global seed; assumes
    # train_detector reads cfg.seed to seed dataloader workers (as mmdet 2.x
    # does) — confirm for the installed version.
    cfg.seed = SEED

    # Build the training dataset.
    datasets = [build_dataset(cfg.data.train)]

    # Build the detector and initialise its weights; presumably this also
    # loads pretrained weights when the config requests them — verify in cfg.
    model = build_detector(cfg.model)
    model.init_weights()

    train_detector(model, datasets, cfg, distributed=False, validate=True)


if __name__ == "__main__":
    main()