Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Commit

Permalink
Handle Higher Order functions, including changing to ForwardDiff
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Oct 21, 2020
1 parent 8a3464d commit 0d63462
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 86 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ version = "0.12.1"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
DualNumbers = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
DiffRules = "^0.0"
DualNumbers = ">=0.6.0"
FDM = "^0.6"
ForwardDiff = "0.10.12"
SpecialFunctions = ">=0.5.0"
julia = "^1.0"

Expand Down
1 change: 1 addition & 0 deletions src/Nabla.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ module Nabla
using ChainRules
using ChainRulesCore
using ExprTools: ExprTools
using ForwardDiff: ForwardDiff
using LinearAlgebra
using Random
using SpecialFunctions
Expand Down
18 changes: 9 additions & 9 deletions src/core.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using DualNumbers

import Base: push!, length, show, getindex, setindex!, eachindex, isassigned,
isapprox, zero, one, lastindex

Expand All @@ -26,7 +24,7 @@ function show(io::IO, t::Tape)
end
end
end
@inline getindex(t::Tape, n::Int) = getindex(tape(t), n)
@inline getindex(t::Tape, n::Int) = unthunk(getindex(tape(t), n))
@inline getindex(t::Tape, node::Node) = getindex(t, pos(node))
@inline lastindex(t::Tape) = length(t)
@inline setindex!(t::Tape, x, n::Int) = (tape(t)[n] = x; t)
Expand Down Expand Up @@ -291,18 +289,20 @@ for T in (:Diagonal, :UpperTriangular, :LowerTriangular)
@eval @inline randned_container(x::$T{<:Real}) = $T(randn(eltype(x), size(x)...))
end

# Bare-bones FMAD implementation based on DualNumbers. Accepts a Tuple of args and returns
# a Tuple of gradients. Currently scales almost exactly linearly with the number of inputs.
# The coefficient of this scaling could be improved by implementing a version of DualNumbers
# which computes from multiple seeds at the same time.
# Bare-bones FMAD implementation based on internals of ForwardDiff.
# Accepts a Tuple of args and returns a Tuple of gradients.
# Currently scales almost exactly linearly with the number of inputs.
# The coefficient of this scaling could be improved by fully utilizing ForwardDiff
# and computing from multiple seeds at the same time.
function dual_call_expr(f, x::Type{<:Tuple}, ::Type{Type{Val{n}}}) where n
dual_call = Expr(:call, :f)
for m in 1:Base.length(x.parameters)
push!(dual_call.args, n == m ? :(Dual(x[$m], 1)) : :(x[$m]))
push!(dual_call.args, n == m ? :(ForwardDiff.Dual(x[$m], 1)) : :(x[$m]))
end
return :(dualpart($dual_call))
return :(first(ForwardDiff.partials($dual_call)))
end
@generated fmad(f, x, n) = dual_call_expr(f, x, n)

function fmad_expr(f, x::Type{<:Tuple})
body = Expr(:tuple)
for n in 1:Base.length(x.parameters)
Expand Down
20 changes: 16 additions & 4 deletions src/sensitivities/functional/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ end
(::typeof(map), ::Type{Arg{N}}, p, y, ȳ, f::Function, A::∇Array...) where N =
_∇(map, Arg{N-1}, p, y, ȳ, f, A...)
_∇(::typeof(map), arg::Type{Arg{N}}, p, y, ȳ, f::Function, A::∇Array...) where N =
hasmethod(∇, Tuple{typeof(f), Type{Arg{N}}, Any, Any, Any, map(eltype, A)...}) ?
map((yn, ȳn, An...)->(f, Arg{N}, p, yn, ȳn, An...), y, ȳ, A...) :
map((ȳn, An...)->ȳn * fmad(f, An, Val{N}), ȳ, A...)

