diff --git a/docs/2.0/_images/RReLU.png b/docs/2.0/_images/RReLU.png
index 75fa1b850701..04f0001d9bc7 100644
Binary files a/docs/2.0/_images/RReLU.png and b/docs/2.0/_images/RReLU.png differ
diff --git a/docs/2.0/_modules/torch/cuda/amp/autocast_mode.html b/docs/2.0/_modules/torch/cuda/amp/autocast_mode.html
index c233fe991454..dd4a059ce9c2 100644
--- a/docs/2.0/_modules/torch/cuda/amp/autocast_mode.html
+++ b/docs/2.0/_modules/torch/cuda/amp/autocast_mode.html
@@ -504,7 +504,7 @@

Source code for torch.cuda.amp.autocast_mode

if isinstance(value, torch.Tensor):
         is_eligible = (value.is_floating_point() and value.is_cuda and (value.dtype is not torch.float64))
         return value.to(dtype) if is_eligible else value
-    elif isinstance(value, str):
+    elif isinstance(value, (str, bytes)):
         return value
     elif HAS_NUMPY and isinstance(value, np.ndarray):
         return value
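For orientation, a hedged sketch of what the widened check means in practice: the private `_cast` helper edited above converts eligible CUDA float tensors to the autocast dtype and now passes both str and bytes values through untouched. The import path is an assumption (the helper is internal and may move between releases).

    import torch
    from torch.cuda.amp.autocast_mode import _cast  # private helper shown above

    if torch.cuda.is_available():
        t32 = torch.ones(2, device="cuda", dtype=torch.float32)
        t64 = torch.ones(2, device="cuda", dtype=torch.float64)
        out = _cast((t32, t64, "label", b"raw-bytes"), torch.float16)
        # t32 is cast to float16; t64, the str and the bytes are returned unchanged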
diff --git a/docs/2.0/_modules/torch/distributed/distributed_c10d.html b/docs/2.0/_modules/torch/distributed/distributed_c10d.html
index 5e0538df744b..7694541ae08a 100644
--- a/docs/2.0/_modules/torch/distributed/distributed_c10d.html
+++ b/docs/2.0/_modules/torch/distributed/distributed_c10d.html
@@ -683,7 +683,6 @@ 

Source code for torch.distributed.distributed_c10d

     def __init__(self, backend: Union[str, Backend]):
         self.device_backend_map: Dict[torch.device, Backend] = {}
-        # error check to make sure the config string is valid
         # Cases for when backend is a single string (without device types)
         if backend == Backend.UNDEFINED:
@@ -700,13 +699,24 @@

Source code for torch.distributed.distributed_c10d

"cuda": backend_val, } else: - # custom backend string in format of "{device_type1}:{backend1},{device_type2}:{backend2}" - # TODO - pass - - required_devices = ["cpu", "cuda"] - for device in required_devices: - assert device in self.device_backend_map + # make sure the backend string is in the correct format + # "{device_type1}:{backend1},{device_type2}:{backend2}" + # e.g. "cpu:gloo,cuda:nccl" + backend_str_error_message = f"""The custom backend string argument is invalid: {backend}. + Custom backend string is an experimental feature where the backend string must be in the format: + "<device_type1>:<backend1>,<device_type2>:<backend2>...". e.g. 'cpu:gloo,cuda:nccl'""" + + # parse the backend string and populate the device_backend_map + for device_backend_pair_str in backend.lower().split(","): + device_backend_pair = device_backend_pair_str.split(":") + if len(device_backend_pair) != 2: + raise ValueError(f"Invalid device:backend pairing: \ + {device_backend_pair_str}. {backend_str_error_message}") + device, backend = device_backend_pair + if device in self.device_backend_map: + raise ValueError(f"Duplicate device type {device} \ + in backend string: {backend}. {backend_str_error_message}") + self.device_backend_map[device] = Backend(backend) def __repr__(self): # string with all the device:backend pairs separared by commas @@ -1293,7 +1303,9 @@

Source code for torch.distributed.distributed_c10d

     .. note:: Support for multiple backends is experimental. Currently when no
         backend is specified, both ``gloo`` and ``nccl`` backends will be created.
         The ``gloo`` backend will be used for collectives with CPU tensors and
         the ``nccl`` backend will be used
-        for collectives with CUDA tensors.
+        for collectives with CUDA tensors. A custom backend can be specified by passing in
+        a string with format "<device_type>:<backend_name>,<device_type>:<backend_name>", e.g.
+        "cpu:gloo,cuda:custom_backend".

     """
     global _world
@@ -1444,6 +1456,9 @@

Source code for torch.distributed.distributed_c10d

             backend_type = ProcessGroup.BackendType.MPI
             if not backend_class:
                 return GroupMember.NON_GROUP_MEMBER
+            # create new process group with accurate rank and size
+            if pg.rank() == -1 and pg.size() == -1:
+                pg = ProcessGroup(backend_prefix_store, backend_class.rank(), backend_class.size(), base_pg_options)
         elif backend_str == Backend.GLOO:
             # TODO: remove this check after lazy initialization is supported
             # if pg_options is not None:
@@ -1527,15 +1542,15 @@

Source code for torch.distributed.distributed_c10d

                 timeout=timeout,
             )
-        # only create single backend pg when backend is set to gloo, nccl, mpi, and ucc
-        if backend in [Backend.GLOO, Backend.NCCL, Backend.UCC, Backend.MPI]:
+        # register only a single backend when all get_device_backend_map values are the same
+        if len(set(backend_config.get_device_backend_map().values())) == 1:
             for device in backend_config.get_device_backend_map().keys():
                 pg._register_backend(torch.device(device), backend_type, backend_class)
             # break out of outer loop to not create any more backends
             break
-        else:
-            pg._register_backend(torch.device(device), backend_type, backend_class)
+
+        pg._register_backend(torch.device(device), backend_type, backend_class)

     # update global state
     _world.pg_map[pg] = (backend, prefix_store)
diff --git a/docs/2.0/_modules/torch/nn/functional.html b/docs/2.0/_modules/torch/nn/functional.html
index 0fdcfaeb1ba0..1737bcf7a48c 100644
--- a/docs/2.0/_modules/torch/nn/functional.html
+++ b/docs/2.0/_modules/torch/nn/functional.html
@@ -5513,11 +5513,21 @@

Source code for torch.nn.functional

             be ignored by the attention. This is an binary mask. When the value is True,
             the corresponding value on the attention layer will be filled with -inf.
         need_weights: output attn_output_weights.
+            Default: `True`
+            Note: `need_weights` defaults to `True`, but it should be set to `False`
+            for best performance when attention weights are not needed.
+            *Setting need_weights to `True`
+            leads to a significant performance degradation.*
         attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
             the batches while a 3D mask allows to specify a different mask for the entries of each batch.
         is_causal: If specified, applies a causal mask as attention mask, and ignores
             attn_mask for computing scaled dot product attention.
             Default: ``False``.
+            .. warning::
+                is_causal provides a hint that the attn_mask is the
+                causal mask. Providing incorrect hints can result in
+                incorrect execution, including forward and backward
+                compatibility.
         use_separate_proj_weight: the function accept the proj. weights for query, key,
             and value in different forms. If false, in_proj_weight will be used, which is
             a combination of q_proj_weight, k_proj_weight, v_proj_weight.
@@ -5617,8 +5627,33 @@ 

Source code for torch.nn.functional

         target_type=query.dtype
     )
 
-    if is_causal:
+    if is_causal and attn_mask is None:
+        raise RuntimeError(
+            "Need attn_mask if specifying the is_causal hint. "
+            "You may use the Transformer module method "
+            "`generate_square_subsequent_mask` to create this mask."
+        )
+
+    if is_causal and key_padding_mask is None and not need_weights:
+        # when we have a kpm or need weights, we need attn_mask
+        # Otherwise, we pass the is_causal hint straight through
+        # as the is_causal indicator to SDPA.
         attn_mask = None
+    else:
+        attn_mask = _canonical_mask(
+            mask=attn_mask,
+            mask_name="attn_mask",
+            other_type=None,
+            other_name="",
+            target_type=query.dtype,
+            check_other=False,
+        )
+
+        if key_padding_mask is not None:
+            # We have the attn_mask, and use that to merge kpm into it.
+            # Turn off use of is_causal hint, as the merged mask is no
+            # longer causal.
+            is_causal = False
 
     assert embed_dim == embed_dim_to_check, \
         f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
@@ -5748,6 +5783,9 @@ 

Source code for torch.nn.functional

     if need_weights:
         B, Nt, E = q.shape
         q_scaled = q / math.sqrt(E)
+
+        assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
+
         if attn_mask is not None:
             attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
         else:
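A hedged usage sketch of the need_weights guidance documented above: leave attention-weight output off whenever the weights are not consumed, so the fused scaled_dot_product_attention path stays available.

    import torch
    import torch.nn as nn

    mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
    x = torch.randn(8, 16, 64)                       # (batch, seq, embed)
    out, attn_weights = mha(x, x, x, need_weights=False)
    # attn_weights is None here; requesting it forces the slower explicit-weights path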
diff --git a/docs/2.0/_modules/torch/nn/modules/activation.html b/docs/2.0/_modules/torch/nn/modules/activation.html
index dac76c8328b2..d152507f86df 100644
--- a/docs/2.0/_modules/torch/nn/modules/activation.html
+++ b/docs/2.0/_modules/torch/nn/modules/activation.html
@@ -1353,9 +1353,8 @@ 

Source code for torch.nn.modules.activation

 
     where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
 
-    ``forward()`` will use the optimized implementation described in
-    `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
-    conditions are met:
+    ``nn.MultiHeadAttention`` will use the optimized implementations of
+    ``scaled_dot_product_attention()`` when possible.
 
     - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
       restriction will be loosened in the future.)
@@ -1510,8 +1509,13 @@ 

Source code for torch.nn.modules.activation

             corresponding position is not allowed to attend. For a float mask, the mask values will be added to
             the attention weight.
             If both attn_mask and key_padding_mask are supplied, their types should match.
-        is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask.
+        is_causal: If specified, applies a causal mask as attention mask.
             Default: ``False``.
+            Warning:
+            ``is_causal`` provides a hint that ``attn_mask`` is the
+            causal mask. Providing incorrect hints can result in
+            incorrect execution, including forward and backward
+            compatibility.
         average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
             heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
             effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
@@ -1530,8 +1534,6 @@ 

Source code for torch.nn.modules.activation

         .. note::
             `batch_first` argument is ignored for unbatched inputs.
         """
-        if attn_mask is not None and is_causal:
-            raise AssertionError("Only allow causal mask or attn_mask")
 
         is_batched = query.dim() == 3
 
@@ -1684,20 +1686,27 @@ 

Source code for torch.nn.modules.activation

             check_other=False,
         )
 
-        if attn_mask is not None:
-            mask_type = 0
-            merged_mask = attn_mask
         if key_padding_mask is not None:
             mask_type = 1
             merged_mask = key_padding_mask
-        if (attn_mask is not None) and (key_padding_mask is not None):
+
+        if attn_mask is not None:
             # In this branch query can't be a nested tensor, so it has a shape
             batch_size, seq_len, _ = query.shape
             mask_type = 2
-            key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len) \
-                                                        .expand(-1, self.num_heads, -1, -1)
-            attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, self.num_heads, -1, -1)
-            merged_mask = attn_mask_expanded + key_padding_mask_expanded
+
+            # Always expand attn_mask to 4D
+            if attn_mask.dim() == 3:
+                attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
+            else:  # attn_mask.dim() == 2:
+                attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, self.num_heads, -1, -1)
+            merged_mask = attn_mask_expanded
+
+            if key_padding_mask is not None:
+                key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1)
+                merged_mask = attn_mask_expanded + key_padding_mask_expanded
+
+        # no attn_mask and no key_padding_mask, returns None, None
         return merged_mask, mask_type
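A hedged, shape-only sketch (separate from the library code above) of how a 2D attn_mask and a key_padding_mask broadcast into the merged 4D mask:

    import torch

    batch_size, num_heads, seq_len = 2, 4, 5
    attn_mask = torch.zeros(seq_len, seq_len)            # 2D float mask
    key_padding_mask = torch.zeros(batch_size, seq_len)  # one row per batch entry

    attn_4d = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, num_heads, -1, -1)
    kpm_4d = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(-1, num_heads, -1, -1)
    merged = attn_4d + kpm_4d                            # (batch, heads, seq, seq)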
diff --git a/docs/2.0/_modules/torch/nn/modules/transformer.html b/docs/2.0/_modules/torch/nn/modules/transformer.html
index 4ea7205f7fff..02cb886a2e77 100644
--- a/docs/2.0/_modules/torch/nn/modules/transformer.html
+++ b/docs/2.0/_modules/torch/nn/modules/transformer.html
@@ -1022,21 +1022,21 @@

Source code for torch.nn.modules.transformer

x = src
         if self.norm_first:
-            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
+            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
             x = x + self._ff_block(self.norm2(x))
         else:
-            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
+            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
             x = self.norm2(x + self._ff_block(x))
 
         return x
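A hedged usage sketch of the is_causal flag now threaded through the self-attention block, with the mask built by the existing generate_square_subsequent_mask helper:

    import torch
    import torch.nn as nn

    layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
    x = torch.randn(8, 16, 64)
    mask = nn.Transformer.generate_square_subsequent_mask(16)
    # is_causal=True is only a hint that src_mask really is the causal mask
    out = layer(x, src_mask=mask, is_causal=True)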
     # self-attention block
     def _sa_block(self, x: Tensor,
-                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
+                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
         x = self.self_attn(x, x, x,
                            attn_mask=attn_mask,
                            key_padding_mask=key_padding_mask,
-                           need_weights=False)[0]
+                           need_weights=False, is_causal=is_causal)[0]
         return self.dropout1(x)

     # feed forward block
diff --git a/docs/2.0/_modules/torch/optim/adam.html b/docs/2.0/_modules/torch/optim/adam.html
index d9c8680f64e7..885d39ff2447 100644
--- a/docs/2.0/_modules/torch/optim/adam.html
+++ b/docs/2.0/_modules/torch/optim/adam.html
@@ -984,7 +984,6 @@

Source code for torch.optim.adam

     capturable: bool,  # Needed for consistency.
     differentiable: bool,
 ) -> None:
-    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
     grad_scale_dict = {grad_scale.device: grad_scale} if grad_scale is not None else None
     found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None
     grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
@@ -997,16 +996,15 @@ 

Source code for torch.optim.adam

             device_max_exp_avg_sqs,
             device_state_steps,
         ) = grouped_tensors[(device, dtype)]
-        if grad_scale is not None and found_inf is not None:
+        device_grad_scale, device_found_inf = None, None
+        if grad_scale is not None:
             if device not in grad_scale_dict:
                 grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)
+            device_grad_scale = grad_scale_dict[device]
+        if found_inf is not None:
             if found_inf not in found_inf_dict:
                 found_inf_dict[device] = found_inf.to(device, non_blocking=True)
-            device_grad_scale = grad_scale_dict[device]
             device_found_inf = found_inf_dict[device]
-        else:
-            device_grad_scale = None
-            device_found_inf = None
         torch._foreach_add_(device_state_steps, 1)
         torch._fused_adam_(
             device_params,
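A hedged sketch of the code path this rewrite serves: fused Adam driven by an AMP GradScaler, where grad_scale and found_inf are now handled independently and moved to each parameter device once, then reused (requires a CUDA device).

    import torch

    if torch.cuda.is_available():
        model = torch.nn.Linear(8, 8).cuda()
        opt = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)
        scaler = torch.cuda.amp.GradScaler()

        with torch.autocast("cuda"):
            loss = model(torch.randn(4, 8, device="cuda")).sum()
        scaler.scale(loss).backward()
        scaler.step(opt)   # grad_scale / found_inf reach the fused kernel per device
        scaler.update()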
diff --git a/docs/2.0/_modules/torch/optim/adamw.html b/docs/2.0/_modules/torch/optim/adamw.html
index 204bf9a06a32..5506ca655aa2 100644
--- a/docs/2.0/_modules/torch/optim/adamw.html
+++ b/docs/2.0/_modules/torch/optim/adamw.html
@@ -1057,16 +1057,15 @@ 

Source code for torch.optim.adamw

             device_max_exp_avg_sqs,
             device_state_steps,
         ) = grouped_tensors[(device, dtype)]
-        if grad_scale is not None and found_inf is not None:
+        device_grad_scale, device_found_inf = None, None
+        if grad_scale is not None:
             if device not in grad_scale_dict:
                 grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)
+            device_grad_scale = grad_scale_dict[device]
+        if found_inf is not None:
             if found_inf not in found_inf_dict:
                 found_inf_dict[device] = found_inf.to(device, non_blocking=True)
-            device_grad_scale = grad_scale_dict[device]
             device_found_inf = found_inf_dict[device]
-        else:
-            device_grad_scale = None
-            device_found_inf = None
         torch._foreach_add_(device_state_steps, 1)
         torch._fused_adamw_(
             device_params,
diff --git a/docs/2.0/_modules/torch/serialization.html b/docs/2.0/_modules/torch/serialization.html
index 828eefb3877e..37b75a1bf0a2 100644
--- a/docs/2.0/_modules/torch/serialization.html
+++ b/docs/2.0/_modules/torch/serialization.html
@@ -1533,7 +1533,7 @@ 

Source code for torch.serialization

         def restore_location(storage, location):
             location = map_location.get(location, location)
             return default_restore_location(storage, location)
-    elif isinstance(map_location, str):
+    elif isinstance(map_location, (str, bytes)):
         def restore_location(storage, location):
             return default_restore_location(storage, map_location)
     elif isinstance(map_location, torch.device):
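A hedged usage sketch of the string form handled above; the checkpoint path is a placeholder.

    import torch

    # A device string (and now bytes) maps every storage to that device on load
    state = torch.load("checkpoint.pt", map_location="cpu")  # path is illustrative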
diff --git a/docs/2.0/_modules/torch/storage.html b/docs/2.0/_modules/torch/storage.html
index 808613bfc3fe..9f1ee809d93e 100644
--- a/docs/2.0/_modules/torch/storage.html
+++ b/docs/2.0/_modules/torch/storage.html
@@ -761,14 +761,37 @@ 

Source code for torch.storage

     else:
         return isinstance(x, int)
 
+_always_warn_typed_storage_removal = False
+
+def _get_always_warn_typed_storage_removal():
+    return _always_warn_typed_storage_removal
+
+def _set_always_warn_typed_storage_removal(always_warn):
+    global _always_warn_typed_storage_removal
+    assert isinstance(always_warn, bool)
+    _always_warn_typed_storage_removal = always_warn
+
 def _warn_typed_storage_removal(stacklevel=2):
-    message = (
-        "TypedStorage is deprecated. It will be removed in the future and "
-        "UntypedStorage will be the only storage class. This should only matter "
-        "to you if you are using storages directly.  To access UntypedStorage "
-        "directly, use tensor.untyped_storage() instead of tensor.storage()"
-    )
-    warnings.warn(message, UserWarning, stacklevel=stacklevel + 1)
+    global _always_warn_typed_storage_removal
+
+    def is_first_time():
+        if not hasattr(_warn_typed_storage_removal, 'has_warned'):
+            return True
+        else:
+            return not _warn_typed_storage_removal.__dict__['has_warned']
+
+    if _get_always_warn_typed_storage_removal() or is_first_time():
+        message = (
+            "TypedStorage is deprecated. It will be removed in the future and "
+            "UntypedStorage will be the only storage class. This should only matter "
+            "to you if you are using storages directly.  To access UntypedStorage "
+            "directly, use tensor.untyped_storage() instead of tensor.storage()"
+        )
+        warnings.warn(message, UserWarning, stacklevel=stacklevel + 1)
+        _warn_typed_storage_removal.__dict__['has_warned'] = True
+
+def _reset_warn_typed_storage_removal():
+    _warn_typed_storage_removal.__dict__['has_warned'] = False
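A hedged sketch of how the new switch can be flipped, for example in tests; both helpers are private, so the import location is an assumption and may change.

    import torch
    from torch.storage import _set_always_warn_typed_storage_removal

    _set_always_warn_typed_storage_removal(True)  # warn on every access, not just the first
    t = torch.ones(3)
    _ = t.storage()  # now emits the TypedStorage deprecation warning each time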
 
 
[docs]class TypedStorage:
    is_sparse = False
diff --git a/docs/2.0/_modules/torch/utils/data/_utils/collate.html b/docs/2.0/_modules/torch/utils/data/_utils/collate.html
index 87f898a5af4e..e3907f7d8d88 100644
--- a/docs/2.0/_modules/torch/utils/data/_utils/collate.html
+++ b/docs/2.0/_modules/torch/utils/data/_utils/collate.html
@@ -524,7 +524,7 @@

Source code for torch.utils.data._utils.collate

             return elem_type(*(default_convert(d) for d in data))
         elif isinstance(data, tuple):
             return [default_convert(d) for d in data]  # Backwards compatibility.
-        elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
+        elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
             try:
                 return elem_type([default_convert(d) for d in data])
             except TypeError:
@@ -653,6 +653,7 @@

Source code for torch.utils.data._utils.collate

 default_collate_fn_map[float] = collate_float_fn
 default_collate_fn_map[int] = collate_int_fn
 default_collate_fn_map[str] = collate_str_fn
+default_collate_fn_map[bytes] = collate_str_fn
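A hedged sketch of the effect of the new registration: byte strings now collate the same way as str, i.e. the batch list is returned as-is.

    from torch.utils.data import default_collate

    print(default_collate([b"foo", b"bar"]))  # [b'foo', b'bar'], kept as-is like str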
[docs]def default_collate(batch):
diff --git a/docs/2.0/_modules/torch/utils/data/dataloader.html b/docs/2.0/_modules/torch/utils/data/dataloader.html
index 8faf4e767bf3..a9ec570aff46 100644
--- a/docs/2.0/_modules/torch/utils/data/dataloader.html
+++ b/docs/2.0/_modules/torch/utils/data/dataloader.html
@@ -489,7 +489,6 @@

Source code for torch.utils.data.dataloader

     Dataset,)
 
 from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper
-from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
 
 from . import _utils
 
@@ -559,6 +558,7 @@ 

Source code for torch.utils.data.dataloader

 
 
 def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
+    global_worker_id = worker_id
     info = torch.utils.data.get_worker_info()
     assert info is not None
     total_workers = info.num_workers
@@ -566,10 +566,10 @@ 

Source code for torch.utils.data.dataloader

     assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
     # To distribute elements across distributed process evenly, we should shard data on distributed
     # processes first then shard on worker processes
-    torch.utils.data.graph_settings.apply_sharding(
-        datapipe, world_size, rank_id, sharding_group=SHARDING_PRIORITIES.DISTRIBUTED)
-    torch.utils.data.graph_settings.apply_sharding(
-        datapipe, total_workers, worker_id, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING)
+    total_workers *= world_size
+    global_worker_id = global_worker_id * world_size + rank_id
+    # For BC, use default SHARDING_PRIORITIES
+    torch.utils.data.graph_settings.apply_sharding(datapipe, total_workers, global_worker_id)
     if worker_init_fn is not None:
         worker_init_fn(worker_id)
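A brief worked example of the flattened sharding above (numbers are illustrative): with world_size=2 distributed ranks and 4 DataLoader workers per rank, a worker with rank_id=1 and worker_id=3 computes total_workers = 4 * 2 = 8 and global_worker_id = 3 * 2 + 1 = 7, so the single apply_sharding call splits the datapipe into 8 disjoint shards, one per global worker.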
 
@@ -857,8 +857,7 @@ 

Source code for torch.utils.data.dataloader

                             ('multiprocessing_context option '
                              'should specify a valid start method in {!r}, but got '
                              'multiprocessing_context={!r}').format(valid_start_methods, multiprocessing_context))
-                    # error: Argument 1 to "get_context" has incompatible type "Union[str, bytes]"; expected "str"  [arg-type]
-                    multiprocessing_context = multiprocessing.get_context(multiprocessing_context)  # type: ignore[arg-type]
+                    multiprocessing_context = multiprocessing.get_context(multiprocessing_context)
 
                 if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
                     raise TypeError(('multiprocessing_context option should be a valid context '
@@ -1122,8 +1121,8 @@ 

Source code for torch.utils.data.dataloader

         # Adds forward compatibilities so classic DataLoader can work with DataPipes:
         #   Taking care of distributed sharding
         if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
-            torch.utils.data.graph_settings.apply_sharding(
-                self._dataset, self._world_size, self._rank, sharding_group=SHARDING_PRIORITIES.DISTRIBUTED)
+            # For BC, use default SHARDING_PRIORITIES
+            torch.utils.data.graph_settings.apply_sharding(self._dataset, self._world_size, self._rank)
 
         self._dataset_fetcher = _DatasetKind.create_fetcher(
             self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
diff --git a/docs/2.0/_sources/notes/extending.func.rst.txt b/docs/2.0/_sources/notes/extending.func.rst.txt
index b3a044f8c2f4..21603ba1becf 100644
--- a/docs/2.0/_sources/notes/extending.func.rst.txt
+++ b/docs/2.0/_sources/notes/extending.func.rst.txt
@@ -37,8 +37,13 @@ Only the latter is supported with function transforms:
   (by calling ``ctx.save_for_backward(*tensors)``), or save non-Tensors
   (by assigning them to the ``ctx`` object).
 
-Any intermediates that need to be saved must be returned as an output from
-:meth:`~Function.forward`.
+Because :meth:`~Function.setup_context` accepts only ``inputs`` and ``output``,
+the only quantities that can be saved are either objects (such as Tensors) in
+the inputs or outputs or quantities (like ``Tensor.shape``) derived from them.
+If you wish to save a non-input intermediate activation from
+:meth:`Function.forward` for backward, then you'll need to return it as an
+output from :meth:`~Function.forward` so that it gets passed to
+:meth:`~Function.setup_context`.
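A minimal sketch of that pattern, in the spirit of the cube example used elsewhere in these docs (the class and variable names are illustrative): the intermediate ``dx`` is returned as an extra output so that :meth:`~Function.setup_context` can save it.

    import torch
    from torch.autograd import Function

    class MyCube(Function):
        @staticmethod
        def forward(x):
            # dx is an intermediate also needed in backward, so it is
            # returned as an extra output rather than stashed here
            dx = 3 * x ** 2
            return x ** 3, dx

        @staticmethod
        def setup_context(ctx, inputs, output):
            x, = inputs
            _, dx = output
            ctx.save_for_backward(x, dx)

        @staticmethod
        def backward(ctx, grad_output, grad_dx):
            x, dx = ctx.saved_tensors
            return grad_output * dx + grad_dx * 6 * x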
 
 Depending on the transform,
 
diff --git a/docs/2.0/distributed.html b/docs/2.0/distributed.html
index 52d99629322a..76e14e919f3e 100644
--- a/docs/2.0/distributed.html
+++ b/docs/2.0/distributed.html
@@ -800,7 +800,9 @@ 

Initialization

 gloo and nccl backends will be created. The gloo backend will be used
 for collectives with CPU tensors and the nccl backend will be used
-for collectives with CUDA tensors.
+for collectives with CUDA tensors. A custom backend can be specified by passing in
+a string with format “<device_type>:<backend_name>,<device_type>:<backend_name>”, e.g.
+“cpu:gloo,cuda:custom_backend”.
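A hedged usage sketch of the string format documented above; rank and world size are taken from the environment (e.g. under torchrun), and any non-built-in backend name must already be registered.

    import torch.distributed as dist

    # one "device:backend" pair per device type, comma-separated
    dist.init_process_group(backend="cpu:gloo,cuda:nccl", init_method="env://")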

diff --git a/docs/2.0/generated/torch.nn.MultiheadAttention.html b/docs/2.0/generated/torch.nn.MultiheadAttention.html
index c398bca93d78..33e9c36bad23 100644
--- a/docs/2.0/generated/torch.nn.MultiheadAttention.html
+++ b/docs/2.0/generated/torch.nn.MultiheadAttention.html
@@ -469,9 +469,8 @@

MultiheadAttention

:math:`\text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h)W^O`

where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

-forward() will use the optimized implementation described in
-FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness if all of the following
-conditions are met:
+nn.MultiHeadAttention will use the optimized implementations of
+scaled_dot_product_attention() when possible.

  • self attention is being computed (i.e., query, key, and value are the same tensor. This
    restriction will be loosened in the future.)

@@ -549,8 +548,13 @@

-  • is_causal (bool) – If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask.
-    Default: False.
+  • is_causal (bool) – If specified, applies a causal mask as attention mask.
+    Default: False.
+    Warning:
+    is_causal provides a hint that attn_mask is the
+    causal mask. Providing incorrect hints can result in
+    incorrect execution, including forward and backward
+    compatibility.

  • average_attn_weights (bool) – If true, indicates that the returned attn_weights should be averaged across
    heads. Otherwise, attn_weights are provided separately per head. Note that this flag only has an effect
    when need_weights=True. Default: True (i.e. average weights across heads)

diff --git a/docs/2.0/notes/extending.func.html b/docs/2.0/notes/extending.func.html
index 2aa252938bed..2d7fca29adb3 100644
--- a/docs/2.0/notes/extending.func.html
+++ b/docs/2.0/notes/extending.func.html
@@ -483,8 +483,13 @@

Basic Usage

 (by calling ctx.save_for_backward(*tensors)), or save non-Tensors
 (by assigning them to the ctx object).

-Any intermediates that need to be saved must be returned as an output from
-forward().
+Because setup_context() accepts only inputs and output,
+the only quantities that can be saved are either objects (such as Tensors) in
+the inputs or outputs or quantities (like Tensor.shape) derived from them.
+If you wish to save a non-input intermediate activation from
+Function.forward() for backward, then you’ll need to return it as an
+output from forward() so that it gets passed to
+setup_context().

 Depending on the transform,

  • to support reverse-mode AD (torch.func.grad(), torch.func.vjp()),

diff --git a/docs/2.0/quantization-backend-configuration.html b/docs/2.0/quantization-backend-configuration.html
index 26a7d09ff41f..4265b51b7c55 100644
--- a/docs/2.0/quantization-backend-configuration.html
+++ b/docs/2.0/quantization-backend-configuration.html
@@ -479,7 +479,7 @@

Default values for native configurations

Operator Tags

class torch.Tag

    Members:

-    nondeterministic_bitwise
-    pointwise
     view_copy
-    nondeterministic_seeded
     dynamic_output_shape
+    nondeterministic_bitwise
     core
-    inplace_view
-    generated
     data_dependent_output
+    generated
+    pointwise
+    nondeterministic_seeded
+    inplace_view

    property name