-
Hello. In more detail: I'm actually using a multi_transform optimizer, where each group of parameters is "active" (i.e., the gradient is computed with respect to those parameters and is not None) only at specific steps. Currently I have written my own multi_transform to deal with such cases. Is it a good implementation? Is it possible to improve it, or to use something that already exists in the library? Here's the code:

# I've edited only the update_fn of multi_transform
def update_fn(updates, state, params=None):
    # Runs inside optax's multi_transform, so `param_labels`, `transforms`
    # and `make_mask` come from the enclosing scope.
    labels = param_labels(updates) if callable(param_labels) else param_labels
    new_inner_state = {}
    for group, tx in transforms.items():
        group_mask = make_mask(labels, group)
        # True only for leaves that belong to this group and have a gradient.
        updates_mask, _ = jtu.tree_flatten(
            jtu.tree_map(lambda m, v: m and v is not None, group_mask, updates))
        if np.any(updates_mask):
            assert updates_mask == jtu.tree_flatten(group_mask)[0], (
                "[TODO] If updates_mask has any True value, the whole group "
                "should have valid gradients. If you see this error message it "
                "means you are voiding only some of the gradients of a group. "
                "This should never happen in normal circumstances. Please "
                "report this.")
            masked_tx = optax.masked(tx, group_mask)
            updates, new_inner_state[group] = masked_tx.update(
                updates, state.inner_states[group], params)
        else:
            # No gradients for this group at this step: leave the updates and
            # the group's inner state untouched.
            updates, new_inner_state[group] = updates, state.inner_states[group]
    return updates, optax.MultiTransformState(new_inner_state)

Thank you for any advice.
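For context, here is a minimal sketch of how the stock optax.multi_transform that the snippet above modifies is normally set up; the parameter names, labels and learning rates below are placeholders, not taken from the post:

import jax
import jax.numpy as jnp
import optax

# Two parameter groups, each routed to its own inner transformation.
params = {'encoder': jnp.ones((4,)), 'decoder': jnp.zeros((4,))}
param_labels = {'encoder': 'group_a', 'decoder': 'group_b'}
tx = optax.multi_transform(
    {'group_a': optax.adam(1e-3), 'group_b': optax.sgd(1e-2)},
    param_labels)

state = tx.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)  # stand-in gradients
updates, state = tx.update(grads, state, params)
params = optax.apply_updates(params, updates)

The custom update_fn slots into this structure: on steps where a whole group's gradients are None, it leaves both the updates and that group's inner state untouched instead of calling the inner transformation.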
-
Did you try wrapping your optimiser with ...? Then ...
-
I see, one potential issue is that I don't think your proposed solution will work with jitting, since the if condition is not based on static values.
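For reference, a minimal sketch (not from the thread) of the static-vs-traced distinction behind this point: under jit, a plain Python if can only branch on values that are concrete at trace time; if the predicate depends on traced array values, tracing fails, and jax.lax.cond is needed instead.

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames='use_double')
def static_branch(x, use_double: bool):
    # `use_double` is marked static, so this Python `if` is resolved at
    # trace time and jit simply specializes on each value it sees.
    if use_double:
        return x * 2
    return x

@jax.jit
def traced_branch(x):
    # The predicate depends on the traced value of `x`, so converting it
    # to a Python bool fails under jit (a tracer/concretization error).
    if jnp.any(x > 0):
        return x * 2
    return x

@jax.jit
def cond_branch(x):
    # jax.lax.cond accepts a traced predicate, at the cost of staging out
    # both branches.
    return jax.lax.cond(jnp.any(x > 0), lambda v: v * 2, lambda v: v, x)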