-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmake_table.py
139 lines (119 loc) · 5.17 KB
/
make_table.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
""" make_table.py
For generating pivot tables
Collaboratively developed
by Avi Schwarzschild, Eitan Borgnia,
Arpit Bansal, and Zeyad Emam.
Developed for DeepThinking project
October 2021
"""
import argparse
import copy
import glob
import json
import pandas as pd
def get_little_df_from_one_run(data_dict):
num_new_dicts = len(data_dict["test_iters"])
test_iters = data_dict["test_iters"]
out_dict = {}
for i in range(num_new_dicts):
new_dict_i = copy.deepcopy(data_dict)
new_dict_i["test_acc"] = data_dict["test_acc"][str(test_iters[i])] \
if data_dict["test_acc"] else 0
new_dict_i["val_acc"] = data_dict["val_acc"][str(test_iters[i])] \
if data_dict["val_acc"] else 0
new_dict_i["train_acc"] = data_dict["train_acc"][str(test_iters[i])] \
if data_dict["train_acc"] else 0
new_dict_i["test_iter"] = test_iters[i]
out_dict[i] = new_dict_i
little_df = pd.DataFrame.from_dict(out_dict, orient="index")
return little_df
def get_df(filepath, acc_filter):
pd.set_option("display.max_rows", None)
df = pd.DataFrame()
num_checkpoints = 0
for f_name in glob.iglob(f"{filepath}/**/*testing*/stats.json", recursive=True):
num_checkpoints += 1
with open(f_name, "r") as fp:
data = json.load(fp)
if acc_filter is not None:
m = data["max_iters"]
if data["train_acc"][str(m)] > acc_filter:
df = df.append(get_little_df_from_one_run(data))
else:
df = df.append(get_little_df_from_one_run(data))
num_trained = len(df)
return df, num_checkpoints, num_trained
def get_table(filepath, disp_max, disp_min, filter_at=None, max_iters_list=None,
alpha_list=None, width_list=None, model_list=None):
pd.set_option("display.max_rows", None)
df, num_checkpoints, num_trained = get_df(filepath, filter_at)
df["count"] = 1
if max_iters_list:
frames = []
for max_iters in max_iters_list:
frames.append(df[df["max_iters"] == int(max_iters)])
df = pd.concat(frames)
if alpha_list:
frames = []
for alpha in alpha_list:
frames.append(df[df["alpha"] == float(alpha)])
df = pd.concat(frames)
if width_list:
frames = []
for width in width_list:
frames.append(df[df["model"].str.contains(str(width))])
df = pd.concat(frames)
if model_list:
frames = []
for model in model_list:
frames.append(df[df["model"].str.contains(model)])
df = pd.concat(frames)
# df = df[df.model_path.str.contains("best")]
index = ["model", "test_data", "max_iters", "alpha", "test_mode", "test_iter"]
values = ["mean", "sem"]
if disp_max:
values.append("max")
if disp_min:
values.append("min")
df.drop_duplicates(subset=["model_path", "test_iter", "test_data", "test_mode"], inplace=True)
try:
table = pd.pivot_table(df, index=index, aggfunc={"train_acc": values,
"val_acc": values,
"test_acc": values,
"count": "count"})
except:
raise KeyError("No data in the table. Check the path to test results. "
"Possibly, there are no results from models that trained "
"with accuracy above the filter value.")
return table
def main():
parser = argparse.ArgumentParser(description="Analysis parser")
parser.add_argument("filepath", type=str)
parser.add_argument("--filter", type=float, default=None,
help="cutoff for filtering by training acc?")
parser.add_argument("--max_iters_list", type=int, nargs="+", default=None,
help="only plot models with max iters in given list")
parser.add_argument("--model_list", type=str, nargs="+", default=None,
help="only plot models with model name in given list")
parser.add_argument("--width_list", type=str, nargs="+", default=None,
help="only plot models with widths in given list")
parser.add_argument("--alpha_list", type=float, nargs="+", default=None,
help="only plot models with alphas in given list")
parser.add_argument("--max", action="store_true", help="add max values to table?")
parser.add_argument("--min", action="store_true", help="add min values too table?")
parser.add_argument("--save_name", type=str, default=None, help="add min values too table?")
args = parser.parse_args()
table = get_table(args.filepath,
args.max,
args.min,
filter_at=args.filter,
max_iters_list=args.max_iters_list,
alpha_list=args.alpha_list,
width_list=args.width_list,
model_list=args.model_list)
table = table.round(3)
print(table.to_markdown())
if args.save_name is not None:
table.to_csv(args.save_name)
if __name__ == "__main__":
main()