diff --git a/main.py b/main.py index 3970641..566eaaf 100644 --- a/main.py +++ b/main.py @@ -14,7 +14,7 @@ from termcolor import cprint from engine import evaluate, train_one_epoch -from qtcls import build_criterion, build_dataset, build_model, build_optimizer, build_scheduler +from qtcls import __version__, build_criterion, build_dataset, build_model, build_optimizer, build_scheduler from qtcls.utils.io import checkpoint_saver, checkpoint_loader, variables_loader, variables_saver from qtcls.utils.misc import makedirs, init_distributed_mode, init_seeds, is_main_process @@ -44,7 +44,7 @@ def get_args_parser(): parser.add_argument('--print_freq', type=int, default=50) parser.add_argument('--need_targets', action='store_true', help='need targets for training') parser.add_argument('--drop_lr_now', action='store_true') - parser.add_argument('--drop_last', action='store_true') + parser.add_argument('--drop_last', type=bool, default=True) parser.add_argument('--amp', action='store_true', help='automatic mixed precision training') parser.add_argument('--no_dist', action='store_true', help='forcibly disable distributed mode') @@ -108,6 +108,7 @@ def get_args_parser(): def main(args): init_seeds(args.seed) init_distributed_mode(args) + cprint(f'QTClassification v{__version__}', 'light_green', attrs=['bold']) device = torch.device(args.device if torch.cuda.is_available() else 'cpu') if device.type == 'cpu' or args.eval: args.amp = False