-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoriginal_weight_matching.txt
41 lines (33 loc) · 1.38 KB
/
original_weight_matching.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def weight_matching(rng,
                    ps: PermutationSpec,
                    params_a,
                    params_b,
                    max_iter=100,
                    init_perm=None,
                    silent=False):
    """Find a permutation of `params_b` to make them match `params_a`.

    The search updates one permutation at a time while holding the others
    fixed. For the permutation being updated, a similarity matrix between
    the units of `params_a` and the (partially permuted) units of
    `params_b` is accumulated over every parameter axis that the
    permutation acts on, and the best one-to-one assignment is taken as the
    new permutation. Sweeps over all permutations repeat until a full sweep
    yields no improvement or `max_iter` sweeps have run.

    Args:
        rng: Base random key. It is combined with the sweep index through
            `rngmix` and handed to `random.permutation` to pick the order
            in which permutations are visited (presumably a JAX PRNG key --
            confirm against the caller).
        ps: Permutation spec. Only `ps.perm_to_axes` is read here: a mapping
            from permutation name to a sequence of `(param_key, axis)`
            pairs.
        params_a: Mapping from parameter key to array; the reference model.
        params_b: Mapping from parameter key to array; the model to permute.
        max_iter: Maximum number of sweeps over all permutations.
        init_perm: Optional mapping from permutation name to an index array
            to start from. It is copied, so the caller's mapping is left
            untouched. When omitted, every permutation starts as identity.
        silent: When false, print the objective change of every update.

    Returns:
        A dict mapping each permutation name to its index array.

    Raises:
        AssertionError: If the row indices returned by
            `linear_sum_assignment` are not `0..n-1` in order.
    """
    # The size of each permutation is read off the first (param, axis) pair
    # listed for it, using the shapes found in `params_a`.
    perm_sizes = {
        p: params_a[axes[0][0]].shape[axes[0][1]]
        for p, axes in ps.perm_to_axes.items()
    }

    if init_perm is None:
        perm = {p: jnp.arange(n) for p, n in perm_sizes.items()}
    else:
        # Shallow copy: entries are reassigned below, and writing through to
        # the caller's mapping would be a surprising side effect.
        perm = dict(init_perm)
    perm_names = list(perm)

    for iteration in range(max_iter):
        progress = False
        # Visit the permutations in a fresh random order on every sweep.
        for p_ix in random.permutation(rngmix(rng, iteration), len(perm_names)):
            p = perm_names[p_ix]
            n = perm_sizes[p]

            # A[i, j] accumulates the inner product between unit i of model
            # A and unit j of model B over every parameter touched by `p`.
            A = jnp.zeros((n, n))
            for wk, axis in ps.perm_to_axes[p]:
                w_a = params_a[wk]
                # NOTE(review): presumably applies the current permutations
                # to params_b[wk] on every axis except `axis` -- confirm
                # against get_permuted_param.
                w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
                # Bring the permuted axis to the front and flatten the rest
                # so that each row describes one unit.
                w_a = jnp.moveaxis(w_a, axis, 0).reshape((n, -1))
                w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))
                A += w_a @ w_b.T

            ri, ci = linear_sum_assignment(A, maximize=True)
            # `ci` is only usable as a permutation if the rows come back in
            # natural order.
            assert (ri == jnp.arange(len(ri))).all()

            # Objective under the previous and the new assignment, computed
            # as <A, P> with P the corresponding one-hot matrix. The
            # identity is built once and indexed twice.
            eye = jnp.eye(n)
            oldL = jnp.vdot(A, eye[perm[p]])
            newL = jnp.vdot(A, eye[ci, :])
            if not silent:
                print(f"{iteration}/{p}: {newL - oldL}")
            progress = progress or newL > oldL + 1e-12

            perm[p] = jnp.array(ci)

        # A whole sweep without improvement: nothing more to gain.
        if not progress:
            break

    return perm