diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 0de25f5b72..8e36e23da5 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -1096,9 +1096,6 @@ def _div_prim_grad(a: Number | TensorProxy, b: Number | TensorProxy, /) -> Numbe
 register_grad(pids.GT, prims.gt)
 register_grad(pids.LE, prims.le)
 register_grad(pids.LT, prims.lt)
-register_grad(pids.NE, prims.ne)
-register_grad(pids.GT, prims.gt)
-register_grad(pids.LE, prims.le)


 @torchctx
diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index f625f041ca..e1fc0a33a7 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -7858,12 +7858,12 @@ def cross_entropy_error_generator(op, device, dtype=torch.float32, **kwargs):
         "Expected the input tensor to have (.*?) dimensions, but it has (.*?) dimensions.",
     )

-    # target shape is input shape except channels dimension
+    # target shape is input shape except class dimension
     incorrect_batch_target = make((10,), low=0, high=C, dtype=torch.long, requires_grad=False)
     yield (
         SampleInput(valid_input, incorrect_batch_target),
         RuntimeError,
-        "Expected the target tensor to have the same shape as the input tensor except for the channels dimension \
+        "Expected the target tensor to have the same shape as the input tensor except for the class dimension \
 (.*?), but it has shape (.*?).",
     )

@@ -8013,12 +8013,12 @@ def nll_loss_error_generator(op, device, dtype=torch.float32, **kwargs):
         "Expected the input tensor to have (.*?) dimensions, but it has (.*?) dimensions.",
     )

-    # target shape is input shape except channels dimension
+    # target shape is input shape except class dimension
    incorrect_batch_target = make((10,), low=0, high=C, dtype=torch.long, requires_grad=False)
     yield (
         SampleInput(valid_input, incorrect_batch_target),
         RuntimeError,
-        "Expected the target tensor to have the same shape as the input tensor except for the channels dimension \
+        "Expected the target tensor to have the same shape as the input tensor except for the class dimension \
 (.*?), but it has shape (.*?).",
     )

diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 9469f58d47..c8f69f262a 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -3578,9 +3578,9 @@ def cross_entropy(

     _cross_entropy_input_checks(a, target, weight, ignore_index, reduction, label_smoothing)

-    # channels dimension is either the first one if no batch dim present (i.e. a.shape[0]),
+    # class dimension is either the first one if no batch dim present (i.e. a.shape[0]),
     # or right next to it (i.e. a.shape[1]).
-    channels_dim = 1 if a.ndim >= 2 else 0
+    class_dim = 1 if a.ndim >= 2 else 0

     # NOTE This short-circuit is subject to change and is placed ahead of other input checks to match PyTorch behavior.
     # The expected behavior when the target and input have zero elements:
@@ -3591,7 +3591,7 @@ def cross_entropy(
     if a.numel() == 0:
         if reduction == "none":
             output_shape = list(a.shape)
-            output_shape.pop(channels_dim)
+            output_shape.pop(class_dim)
             return full(output_shape, 0.0, device=a.device, dtype=a.dtype)
         elif reduction == "sum":
             return full(result_shape := [], fill_value := 0.0, device=a.device, dtype=a.dtype)
@@ -3603,7 +3603,7 @@ def cross_entropy(
     elif label_smoothing != 0.0:
         return _cross_entropy_loss_label_smoothing(a, target, weight, ignore_index, reduction, label_smoothing)
     else:
-        log_softmax_input = log_softmax(a, dim=channels_dim)
+        log_softmax_input = log_softmax(a, dim=class_dim)
         return nll_loss(log_softmax_input, target, weight, ignore_index, reduction)


@@ -3632,14 +3632,14 @@ def _cross_entropy_input_checks(
         lambda: f"Expected label_smoothing to be in [0, 1] range but got {label_smoothing}.",
     )

-    # channels dimension is either the first one if no batch dim present (i.e. a.shape[0]),
+    # class dimension is either the first one if no batch dim present (i.e. a.shape[0]),
     # or right next to it (i.e. a.shape[1]).
-    channels_dim = 1 if a.ndim >= 2 else 0
-    num_channels = a.shape[channels_dim]
+    class_dim = 1 if a.ndim >= 2 else 0
+    num_class = a.shape[class_dim]

     utils.check(
-        weight is None or (weight.ndim == 1 and weight.shape[0] == num_channels),
-        lambda: f"Expected a 1D tensor with {num_channels} elements for weight argument, \
+        weight is None or (weight.ndim == 1 and weight.shape[0] == num_class),
+        lambda: f"Expected a 1D tensor with {num_class} elements for weight argument, \
 but found a tensor with {weight.ndim} dimensions and {weight.shape[0]} elements.",
     )

@@ -3654,13 +3654,13 @@ def _cross_entropy_input_checks(
             lambda: f"Expected the input tensor to have {(target.ndim + 1)=} dimensions, but it has {a.ndim} dimensions.",
         )

-        # target should match input in dims which do not correspond to the channels dim, i.e.
-        # (input.shape[:channels_dim] + input.shape[channels_dim + 1:]) == target.shape <=> True
-        expected_target_shape = a.shape[:channels_dim] + a.shape[channels_dim + 1 :]
+        # target should match input in dims which do not correspond to the class dim, i.e.
+        # (input.shape[:class_dim] + input.shape[class_dim + 1:]) == target.shape <=> True
+        expected_target_shape = a.shape[:class_dim] + a.shape[class_dim + 1 :]
         utils.check(
             expected_target_shape == target.shape,
-            lambda: f"Expected the target tensor to have the same shape as the input tensor except for the channels dimension \
+            lambda: f"Expected the target tensor to have the same shape as the input tensor except for the class dimension \
 {expected_target_shape}, but it has shape {target.shape}.",
         )
     else:
@@ -3685,28 +3685,28 @@ def _cross_entropy_loss_probability_target(
     reduction: str,
     label_smoothing: float,
 ) -> TensorLike:
-    # channels dimension is either the first one if no batch dim present (i.e. a.shape[0]),
+    # class dimension is either the first one if no batch dim present (i.e. a.shape[0]),
     # or right next to it (i.e. a.shape[1]).
-    channels_dim = 1 if a.ndim >= 2 else 0
-    num_channels = a.shape[channels_dim]
+    class_dim = 1 if a.ndim >= 2 else 0
+    num_class = a.shape[class_dim]

     if label_smoothing > 0.0:
-        target = (target * (1 - label_smoothing)) + (label_smoothing / num_channels)
+        target = (target * (1 - label_smoothing)) + (label_smoothing / num_class)

-    out = log_softmax(a, dim=channels_dim) * target
+    out = log_softmax(a, dim=class_dim) * target

     if weight is not None:
-        bcast_weight = reshape(weight, [num_channels] + [1 for _ in range(2, a.ndim)])
+        bcast_weight = reshape(weight, [num_class] + [1 for _ in range(2, a.ndim)])
         out = out * bcast_weight

     out = -out

     if reduction == "none":
-        return sum(out, dim=channels_dim)
+        return sum(out, dim=class_dim)
     elif reduction == "sum":
         return sum(out)
     elif reduction == "mean":
-        return sum(out) / (a.numel() // num_channels)
+        return sum(out) / (a.numel() // num_class)


 def _cross_entropy_loss_label_smoothing(
@@ -3718,20 +3718,20 @@ def _cross_entropy_loss_label_smoothing(
     reduction: str,
     label_smoothing: int,
 ) -> TensorLike:
-    # channels dimension is either the first one if no batch dim present (i.e. a.shape[0]),
+    # class dimension is either the first one if no batch dim present (i.e. a.shape[0]),
     # or right next to it (i.e. a.shape[1]).
-    channels_dim = 1 if a.ndim >= 2 else 0
-    num_channels = a.shape[channels_dim]
+    class_dim = 1 if a.ndim >= 2 else 0
+    num_class = a.shape[class_dim]

-    log_softmax_value = log_softmax(a, dim=channels_dim)
+    log_softmax_value = log_softmax(a, dim=class_dim)

     if weight is not None:
-        bcast_weight = reshape(weight, [num_channels] + [1 for _ in range(2, len(a.shape))])
+        bcast_weight = reshape(weight, [num_class] + [1 for _ in range(2, len(a.shape))])
         out = -(log_softmax_value * bcast_weight)
     else:
         out = -log_softmax_value

-    smooth_loss = sum(out, dim=channels_dim)
+    smooth_loss = sum(out, dim=class_dim)

     # Make target broadcastable with output, which has same shape as input tensor.
     selected_target_mask = target != ignore_index
@@ -3749,8 +3749,8 @@ def _cross_entropy_loss_label_smoothing(
         # Sum together all target weights.
         # Make target broadcastable with output, which has same shape as input tensor.
         expanded_weight = expand(bcast_weight, a.shape)
-        bcast_target = unsqueeze(target, channels_dim)
-        selected_weight = take_along_dim(expanded_weight, bcast_target, channels_dim)
+        bcast_target = unsqueeze(target, class_dim)
+        selected_weight = take_along_dim(expanded_weight, bcast_target, class_dim)
         selected_weight = where(selected_target_mask, squeeze(selected_weight), 0)
         ret = reduced_sum / sum(selected_weight)
     else:
@@ -3760,7 +3760,7 @@ def _cross_entropy_loss_label_smoothing(

     nll_loss_value = nll_loss(log_softmax_value, target, weight, ignore_index, reduction)

-    return (nll_loss_value * (1.0 - label_smoothing)) + (ret * (label_smoothing / num_channels))
+    return (nll_loss_value * (1.0 - label_smoothing)) + (ret * (label_smoothing / num_class))


 # TODO Is this a method?
@@ -4128,30 +4128,30 @@ def _nll_loss_helper(
         lambda: f"Expected the input tensor to have {(target.ndim + 1)=} dimensions, but it has {a.ndim} dimensions.",
     )

-    # channels dimension is either the first one if no batch dim present (i.e. a.shape[0]),
+    # class dimension is either the first one if no batch dim present (i.e. a.shape[0]),
     # or right next to it (i.e. a.shape[1]).
-    channels_dim = 1 if a.ndim >= 2 else 0
-    num_channels = a.shape[channels_dim]
-    # target should match input in dims which do not correspond to the channels dim, i.e.
-    # (input.shape[:channels_dim] + input.shape[channels_dim + 1:]) == target.shape <=> True
-    expected_target_shape = a.shape[:channels_dim] + a.shape[channels_dim + 1 :]
+    class_dim = 1 if a.ndim >= 2 else 0
+    num_class = a.shape[class_dim]
+    # target should match input in dims which do not correspond to the class dim, i.e.
+    # (input.shape[:class_dim] + input.shape[class_dim + 1:]) == target.shape <=> True
+    expected_target_shape = a.shape[:class_dim] + a.shape[class_dim + 1 :]

     utils.check(
         expected_target_shape == target.shape,
-        lambda: f"Expected the target tensor to have the same shape as the input tensor except for the channels dimension \
+        lambda: f"Expected the target tensor to have the same shape as the input tensor except for the class dimension \
 {expected_target_shape}, but it has shape {target.shape}.",
     )

     utils.check(
-        weight is None or (weight.ndim == 1 and weight.shape[0] == num_channels),
-        lambda: f"Expected a 1D tensor with {num_channels} elements for weight argument, \
+        weight is None or (weight.ndim == 1 and weight.shape[0] == num_class),
+        lambda: f"Expected a 1D tensor with {num_class} elements for weight argument, \
 but found a tensor with {weight.ndim} dimensions and {weight.shape[0]} elements.",
     )

     # NOTE: [Handling of 'ignore_index' parameter]
     # What does it mean to ignore an index?
     # The 'ignore_index' parameter specifies a target value that does not contribute to input gradient.
-    # 'ignore_index' can be outside of the [0, num_channels) range, which can cause out-of-bounds errors when gathering
+    # 'ignore_index' can be outside of the [0, num_class) range, which can cause out-of-bounds errors when gathering
     # values from input tensor.
     #
     # What does ATen do?
@@ -4160,12 +4160,12 @@ def _nll_loss_helper(
     #
     # What do we do?
     # We mask the ignore_index entries on the output tensor from take_along_axis because we expect the targets to be
-    # within [0, num_channels) range.
+    # within [0, num_class) range.
     #
     # Why do we like our approach better?
     # Mimicking Aten behavior requires masking the target tensor before calling take_along_axis, which would add more
     # operations to the fusion. We should follow this approach until we see real examples where ignore_index is
-    # out-of-bounds of [0, num_channels) range.
+    # out-of-bounds of [0, num_class) range.
     #
     # What are the alternative options?
     # We can add a `mode` parameter to take_along_axis that controls how to handle out-of-bounds indices.
@@ -4174,20 +4174,20 @@ def _nll_loss_helper(
     out = -a

     if weight is not None:
-        bcast_weight = reshape(weight, [num_channels] + [1 for _ in range(2, a.ndim)])
+        bcast_weight = reshape(weight, [num_class] + [1 for _ in range(2, a.ndim)])
         out = out * bcast_weight

     # Make target broadcastable with output, which has same shape as input tensor.
-    bcast_target = unsqueeze(target, channels_dim)
+    bcast_target = unsqueeze(target, class_dim)

-    out = take_along_dim(out, bcast_target, channels_dim)
+    out = take_along_dim(out, bcast_target, class_dim)
     selected_target_mask = bcast_target != ignore_index
     out = where(selected_target_mask, out, 0)

     # This section handles applying the reduction parameter to the output.
     # We return None for the total_weight when reduction is "none" or "sum" since it is unused in the backwards pass.
     if reduction == "none":
-        return squeeze(out, channels_dim), None
+        return squeeze(out, class_dim), None
     elif reduction == "sum":
         return sum(out), None
     elif reduction == "mean":
@@ -4197,7 +4197,7 @@ def _nll_loss_helper(
         # Mask the ignored target classes.
         # Sum together all target weights.
         expanded_weight = expand(bcast_weight, a.shape)
-        selected_weight = take_along_dim(expanded_weight, bcast_target, channels_dim)
+        selected_weight = take_along_dim(expanded_weight, bcast_target, class_dim)
         selected_weight = where(selected_target_mask, selected_weight, 0)
         bcast_weight_sum = sum(selected_weight)
         return (reduced_sum / bcast_weight_sum), bcast_weight_sum
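Reviewer note (not part of the patch): the renamed comments describe the same convention PyTorch uses for torch.nn.functional.cross_entropy: the class dimension is dim 0 for an unbatched (C,) input and dim 1 otherwise, and the target must have the input's shape with that dimension removed. Below is a minimal PyTorch-only sketch of that shape rule; the sizes N, C, d1 and the variable names are illustrative, with class_dim mirroring the rule in the patch.

import torch
import torch.nn.functional as F

N, C, d1 = 4, 5, 7
a = torch.randn(N, C, d1)                              # input: (N, C, d1)
class_dim = 1 if a.ndim >= 2 else 0                    # same rule as in the patch
expected_target_shape = tuple(a.shape[:class_dim]) + tuple(a.shape[class_dim + 1 :])
target = torch.randint(0, C, expected_target_shape)    # (N, d1): input shape minus the class dim

loss = F.cross_entropy(a, target, reduction="mean")
print(expected_target_shape, loss.item())              # (4, 7) and a scalar loss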
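A second sketch, for the return line of _cross_entropy_loss_label_smoothing: with no weight and no ignored targets, the smoothed loss should decompose as (1 - label_smoothing) * nll + (label_smoothing / num_class) * smooth_term. This is a hedged check against public PyTorch only (not Thunder); eps and the tensor sizes are arbitrary.

import torch
import torch.nn.functional as F

N, C = 8, 5
eps = 0.1
logits = torch.randn(N, C)
target = torch.randint(0, C, (N,))

log_probs = F.log_softmax(logits, dim=1)
nll = F.nll_loss(log_probs, target, reduction="mean")
smooth = (-log_probs).sum(dim=1).mean()                 # per-sample sum over classes, then batch mean
manual = (1 - eps) * nll + (eps / C) * smooth

reference = F.cross_entropy(logits, target, label_smoothing=eps, reduction="mean")
print(torch.allclose(manual, reference))                # expected: True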
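Finally, the NOTE in _nll_loss_helper about ignore_index can be exercised outside Thunder as well. The sketch below is a rough, unweighted re-creation in plain PyTorch of the strategy that NOTE describes: gather the selected class per element with take_along_dim and then zero out the ignored positions, which assumes the targets (including ignore_index) stay within [0, num_class). Names and sizes are illustrative, not Thunder APIs.

import torch
import torch.nn.functional as F

N, C = 6, 5
ignore_index = 2                                        # in-range, per the assumption stated in the NOTE
class_dim = 1

log_probs = F.log_softmax(torch.randn(N, C), dim=class_dim)
target = torch.randint(0, C, (N,))
target[0] = ignore_index                                # at least one ignored entry

out = -log_probs
bcast_target = target.unsqueeze(class_dim)              # make target broadcastable with out
gathered = torch.take_along_dim(out, bcast_target, dim=class_dim)
selected_target_mask = bcast_target != ignore_index
gathered = torch.where(selected_target_mask, gathered, torch.zeros_like(gathered))

# "mean" reduction without a weight tensor: divide by the number of non-ignored targets.
manual = gathered.sum() / selected_target_mask.sum()
reference = F.nll_loss(log_probs, target, ignore_index=ignore_index, reduction="mean")
print(torch.allclose(manual, reference))                # expected: True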