diff --git a/src/core.jl b/src/core.jl
index 29b70c12..59452f2b 100644
--- a/src/core.jl
+++ b/src/core.jl
@@ -179,7 +179,36 @@ output and `ȳ` the reverse-mode sensitivity of `y`.
 ∇(y::Node, ȳ) = propagate(tape(y), reverse_tape(y, ȳ))
 @inline ∇(y::Node{<:∇Scalar}) = ∇(y, one(unbox(y)))
 
-@inline ∇(x̄, f, ::Type{Arg{N}}, args...) where N = x̄ + ∇(f, Arg{N}, args...)
+# This is a fallback method where we don't necessarily know what we'll be adding and whether
+# we can update the value in-place, so we'll try to be clever and dispatch.
+@inline ∇(x̄, f, ::Type{Arg{N}}, args...) where {N} = update!(x̄, ∇(f, Arg{N}, args...))
+
+"""
+    update!(x̄, y)
+
+Return `x̄ + y`, reusing the storage of `x̄` when that is known to be safe.
+
+Only plain `Array`s (e.g. `Vector` and `Matrix`) are ever mutated. Structured array types
+are not updated in-place, even though it technically "works"
+(https://github.com/JuliaLang/julia/issues/31674). Mixed array and scalar adds should not
+occur, as sensitivities should always have the same shape, so updating an array with a
+scalar on the RHS is not special-cased.
+
+Callers must use the returned value rather than rely on `x̄` having been mutated.
+"""
+function update!(x̄::Array{T,N}, y::AbstractArray{S,N}) where {T,S,N}
+    # Mutate only when the result matches what `x̄ + y` would give: the sum must be
+    # representable with eltype `T` (otherwise `.+=` rounds or throws `InexactError`),
+    # and the axes must match (otherwise `.+=` silently broadcasts singleton dimensions).
+    if promote_type(T, S) === T && axes(x̄) == axes(y)
+        x̄ .+= y
+        return x̄
+    end
+    return x̄ + y
+end
+
+# Fall back to using regular addition
+update!(x̄, y) = x̄ + y
 
 """
     ∇(f; get_output::Bool=false)
diff --git a/test/core.jl b/test/core.jl
index 3de763c3..8f8eaed2 100644
--- a/test/core.jl
+++ b/test/core.jl
@@ -190,4 +190,41 @@ let
     @test oned_container(Dict("a"=>5.0, "b"=>randn(3))) == Dict("a"=>1.0, "b"=>ones(3))
 end
 
+# To ensure we end up using the fallback machinery for ∇(x̄, f, ...) we'll define a new
+# function and setup for it to use in the testset below
+quad(A::Matrix, B::Matrix) = B'A*B
+@explicit_intercepts quad Tuple{Matrix, Matrix}
+Nabla.∇(::typeof(quad), ::Type{Arg{1}}, p, Y, Ȳ, A::Matrix, B::Matrix) = B*Ȳ*B'
+Nabla.∇(::typeof(quad), ::Type{Arg{2}}, p, Y, Ȳ, A::Matrix, B::Matrix) = A*B*Ȳ' + A'B*Ȳ
+
+@testset "Mutating values in the tape" begin
+    rng = MersenneTwister(123456)
+    n = 5
+    A = Leaf(Tape(), randn(rng, n, n))
+    B = randn(rng, n, n)
+    Q = quad(A, B)
+    QQ = quad(Q, B)
+    rt = ∇(QQ, Matrix(1.0I, n, n))
+    oldvals = map(deepcopy∘unbox, getfield(rt, :tape))
+    Nabla.propagate(Q, rt) # This triggers a mutating addition
+    newvals = map(unbox, getfield(rt, :tape))
+    @test !(oldvals[1] ≈ newvals[1])
+    @test oldvals[2:end] ≈ newvals[2:end]
+end
+
+@testset "update!" begin
+    # Same eltype and shape: the accumulator is reused.
+    x̄ = [1.0, 2.0]
+    @test Nabla.update!(x̄, [3.0, 4.0]) === x̄
+    @test x̄ == [4.0, 6.0]
+
+    # The sum does not fit in the accumulator's eltype: allocate instead of mutating.
+    x̄ = [1, 2]
+    @test Nabla.update!(x̄, [0.5, 0.5]) == [1.5, 2.5]
+    @test x̄ == [1, 2]
+
+    # Mismatched shapes must error rather than broadcast silently.
+    @test_throws DimensionMismatch Nabla.update!(ones(2, 2), ones(1, 2))
+end
+
 end # testset "core"