-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsinglehead_rotated_tasks.py
72 lines (60 loc) · 2.32 KB
/
singlehead_rotated_tasks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
from math import exp
import random
import numpy as np
import torch
import pandas as pd
import seaborn as sns
sns.set_theme()
from utils.config import fetch_configs
from datasets.cifar import RotatedCIFAR10Handler
from net.smallconv import SingleHeadNet
from utils.run_net import train, evaluate
SEED = 1234


def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs and make cuDNN deterministic.

    Called once at import time so every run of this script starts from the
    same random state.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Ask cuDNN to pick deterministic algorithms so GPU runs are repeatable.
    torch.backends.cudnn.deterministic = True


_seed_everything(SEED)
def _build_training_objects(hp, num_batches):
    """Create a fresh network, SGD optimizer and cosine LR scheduler for one run.

    Args:
        hp: hyperparameter dict; reads ``'lr'``, ``'l2_reg'`` and ``'epochs'``.
        num_batches: number of batches per epoch (``len(train_loader)``).
            ``T_max`` is set to ``epochs * num_batches``; presumably ``train``
            steps the scheduler once per batch (TODO confirm in utils.run_net).

    Returns:
        Tuple ``(net, optimizer, lr_scheduler)``.
    """
    net = SingleHeadNet(
        num_task=1,
        num_cls=2,
        channels=3,
        avg_pool=2,
        lin_size=320,
    )
    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=hp['lr'],
        momentum=0.9,
        nesterov=True,
        weight_decay=hp['l2_reg'],
    )
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, hp['epochs'] * num_batches)
    return net, optimizer, lr_scheduler


def run_experiment(exp_conf):
    """Train and evaluate a single-head net on rotated CIFAR-10 tasks.

    For every rotation angle in ``exp_conf['angles']`` a
    ``RotatedCIFAR10Handler`` is built. For every ``m`` in ``exp_conf['m']``
    the run is repeated ``exp_conf['reps']`` times. Each repetition resamples
    data with ``sample_data(n=n, m=m)``, trains a freshly initialised
    ``SingleHeadNet`` and records the value returned by ``evaluate``.

    Results are collected in one DataFrame:
    - one row per ``(m, rep)`` pair, stored in columns ``"m"`` and ``"r"``;
    - one column per angle, named ``str(angle)``.
    The table is written to
    ``./experiments/results/rotated_tasks_exp_<task>.csv`` after each angle.

    Args:
        exp_conf: experiment config mapping; reads ``'n'``, ``'hp'``,
            ``'angles'``, ``'task'``, ``'m'`` and ``'reps'``.
    """
    import os  # function-scope import keeps this block self-contained

    n = exp_conf['n']
    hp = exp_conf['hp']
    out_path = './experiments/results/rotated_tasks_exp_{}.csv'.format(
        str(exp_conf['task']))
    # Make sure the results folder exists before any training starts, so the
    # final to_csv cannot fail on a missing directory after all the work.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    df = pd.DataFrame()
    for angle in exp_conf['angles']:
        print("Doing angle = {}".format(angle))
        dataset = RotatedCIFAR10Handler(exp_conf['task'], angle)
        # The row counter restarts per angle, so row i refers to the same
        # (m, rep) pair in every angle column.
        i = 0
        for m in exp_conf['m']:
            print("m = {}".format(m))
            for rep in range(exp_conf['reps']):
                print("Doing rep...{}".format(rep))
                df.at[i, "m"] = m
                df.at[i, "r"] = rep
                dataset.sample_data(n=n, m=m)
                train_loader = dataset.get_data_loader(hp['batch_size'], train=True)
                # Built after the loader, as in the original, so the RNG is
                # consumed in the same order.
                net, optimizer, lr_scheduler = _build_training_objects(
                    hp, len(train_loader))
                net = train(net, hp, train_loader, optimizer, lr_scheduler,
                            verbose=False, task_id_flag=False)
                risk = evaluate(net, dataset, task_id_flag=False)
                df.at[i, str(angle)] = risk
                i += 1
        # Checkpoint after every angle so finished angles survive a failure in
        # a later one; the last write contains the complete table.
        df.to_csv(out_path)
def main():
    """Command-line entry point: load the YAML experiment config and run it.

    The config path is taken from ``--exp_config``. When the flag is omitted,
    ``./experiments/config/singlehead_rotated_tasks.yaml`` is used.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--exp_config',
        type=str,
        default="./experiments/config/singlehead_rotated_tasks.yaml",
        help="Experiment configuration",
    )
    config_path = cli.parse_args().exp_config
    run_experiment(fetch_configs(config_path))


if __name__ == '__main__':
    # Run only when executed as a script, not when imported.
    main()