# Implementation of sensitivities w.r.t. `broadcast`.
Expand Down Expand Up @@ -94,6 +92,20 @@ broadcastsum(f, add::Bool, z::Ref{<:Number}, As...) = broadcastsum(f, add, z[],
(::typeof(broadcast), ::Type{Arg{N}}, p, y, ȳ, f, A...) where N =
_∇(broadcast, Arg{N-1}, p, y, ȳ, f, A...)
_∇(::typeof(broadcast), ::Type{Arg{N}}, p, y, ȳ, f, A...) where N =
hasmethod(∇, Tuple{typeof(f), Type{Arg{N}}, Any, Any, Any, map(eltype, A)...}) ?
broadcastsum((yn, ȳn, xn...)->(f, Arg{N}, p, yn, ȳn, xn...), false, A[N], y, ȳ, A...) :
broadcastsum((ȳn, xn...)->ȳn * fmad(f, xn, Val{N}), false, A[N], ȳ, A...)

# Division from the right by a scalar.
import Base: /
@eval @explicit_intercepts $(Symbol("/")) Tuple{∇Array, ∇Scalar}
@inline (::typeof(/), ::Type{Arg{1}}, p, z, z̄, x::∇Scalar, y::∇Array) =
(broadcast, Arg{2}, p, z, z̄, /, x, y)
@inline (::typeof(/), ::Type{Arg{2}}, p, z, z̄, x::∇Scalar, y::∇Array) =
(broadcast, Arg{3}, p, z, z̄, /, x, y)

# Division from the left by a scalar.
import Base: \
@eval @explicit_intercepts $(Symbol("\\")) Tuple{∇Scalar, ∇Array}
@inline (::typeof(\), ::Type{Arg{1}}, p, z, z̄, x::∇Array, y::∇Scalar) =
(broadcast, Arg{2}, p, z, z̄, \, x, y)
@inline (::typeof(\), ::Type{Arg{2}}, p, z, z̄, x::∇Array, y::∇Scalar) =
(broadcast, Arg{3}, p, z, z̄, \, x, y)
7 changes: 3 additions & 4 deletions src/sensitivities/functional/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ for f in (:mapfoldl, :mapfoldr)
@eval begin
import Base: $f
@explicit_intercepts $f $type_tuple [false, false, true] #(init=0,)
(::typeof($f), ::Type{Arg{3}}, p, y, ȳ, f, ::$plustype, A::∇ArrayOrScalar) =
hasmethod(∇, Tuple{typeof(f), Type{Arg{1}}, Real}) ?
broadcast(An->* (f, Arg{1}, An), A) :
broadcast(An->* fmad(f, (An,), Val{1}), A)
function (::typeof($f), ::Type{Arg{3}}, p, y, ȳ, f, ::$plustype, A::∇ArrayOrScalar)
return.* ForwardDiff.derivative.(f, A)
end
end
end
44 changes: 6 additions & 38 deletions src/sensitivities/functional/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ import Statistics: mean
A::AbstractArray{<:∇Scalar};
dims=:,
)
hasmethod(∇, Tuple{typeof(f), Type{Arg{1}}, ∇Scalar}) ?
broadcast((An, ȳn)->ȳn * (f, Arg{1}, An), A, ȳ) :
broadcast((An, ȳn)->ȳn * fmad(f, (An,), Val{1}), A, ȳ)
return broadcast((An, ȳn)->ȳn * ForwardDiff.derivative(f, An), A, ȳ)
end
end

