diff --git a/docs/master/_dynamo.html b/docs/master/_dynamo.html
index 32a3404f115d..5ae4189aca3e 100644
--- a/docs/master/_dynamo.html
+++ b/docs/master/_dynamo.html
@@ -237,7 +237,7 @@
if len(state) == 0:
state['step'] = (
torch.zeros((1,), dtype=torch.float, device=p.device)
- if self.defaults['capturable'] or self.defaults['fused']
+ if group['capturable'] or group['fused']
else torch.tensor(0.)
)
# Exponential moving average of gradient values
@@ -579,8 +579,6 @@ Source code for torch.optim.adam
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
- grad_scaler (:class:`torch.cuda.amp.GradScaler`, optional): A GradScaler which is
- supplied from ``grad_scaler.step(optimizer)``.
"""
self._cuda_graph_capture_health_check()
@@ -607,25 +605,27 @@ Source code for torch.optim.adam
max_exp_avg_sqs,
state_steps)
- adam(params_with_grad,
- grads,
- exp_avgs,
- exp_avg_sqs,
- max_exp_avg_sqs,
- state_steps,
- amsgrad=group['amsgrad'],
- beta1=beta1,
- beta2=beta2,
- lr=group['lr'],
- weight_decay=group['weight_decay'],
- eps=group['eps'],
- maximize=group['maximize'],
- foreach=group['foreach'],
- capturable=group['capturable'],
- differentiable=group['differentiable'],
- fused=group['fused'],
- grad_scale=getattr(self, "grad_scale", None),
- found_inf=getattr(self, "found_inf", None))
+ adam(
+ params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ max_exp_avg_sqs,
+ state_steps,
+ amsgrad=group['amsgrad'],
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ maximize=group['maximize'],
+ foreach=group['foreach'],
+ capturable=group['capturable'],
+ differentiable=group['differentiable'],
+ fused=group['fused'],
+ grad_scale=getattr(self, "grad_scale", None),
+ found_inf=getattr(self, "found_inf", None),
+ )
return loss
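The hunks above make Adam read capturable and fused from each param group instead of from self.defaults, and pass grad_scale/found_inf into the functional adam() call. A minimal sketch of what the per-group lookup enables follows; the two-group split and the variable names are illustrative, not part of the diff, and it assumes a PyTorch build whose Adam exposes the capturable option:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = torch.nn.Linear(4, 2).to(device)
    head = torch.nn.Linear(2, 1).to(device)

    # One group opts into capturable, the other inherits the default (False).
    # With the change above, step() consults group['capturable'] per group,
    # so only the first group stores its 'step' counter as a device tensor.
    opt = torch.optim.Adam(
        [
            {"params": net.parameters(), "capturable": device == "cuda"},
            {"params": head.parameters()},
        ],
        lr=1e-3,
    )

    for i, group in enumerate(opt.param_groups):
        print(i, group["capturable"])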
diff --git a/docs/master/_modules/torch/optim/adamax.html b/docs/master/_modules/torch/optim/adamax.html
index d76f09455ac3..523cdc55981f 100644
--- a/docs/master/_modules/torch/optim/adamax.html
+++ b/docs/master/_modules/torch/optim/adamax.html
@@ -235,7 +235,7 @@
diff --git a/docs/master/_modules/torch/optim/adamw.html b/docs/master/_modules/torch/optim/adamw.html
index aea2d816b2ba..ea1e4d317d84 100644
--- a/docs/master/_modules/torch/optim/adamw.html
+++ b/docs/master/_modules/torch/optim/adamw.html
@@ -235,7 +235,7 @@
@@ -574,7 +574,7 @@ Source code for torch.optim.adamw
if len(state) == 0:
state["step"] = (
torch.zeros((1,), dtype=torch.float, device=p.device)
- if self.defaults["capturable"] or self.defaults["fused"]
+ if group["capturable"] or group["fused"]
else torch.tensor(0.0)
)
# Exponential moving average of gradient values
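AdamW picks up the same per-group lookup. The branch it guards chooses how the 'step' counter is stored: capturable or fused groups keep it as a one-element float tensor on the params' device, so CUDA graph capture and fused kernels can update it in place, while ordinary groups keep a cheap CPU scalar. A small illustration of the two representations (illustrative only, mirroring the lines in the hunk above):

    import torch

    p = torch.empty(3, device="cuda" if torch.cuda.is_available() else "cpu")

    # capturable/fused groups: step lives on the params' device.
    step_capturable = torch.zeros((1,), dtype=torch.float, device=p.device)

    # default groups: a plain CPU scalar tensor is enough.
    step_default = torch.tensor(0.0)

    print(step_capturable.device, step_default.device)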
diff --git a/docs/master/_modules/torch/optim/asgd.html b/docs/master/_modules/torch/optim/asgd.html
index a742f3848f6c..fe398a82473f 100644
--- a/docs/master/_modules/torch/optim/asgd.html
+++ b/docs/master/_modules/torch/optim/asgd.html
@@ -235,7 +235,7 @@
diff --git a/docs/master/_modules/torch/optim/lbfgs.html b/docs/master/_modules/torch/optim/lbfgs.html
index fbc926bf2caf..ea8f10d9c189 100644
--- a/docs/master/_modules/torch/optim/lbfgs.html
+++ b/docs/master/_modules/torch/optim/lbfgs.html
@@ -235,7 +235,7 @@
diff --git a/docs/master/_modules/torch/optim/lr_scheduler.html b/docs/master/_modules/torch/optim/lr_scheduler.html
index f87dae8c13d9..18c01fe2af4e 100644
--- a/docs/master/_modules/torch/optim/lr_scheduler.html
+++ b/docs/master/_modules/torch/optim/lr_scheduler.html
@@ -235,7 +235,7 @@
diff --git a/docs/master/_modules/torch/optim/nadam.html b/docs/master/_modules/torch/optim/nadam.html
index bd25a94c12da..e1f58ee486f6 100644
--- a/docs/master/_modules/torch/optim/nadam.html
+++ b/docs/master/_modules/torch/optim/nadam.html
@@ -235,7 +235,7 @@
diff --git a/docs/master/_modules/torch/optim/optimizer.html b/docs/master/_modules/torch/optim/optimizer.html
index 64d491a30222..98a11f9d7997 100644
--- a/docs/master/_modules/torch/optim/optimizer.html
+++ b/docs/master/_modules/torch/optim/optimizer.html
@@ -235,7 +235,7 @@
@@ -681,19 +681,21 @@ Source code for torch.optim.optimizer
if torch.has_cuda and torch.cuda.is_available():
capturing = torch.cuda.is_current_stream_capturing()
- if capturing and not self.defaults['capturable']:
+ if capturing and not all(group['capturable'] for group in self.param_groups):
raise RuntimeError("Attempting CUDA graph capture of step() for an instance of " +
self.__class__.__name__ +
- " but this instance was constructed with capturable=False.")
+ " but param_groups' capturable is False.")
if (
(not getattr(self, "_warned_capturable_if_run_uncaptured", False))
- and self.defaults["capturable"]
+ and all(group['capturable'] for group in self.param_groups)
and (not capturing)
):
- print("Warning: This instance was constructed with capturable=True, but step() " +
- "is running without CUDA graph capture. If you never intend to graph-capture this " +
- "instance, capturable=True can impair performance, and you should set capturable=False.")
+ warnings.warn(
+ "This instance was constructed with capturable=True or some of all the param_groups came with capturable=True, "
+ "but step() is running without CUDA graph capture. If you never intend to graph-capture this "
+ "instance, capturable=True can impair performance, and you should set capturable=False."
+ )
self._warned_capturable_if_run_uncaptured = True
def _optimizer_step_code(self):
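A practical consequence of replacing print with warnings.warn: the message now goes through Python's warnings machinery, so callers can silence it or turn it into an error instead of it unconditionally hitting stdout. A hedged sketch; the regex and the choice of filter are illustrative, not part of the diff:

    import warnings

    # Suppress the "running uncaptured with capturable=True" message,
    # e.g. in a benchmark harness that toggles graph capture on and off.
    warnings.filterwarnings("ignore", message=r".*capturable=True.*")

    # Or escalate it while debugging a CUDA-graph setup:
    # warnings.filterwarnings("error", message=r".*capturable=True.*")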
diff --git a/docs/master/_modules/torch/optim/radam.html b/docs/master/_modules/torch/optim/radam.html
index 7fede1eae0d0..acb3b47d8bdc 100644
--- a/docs/master/_modules/torch/optim/radam.html
+++ b/docs/master/_modules/torch/optim/radam.html
@@ -235,7 +235,7 @@
diff --git a/docs/master/_modules/torch/optim/rmsprop.html b/docs/master/_modules/torch/optim/rmsprop.html
index d55d287fb034..2ef92ea11296 100644
--- a/docs/master/_modules/torch/optim/rmsprop.html
+++ b/docs/master/_modules/torch/optim/rmsprop.html
@@ -235,7 +235,7 @@