-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdomain_adaptation.py
60 lines (45 loc) · 2.14 KB
/
domain_adaptation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import collections
import os

import torch

# from adatime.da.da import get_backbone_class
from models.adatime.configs.get_configs import Config
from models.adatime.da.algorithms import get_algorithm_class
from models.adatime.da.models import get_backbone_class
from models.adatime.load_data import load_data
from models.adatime.log_utils.log_experiment import starting_logs
from models.adatime.utils import AverageMeter, save_checkpoint
from utils.model_size import get_model_size
from utils.path_utils import project_root
if __name__ == '__main__':
    # Experiment configuration (hyperparameters, dataset settings, etc.).
    config = Config()

    # Source (pretrain) and target (finetune) datasets for the DA scenario.
    source_path = os.path.join(project_root(), 'data', 'tl_datasets', 'pretrain', 'pretrain.pt')
    target_path = os.path.join(project_root(), 'data', 'tl_datasets', 'finetune', 'finetune.pt')
    source_dataloader = load_data(source_path, config)
    target_dataloader = load_data(target_path, config)

    # Domain-adaptation algorithm and backbone classes, resolved by name.
    da_algorithm_name = 'AdvSKM'
    da_algorithm = get_algorithm_class(da_algorithm_name)
    backbone_name = 'GTN'
    da_backbone = get_backbone_class(backbone_name)

    # Fall back to CPU when CUDA is unavailable instead of crashing on
    # algorithm.to('cuda') on a GPU-less machine.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    algorithm = da_algorithm(backbone=da_backbone, configs=config, device=device)
    algorithm.to(device)

    # Report the model's parameter footprint.
    get_model_size(algorithm)

    # Logging layout: results/adatime/<algorithm>/<dataset>/<run_description>.
    home_path = os.path.join(project_root(), 'results', 'adatime')
    save_dir = da_algorithm_name  # plain variable; f-string wrapper was redundant
    dataset_name = 'pretrain_finetune'
    experiment_description = dataset_name
    exp_name = f'{da_algorithm_name}_gtn'
    run_description = f"{da_algorithm_name}_{exp_name}"
    exp_log_dir = os.path.join(home_path, save_dir, experiment_description, run_description)
    logger, scenario_log_dir = starting_logs(dataset_name, da_algorithm_name,
                                             exp_log_dir, '0', '1', '0')

    # Per-loss running averages; defaultdict lazily creates a meter per loss
    # name (the class itself is the factory — no lambda needed).
    loss_avg_meters = collections.defaultdict(AverageMeter)

    # Train the DA algorithm; update() returns the final and best model states.
    last_model, best_model = algorithm.update(source_dataloader, target_dataloader, loss_avg_meters, logger)

    # Persist both checkpoints under the experiment's log directory.
    save_checkpoint(home_path, scenario_log_dir, last_model, best_model)