Expand All @@ -28,12 +26,6 @@ end
[false, true],
(dims=:,),
)
@explicit_intercepts(
sum,
Tuple{AbstractArray{<:∇Scalar}},
[true],
(dims=:,),
)
function (
::typeof(sum),
::Type{Arg{2}},
Expand All @@ -45,26 +37,17 @@ function ∇(
# Just pass through to mapreduce
return (mapreduce, Arg{3}, p, y, ȳ, f, Base.add_sum, A; dims=dims)
end
function (
::typeof(sum),
::Type{Arg{1}},
p, y, ȳ,
A::AbstractArray{<:∇Scalar};
dims=:,
)
# Again pass through to mapreduce, using identity as the mapped function
return (mapreduce, Arg{3}, p, y, ȳ, identity, Base.add_sum, A; dims=dims)
end
# Specialize on sum(abs2, x) as it's a common pattern with a simple derivative
# sum(abs2, xs) is in ChainRules, but it results in method ambiguties with the
# version that accepts any function above
function (
::typeof(sum),
::Type{Arg{2}},
p, y, ,
p, y, ȳ,
::typeof(abs2),
A::AbstractArray{<:∇Scalar};
A::AbstractArray{<:Real};
dims=:,
)
return 2 .* A
return 2ȳ .* A
end

@explicit_intercepts(
Expand All @@ -73,12 +56,6 @@ end
[false, true],
#(dims=:,) # https://github.com/JuliaLang/julia/issues/31412
)
@explicit_intercepts(
mean,
Tuple{AbstractArray{<:∇Scalar}},
[true],
(dims=:,)
)

_denom(x, dims::Colon) = length(x)
_denom(x, dims::Integer) = size(x, dims)
Expand All @@ -93,12 +70,3 @@ function ∇(
)
return (sum, Arg{2}, p, y, ȳ, f, x; dims=:) / _denom(x, :)
end
function (
::typeof(mean),
::Type{Arg{1}},
p, y, ȳ,
x::AbstractArray{<:∇Scalar};
dims=:,
)
return (sum, Arg{1}, p, y, ȳ, x; dims=dims) ./ _denom(x, dims)
end
3 changes: 2 additions & 1 deletion src/sensitivities/linalg/generic.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Implementation of sensitivities for unary linalg optimisations.
_ϵ, lb, ub = 3e-2, -3.0, 3.0
unary_linalg_optimisations = [
unary_linalg_optimisations = [#==
(:-, ∇Array, ∇Array, :(-Ȳ), (lb, ub)),
(:tr, ∇Array, ∇Scalar, :(Diagonal(fill!(similar(X), Ȳ))), (lb, ub)),
(:inv, ∇Array, ∇Array, :(-transpose(Y) * Ȳ * transpose(Y)), (lb, ub)),
Expand All @@ -11,6 +11,7 @@ unary_linalg_optimisations = [
(:adjoint, ∇Array, ∇Array, :(adjoint(Ȳ)), (lb, ub)),
(:norm, ∇Array, ∇Scalar, :(Ȳ ./ Y .* abs2.(X) ./ X), (lb, ub)),
(:norm, ∇Scalar, ∇Scalar, :(Ȳ * sign(X)), (lb, ub))
==#
]
for (f, T_In, T_Out, X̄, bounds) in unary_linalg_optimisations
if f === :-
Expand Down
2 changes: 1 addition & 1 deletion src/sensitivities/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ for (package, f, arity) in diffrules()
∂f∂x = diffrule(package, f, :x)
#@eval @explicit_intercepts $f Tuple{∇Scalar}
#@eval @inline ∇(::typeof($f), ::Type{Arg{1}}, p, y, ȳ, x::∇Scalar) = ȳ * $∂f∂x
@eval @inline (::typeof($f), ::Type{Arg{1}}, x::∇Scalar) = $∂f∂x
#@eval @inline ∇(::typeof($f), ::Type{Arg{1}}, x::∇Scalar) = $∂f∂x
elseif arity == 2
push!(binary_sensitivities, (package, f))
∂f∂x, ∂f∂y = diffrule(package, f, :x, :y)
Expand Down
12 changes: 10 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
using Nabla
using Test, LinearAlgebra, Statistics, Random
using Distributions, BenchmarkTools, SpecialFunctions, DualNumbers
using Test, LinearAlgebra, Statistics, Random, ForwardDiff
using Distributions, BenchmarkTools, SpecialFunctions

using Nabla: unbox, pos, tape, oneslike, zeroslike

# Helper function for comparing `Ref`s, since they don't compare equal under `==`
ref_equal(a::Ref{T}, b::Ref{T}) where {T} = a[] == b[]
ref_equal(a::Ref, b::Ref) = false

# for comparing against scalar rules
derivative_via_frule(f, x) = last(Nabla.frule((Nabla.NO_FIELDS, 1.0), f, x))
# Sensiblity checkes that his is defined right
@test derivative_via_frule(cos, 0) == 0
@test derivative_via_frule(sin, 0) == 1
@test derivative_via_frule(sin, 1.2) == derivative_via_frule(sin, 2π + 1.2)


@testset "Nabla.jl" begin

@testset "Core" begin
Expand Down
36 changes: 15 additions & 21 deletions test/sensitivities/functional/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ using DiffRules: diffrule, hasdiffrule
function check_unary_broadcast(f, x)
x_ = Leaf(Tape(), x)
s = broadcast(f, x_)
return (s, oneslike(unbox(s)))[x_] .(f, Arg{1}, x)
return (s, oneslike(unbox(s)))[x_] derivative_via_frule.(f, x)
end
for (package, f) in Nabla.unary_sensitivities
@testset "$package.$f" for (package, f) in Nabla.unary_sensitivities
domain = domain1(eval(f))
domain === nothing && error("Could not determine domain for $f.")
x_dist = Uniform(domain...)
Expand All @@ -36,43 +36,37 @@ using DiffRules: diffrule, hasdiffrule
s = broadcast(f, x_, y_)
o = oneslike(unbox(s))
∇s = (s, o)
∇x = broadcast((z, z̄, x, y)->(f, Arg{1}, nothing, z, z̄, x, y),
unbox(s), o, x, y)
∇y = broadcast((z, z̄, x, y)->(f, Arg{2}, nothing, z, z̄, x, y),
unbox(s), o, x, y)
# ∇x = sum(broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), unbox(s), o, x, y))
# ∇y = broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), unbox(s), o, x, y)
@test broadcast(f, x, y) == unbox(s)
@test ∇s[x_] ∇x
@test ∇s[y_] ∇y
# @test ∇s[x_] ≈ ∇x
# @test ∇s[y_] ≈ ∇y
end
function check_binary_broadcast(f, x::Real, y)
tape = Tape()
x_, y_ = Leaf(tape, x), Leaf(tape, y)
s = broadcast(f, x_, y_)
o = oneslike(unbox(s))
∇s = (s, o)
∇x = sum(broadcast((z, z̄, x, y)->(f, Arg{1}, nothing, z, z̄, x, y),
unbox(s), o, x, y))
∇y = broadcast((z, z̄, x, y)->(f, Arg{2}, nothing, z, z̄, x, y),
unbox(s), o, x, y)
# ∇x = sum(broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), unbox(s), o, x, y))
# ∇y = broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), unbox(s), o, x, y)
@test broadcast(f, x, y) == unbox(s)
@test ∇s[x_] ∇x
@test ∇s[y_] ∇y
# @test ∇s[x_] ≈ ∇x
# @test ∇s[y_] ≈ ∇y
end
function check_binary_broadcast(f, x, y::Real)
tape = Tape()
x_, y_ = Leaf(tape, x), Leaf(tape, y)
s = broadcast(f, x_, y_)
o = oneslike(unbox(s))
∇s = (s, o)
∇x = broadcast((z, z̄, x, y)->(f, Arg{1}, nothing, z, z̄, x, y),
unbox(s), o, x, y)
∇y = sum(broadcast((z, z̄, x, y)->(f, Arg{2}, nothing, z, z̄, x, y),
unbox(s), o, x, y))
# ∇x = sum(broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), unbox(s), o, x, y))
# ∇y = broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), unbox(s), o, x, y)
@test broadcast(f, x, y) == unbox(s)
@test ∇s[x_] ∇x
@test ∇s[y_] ∇y
# @test ∇s[x_] ≈ ∇x
# @test ∇s[y_] ≈ ∇y
end
for (package, f) in Nabla.binary_sensitivities
@testset "$package.$f" for (package, f) in Nabla.binary_sensitivities

# TODO: More care needs to be taken to test the following.
if hasdiffrule(package, f, 2)
Expand Down
8 changes: 4 additions & 4 deletions test/sensitivities/functional/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# Test +.
x_ = Leaf(Tape(), x)
s = functional(f, +, x_)
@test (s)[x_] .(f, Arg{1}, x)
@test (s)[x_] derivative_via_frule.(f, x)
end

# Some composite sensitivities.
Expand All @@ -34,7 +34,7 @@
x_ = Leaf(Tape(), x)
s = functional(f, +, x_)
@test unbox(s) functional(f, +, x)
@test (s)[x_] map(x->fmad(f, (x,), Val{1}), x)
@test (s)[x_] map(xn->ForwardDiff.derivative(f, xn), x)
end
end
end
Expand Down Expand Up @@ -74,7 +74,7 @@
x_ = Leaf(Tape(), x)
s = sum(f, x_)
@test unbox(s) == sum(f, x)
@test (s)[x_] .(f, Arg{1}, x)
@test (s)[x_] derivative_via_frule.(f, x)
end

# Some composite functions.
Expand All @@ -88,7 +88,7 @@
x_ = Leaf(Tape(), x)
s = sum(f, x_)
@test unbox(s) == sum(f, x)
@test (s)[x_] map(x->fmad(f, (x,), Val{1}), x)
@test (s)[x_] map(xn->ForwardDiff.derivative(f, xn), x)
end
end
end
Expand Down

0 comments on commit 0d63462

Please sign in to comment.