-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path pymanopt_test_both.py
67 lines (56 loc) · 1.96 KB
/
pymanopt_test_both.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import autograd.numpy as np
from pymanopt.manifolds import Stiefel, Grassmann, Euclidean, Product
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions, ConjugateGradient
# (1) Instantiate a manifold
# Problem sizes: `handles` is the dimension of the fitted (solution) flat,
# `N` is the number of data flats, and each data flat is the solution set of
# `Q` linear constraints in the ambient space.
poses = 10
handles = 10
N = 100
Q = 3*poses
# Ambient dimension: 12 coordinates per pose
# (NOTE(review): presumably a flattened 3x4 transform per pose -- confirm).
dim = 12*poses
# Manifold components, in the order cost() unpacks them:
#   p                -- Euclidean(dim): a point on the fitted flat
#   B                -- Grassmann(dim, handles): the fitted flat's direction subspace
#   z_0 .. z_{N-1}   -- Euclidean(handles): per-datum coordinates in the fitted flat
#   y_0 .. y_{N-1}   -- Euclidean(dim - Q): per-datum coordinates in each data
#                       flat's null-space basis
manifold = Product(
    [Euclidean(dim), Grassmann(dim, handles)]
    + [Euclidean(handles)]*N
    + [Euclidean(dim - Q)]*N
)
## (1b) Generate data
## TODO: Zero energy test data.
np.random.seed(0)
## Data flat i is {a_i + nullspaces[i] @ y}: a point a_i plus the null space
## of a Q x dim constraint matrix A_i.
## Start from random (not yet orthonormal) constraint rows and a random point.
flats = [(np.random.random((Q, dim)), np.random.random(dim)) for _ in range(N)]
## Orthonormalize the rows: the first Q right-singular vectors span the row
## space of A (A is full rank with probability 1).
flats = [(np.linalg.svd(A, full_matrices=False)[2][:Q], a) for A, a in flats]
## Null-space bases, shape (dim, dim - Q): the remaining right-singular
## vectors of the full SVD, stored as columns.
nullspaces = [np.linalg.svd(A, full_matrices=True)[2][Q:].T for A, _ in flats]
# (2) Define the cost function (here using autograd.numpy)
def cost(X):
    """Return the total squared residual between the fitted flat and the data flats.

    ``X`` is a point on ``manifold``, laid out as
    ``[p, B, z_0, ..., z_{N-1}, y_0, ..., y_{N-1}]``.  For each data flat ``i``
    the residual is the vector from the point ``a_i + nullspaces[i] @ y_i`` on
    data flat ``i`` to the point ``p + B @ z_i`` on the fitted flat.  The cost
    is the sum of the squared Euclidean norms of these residuals.

    Reads the module-level ``flats``, ``nullspaces`` and ``N``.  Only ``np``
    operations are used (``np`` is autograd.numpy in this script), so the
    function stays differentiable.
    """
    p, B = X[:2]
    # The per-datum coordinates follow p and B: first N z's, then N y's.
    zs = X[2:2 + N]
    ys = X[2 + N:]
    total = 0.  # avoid shadowing the builtin ``sum``
    for (_, a), nullA, z, y in zip(flats, nullspaces, zs, ys):
        diff = p + np.dot(B, z) - (a + np.dot(nullA, y))
        total += np.dot(diff, diff)
    return total
problem = Problem(manifold=manifold, cost=cost)
# Import the analysis helper BEFORE the (up to `maxtime` seconds) solve, so a
# missing module fails immediately instead of after the optimization finishes.
import flat_metrics
# (3) Instantiate a Pymanopt solver.
# Alternatives tried: SteepestDescent(), ConjugateGradient(maxiter=10000).
solver = TrustRegions(maxtime=2000)
## Setting Delta_bar made a huge difference (running without it printed a
## suggestion to do it).
## NOTE(review): an earlier note cited Delta_bar = 100, but 1000 is passed
## here -- confirm which value is intended.
solver_args = {'Delta_bar': 1000.}
# let Pymanopt do the rest
Xopt = solver.solve(problem, **solver_args)
print("Final cost:", cost(Xopt))
# Is zero in the solution flat {p + B z}? Report how far the flat is from the origin.
p, B = Xopt[:2]
print('p:')
print(p)
print('B:')
print(B)
# NOTE(review): canonical_point presumably returns the point of the flat
# closest to the origin (per the variable name) -- flat_metrics is not visible here.
p_closest_to_origin = flat_metrics.canonical_point(p, B)
dist_to_origin = np.linalg.norm(p_closest_to_origin)
print("Distance to the flat from the origin:", dist_to_origin)