Add an internal helper function to do more in-place updating #145
Codecov Report

    @@            Coverage Diff             @@
    ##           master     #145      +/-   ##
    ==========================================
    + Coverage   97.04%   97.05%   +<.01%
    ==========================================
      Files          20       20
      Lines         744      746       +2
    ==========================================
    + Hits          722      724       +2
      Misses         22       22
Seems reasonable. Can you add tests for the in-place behaviour?
I was hoping no one would notice that. I couldn't come up with a way to test it that would be sensible from a real-world-code perspective, but I can do something targeted.
Sounds good to me. I'd also like a Nabla expert to review this change.
src/core.jl (outdated)

    @inline ∇(x̄, f, ::Type{Arg{N}}, args...) where {N} = update!(x̄, ∇(f, Arg{N}, args...))

    # Use broadcast for mixed array/scalar operations
    @inline update!(x̄::∇Scalar, y::∇Array) = x̄ .+ y
Looks like this method isn't covered...
Based on a conversation with @wesselb, our thinking is that the mixed scalar/array case should not occur, and if it does, that indicates a bug. I'll adjust the definitions here appropriately.
src/core.jl (outdated)

    # Update arrays in-place. Mixed array and scalar adds should not occur, so don't bother
    # accepting scalars on the RHS.
    @inline update!(x̄::∇Array, y::∇Array) = x̄ .+= y
I realized that this is not legit to do in the general case; we'll have to be even more restrictive. Say we have an `UpperTriangular` stored on the tape. When we go to update this, if we do `.+=`, the resulting value is still an `UpperTriangular`, which is wrong. I think updating the value on the tape in-place is only safe when the existing and new values are both `Array`s, not `AbstractArray`s.
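To make the concern concrete, here is a minimal sketch using only the `LinearAlgebra` standard library. The values are made up for illustration. It shows that the correct sum of an `UpperTriangular` and a dense matrix is itself dense, so a destination that stays `UpperTriangular` has nowhere to hold the lower-triangle contribution. The sketch deliberately avoids calling `.+=` on the wrapper, since what that does may differ between Julia versions.

```julia
using LinearAlgebra

x̄ = UpperTriangular([1.0 2.0; 0.0 3.0])   # structured value, as if stored on the tape
y = ones(2, 2)                            # dense incoming contribution

total = x̄ + y                             # out-of-place sum allocates a new value

# The wrapper is an AbstractArray but not an Array, so dispatching on
# Array excludes it from any in-place method.
@assert x̄ isa AbstractArray && !(x̄ isa Array)

# The correct result is dense and has a non-zero entry below the diagonal,
# which an UpperTriangular destination cannot represent.
@assert total isa Matrix{Float64}
@assert total[2, 1] == 1.0
```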
Related to/fixes #61? @willtebbutt
This is a really nice optimisation. I've left one particular comment that I think needs addressing, but other than that I'm happy. The only other thing we might wish to consider is whether or not to stop force-inlining stuff? I'm honestly not sure why I thought that was a good idea when I first wrote the code.
Currently the `∇(x̄, f, Arg{N}, args...)` method updates `x̄` with the result of `∇(f, Arg{N}, args...)`. This is done in-place for some functions `f` but not all. In the case of the fallback method, we can use dispatch to determine whether it's safe to do this in-place, thereby hopefully saving some allocations.
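The dispatch idea can be sketched roughly as follows. This is not Nabla's actual code: `sketch_update!` is a hypothetical stand-in name, and the restriction to plain `Array`s of matching element type and rank is an assumption taken from the review discussion. Anything else falls through to an allocating fallback.

```julia
using LinearAlgebra

# Hypothetical stand-in for the helper discussed in this PR (not Nabla's code).
# Fallback: allocate a fresh value and leave x̄ untouched.
sketch_update!(x̄, y) = x̄ + y

# Plain dense arrays of matching element type and rank: accumulate in place.
sketch_update!(x̄::Array{T,N}, y::Array{T,N}) where {T,N} = (x̄ .+= y)

a = [1.0 2.0; 3.0 4.0]
r = sketch_update!(a, ones(2, 2))
@assert r === a                      # same object: the update happened in place

u = UpperTriangular([1.0 2.0; 0.0 3.0])
s = sketch_update!(u, ones(2, 2))
@assert s isa Matrix{Float64}        # fallback produced a new dense matrix
@assert u == [1.0 2.0; 0.0 3.0]      # the structured value was left unchanged
```

The `===` check doubles as one possible targeted test for the in-place behaviour: it passes only if the returned value is the very array that was passed in.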
I've opted to remove the |