forked from zhangzjn/ADer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathruns_single_class.py
52 lines (47 loc) · 2.34 KB
/
runs_single_class.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import argparse
import os
import shlex
import subprocess
from multiprocessing import Process
def runcmd(command):
    """Run ``command`` through the shell and block until it exits.

    A shell is used because the launch commands built by this script begin
    with an environment assignment (``CUDA_VISIBLE_DEVICES=...``). Output is
    not captured; it is inherited from the calling process.

    Args:
        command: full shell command line as a single string.

    Returns:
        The ``subprocess.CompletedProcess`` describing the finished command.
    """
    return subprocess.run(args=command, shell=True)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--dataset', type=str, default='mvtec')
parser.add_argument('-c', '--config', type=str, default='configs/rd/rd_mvtec.py')
parser.add_argument('-p', '--prefix', type=str, default='single_cls')
parser.add_argument('-n', '--nproc_per_node', type=int, default=8)
parser.add_argument('-m', '--multiprocessing', type=int, default=-1)
parser.add_argument('-g', '--gpu_num', type=int, default=0)
cfg = parser.parse_args()
if cfg.dataset == 'mvtec':
names = ['carpet', 'grid', 'leather', 'tile', 'wood',
'bottle', 'cable', 'capsule', 'hazelnut', 'metal_nut',
'pill', 'screw', 'toothbrush', 'transistor', 'zipper', ]
elif cfg.dataset == 'mvtec3d':
names = ['bagel', 'cable_gland', 'carrot', 'cookie', 'dowel',
'foam', 'peach', 'potato', 'rope', 'tire', ]
elif cfg.dataset == 'visa':
names = ['candle', 'capsules', 'cashew', 'chewinggum', 'fryum',
'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3',
'pcb4', 'pipe_fryum', ]
config = cfg.config
suffix = config.split('/')[-1].split('.')[0]
nproc_per_node = cfg.nproc_per_node
use_multiprocessing = True if cfg.multiprocessing > 0 else False
if use_multiprocessing:
process_list = []
for i, name in enumerate(names):
command = f'CUDA_VISIBLE_DEVICES={i % nproc_per_node} '
command += f'python3 run.py -c {cfg.config} -m train data.cls_names={name} trainer.checkpoint=runs/{cfg.prefix}/{suffix}/{name}'
p = Process(target=runcmd, args=(command,))
p.start()
process_list.append(p)
for p in process_list:
p.join()
else:
for i, name in enumerate(names):
command = f'CUDA_VISIBLE_DEVICES={cfg.gpu_num} '
command += f'python3 run.py -c {cfg.config} -m train data.cls_names={name} trainer.checkpoint=runs/{cfg.prefix}/{suffix}/{name}'
ret = runcmd(command)
print(command, ret)