AA_score.py
"""
Code to produce AA scores for DT-models, changing parameters for the models is done in the exact same way as for training
"""
import torch
import numpy as np
import deepthinking.models as models
import deepthinking as dt
from omegaconf import DictConfig, OmegaConf
import hydra
import sys
import os
from tqdm import tqdm
import nvidia_smi
import argparse

def get_mazes_net(device, path):
    """
    Returns the DT recall (progressive) network NOT in evaluation mode to solve maze problems
    Note: We are not in eval mode as we want to access interim_thought
    Args:
        device (str): the device to store the network on
        path (str): the relative file path to the network to be loaded
    Returns:
        torch.nn.Module: the neural net
    """
    state_dict = torch.load(path, map_location=device)
    net = getattr(models, "dt_net_recall_2d")(width=128, in_channels=3, max_iters=50)
    net = net.to(device)
    net = torch.nn.DataParallel(net)
    net.load_state_dict(state_dict["net"])
    return net

def get_sums_net(device, path):
    """
    Returns the DT recall (progressive) network NOT in evaluation mode to solve prefix sums problems
    Note: We are not in eval mode as we want to access interim_thought
    Args:
        device (str): the device to store the network on
        path (str): the relative file path to the network to be loaded
    Returns:
        torch.nn.Module: the neural net
    """
    net = getattr(models, "dt_net_recall_1d")(width=400, in_channels=3, max_iters=300)
    state_dict = torch.load(path, map_location=device)
    net = net.to(device)
    net = torch.nn.DataParallel(net)
    net.load_state_dict(state_dict["net"])
    return net

def get_chess_net(device, path):
    """
    Returns the DT recall (progressive) network NOT in evaluation mode to solve chess problems
    Note: We are not in eval mode as we want to access interim_thought
    Args:
        device (str): the device to store the network on
        path (str): the relative file path to the network to be loaded
    Returns:
        torch.nn.Module: the neural net
    """
    net = getattr(models, "dt_net_recall_2d")(width=512, in_channels=12, max_iters=30)
    state_dict = torch.load(path, map_location=device)
    net = net.to(device)
    net = torch.nn.DataParallel(net)
    net.load_state_dict(state_dict["net"])
    return net
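
# The three loaders above differ only in architecture hyperparameters, so a model
# can be selected by problem name. A minimal sketch (the checkpoint path below is
# hypothetical and only for illustration):
#
#   loaders = {"prefix_sums": get_sums_net, "mazes": get_mazes_net, "chess": get_chess_net}
#   net = loaders["mazes"]("cuda" if torch.cuda.is_available() else "cpu",
#                          "outputs/<run_name>/model_best.pth")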

def compute_cross_pi(net, testloader, iters, problem, device):
    """
    Computes the Asymptotic Alignment Score (AA score) for a given network
    Code provided by Cem Anil and Ashwini Pokle, from a prerelease of their code base which is currently being cleaned ready for release
    Code modified by the authors: minor changes, such as the if statement, to make the code work with our models and codebase
    Args:
        net (torch.nn.Module): the network to be tested
        testloader (torch.utils.data.DataLoader): the test data to run the tests on
        iters (int): the number of iterations to run each part of the test for
        problem (DictConfig): the problem configuration; its name selects the input shape
        device (str): the device we are working on
    Returns:
        float: the AA score
    """
    total = 0
    path_indep_val = 0
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    with torch.no_grad():
        for inputs, targets in tqdm(testloader, leave=False):
            torch.cuda.empty_cache()
            inputs, targets = inputs.to(device), targets.to(device)
            init_outputs, fp_val1 = net(inputs, return_fp=True)
            if problem.name == "prefix_sums":  # prefix sums inputs have a different shape from the other two problems
                tiled_inputs = torch.tile(inputs, (inputs.shape[0], 1, 1))
                tiled_targets = torch.tile(targets, (targets.shape[0], 1))
            else:
                tiled_inputs = torch.tile(inputs, (inputs.shape[0], 1, 1, 1))
                tiled_targets = torch.tile(targets, (targets.shape[0], 1, 1))
            repeated_fp = torch.repeat_interleave(fp_val1, repeats=inputs.shape[0], dim=0)
            next_outputs, fp_val2 = net(tiled_inputs, interim_thought=repeated_fp, return_fp=True)
            total += fp_val2.size(0)
            idx = np.arange(0, tiled_inputs.shape[0], inputs.shape[0])
            fp1 = repeated_fp.view(repeated_fp.shape[0], -1)
            fp2 = fp_val2.view(fp_val2.shape[0], -1)
            bsz = inputs.shape[0]
            for i in range(inputs.shape[0]):
                cur_idx = idx + i
                conseq_idx = np.arange(i * bsz, i * bsz + inputs.shape[0])
                path_indep_val += cos(fp1[cur_idx], fp2[conseq_idx]).sum()
    return path_indep_val / total
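
# Illustration of the pairing logic above (toy numbers, not part of the original
# routine): with a batch of size bsz = 3, tiled_inputs is ordered
# [x0, x1, x2, x0, x1, x2, x0, x1, x2] and repeated_fp is
# [f0, f0, f0, f1, f1, f1, f2, f2, f2], so entry i*bsz + j runs input x_j
# warm-started from x_i's fixed point. The score then averages
# cos(f_j, fixed point reached for x_j when started from f_i) over all (i, j)
# pairs; a value near 1 indicates that the fixed point reached does not depend on
# the starting thought, i.e. the model is path independent.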

@hydra.main(config_path="config", config_name="test_model_config")
def main(cfg: DictConfig):
    """
    Computes the Asymptotic Alignment score (AA score) for each model in the filepaths list and writes the scores to a file
    Args:
        cfg (DictConfig): uses the hydra framework, like the rest of the codebase from Bansal et al., to take command line arguments
    """
    problem = cfg.problem
    loaders = dt.utils.get_dataloaders(problem)
    os.chdir('../../..')  # Hydra creates a working directory for this run instance, so move back to the project root
    testloader = loaders["test"]  # use the test data specified by the input parameters
    device = "cuda" if torch.cuda.is_available() else "cpu"
    iters = 300
    aa = []
    # Add your file paths here. They must all be for the same problem (e.g. prefix sums),
    # and that problem must match the command line argument given.
    # EXAMPLE: filepaths = ["outputs/prefix_sums_ablation/training-unbalanced-Nick/model_best.pth"]
    filepaths = ["outputs/mazes_13x13/training-smuggest-Bo/model_best.pth", "outputs/mazes_13x13/training-repand-Natilee/model_best.pth"]
    for path in filepaths:
        if problem.name == "prefix_sums":
            net = get_sums_net(device, path)
        elif problem.name == "mazes":
            net = get_mazes_net(device, path)
        else:
            net = get_chess_net(device, path)
        aa_score = compute_cross_pi(net, testloader, iters, problem, device).tolist()
        aa.append(aa_score)
    file_name = f"score_{problem.name}_{problem.train_data}.txt"
    with open(file_name, 'w+') as f:
        f.write(f"for models in paths: {filepaths} the AA scores are {aa}")

if __name__ == "__main__":
    run_id = dt.utils.generate_run_id()
    sys.argv.append(f"+run_id={run_id}")
    main()
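
# Example invocation (the exact overrides are assumptions and depend on the groups
# defined in config/test_model_config.yaml; adjust to your config layout):
#   python AA_score.py problem=mazes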