forked from amazon-science/tubelet-transformer
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path eval_tuber_ava.py
59 lines (47 loc) · 2.06 KB
/
eval_tuber_ava.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import argparse
import datetime
import time
import torch
import torch.optim
from tensorboardX import SummaryWriter
from models.tuber_ava import build_model
from utils.model_utils import deploy_model, load_model, save_checkpoint
from utils.video_action_recognition import validate_tuber_detection
from pipelines.video_action_recognition_config import get_cfg_defaults
from pipelines.launch import spawn_workers
from utils.utils import build_log_dir
from datasets.ava_frame import build_dataloader
def main_worker(cfg):
    """Evaluate a trained TubeR model on the AVA test loader.

    Handed to ``spawn_workers`` as the per-worker entry point. Builds the
    model, criterion and post-processors from ``cfg``, loads the checkpoint
    selected by the config, and runs ``validate_tuber_detection`` once over
    the test loader, printing the wall-clock time it took.

    Args:
        cfg: Experiment config (``get_cfg_defaults()`` merged with a YAML
            file). This function reads ``DDP_CONFIG.GPU_WORLD_RANK``,
            ``CONFIG.MODEL.NAME``, ``CONFIG.MODEL.LOAD`` and
            ``CONFIG.MODEL.LOAD_FC``; other keys are consumed by the helpers
            it calls.

    Raises:
        ValueError: if ``cfg.CONFIG.MODEL.LOAD`` is falsy. Evaluation needs a
            checkpoint loaded via ``load_model``.
    """
    # Fail fast, before paying for model and dataloader construction.
    if not cfg.CONFIG.MODEL.LOAD:
        raise ValueError(
            "model dir not found: set CONFIG.MODEL.LOAD to evaluate a checkpoint")

    # Only rank 0 creates a TensorBoard writer; other ranks pass None along.
    if cfg.DDP_CONFIG.GPU_WORLD_RANK == 0:
        tb_logdir = build_log_dir(cfg)
        writer = SummaryWriter(log_dir=tb_logdir)
    else:
        writer = None

    try:
        print('Creating TubeR model: %s' % cfg.CONFIG.MODEL.NAME)
        model, criterion, postprocessors = build_model(cfg)
        model = deploy_model(model, cfg, is_tuber=True)
        num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('Number of parameters in the model: %6.2fM' % (num_parameters / 1000000))

        # build_dataloader returns five items; only the test loader is used here.
        _, test_loader, _, _, _ = build_dataloader(cfg)

        model, _ = load_model(model, cfg, load_fc=cfg.CONFIG.MODEL.LOAD_FC)

        print('Start testing...')
        start_time = time.time()
        # NOTE(review): 0 is passed positionally; presumably the epoch index
        # used for logging -- confirm against validate_tuber_detection.
        validate_tuber_detection(cfg, model, criterion, postprocessors, test_loader, 0, writer)
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('testing time {}'.format(total_time_str))
    finally:
        # Flush and release the TensorBoard event file even if evaluation fails.
        if writer is not None:
            writer.close()
if __name__ == '__main__':
    # Command-line entry point: build the config from defaults plus a YAML
    # file, then hand main_worker to the project launcher (spawn_workers).
    parser = argparse.ArgumentParser(
        description='Evaluate a trained TubeR video action detection model on AVA.')
    parser.add_argument('--config-file',
                        default='/xxx/TubeR_AVA_v2.2_CSN-152.yaml',
                        help='path to config file (the default is a placeholder; '
                             'pass the path to a real config).')
    args = parser.parse_args()
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config_file)
    spawn_workers(main_worker, cfg)