cleanup Flux.Losses documentation #1930

Open · wants to merge 6 commits into master
Changes from 3 commits
5 changes: 5 additions & 0 deletions docs/make.jl
@@ -1,6 +1,11 @@
 using Documenter, Flux, NNlib, Functors, MLUtils
 
 DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)
+DocMeta.setdocmeta!(Flux.Losses, :DocTestSetup, :(using Flux.Losses); recursive = true)
Member Author:
It seems this line is not doing its job, and docs CI is failing because of it.

Member:
I believe these lines are useless, since makedocs has doctests disabled. The GH CI.yml contains a separate version of these lines, and that's the one that matters.
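For context, such a standalone doctest step usually looks roughly like the following. This is only a sketch assuming the standard Documenter workflow, since the actual CI.yml contents are not shown in this diff:

```julia
# Sketch of a standalone doctest run, as CI workflows commonly do it
# when makedocs itself is called with doctest = false.
using Documenter, Flux

DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)
doctest(Flux)  # runs the doctests in Flux's docstrings and manual pages
```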

Member Author:
Turns out there is an additional problem: DocMeta.setdocmeta! doesn't accept the DocTestFilters key that I also needed, so I had to revert this change entirely.
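One possible workaround, sketched here as an assumption rather than what this PR ended up doing, is to apply the filter globally through the doctest runner, which in recent Documenter versions accepts a doctestfilters keyword:

```julia
# Hypothetical alternative: apply the Float32 filter via doctest()
# instead of via DocMeta.setdocmeta!.
using Documenter, Flux

doctest(Flux; doctestfilters = [r"[0-9\.]+f0"])
```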


+# In the Losses module, doctests that differ only in the printed Float32 values won't fail
+DocMeta.setdocmeta!(Flux.Losses, :DocTestFilters, :(r"[0-9\.]+f0"); recursive = true)
+
 makedocs(modules = [Flux, NNlib, Functors, MLUtils],
          doctest = false,
          sitename = "Flux",
3 changes: 2 additions & 1 deletion docs/src/models/advanced.md
@@ -213,13 +213,14 @@ model(gpu(rand(10)))
 A custom loss function for the multiple outputs may look like this:
 ```julia
 using Statistics
+using Flux.Losses: mse
 
 # assuming model returns the output of a Split
 # x is a single input
 # ys is a tuple of outputs
 function loss(x, ys, model)
   # rms over all the mse
   ŷs = model(x)
-  return sqrt(mean(Flux.mse(y, ŷ) for (y, ŷ) in zip(ys, ŷs)))
+  return sqrt(mean(mse(y, ŷ) for (y, ŷ) in zip(ys, ŷs)))
 end
 ```
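To see the loss above in action, here is a hedged usage sketch. The `Split` layer is the custom multi-output layer defined earlier in advanced.md, redefined here so the snippet is self-contained, and the layer sizes are made up for illustration:

```julia
using Flux, Statistics
using Flux.Losses: mse

# The custom multi-output layer from earlier in advanced.md:
# applies each stored path to the same input and returns a tuple.
struct Split{T}
  paths::T
end
Split(paths...) = Split(paths)
(m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)
Flux.@functor Split

function loss(x, ys, model)
  # rms over all the mse
  ŷs = model(x)
  return sqrt(mean(mse(y, ŷ) for (y, ŷ) in zip(ys, ŷs)))
end

model = Chain(Dense(10 => 5), Split(Dense(5 => 1), Dense(5 => 2)))
x = rand(Float32, 10)
ys = (rand(Float32, 1), rand(Float32, 2))
loss(x, ys, model)  # a single scalar, ready to differentiate
```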
45 changes: 19 additions & 26 deletions docs/src/models/losses.md
@@ -3,43 +3,36 @@
 Flux provides a large number of common loss functions used for training machine learning models.
 They are grouped together in the `Flux.Losses` module.
 
-Loss functions for supervised learning typically expect as inputs a target `y`, and a prediction `ŷ`.
-In Flux's convention, the order of the arguments is the following
+As an example, the crossentropy function for multi-class classification that takes logit predictions (i.e. not [`softmax`](@ref)ed)
+can be imported with
+
+```julia
+using Flux.Losses: logitcrossentropy
+```
+
+Loss functions for supervised learning typically expect as inputs a true target `y` and a prediction `ŷ`.
+In Flux's convention, the order of the arguments is the following:
 
 ```julia
 loss(ŷ, y)
 ```
 
+They are commonly passed as arrays of size `num_target_features x num_examples_in_batch`.
+
 Most loss functions in Flux have an optional argument `agg`, denoting the type of aggregation performed over the
 batch:
 
 ```julia
-loss(ŷ, y) # defaults to `mean`
-loss(ŷ, y, agg=sum) # use `sum` for reduction
-loss(ŷ, y, agg=x->sum(x, dims=2)) # partial reduction
-loss(ŷ, y, agg=x->mean(w .* x)) # weighted mean
-loss(ŷ, y, agg=identity) # no aggregation.
+loss(ŷ, y)                            # defaults to `mean`
+loss(ŷ, y, agg = sum)                 # use `sum` for reduction
+loss(ŷ, y, agg = x -> sum(x, dims=2)) # partial reduction
+loss(ŷ, y, agg = x -> mean(w .* x))   # weighted mean
+loss(ŷ, y, agg = identity)            # no aggregation.
 ```
 
 ## Losses Reference
 
-```@docs
-Flux.Losses.mae
-Flux.Losses.mse
-Flux.Losses.msle
-Flux.Losses.huber_loss
-Flux.Losses.label_smoothing
-Flux.Losses.crossentropy
-Flux.Losses.logitcrossentropy
-Flux.Losses.binarycrossentropy
-Flux.Losses.logitbinarycrossentropy
-Flux.Losses.kldivergence
-Flux.Losses.poisson_loss
-Flux.Losses.hinge_loss
-Flux.Losses.squared_hinge_loss
-Flux.Losses.dice_coeff_loss
-Flux.Losses.tversky_loss
-Flux.Losses.binary_focal_loss
-Flux.Losses.focal_loss
-Flux.Losses.siamese_contrastive_loss
+```@autodocs
+Modules = [Flux.Losses]
+Pages = ["functions.jl"]
 ```
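As a concrete illustration of the argument convention documented above, a small sketch using the `logitcrossentropy` import from this page (the values are made up):

```julia
using Flux: onehotbatch
using Flux.Losses: logitcrossentropy

# 3 target features (classes) × 5 examples in the batch, raw logits:
ŷ = randn(Float32, 3, 5)
y = onehotbatch([1, 3, 2, 1, 2], 1:3)

logitcrossentropy(ŷ, y)  # prediction first, target second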
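And a minimal sketch of the `agg` argument, here with `mse` (again with illustrative values):

```julia
using Flux.Losses: mse

ŷ = Float32[0.9 0.1; 0.2 0.8]
y = Float32[1.0 0.0; 0.0 1.0]

mse(ŷ, y)                             # mean over all elements (the default)
mse(ŷ, y, agg = sum)                  # total squared error
mse(ŷ, y, agg = x -> sum(x, dims=2))  # reduce over the batch dimension only
mse(ŷ, y, agg = identity)             # raw elementwise losses, no reduction
```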
3 changes: 1 addition & 2 deletions src/Flux.jl
@@ -51,9 +51,8 @@ include("outputsize.jl")
 include("data/Data.jl")
 using .Data
 
-
 include("losses/Losses.jl")
-using .Losses # TODO: stop importing Losses in Flux's namespace in v0.12
+using .Losses # TODO: stop importing Losses in Flux's namespace in v0.14?
 
 include("deprecations.jl")
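For readers wondering what the TODO above would change: while `using .Losses` is in place, loss functions are reachable both at top level and fully qualified. A quick sketch of the two spellings (illustrative values):

```julia
using Flux

ŷ, y = rand(Float32, 3), rand(Float32, 3)

Flux.mse(ŷ, y)         # works while Losses is imported into Flux's namespace
Flux.Losses.mse(ŷ, y)  # the qualified form that would remain after the TODO
```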