-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path plot_flat_difficulites.py
114 lines (96 loc) · 4 KB
/
plot_flat_difficulites.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from __future__ import print_function, division
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
def parseOutput( path, maxiter = 200 ):
    """Parse a flat-difficulties test log into per-run records.

    The log is scanned line by line for these prefixes:

    * ``ambient dimension: <int>``
    * ``given flat orthogonal dimension: <int>``
    * ``affine subspace dimension: <int>``
    * ``Terminated - <reason>``  -- closes a run and records its iteration count
    * ``Final cost: <float>``

    Parameters:
        path: path of the log file to read.
        maxiter: iteration count recorded for runs whose termination reason
            starts with "max" (iteration cap or time limit). Defaults to 200.

    Returns:
        (data, costs) as numpy arrays. ``data`` has one row
        ``[dim - ortho, handle, iterations]`` per ``Terminated`` line;
        ``costs`` has one row ``[final_cost]`` per ``Final cost`` line.

    Raises:
        AssertionError: if a ``Terminated`` line appears before all three
            dimensions were read, if the ambient dimension changes within a
            run, or if the numbers of ``Terminated`` and ``Final cost`` lines
            differ.
    """
    ambient_prefix = "ambient dimension: "
    ortho_prefix = "given flat orthogonal dimension: "
    handle_prefix = "affine subspace dimension: "
    terminated_prefix = "Terminated - "
    cost_prefix = "Final cost: "

    data, costs = [], []
    dim, ortho, handle = None, None, None
    ## `with` guarantees the file is closed even if parsing fails midway.
    with open( path, "r" ) as infile:
        for line in infile:
            ## The prefixes are mutually exclusive, so at most one branch fires.
            if line.startswith( ambient_prefix ):
                newdim = int( line[ len( ambient_prefix ): ] )
                assert dim is None or dim == newdim
                dim = newdim
            elif line.startswith( ortho_prefix ):
                ortho = int( line[ len( ortho_prefix ): ] )
            elif line.startswith( handle_prefix ):
                handle = int( line[ len( handle_prefix ): ] )
            elif line.startswith( terminated_prefix ):
                assert dim is not None
                assert ortho is not None
                assert handle is not None
                reason = line[ len( terminated_prefix ): ]
                words = reason.split( " " )
                if words[0] == "max":
                    ## "max iterations reached" or "max time reached": a
                    ## time-limited run is deliberately recorded as if it had
                    ## hit the iteration cap.
                    iterations = maxiter
                else:
                    ## min grad norm reached. NOTE(review): assumes the
                    ## iteration count is the sixth space-separated word of the
                    ## reason (e.g. "min grad norm reached after 57 ...") --
                    ## confirm against the solver's output format.
                    iterations = int( words[5] )
                data.append( [ dim - ortho, handle, iterations ] )
                ## Start a fresh run.
                dim = ortho = handle = None
            elif line.startswith( cost_prefix ):
                costs.append( [ float( line[ len( cost_prefix ): ] ) ] )

    data = np.array( data )
    costs = np.array( costs )
    assert len( data ) == len( costs )
    return data, costs
if __name__ == '__main__':
    ## Command-line driver: parse a flat-difficulties log with parseOutput()
    ## and draw an annotated heatmap over (given flats dimension d, unknown
    ## flat dimension k) of either the iteration count ("iterations") or the
    ## clipped log10 of the final cost ("error").
    import argparse

    parser = argparse.ArgumentParser( description='plot flat intersection difficulies' )
    parser.add_argument( 'path', type=str, help='path of the data file' )
    ## nargs='?' makes this positional optional; without it argparse always
    ## requires the argument and the default below is never used.
    parser.add_argument( 'which', type=str, nargs='?', default = "error", choices = ['iterations', 'error'], help='Whether to plot "iterations" or "error"' )
    parser.add_argument( '--out', type=str, help='path to save the plot' )

    def str2bool( s ):
        """Map 'true'/'yes'/'false'/'no' (any case) to a bool for argparse."""
        try:
            return {'true': True, 'yes': True, 'false': False, 'no': False}[s.lower()]
        except KeyError:
            ## argparse reports ArgumentTypeError as a normal usage error;
            ## a bare KeyError would escape as a traceback.
            raise argparse.ArgumentTypeError( "expected true/yes/false/no, got %r" % s )
    parser.add_argument( '--show', type=str2bool, default = True, help='Whether to show the result.' )

    print( "Example: python3 plot_flat_difficulites.py test_difficulties/test_flat_difficulites.out-v3 error --show no --out error.pdf\n"
           "Example: python3 plot_flat_difficulites.py test_difficulties/test_flat_difficulites.out-v3 iterations --show no --out iterations.pdf\n" )
    args = parser.parse_args()

    data, costs = parseOutput( args.path )
    names = ["given flats dimension $d$", "unknown flat dimension $k$", "number of iterations"]

    ## Iteration counts as a d-by-k grid. pandas >= 2.0 only accepts keyword
    ## arguments to DataFrame.pivot, so positional calls raise TypeError.
    df = pd.DataFrame( data, columns=names )
    df = df.pivot( index=names[0], columns=names[1], values=names[2] )

    ## Final cost on a log10 scale: the order of magnitude matters more than
    ## the exact value. Everything below 1e-10 is clipped to -10 because the
    ## values between -10 and -30 are distracting in the plot.
    costs = np.log10( costs ).clip( -10, None )
    data2 = np.hstack( ( data[:, :2], costs ) )
    df2 = pd.DataFrame( data2, columns=names )
    df2 = df2.pivot( index=names[0], columns=names[1], values=names[2] )
    ## hstack with the float costs promoted d and k to float; restore ints.
    df2.index = df2.index.astype( int )
    df2.columns = df2.columns.astype( int )

    ## The default figure size (9, 6) suits 12-dimensional data; scale it up
    ## for larger d. Scale vertically less: rows are only cramped by
    ## two-digit labels.
    width = max( 9, 9*df.shape[0]/12 )
    height = max( 6, 6*(df.shape[0]-12)/8 )
    make_tiny = True
    ## Defined unconditionally so the heatmap call also works when
    ## make_tiny is False (seaborn accepts annot_kws=None).
    annot_kws = None
    if make_tiny:
        width *= 0.5
        height *= 0.5
        annot_kws = {"size": 5}
    print( "width:", width )
    print( "height:", height )

    ## Iteration cells use ".0f" rather than "d": when some (d, k) pairs are
    ## missing, pivot fills them with NaN and promotes the grid to float,
    ## and the "d" format code cannot format floats.
    if args.which == "iterations":
        table, fmt = df, ".0f"
    else:
        table, fmt = df2, ".2g"
    fig, ax = plt.subplots( figsize=(width, height) )
    sns.heatmap( table, annot=True, fmt=fmt, linewidths=.5, ax=ax, annot_kws=annot_kws )
    ## Flip the y axis so d increases upward.
    ax.invert_yaxis()

    if args.out:
        print( "Saving", args.out, "..." )
        plt.savefig( args.out )
        print( "Saved:", args.out )
    if args.show:
        plt.show()