Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Logit focal loss #2138

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft

Conversation

josephsdavid
Copy link

It would be quite nice to have focal loss from logits, for numerical stability. This PR implements that! We have logit versions of crossentropy et al, so this i think has precedence!

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

src/losses/functions.jl Outdated Show resolved Hide resolved
See also: [`Losses.focal_loss`](@ref)

"""
function logit_focal_loss(ŷ, y; γ=2.0f0, agg=mean, dims=1, ϵ=epseltype(ŷ))
Copy link
Member

@mcabbott mcabbott Dec 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some have crept in & need fixing, but there should not be greek-letter keywords. These can be gamma and eps?

Also, as written, γ=1.5 will cause Float32 input to be promoted to Float64. Can you avoid this somehow? Perhaps there should be a line like γ = gamma isa Integer ? gamma : convert(eltype(logpt), gamma). (Integer powers are faster.)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nice idea, can do!

0.665241 0.665241 0.665241 0.665241 0.665241

julia> Flux.logit_focal_loss(ŷ, y) ≈ 1.1277571935622628
true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This example output doesn't match what's written.

More importantly, the example is an opportunity to show exactly how this relates to focal_loss, i.e. where the softmax goes. And perhaps (if you can think of a compact & clear way) the relation to crossentropy (or rather logitcrossentropy?) too.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah i still need to work through these tests, did not realize about the docstring tests until after already putting tests elsewhere :) Can do !

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants