Perform gradient clipping on global batch when using gradient accumulation #9
base: main
Conversation
…n using ShardedStaticAccumulator
@ashors1 sorry for the late review, could you rebase to head? I want to import it and run some internal CI, thanks!
There are quite a few redundant whitespace characters. Could you run a Python linter to remove them?
if optimizer_name is None:
  optimizer_name = ''
else:
  optimizer_name = optimizer_name + '/'
I think you are missing the following code block from the original scale_gradient?
if clip_gradient_norm_to_value is None:
clip_gradient_norm_to_value = p.optimizer.clip_gradient_norm_to_value
if clip_gradient_single_norm_to_value is None:
clip_gradient_single_norm_to_value = (
p.optimizer.clip_gradient_single_norm_to_value
)
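For readers following along, here is a minimal sketch of what these two hparams control once resolved. This is a hypothetical helper, not the paxml implementation; clip_grads and its parameter names are made up for illustration.
import jax
import jax.numpy as jnp

def clip_grads(grads, clip_norm_to_value=None, clip_single_norm_to_value=None):
  # Hypothetical helper: clip a gradient pytree either by its global norm or
  # by each tensor's individual norm, mirroring the two hparams above.
  if clip_norm_to_value:
    global_norm = jnp.sqrt(
        sum(jnp.sum(jnp.square(g)) for g in jax.tree_util.tree_leaves(grads)))
    scale = jnp.minimum(1.0, clip_norm_to_value / (global_norm + 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)
  if clip_single_norm_to_value:
    def _clip(g):
      norm = jnp.sqrt(jnp.sum(jnp.square(g)))
      return g * jnp.minimum(1.0, clip_single_norm_to_value / (norm + 1e-12))
    return jax.tree_util.tree_map(_clip, grads)
  return grads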
paxml/learners.py
Outdated
else:
  optimizer_name = optimizer_name + '/'
self.get_individual_grad_norms(raw_grads,
                               optimizer_name)
nit: let's not line break here; optimizer_name can be on the previous line
Actually, can we move get_individual_grad_norms back inline? It's not used anywhere else, and it seems more consistent with the inlined global grad norm below.
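As a rough illustration of the inlining being suggested (the dummy pytree stands in for raw_grads; the actual summary-logging calls in paxml/learners.py are omitted):
import jax
import jax.numpy as jnp

# Example gradient pytree standing in for raw_grads.
raw_grads = {'w': jnp.ones((2, 3)), 'b': jnp.zeros((3,))}

# Inlined per-variable grad norms: one L2 norm per leaf, keyed like the
# gradients so each variable's norm can be logged under its own name.
individual_grad_norms = jax.tree_util.tree_map(
    lambda g: jnp.sqrt(jnp.sum(jnp.square(g))), raw_grads)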
paxml/learners.py
Outdated
if p.check_valid_step:
  # Mark the step as invalid if any gradient anomaly is detected (e.g. Nan
  # or Inf, or excessively big gradient norm).
  valid_step = self.keep_step(raw_grad_norm)
Let's move keep_step back to a free function inside get_grad_norm_valid_step rather than a new instance method? The original code is a bit complicated; let's avoid refactoring too much because it might make it harder to spot whether the existing logic still holds.
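A hedged sketch of that shape, with keep_step as a local function inside get_grad_norm_valid_step (hparam names like skip_step_gradient_norm_value are assumptions for illustration, not necessarily paxml's exact fields):
import jax.numpy as jnp

def get_grad_norm_valid_step(raw_grad_norm, p):
  # p stands in for the learner hparams in this sketch.

  def keep_step(grad_norm):
    # Mark the step invalid on NaN/Inf or an excessively large gradient norm.
    keep = jnp.isfinite(grad_norm)
    if getattr(p, 'skip_step_gradient_norm_value', 0.0):
      keep = jnp.logical_and(keep, grad_norm < p.skip_step_gradient_norm_value)
    return keep

  valid_step = keep_step(raw_grad_norm) if p.check_valid_step else True
  return raw_grad_norm, valid_step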
paxml/learners.py
Outdated
grads, valid_step = self.scale_gradients(grads)
grad_norm, valid_step = self.get_grad_norm_valid_step(grads)

using_ga = hasattr(p.optimizer, 'num_sub_batches')
nit: let's use using_grad_accum; most readers might not know what ga means
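For context, one way the renamed flag might gate the control flow. This is an illustrative sketch consistent with the PR description, not the exact diff above:
# Sketch: when the optimizer accumulates subbatches, only compute the grad
# norm here (for skip-step detection) and defer clipping to the accumulated
# global-batch gradient inside the base optimizer update.
using_grad_accum = hasattr(p.optimizer, 'num_sub_batches')
if using_grad_accum:
  grad_norm, valid_step = self.get_grad_norm_valid_step(grads)
else:
  grads, valid_step = self.scale_gradients(grads)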
paxml/learners.py
Outdated
@@ -588,8 +631,16 @@ def scale_gradients_by_optimizer(
) -> Tuple[NestedMap, JTensor]:
  optimizer_mask, default_mask = self.get_masks(var_weight_hparams)

  all_grads, all_valid_step = self.scale_gradients(
      jax.tree_map(lambda x, y: x * y, raw_grads, default_mask),
  raw_grads = jax.tree_map(lambda x, y: x * y, raw_grads, default_mask)
Let's not reuse raw_grads; let's call this grads_after_mask, because you've introduced a subtle bug here: if you look at line 659 inside the auxiliary_optimizers loop, you are now combining this outer mask with the inner mask. I would not overwrite the raw_grads variable, just

grads_after_mask = jax.tree_map(lambda x, y: x * y, raw_grads, default_mask)
grad_norm, all_valid_step = self.get_grad_norm_valid_step(
    grads_after_mask,
    optimizer_name='main',
)

so that inside the auxiliary_optimizers loop, raw_grads is only masked by each auxiliary optimizer's mask.
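To spell out the requested pattern, here is a small self-contained sketch; aux_mask and the stand-in pytrees are made up for illustration and the real auxiliary_optimizers loop body is omitted:
import jax
import jax.numpy as jnp

# Stand-in pytrees for the real gradients and masks.
raw_grads = {'w': jnp.ones((2,)), 'b': jnp.ones((3,))}
default_mask = {'w': jnp.ones((2,)), 'b': jnp.ones((3,))}
aux_mask = {'w': jnp.zeros((2,)), 'b': jnp.ones((3,))}

# Main optimizer: mask once, keep the result in a new variable.
grads_after_mask = jax.tree_map(lambda x, y: x * y, raw_grads, default_mask)

# Auxiliary optimizers: apply each auxiliary mask to the untouched raw_grads,
# so the default mask above never compounds with the per-optimizer masks.
aux_grads = jax.tree_map(lambda x, y: x * y, raw_grads, aux_mask)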
@zhangqiaorjc is there a reason this has been approved but not merged yet?
…kage/tensorflow-2.11.1 PiperOrigin-RevId: 524892551
Refactoring to allow gradient clipping to be performed on the full batch rather than on subbatches when using ShardedStaticAccumulator. Note that this refactor allows us to maintain support for enable_skip_step_on_gradient_anomalies and requires x+1 grad norm calculations per global batch when using ShardedStaticAccumulator with x subbatches (once per subbatch to determine whether the step should be skipped, and once when applying gradient clipping in the base optimizer update), plus one grad clip per global batch. This PR should be taken together with the corresponding Praxis PR.
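To make the intended ordering concrete, here is a minimal, self-contained JAX sketch (not the paxml/praxis implementation) of this scheme: skip-step detection runs once per subbatch, while clipping is applied once to the accumulated global-batch gradient.
import jax
import jax.numpy as jnp

def global_norm(grads):
  # L2 norm over the whole gradient pytree.
  return jnp.sqrt(
      sum(jnp.sum(jnp.square(g)) for g in jax.tree_util.tree_leaves(grads)))

def accumulate_and_clip(subbatch_grads, clip_norm):
  """subbatch_grads: list of x gradient pytrees, one per subbatch."""
  # 1) Per-subbatch grad norms: x norm computations, used only to decide
  #    whether the step should be skipped (NaN/Inf or anomalous norm).
  valid_step = True
  for g in subbatch_grads:
    valid_step = jnp.logical_and(valid_step, jnp.isfinite(global_norm(g)))

  # 2) Average over subbatches to form the global-batch gradient.
  avg = jax.tree_util.tree_map(lambda *gs: sum(gs) / len(gs), *subbatch_grads)

  # 3) One more norm computation (the x+1-th) to clip the global-batch
  #    gradient exactly once, in the base optimizer update.
  scale = jnp.minimum(1.0, clip_norm / (global_norm(avg) + 1e-12))
  clipped = jax.tree_util.tree_map(lambda g: g * scale, avg)
  return clipped, valid_step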