Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Commit

Permalink
For now keep diffrules for special nosensitivity scalar case so as no…
Browse files Browse the repository at this point in the history
…t to break higher order functions

Handle Higher Order functions, including changing to ForwardDiff
  • Loading branch information
oxinabox committed Oct 21, 2020
1 parent aafa66b commit 82caafc
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 85 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ version = "0.12.1"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
DualNumbers = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
DiffRules = "^0.0"
DualNumbers = ">=0.6.0"
FDM = "^0.6"
ForwardDiff = "0.10.12"
SpecialFunctions = ">=0.5.0"
julia = "^1.0"

Expand Down
1 change: 1 addition & 0 deletions src/Nabla.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ module Nabla
using ChainRules
using ChainRulesCore
using ExprTools: ExprTools
using ForwardDiff: ForwardDiff
using LinearAlgebra
using Random
using SpecialFunctions
Expand Down
18 changes: 9 additions & 9 deletions src/core.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using DualNumbers

import Base: push!, length, show, getindex, setindex!, eachindex, isassigned,
isapprox, zero, one, lastindex

Expand All @@ -26,7 +24,7 @@ function show(io::IO, t::Tape)
end
end
end
@inline getindex(t::Tape, n::Int) = getindex(tape(t), n)
@inline getindex(t::Tape, n::Int) = unthunk(getindex(tape(t), n))
@inline getindex(t::Tape, node::Node) = getindex(t, pos(node))
@inline lastindex(t::Tape) = length(t)
@inline setindex!(t::Tape, x, n::Int) = (tape(t)[n] = x; t)
Expand Down Expand Up @@ -278,18 +276,20 @@ for T in (:Diagonal, :UpperTriangular, :LowerTriangular)
@eval @inline randned_container(x::$T{<:Real}) = $T(randn(eltype(x), size(x)...))
end

# Bare-bones FMAD implementation based on DualNumbers. Accepts a Tuple of args and returns
# a Tuple of gradients. Currently scales almost exactly linearly with the number of inputs.
# The coefficient of this scaling could be improved by implementing a version of DualNumbers
# which computes from multiple seeds at the same time.
# Bare-bones FMAD implementation based on internals of ForwardDiff.
# Accepts a Tuple of args and returns a Tuple of gradients.
# Currently scales almost exactly linearly with the number of inputs.
# The coefficient of this scaling could be improved by fully utilizing ForwardDiff
# and computing from multiple seeds at the same time.
function dual_call_expr(f, x::Type{<:Tuple}, ::Type{Type{Val{n}}}) where n
dual_call = Expr(:call, :f)
for m in 1:Base.length(x.parameters)
push!(dual_call.args, n == m ? :(Dual(x[$m], 1)) : :(x[$m]))
push!(dual_call.args, n == m ? :(ForwardDiff.Dual(x[$m], 1)) : :(x[$m]))
end
return :(dualpart($dual_call))
return :(first(ForwardDiff.partials($dual_call)))
end
@generated fmad(f, x, n) = dual_call_expr(f, x, n)

function fmad_expr(f, x::Type{<:Tuple})
body = Expr(:tuple)
for n in 1:Base.length(x.parameters)
Expand Down
20 changes: 16 additions & 4 deletions src/sensitivities/functional/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ end
(::typeof(map), ::Type{Arg{N}}, p, y, ȳ, f::Function, A::∇Array...) where N =
_∇(map, Arg{N-1}, p, y, ȳ, f, A...)
_∇(::typeof(map), arg::Type{Arg{N}}, p, y, ȳ, f::Function, A::∇Array...) where N =
hasmethod(∇, Tuple{typeof(f), Type{Arg{N}}, Any, Any, Any, map(eltype, A)...}) ?
map((yn, ȳn, An...)->(f, Arg{N}, p, yn, ȳn, An...), y, ȳ, A...) :
map((ȳn, An...)->ȳn * fmad(f, An, Val{N}), ȳ, A...)

