Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Commit

Permalink
Add an internal helper function to do more in-place updating
Browse files Browse the repository at this point in the history
Currently the `∇(x̄, f, Arg{N}, args...)` method updates `x̄` with the
result of `∇(f, Arg{N}, args...)`. This is done in-place for some
functions `f` but not all. In the case of the fallback method, we can
use dispatch to determine whether it's safe to do this in-place, thereby
hopefully saving some allocations.
  • Loading branch information
ararslan committed Apr 12, 2019
1 parent 85b27f7 commit 9c5a84f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,19 @@ output and `ȳ` the reverse-mode sensitivity of `y`.
(y::Node, ȳ) = propagate(tape(y), reverse_tape(y, ȳ))
@inline (y::Node{<:∇Scalar}) = (y, one(unbox(y)))

@inline (x̄, f, ::Type{Arg{N}}, args...) where N =+ (f, Arg{N}, args...)
# This is a fallback method where we don't necessarily know what we'll be adding and whether
# we can update the value in-place, so we'll try to be clever and dispatch.
@inline (x̄, f, ::Type{Arg{N}}, args...) where {N} = update!(x̄, (f, Arg{N}, args...))

# Update regular arrays in-place. Structured array types should not be updated in-place,
# even though it technically "works" (https://github.com/JuliaLang/julia/issues/31674),
# so we'll only permit mutating addition for `Array`s, e.g. `Vector` and `Matrix`.
# Mixed array and scalar adds should not occur, as sensitivities should always have the
# same shape, so we won't bother allowing e.g. updating an array with a scalar on the RHS.
update!(x̄::Array{T,N}, y::AbstractArray{S,N}) where {T,S,N} =.+= y

# Fall back to using regular addition
update!(x̄, y) =+ y

"""
∇(f; get_output::Bool=false)
Expand Down
22 changes: 22 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,26 @@ let
@test oned_container(Dict("a"=>5.0, "b"=>randn(3))) == Dict("a"=>1.0, "b"=>ones(3))
end

# To ensure we end up using the fallback machinery for ∇(x̄, f, ...) we'll define a new
# function and setup for it to use in the testset below
quad(A::Matrix, B::Matrix) = B'A*B
@explicit_intercepts quad Tuple{Matrix, Matrix}
Nabla.(::typeof(quad), ::Type{Arg{1}}, p, Y, Ȳ, A::Matrix, B::Matrix) = B**B'
Nabla.(::typeof(quad), ::Type{Arg{2}}, p, Y, Ȳ, A::Matrix, B::Matrix) = A*B*' + A'B*

@testset "Mutating values in the tape" begin
rng = MersenneTwister(123456)
n = 5
A = Leaf(Tape(), randn(rng, n, n))
B = randn(rng, n, n)
Q = quad(A, B)
QQ = quad(Q, B)
rt = (QQ, Matrix(1.0I, n, n))
oldvals = map(deepcopyunbox, getfield(rt, :tape))
Nabla.propagate(Q, rt) # This triggers a mutating addition
newvals = map(unbox, getfield(rt, :tape))
@test !(oldvals[1] newvals[1])
@test oldvals[2:end] newvals[2:end]
end

end # testset "core"

0 comments on commit 9c5a84f

Please sign in to comment.