# Implementation of sensitivities w.r.t. `broadcast`.
Expand Down Expand Up @@ -94,6 +92,20 @@ broadcastsum(f, add::Bool, z::Ref{<:Number}, As...) = broadcastsum(f, add, z[],
(::typeof(broadcast), ::Type{Arg{N}}, p, y, ȳ, f, A...) where N =
_∇(broadcast, Arg{N-1}, p, y, ȳ, f, A...)
_∇(::typeof(broadcast), ::Type{Arg{N}}, p, y, ȳ, f, A...) where N =
hasmethod(∇, Tuple{typeof(f), Type{Arg{N}}, Any, Any, Any, map(eltype, A)...}) ?
broadcastsum((yn, ȳn, xn...)->(f, Arg{N}, p, yn, ȳn, xn...), false, A[N], y, ȳ, A...) :
broadcastsum((ȳn, xn...)->ȳn * fmad(f, xn, Val{N}), false, A[N], ȳ, A...)

# Division from the right by a scalar.
import Base: /
@eval @explicit_intercepts $(Symbol("/")) Tuple{∇Array, ∇Scalar}
@inline (::typeof(/), ::Type{Arg{1}}, p, z, z̄, x::∇Scalar, y::∇Array) =
(broadcast, Arg{2}, p, z, z̄, /, x, y)
@inline (::typeof(/), ::Type{Arg{2}}, p, z, z̄, x::∇Scalar, y::∇Array) =
(broadcast, Arg{3}, p, z, z̄, /, x, y)

# Division from the left by a scalar.
import Base: \
@eval @explicit_intercepts $(Symbol("\\")) Tuple{∇Scalar, ∇Array}
@inline (::typeof(\), ::Type{Arg{1}}, p, z, z̄, x::∇Array, y::∇Scalar) =
(broadcast, Arg{2}, p, z, z̄, \, x, y)
@inline (::typeof(\), ::Type{Arg{2}}, p, z, z̄, x::∇Array, y::∇Scalar) =
(broadcast, Arg{3}, p, z, z̄, \, x, y)
7 changes: 3 additions & 4 deletions src/sensitivities/functional/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ for f in (:mapfoldl, :mapfoldr)
@eval begin
import Base: $f
@explicit_intercepts $f $type_tuple [false, false, true] #(init=0,)
(::typeof($f), ::Type{Arg{3}}, p, y, ȳ, f, ::$plustype, A::∇ArrayOrScalar) =
hasmethod(∇, Tuple{typeof(f), Type{Arg{1}}, Real}) ?
broadcast(An->* (f, Arg{1}, An), A) :
broadcast(An->* fmad(f, (An,), Val{1}), A)
function (::typeof($f), ::Type{Arg{3}}, p, y, ȳ, f, ::$plustype, A::∇ArrayOrScalar)
return.* ForwardDiff.derivative.(f, A)
end
end
end
44 changes: 6 additions & 38 deletions src/sensitivities/functional/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ import Statistics: mean
A::AbstractArray{<:∇Scalar};
dims=:,
)
hasmethod(∇, Tuple{typeof(f), Type{Arg{1}}, ∇Scalar}) ?
broadcast((An, ȳn)->ȳn * (f, Arg{1}, An), A, ȳ) :
broadcast((An, ȳn)->ȳn * fmad(f, (An,), Val{1}), A, ȳ)
return broadcast((An, ȳn)->ȳn * ForwardDiff.derivative(f, An), A, ȳ)
end
end

Expand All @@ -28,12 +26,6 @@ end
[false, true],
(dims=:,),
)
@explicit_intercepts(
sum,
Tuple{AbstractArray{<:∇Scalar}},
[true],
(dims=:,),
)
function (
::typeof(sum),
::Type{Arg{2}},
Expand All @@ -45,26 +37,17 @@ function ∇(
# Just pass through to mapreduce
return (mapreduce, Arg{3}, p, y, ȳ, f, Base.add_sum, A; dims=dims)
end
function (
::typeof(sum),
::Type{Arg{1}},
p, y, ȳ,
A::AbstractArray{<:∇Scalar};
dims=:,
)
# Again pass through to mapreduce, using identity as the mapped function
return (mapreduce, Arg{3}, p, y, ȳ, identity, Base.add_sum, A; dims=dims)
end
# Specialize on sum(abs2, x) as it's a common pattern with a simple derivative
# sum(abs2, xs) is in ChainRules, but it results in method ambiguties with the
# version that accepts any function above
function (
::typeof(sum),
::Type{Arg{2}},
p, y, ,
p, y, ȳ,
::typeof(abs2),
A::AbstractArray{<:∇Scalar};
A::AbstractArray{<:Real};
dims=:,
)
return 2 .* A
return 2ȳ .* A
end

@explicit_intercepts(
Expand All @@ -73,12 +56,6 @@ end
[false, true],
#(dims=:,) # https://github.com/JuliaLang/julia/issues/31412
)
@explicit_intercepts(
mean,
Tuple{AbstractArray{<:∇Scalar}},
[true],
(dims=:,)
)

_denom(x, dims::Colon) = length(x)
_denom(x, dims::Integer) = size(x, dims)
Expand All @@ -93,12 +70,3 @@ function ∇(
)
return (sum, Arg{2}, p, y, ȳ, f, x; dims=:) / _denom(x, :)
end
function (
::typeof(mean),
::Type{Arg{1}},
p, y, ȳ,
x::AbstractArray{<:∇Scalar};
dims=:,
)
return (sum, Arg{1}, p, y, ȳ, x; dims=dims) ./ _denom(x, dims)
end
3 changes: 2 additions & 1 deletion src/sensitivities/linalg/generic.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Implementation of sensitivities for unary linalg optimisations.
_ϵ, lb, ub = 3e-2, -3.0, 3.0
unary_linalg_optimisations = [
unary_linalg_optimisations = [#==
(:-, ∇Array, ∇Array, :(-Ȳ), (lb, ub)),
(:tr, ∇Array, ∇Scalar, :(Diagonal(fill!(similar(X), Ȳ))), (lb, ub)),
(:inv, ∇Array, ∇Array, :(-transpose(Y) * Ȳ * transpose(Y)), (lb, ub)),
Expand All @@ -11,6 +11,7 @@ unary_linalg_optimisations = [
(:adjoint, ∇Array, ∇Array, :(adjoint(Ȳ)), (lb, ub)),
(:norm, ∇Array, ∇Scalar, :(Ȳ ./ Y .* abs2.(X) ./ X), (lb, ub)),
(:norm, ∇Scalar, ∇Scalar, :(Ȳ * sign(X)), (lb, ub))
==#
]
for (f, T_In, T_Out, X̄, bounds) in unary_linalg_optimisations
if f === :-
Expand Down
12 changes: 10 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
using Nabla
using Test, LinearAlgebra, Statistics, Random
using Distributions, BenchmarkTools, SpecialFunctions, DualNumbers
using Test, LinearAlgebra, Statistics, Random, ForwardDiff
using Distributions, BenchmarkTools, SpecialFunctions

using Nabla: unbox, pos, tape, oneslike, zeroslike

# Helper function for comparing `Ref`s, since they don't compare equal under `==`
ref_equal(a::Ref{T}, b::Ref{T}) where {T} = a[] == b[]
ref_equal(a::Ref, b::Ref) = false

# for comparing against scalar rules
derivative_via_frule(f, x) = last(Nabla.frule((Nabla.NO_FIELDS, 1.0), f, x))
# Sensiblity checkes that his is defined right
@test derivative_via_frule(cos, 0) == 0
@test derivative_via_frule(sin, 0) == 1
@test derivative_via_frule(sin, 1.2) == derivative_via_frule(sin, 2π + 1.2)


@testset "Nabla.jl" begin

@testset "Core" begin
Expand Down
36 changes: 15 additions & 21 deletions test/sensitivities/functional/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ using DiffRules: diffrule, hasdiffrule
function check_unary_broadcast(f, x)
x_ = Leaf(Tape(), x)
s = broadcast(f, x_)
return (s, oneslike(unbox(s)))[x_] .(f, Arg{1}, x)
return (s, oneslike(unbox(s)))[x_] derivative_via_frule.(f, x)
end
for (package, f) in Nabla.unary_sensitivities
@testset "$package.$f" for (package, f) in Nabla.unary_sensitivities
domain = domain1(eval(f))
domain === nothing && error("Could not determine domain for $f.")
x_dist = Uniform(domain...)
Expand All @@ -36,43 +36,37 @@ using DiffRules: diffrule, hasdiffrule
s = broadcast(f, x_, y_)
o = oneslike(unbox(s))
∇s = (s, o)
∇x = broadcast((z, z̄, x, y)->(f, Arg{1}, nothing, z, z̄, x, y),
unbox(s), o, x, y)
∇y = broadcast((z, z̄, x, y)->(f, Arg{2}, nothing, z, z̄, x, y),
unbox(s), o, x, y)
# ∇x = sum(broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), unbox(s), o, x, y))
# ∇y = broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), unbox(s), o, x, y)
@test broadcast(f, x, y) == unbox(s)
@test ∇s[x_] ∇x
@test ∇s[y_] ∇y
# @test ∇s[x_] ≈ ∇x
# @test ∇s[y_] ≈ ∇y
end
function check_binary_broadcast(f, x::Real, y)
tape = Tape()
x_, y_ = Leaf(tape, x), Leaf(tape, y)
s = broadcast(f, x_, y_)
o = oneslike(unbox(s))
∇s = (s, o)
∇x = sum(broadcast((z, z̄, x, y)->(f, Arg{1}, nothing, z, z̄, x, y),
unbox(s), o, x, y))
∇y = broadcast((z, z̄, x, y)->(f, Arg{2}, nothing, z, z̄, x, y),
unbox(s), o, x, y)
# ∇x = sum(broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), unbox(s), o, x, y))
# ∇y = broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), unbox(s), o, x, y)
@test broadcast(f, x, y) == unbox(s)
@test ∇s[x_] ∇x
@test ∇s[y_] ∇y
# @test ∇s[x_] ≈ ∇x
# @test ∇s[y_] ≈ ∇y
end
function check_binary_broadcast(f, x, y::Real)
tape = Tape()
x_, y_ = Leaf(tape, x), Leaf(tape, y)
s = broadcast(f, x_, y_)
o = oneslike(unbox(s))
∇s = (s, o)
∇x = broadcast((z, z̄, x, y)->(f, Arg{1}, nothing, z, z̄, x, y),
unbox(s), o, x, y)
∇y = sum(broadcast((z, z̄, x, y)->(f, Arg{2}, nothing, z, z̄, x, y),
unbox(s), o, x, y))
# ∇x = sum(broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), unbox(s), o, x, y))
# ∇y = broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), unbox(s), o, x, y)
@test broadcast(f, x, y) == unbox(s)
@test ∇s[x_] ∇x
@test ∇s[y_] ∇y
# @test ∇s[x_] ≈ ∇x
# @test ∇s[y_] ≈ ∇y
end
for (package, f) in Nabla.binary_sensitivities
@testset "$package.$f" for (package, f) in Nabla.binary_sensitivities

# TODO: More care needs to be taken to test the following.
if hasdiffrule(package, f, 2)
Expand Down
8 changes: 4 additions & 4 deletions test/sensitivities/functional/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# Test +.
x_ = Leaf(Tape(), x)
s = functional(f, +, x_)
@test (s)[x_] .(f, Arg{1}, x)
@test (s)[x_] derivative_via_frule.(f, x)
end

# Some composite sensitivities.
Expand All @@ -34,7 +34,7 @@
x_ = Leaf(Tape(), x)
s = functional(f, +, x_)
@test unbox(s) functional(f, +, x)
@test (s)[x_] map(x->fmad(f, (x,), Val{1}), x)
@test (s)[x_] map(xn->ForwardDiff.derivative(f, xn), x)
end
end
end
Expand Down Expand Up @@ -74,7 +74,7 @@
x_ = Leaf(Tape(), x)
s = sum(f, x_)
@test unbox(s) == sum(f, x)
@test (s)[x_] .(f, Arg{1}, x)
@test (s)[x_] derivative_via_frule.(f, x)
end

# Some composite functions.
Expand All @@ -88,7 +88,7 @@
x_ = Leaf(Tape(), x)
s = sum(f, x_)
@test unbox(s) == sum(f, x)
@test (s)[x_] map(x->fmad(f, (x,), Val{1}), x)
@test (s)[x_] map(xn->ForwardDiff.derivative(f, xn), x)
end
end
end
Expand Down

0 comments on commit 82caafc

Please sign in to comment.