From 32f951a0d7ab933d8bf743c396f58c6e0f098b31 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 21 Oct 2020 14:52:33 +0100 Subject: [PATCH 01/71] Add ChainRules to Project.toml --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index a73d650b..4b6f9a0f 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,8 @@ uuid = "49c96f43-aa6d-5a04-a506-44c7070ebe78" version = "0.12.3" [deps] +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" DualNumbers = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128" From 83d24c544b82b7c794435e50f635ad276cf4be28 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 3 Sep 2020 19:18:28 +0100 Subject: [PATCH 02/71] Load all rules from ChainRules WIP update to use preprocess to handle pullbacks wip improve generation code generate all the possible methods Make show that scalars mostly work deal with redundant where N in varargs Allow chainrules to do more define and use node_type rather than unionize for defining overload Correctly handing varargs and outer unionalls with same type var symbols as inner union alls Fix code_tranformation utils.jl Note similarities between node_type and unionize type comment vararg --- Project.toml | 2 + dev/ExprTools | 1 + src/Nabla.jl | 13 +- src/code_transformation/util.jl | 34 +++- src/core.jl | 24 ++- src/sensitivities/chainrules.jl | 189 +++++++++++++++++++++ src/sensitivities/functional/functional.jl | 40 ----- src/sensitivities/scalar.jl | 12 +- src/sensitivity.jl | 12 +- test/code_transformation/differentiable.jl | 6 + test/code_transformation/util.jl | 5 + test/scratch.jl | 3 + test/sensitivities/scalar.jl | 4 +- 13 files changed, 286 insertions(+), 59 deletions(-) create mode 160000 dev/ExprTools create mode 100644 src/sensitivities/chainrules.jl create mode 100644 test/scratch.jl diff --git a/Project.toml b/Project.toml index 4b6f9a0f..9759c90e 100644 --- a/Project.toml +++ b/Project.toml @@ -7,8 +7,10 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" DualNumbers = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" +ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/dev/ExprTools b/dev/ExprTools new file mode 160000 index 00000000..cba2e159 --- /dev/null +++ b/dev/ExprTools @@ -0,0 +1 @@ +Subproject commit cba2e15975636c8502e40c346ad2597266403128 diff --git a/src/Nabla.jl b/src/Nabla.jl index fa562d0d..6454faae 100644 --- a/src/Nabla.jl +++ b/src/Nabla.jl @@ -1,9 +1,12 @@ __precompile__() module Nabla - - using SpecialFunctions + using ChainRules + using ChainRulesCore + using ExprTools: ExprTools using LinearAlgebra + using Random + using SpecialFunctions using Statistics # Some aliases used repeatedly throughout the package. @@ -25,6 +28,9 @@ module Nabla end end + # Link up to ChainRulesCore so rules are generated when new rrules are declared. + __init__() = on_new_rule(generate_overload, rrule) + # Meta-programming utilities specific to Nabla. include("code_transformation/util.jl") include("code_transformation/differentiable.jl") @@ -39,6 +45,9 @@ module Nabla # into a separate module at some point. include("finite_differencing.jl") + # Sensitivities via ChainRules + include("sensitivities/chainrules.jl") + # Sensitivities for the basics. include("sensitivities/indexing.jl") include("sensitivities/scalar.jl") diff --git a/src/code_transformation/util.jl b/src/code_transformation/util.jl index 6aa069fd..0b9e6472 100644 --- a/src/code_transformation/util.jl +++ b/src/code_transformation/util.jl @@ -40,6 +40,23 @@ function unionise_type(tp::Union{Symbol, Expr}) return replace_vararg(:(Union{$_tp, Node{<:$tp_clean}}), (_tp, _info)) end +""" + node_type(tp::Union{Symbol, Expr}) + +Returns an expression for the `Node{<:tp}`. e.g. +`node_type(:Real)` returns `:(Node{<:Real}})`. + +Correctly `Varargs{Real}` becomes `:(Varargs{Node{<:Real}})` + +This is a lot like [`unionize_type`](ref) but it doesn't permit the original type anymore. +""" +function node_type(tp::Union{Symbol, Expr}) + (_tp, _info) = remove_vararg(tp) + tp_clean = (isa(_tp, Expr) && _tp.head == Symbol("<:")) ? _tp.args[1] : _tp + return replace_vararg(:(Node{<:$tp_clean}), (_tp, _info)) +end + + """ replace_body(unionall::Union{Symbol, Expr}, replacement::Union{Symbol, Expr}) @@ -91,6 +108,21 @@ function remove_vararg(typ::Expr) if isa_vararg(typ) body = get_body(typ) new_typ = replace_body(typ, body.args[2]) + + # This is a bit ugly: + # handle interally `where N` from `typ = :(Vararg{FOO, N} where N)` which results in + # `body = :(Vararg{FOO, N})` and `new_type = Foo where N`, we don't need to keep it + # at all, the `where N` wasn't doing anything to begin with, so we just strip it out + if Meta.isexpr(new_typ, :where, 2) && Meta.isexpr(body, :curly, 3) + @assert body.args[1] == :Vararg + T = body.args[2] + N = body.args[3] + if new_typ.args == [T, N] + body = :(Vararg{T}) + new_typ = T + end + end + vararg_info = length(body.args) == 3 ? body.args[3] : :Vararg return new_typ, vararg_info else @@ -107,7 +139,7 @@ Convert `typ` to the `Vararg` containing elements of type `typ` specified by replace_vararg(typ::SymOrExpr, vararg_info::Tuple) = vararg_info[2] == :nothing ? typ : - vararg_info[2] == :no_N || vararg_info[2] == :Vararg ? + vararg_info[2] == :no_N || vararg_info[2] == :Vararg ? #TODO: :no_N is impossible now? replace_body(typ, :(Vararg{$(get_body(typ))})) : replace_body(typ, :(Vararg{$(get_body(typ)), $(vararg_info[2])})) diff --git a/src/core.jl b/src/core.jl index 5025c722..44a01604 100644 --- a/src/core.jl +++ b/src/core.jl @@ -77,10 +77,17 @@ struct Branch{T} <: Node{T} kwargs::NamedTuple tape::Tape pos::Int + pullback # if we have a rrule pullback for this it is stored here end function Branch(f, args::Tuple, tape::Tape; kwargs...) unboxed = unbox.(args) - branch = Branch(f(unboxed...; kwargs...), f, args, kwargs.data, tape, length(tape) + 1) + + # We could check for an `rrule` here if we wanted but we don't, + # because we should never reach this point if we have an rrule + primal_val = f(unboxed...; kwargs...) + pullback = nothing + + branch = Branch(primal_val, f, args, kwargs.data, tape, length(tape) + 1, pullback) push!(tape, branch) return branch end @@ -126,16 +133,18 @@ one(n::Node) = one(unbox(n)) @inline propagate(y::Leaf, rvs_tape::Tape) = nothing function propagate(y::Branch, rvs_tape::Tape) tape = Nabla.tape(rvs_tape) - ȳ, f = tape[pos(y)], getfield(y, :f) + ȳ = tape[pos(y)] + f = getfield(y, :f) args = getfield(y, :args) kwargs = getfield(y, :kwargs) - xs, xids = map(unbox, args), map(pos, args) - p = preprocess(f, unbox(y), ȳ, xs...) + xs = map(unbox, args) + xids = map(pos, args) + p = preprocess(f, y, ȳ, args...) # inlining CSE will avoid unboxing twice. for j in eachindex(xs) x, xid = xs[j], xids[j] if xid > 0 tape[xid] = isassigned(tape, xid) ? - ∇(tape[xid], f, Arg{j}, p, unbox(y), ȳ, xs...; kwargs...) : + ∇(tape[xid], f, Arg{j}, p, unbox(y), ȳ, xs...; kwargs...) : # maybe-inplace version ∇(f, Arg{j}, p, unbox(y), ȳ, xs...; kwargs...) end end @@ -172,11 +181,14 @@ computing the gradient of `y` w.r.t. each of the elements in the `Tape`. ∇(f::Function, ::Type{Arg{N}}, p, y, ȳ, x...) - ∇(x̄, f::Function, ::Type{Arg{N}}, p, y, ȳ, x...) To implement a new reverse-mode sensitivity for the `N^{th}` argument of function `f`. p is the output of `preprocess`. `x1`, `x2`,... are the inputs to the function, `y` is its output and `ȳ` the reverse-mode sensitivity of `y`. + +∇(x̄, f::Function, ::Type{Arg{N}}, p, y, ȳ, x...) +This is the optionally inplace version of `∇` that should, if implemented, mutate +x̄ to have the gradient added to it. """ ∇(y::Node, ȳ) = propagate(tape(y), reverse_tape(y, ȳ)) @inline ∇(y::Node{<:∇Scalar}) = ∇(y, one(unbox(y))) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl new file mode 100644 index 00000000..51983bcc --- /dev/null +++ b/src/sensitivities/chainrules.jl @@ -0,0 +1,189 @@ +function generate_overload(sig) + opT, argTs = Iterators.peel(ExprTools.parameters(sig)) + opT <: Core.Builtin && return false # can't do operater overloading for builtins + opT <: Function || return false # not handling non-functions + + fieldcount(opT) == 0 || return false # not handling functors + isempty(argTs) && return false # we are an operator overloading AD, need operands + + + nameof(opT.name.module) == :NaNMath && return false # Don't care about NaNMath + + # Ignore functions that have complex ranges. This may change when Nabla supports complex + # numbers. + opT ∈ typeof.(( + SpecialFunctions.hankelh1, SpecialFunctions.hankelh2, + log1p, rem2pi, mod, atan, rem, + )) && return false + + # Ingore functions because have better Nabla specific version. + opT ∈ typeof.(( + isapprox, size, length, + )) && return false + + + signature_def = build_def(sig) + original_signature_args = signature_def[:args] + signature_def[:args] = unionise_sig.(original_signature_args) + + fdef = quote + @inline $(preprocess_declaration(signature_def)) + @inline $(∇_declaration(signature_def)) + $(overload_declarations!(signature_def, original_signature_args)...) + end + #@show fdef + eval(fdef) + return true +end + +"like `ExprTools.signature` but on a signature type-tuple, not a Method" +function build_def(orig_sig) + sig = _truely_rename_unionall(orig_sig) # TODO ExprTools possibly should do this for `signature(::Method)`` also + def = Dict{Symbol, Any}() + + opT = ExprTools.parameters(sig)[1] + def[:name] = :(op::$opT) + + explicit_tvars = Core.TypeName[]#ExprTools.extract_tvars(sig) + arg_types = ExprTools.name_of_type.(ExprTools.argument_types(sig)) + arg_names = [Symbol(:x, ii) for ii in eachindex(arg_types)] #TODO: should we pass the arg_names in? + def[:args] = Expr.(:(::), arg_names, arg_types) + def[:whereparams] = ExprTools.where_parameters(sig) + + def = Dict{Symbol, Any}(k => v for (k, v) in def if v !== nothing) # filter out nonfields. + + return def +end + +"this overwrites and ruins `signature_def` for others" +function overload_declarations!(signature_def, original_signature_args) + + # Our macro-hygine is not complete here. + # the argument names and `op`, `tape` `args`, `kwargs` etc could conflict with + # where-params. but for sake of outputting readable code we are not gensyming everything + # chance of conflict seems low as where-params are normally upper-case. + @assert(signature_def[:name].head == :(::)) + @assert(signature_def[:name].args[1] == :op) + + + signature_def[:kwargs] = [:(kwargs...)] + signature_def[:body] = quote + #@show op + args = $(_args_tuple(signature_def[:args])) + primal_val, pullback = rrule(op, unbox.(args)...; kwargs...) + tape = get_tape(args) + + branch = Branch(primal_val, op, args, kwargs.data, tape, length(tape) + 1, pullback) + push!(tape, branch) + return branch + end + + # we need to generate a version of this for each place that an arg could be + n_args = length(original_signature_args) + definitions = Expr[] + for swap_mask in Iterators.product(ntuple(_->(true,false), n_args)...) + any(swap_mask) || continue # don't generate if not swapping anything. + signature_def[:args] = map(swap_mask, original_signature_args) do swap, orig_arg + if swap + @assert Meta.isexpr(orig_arg, :(::), 2) + Expr(:(::), orig_arg.args[1], node_type(orig_arg.args[2])) + else + orig_arg + end + end + push!(definitions, ExprTools.combinedef(signature_def)) + end + + return definitions +end + +function preprocess_declaration(signature_def) + # basically want to generate things like: + # `preprocess(f::$opT, y::Branch, ȳ, $((arg_sig)...)) = y.pullback(ȳ)` + # We need the pullback value to use to compute the sensitivies of the inputs + + op = signature_def[:name] + args = signature_def[:args] + y = gensym(:y) + ȳ = gensym(:ȳ) + + # preprocess has a broadly similar definition, signature-wise, to the overload. + # so we copy it to get whereparams etc + preprocess_def = Dict{Symbol, Any}( + :name => :preprocess, + :args => [op, :($y::Branch), ȳ, args...], + :body => quote $y.pullback($ȳ) end, + ) + + where_params = get(signature_def, :whereparams, nothing) + if where_params !== nothing + preprocess_def[:whereparams] = where_params + end + return ExprTools.combinedef(preprocess_def) +end + + +function ∇_declaration(signature_def) + # basically want to generate things like: + # `∇(::$opT, ::Type{Arg{N}}, p, y, ȳ, xs...) where N = p[N+1] # Skip dself` + # We need the pullback value to use to compute the sensitivies of the inputs + + # For readability lets name all the parts, NB: this is being a bit too cute. + op = signature_def[:name] + args = signature_def[:args] + N = gensym(:N) + p = gensym(:p) + y = :(::Any) + ȳ = :(::Any) + + ∇_def = Dict{Symbol, Any}( + :name => :∇, + :args => [op, :(::Type{Arg{$N}}), p, y, ȳ, args...], + :whereparams => [N; get(signature_def, :whereparams, [])], + :body => quote $p[$N+1] end, + ) + return ExprTools.combinedef(∇_def) +end + + +""" + _args_tuple(arg_exprs) + +For `arg_exprs` being a list of arguments expressions from a signature, of a form +such as `[:(x::Int), :(y::Float64), :(z::Vararg)]`, returns a tuple expresion containing all +of them by name; while correctly handling splatting, +e.g for prior example `:((x, y, z...))` +""" +function _args_tuple(arg_exprs) + ret = Expr(:tuple) + ret.args = map(arg_exprs) do arg + @assert Meta.isexpr(arg, :(::), 2) + arg_name, Texpr = arg.args + if Texpr == :Vararg || (Meta.isexpr(Texpr, :curly) && Texpr.args[1] == :Vararg) + return :($arg_name...) + else + return arg_name + end + end + return ret +end + +"like `Base.rename_unionall`, but actually gensyms the name also, not just a new instance" +function _truely_rename_unionall(@nospecialize(u)) + isa(u,UnionAll) || return u + body = _truely_rename_unionall(u.body) + if body === u.body + body = u + else + body = UnionAll(u.var, body) + end + var = u.var::TypeVar + nv = TypeVar(gensym(var.name), var.lb, var.ub) + return UnionAll(nv, body{nv}) +end + + + +# Find a tape, ds might be Nodes or might be something else. +# All nodes should have the same tape, so the first one will do +get_tape(ds) = first(tape(d) for d in ds if d isa Node) diff --git a/src/sensitivities/functional/functional.jl b/src/sensitivities/functional/functional.jl index f72e75db..54af1b47 100644 --- a/src/sensitivities/functional/functional.jl +++ b/src/sensitivities/functional/functional.jl @@ -97,43 +97,3 @@ _∇(::typeof(broadcast), ::Type{Arg{N}}, p, y, ȳ, f, A...) where N = hasmethod(∇, Tuple{typeof(f), Type{Arg{N}}, Any, Any, Any, map(eltype, A)...}) ? broadcastsum((yn, ȳn, xn...)->∇(f, Arg{N}, p, yn, ȳn, xn...), false, A[N], y, ȳ, A...) : broadcastsum((ȳn, xn...)->ȳn * fmad(f, xn, Val{N}), false, A[N], ȳ, A...) - -# Addition. -import Base: + -@eval @explicit_intercepts $(Symbol("+")) Tuple{∇Array, ∇Array} -@inline ∇(::typeof(+), ::Type{Arg{1}}, p, z, z̄, x::∇Array, y::∇Array) = - ∇(broadcast, Arg{2}, p, z, z̄, +, x, y) -@inline ∇(::typeof(+), ::Type{Arg{2}}, p, z, z̄, x::∇Array, y::∇Array) = - ∇(broadcast, Arg{3}, p, z, z̄, +, x, y) - -# Multiplication. -import Base: * -@eval @explicit_intercepts $(Symbol("*")) Tuple{∇ArrayOrScalar, ∇ArrayOrScalar} -@inline ∇(::typeof(*), ::Type{Arg{1}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = - ∇(broadcast, Arg{2}, p, z, z̄, *, x, y) -@inline ∇(::typeof(*), ::Type{Arg{2}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = - ∇(broadcast, Arg{3}, p, z, z̄, *, x, y) - -# Subtraction. -import Base: - -@eval @explicit_intercepts $(Symbol("-")) Tuple{∇Array, ∇Array} -@inline ∇(::typeof(-), ::Type{Arg{1}}, p, z, z̄, x::∇Array, y::∇Array) = - ∇(broadcast, Arg{2}, p, z, z̄, -, x, y) -@inline ∇(::typeof(-), ::Type{Arg{2}}, p, z, z̄, x::∇Array, y::∇Array) = - ∇(broadcast, Arg{3}, p, z, z̄, -, x, y) - -# Division from the right by a scalar. -import Base: / -@eval @explicit_intercepts $(Symbol("/")) Tuple{∇Array, ∇Scalar} -@inline ∇(::typeof(/), ::Type{Arg{1}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = - ∇(broadcast, Arg{2}, p, z, z̄, /, x, y) -@inline ∇(::typeof(/), ::Type{Arg{2}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = - ∇(broadcast, Arg{3}, p, z, z̄, /, x, y) - -# Division from the left by a scalar. -import Base: \ -@eval @explicit_intercepts $(Symbol("\\")) Tuple{∇Scalar, ∇Array} -@inline ∇(::typeof(\), ::Type{Arg{1}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = - ∇(broadcast, Arg{2}, p, z, z̄, \, x, y) -@inline ∇(::typeof(\), ::Type{Arg{2}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) = - ∇(broadcast, Arg{3}, p, z, z̄, \, x, y) diff --git a/src/sensitivities/scalar.jl b/src/sensitivities/scalar.jl index 531854b6..cd75b702 100644 --- a/src/sensitivities/scalar.jl +++ b/src/sensitivities/scalar.jl @@ -28,15 +28,15 @@ for (package, f, arity) in diffrules() if arity == 1 push!(unary_sensitivities, (package, f)) ∂f∂x = diffrule(package, f, :x) - @eval @explicit_intercepts $f Tuple{∇Scalar} - @eval @inline ∇(::typeof($f), ::Type{Arg{1}}, p, y, ȳ, x::∇Scalar) = ȳ * $∂f∂x - @eval @inline ∇(::typeof($f), ::Type{Arg{1}}, x::∇Scalar) = $∂f∂x + #@eval @explicit_intercepts $f Tuple{∇Scalar} + #@eval @inline ∇(::typeof($f), ::Type{Arg{1}}, p, y, ȳ, x::∇Scalar) = ȳ * $∂f∂x + #@eval @inline ∇(::typeof($f), ::Type{Arg{1}}, x::∇Scalar) = $∂f∂x elseif arity == 2 push!(binary_sensitivities, (package, f)) ∂f∂x, ∂f∂y = diffrule(package, f, :x, :y) - @eval @explicit_intercepts $f Tuple{∇Scalar, ∇Scalar} - @eval ∇(::typeof($f), ::Type{Arg{1}}, p, z, z̄, x::∇Scalar, y::∇Scalar) = z̄ * $∂f∂x - @eval ∇(::typeof($f), ::Type{Arg{2}}, p, z, z̄, x::∇Scalar, y::∇Scalar) = z̄ * $∂f∂y + #@eval @explicit_intercepts $f Tuple{∇Scalar, ∇Scalar} + #@eval ∇(::typeof($f), ::Type{Arg{1}}, p, z, z̄, x::∇Scalar, y::∇Scalar) = z̄ * $∂f∂x + #@eval ∇(::typeof($f), ::Type{Arg{2}}, p, z, z̄, x::∇Scalar, y::∇Scalar) = z̄ * $∂f∂y else error("Cannot implement sensitivity for $package.$f: arity $arity not supported.") end diff --git a/src/sensitivity.jl b/src/sensitivity.jl index 0652e725..f7bd02fd 100644 --- a/src/sensitivity.jl +++ b/src/sensitivity.jl @@ -204,10 +204,18 @@ function tape_expr(x::Tuple, syms::NTuple{N, Symbol} where N, is_node::Vector{Bo end """ - preprocess(::Function, args...) + preprocess(f, y, ȳ, xs...) = () Default implementation of preprocess returns an empty Tuple. Individual sensitivity implementations should add methods specific to their use case. The output is passed in to `∇` as the 3rd or 4th argument in the new-x̄ and update-x̄ cases respectively. + +`preprocess` is invoked with `y` and `xs` still boxed. +The default implementation just calls `unbox` on them then calls `preprocess` on the unboxed +values. +If for preprocessing you need the boxed values you should overload +`preprocess(f, y::Node, ȳ, xs...)`. +If you need them unboxed, then overloading `preprocess(f, y, ȳ, xs...)` is fine. """ -@inline preprocess(::Any, args...) = () +@inline preprocess(f, y, ȳ, xs...) = () +@inline preprocess(f, y::Node, ȳ, xs...) = preprocess(f, unbox(y), ȳ, map(unbox, xs)...) diff --git a/test/code_transformation/differentiable.jl b/test/code_transformation/differentiable.jl index f899653c..7f298dcb 100644 --- a/test/code_transformation/differentiable.jl +++ b/test/code_transformation/differentiable.jl @@ -63,6 +63,12 @@ skip_line_info(ex) = ex @test unionise_sig(:(foo(x::T))) == :(foo($(unionise_arg(:(x::T))))) @test unionise_sig(:(foo(x::T) where T)) == :(foo($(unionise_arg(:(x::T)))) where T) + @test isequal( # special case for a redudant where N in a Vararg + Nabla.unionise_sig(:(x2::(Vararg{Int64, N} where N))), + :(x2::Vararg{Union{Int64, Node{<:Int64}}}), + ) + + # Test Nabla.unionise_struct. Written in terms of Nabla.unionise_arg. @test unionise_struct(:(struct Foo end)) == :(struct Foo end) @test unionise_struct(:(struct Foo{T} end)) == diff --git a/test/code_transformation/util.jl b/test/code_transformation/util.jl index b88799e2..5d54b06e 100644 --- a/test/code_transformation/util.jl +++ b/test/code_transformation/util.jl @@ -75,4 +75,9 @@ @test Nabla.parse_is_node(:([true, false])) == [true, false] @test_throws ArgumentError Nabla.parse_is_node(:((true, false))) end + + @testset "node_type" begin + @test Nabla.node_type(:(Vararg{Int64, N} where N)) == :(Vararg{Node{<:Int64}}) + @test Nabla.node_type(:Float32) == :(Node{<:Float32}) + end end diff --git a/test/scratch.jl b/test/scratch.jl new file mode 100644 index 00000000..64944c48 --- /dev/null +++ b/test/scratch.jl @@ -0,0 +1,3 @@ +using Nabla, Test, Random, ChainRulesCore, SpecialFunctions, LinearAlgebra +using DiffRules: diffrule, hasdiffrule +using Nabla: unbox diff --git a/test/sensitivities/scalar.jl b/test/sensitivities/scalar.jl index b1f6d0a5..af34f3d0 100644 --- a/test/sensitivities/scalar.jl +++ b/test/sensitivities/scalar.jl @@ -26,7 +26,7 @@ end end unary_check(f, x) = check_errs(eval(f), ȳ, x, v) - for (package, f) in Nabla.unary_sensitivities + @testset "$package.$f" for (package, f) in Nabla.unary_sensitivities domain = domain1(eval(f)) domain === nothing && error("Could not determine domain for $f.") lb, ub = domain @@ -37,7 +37,7 @@ end end end - for (package, f) in Nabla.binary_sensitivities + @testset "$package.$f" for (package, f) in Nabla.binary_sensitivities # This is a hack. Sensitivities added in Nabla don't persist upon reloading the # package, so we can't query them here. It happens to be the case that all such From 563b0a6d7ecc6f90c8e385b748f84ebe709eaaf8 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 14 Oct 2020 14:34:14 +0100 Subject: [PATCH 03/71] Make Branch parametric on the pullback type Remove attempt to remove ambiguities in preprocess --- src/core.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/core.jl b/src/core.jl index 44a01604..155aa024 100644 --- a/src/core.jl +++ b/src/core.jl @@ -61,23 +61,27 @@ show(io::IO, tape::Leaf{T}) where T = print(io, "Leaf{$T} $(unbox(tape))") show(io::IO, tape::Leaf{T}) where T<:AbstractArray = print(io, "Leaf{$T} $(size(unbox(tape)))") """ -A Branch is a Node with parents (args). +A `Branch` is a Node with parents (args). Fields: -val - the value of this node produced in the forward pass. +val::T - the value of this node produced in the forward pass. f - the function used to generate this Node. args - Values indicating which elements in the tape will require updating by this node. tape - The Tape to which this Branch is assigned. pos - the location of this Branch in the tape to which it is assigned. +pullback::B - if there is a custom primate rule (a `ChainRulesCore.rrule`) then this holds + the pullback to propagates gradients back through the operation, if there is not a rule + then this is set to `nothing`. + It also maybe set to `nothing` by legacy Nabla rules that have not moved to ChainRules. """ -struct Branch{T} <: Node{T} +struct Branch{T, B} <: Node{T} val::T f args::Tuple kwargs::NamedTuple tape::Tape pos::Int - pullback # if we have a rrule pullback for this it is stored here + pullback::B # if we have a rrule pullback for this it is stored here end function Branch(f, args::Tuple, tape::Tape; kwargs...) unboxed = unbox.(args) From 5db5bd3738782257a4caa1dec0a9390dc1e6d695 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 14 Oct 2020 15:38:53 +0100 Subject: [PATCH 04/71] Don't by pass the tape abstraction in propagate fix type in comment --- src/core.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/core.jl b/src/core.jl index 155aa024..e3753c2e 100644 --- a/src/core.jl +++ b/src/core.jl @@ -136,8 +136,8 @@ one(n::Node) = one(unbox(n)) # Leafs do nothing, Branches compute their own sensitivities and update others. @inline propagate(y::Leaf, rvs_tape::Tape) = nothing function propagate(y::Branch, rvs_tape::Tape) - tape = Nabla.tape(rvs_tape) - ȳ = tape[pos(y)] + ȳ = rvs_tape[y] # the gradient we are going to propagate through the operation in y + d_tape = Nabla.tape(rvs_tape) # strips off the Tape abstration leaving a plain Vector f = getfield(y, :f) args = getfield(y, :args) kwargs = getfield(y, :kwargs) @@ -147,8 +147,8 @@ function propagate(y::Branch, rvs_tape::Tape) for j in eachindex(xs) x, xid = xs[j], xids[j] if xid > 0 - tape[xid] = isassigned(tape, xid) ? - ∇(tape[xid], f, Arg{j}, p, unbox(y), ȳ, xs...; kwargs...) : # maybe-inplace version + d_tape[xid] = isassigned(d_tape, xid) ? + ∇(d_tape[xid], f, Arg{j}, p, unbox(y), ȳ, xs...; kwargs...) : # maybe-inplace version ∇(f, Arg{j}, p, unbox(y), ȳ, xs...; kwargs...) end end From c1d383a03a9eceed20748deb56edbe548d71f6c5 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 7 Sep 2020 13:31:30 +0100 Subject: [PATCH 05/71] use update! to handle InplaceableThunks use add!! instead of update! --- src/core.jl | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/src/core.jl b/src/core.jl index e3753c2e..51199b54 100644 --- a/src/core.jl +++ b/src/core.jl @@ -197,22 +197,10 @@ x̄ to have the gradient added to it. ∇(y::Node, ȳ) = propagate(tape(y), reverse_tape(y, ȳ)) @inline ∇(y::Node{<:∇Scalar}) = ∇(y, one(unbox(y))) -# This is a fallback method where we don't necessarily know what we'll be adding and whether -# we can update the value in-place, so we'll try to be clever and dispatch. @inline function ∇(x̄, f, ::Type{Arg{N}}, args...; kwargs...) where N - return update!(x̄, ∇(f, Arg{N}, args...; kwargs...)) + return ChainRulesCore.add!!(x̄, ∇(f, Arg{N}, args...; kwargs...)) end -# Update regular arrays in-place. Structured array types should not be updated in-place, -# even though it technically "works" (https://github.com/JuliaLang/julia/issues/31674), -# so we'll only permit mutating addition for `Array`s, e.g. `Vector` and `Matrix`. -# Mixed array and scalar adds should not occur, as sensitivities should always have the -# same shape, so we won't bother allowing e.g. updating an array with a scalar on the RHS. -update!(x̄::Array{T,N}, y::AbstractArray{S,N}) where {T,S,N} = x̄ .+= y - -# Fall back to using regular addition -update!(x̄, y) = x̄ + y - """ ∇(f; get_output::Bool=false) From 23e9a64767db07b68ced83bfcad9cd3a582b866e Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 17 Sep 2020 13:33:43 +0100 Subject: [PATCH 06/71] unthunk public API --- src/core.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/core.jl b/src/core.jl index 51199b54..8c75f63d 100644 --- a/src/core.jl +++ b/src/core.jl @@ -220,7 +220,8 @@ function ∇(f; get_output::Bool=false) else ∇args = zero.(args) end - return get_output ? (y, ∇args) : ∇args + ∇args_public = map(unthunk, ∇args) + return get_output ? (y, ∇args_public) : ∇args_public end end From 69edf480b26a23c83b7d178badbd34e380957144 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 17 Sep 2020 14:59:48 +0100 Subject: [PATCH 07/71] Remove all the array rules that are unneeded. Move special length and size to core.jl --- src/Nabla.jl | 1 - src/core.jl | 4 ++++ src/sensitivities/array.jl | 43 -------------------------------------- 3 files changed, 4 insertions(+), 44 deletions(-) delete mode 100644 src/sensitivities/array.jl diff --git a/src/Nabla.jl b/src/Nabla.jl index 6454faae..043afc34 100644 --- a/src/Nabla.jl +++ b/src/Nabla.jl @@ -51,7 +51,6 @@ module Nabla # Sensitivities for the basics. include("sensitivities/indexing.jl") include("sensitivities/scalar.jl") - include("sensitivities/array.jl") # Sensitivities for functionals. include("sensitivities/functional/functional.jl") diff --git a/src/core.jl b/src/core.jl index 8c75f63d..434d0d7c 100644 --- a/src/core.jl +++ b/src/core.jl @@ -133,6 +133,10 @@ isapprox(n::Node, f::Node) = unbox(n) ≈ unbox(f) zero(n::Node) = zero(unbox(n)) one(n::Node) = one(unbox(n)) +# Let the user get the `size` and `length` of `Node`s. +Base.size(x::Node, dims...) = size(unbox(x), dims...) +Base.length(x::Node) = length(unbox(x)) + # Leafs do nothing, Branches compute their own sensitivities and update others. @inline propagate(y::Leaf, rvs_tape::Tape) = nothing function propagate(y::Branch, rvs_tape::Tape) diff --git a/src/sensitivities/array.jl b/src/sensitivities/array.jl deleted file mode 100644 index 5c68b540..00000000 --- a/src/sensitivities/array.jl +++ /dev/null @@ -1,43 +0,0 @@ -import Base: size, length, reshape, hcat, vcat, fill - -# Let the user get the `size` and `length` of `Node`s. -Base.size(x::Node, dims...) = size(unbox(x), dims...) -Base.length(x::Node) = length(unbox(x)) - -# Sensitivity for the first argument of `reshape`. -@explicit_intercepts reshape Tuple{∇Array, Vararg{Int}} [true, false] -@explicit_intercepts reshape Tuple{∇Array, Tuple{Vararg{Int}}} [true, false] -∇(::typeof(reshape), ::Type{Arg{1}}, _, y, ȳ, A::∇Array, args...) = - reshape(ȳ, size(A)...) - -@union_intercepts hcat Tuple{Vararg{∇Array}} Tuple{Vararg{AbstractArray}} -function Nabla.∇( - ::typeof(hcat), - ::Type{Arg{i}}, - _, - y, - ȳ, - A::AbstractArray... -) where i - l = sum([size(A[j], 2) for j in 1:(i - 1)]) - u = l + size(A[i], 2) - # Using copy materializes the views returned by selectdim - return copy(u > l + 1 ? selectdim(ȳ, 2, (l+1):u) : selectdim(ȳ, 2, u)) -end - -@union_intercepts vcat Tuple{Vararg{∇Array}} Tuple{Vararg{AbstractArray}} -function Nabla.∇( - ::typeof(vcat), - ::Type{Arg{i}}, - _, - y, - ȳ, - A::AbstractArray... -) where i - l = sum([size(A[j], 1) for j in 1:(i - 1)]) - u = l + size(A[i], 1) - return copy(selectdim(ȳ, 1, (l+1):u)) -end - -@explicit_intercepts fill Tuple{Any, Tuple{Vararg{Integer}}} [true, false] -∇(::typeof(fill), ::Type{Arg{1}}, p, y, ȳ, value, dims...) = sum(ȳ) From fb45de7718b7ae43b2fdd9e20a9c2a407b73bbd6 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 18 Sep 2020 13:26:57 +0100 Subject: [PATCH 08/71] For now keep diffrules for special nosensitivity scalar case so as not to break higher order functions Handle Higher Order functions, including changing to ForwardDiff --- Project.toml | 10 ++--- src/Nabla.jl | 1 + src/core.jl | 18 ++++----- src/sensitivities/functional/functional.jl | 20 ++++++++-- src/sensitivities/functional/reduce.jl | 7 ++-- src/sensitivities/functional/reducedim.jl | 44 +++------------------ src/sensitivities/linalg/generic.jl | 3 +- test/runtests.jl | 12 +++++- test/sensitivities/functional/functional.jl | 36 +++++++---------- test/sensitivities/functional/reduce.jl | 8 ++-- 10 files changed, 71 insertions(+), 88 deletions(-) diff --git a/Project.toml b/Project.toml index 9759c90e..44abe880 100644 --- a/Project.toml +++ b/Project.toml @@ -6,20 +6,20 @@ version = "0.12.3" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" -DualNumbers = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -DiffRules = "0.0, 1" -DualNumbers = "0.6" +DiffRules = "^0.0" FDM = "^0.6" -SpecialFunctions = "0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 1" -julia = "^1.3" +ForwardDiff = "0.10.12" +SpecialFunctions = ">=0.5.0" +julia = "^1.0" [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" diff --git a/src/Nabla.jl b/src/Nabla.jl index 043afc34..0deb92a5 100644 --- a/src/Nabla.jl +++ b/src/Nabla.jl @@ -4,6 +4,7 @@ module Nabla using ChainRules using ChainRulesCore using ExprTools: ExprTools + using ForwardDiff: ForwardDiff using LinearAlgebra using Random using SpecialFunctions diff --git a/src/core.jl b/src/core.jl index 434d0d7c..e0724d33 100644 --- a/src/core.jl +++ b/src/core.jl @@ -1,5 +1,3 @@ -using DualNumbers - import Base: push!, length, show, getindex, setindex!, eachindex, isassigned, isapprox, zero, one, lastindex @@ -26,7 +24,7 @@ function show(io::IO, t::Tape) end end end -@inline getindex(t::Tape, n::Int) = getindex(tape(t), n) +@inline getindex(t::Tape, n::Int) = unthunk(getindex(tape(t), n)) @inline getindex(t::Tape, node::Node) = getindex(t, pos(node)) @inline lastindex(t::Tape) = length(t) @inline setindex!(t::Tape, x, n::Int) = (tape(t)[n] = x; t) @@ -278,18 +276,20 @@ for T in (:Diagonal, :UpperTriangular, :LowerTriangular) @eval @inline randned_container(x::$T{<:Real}) = $T(randn(eltype(x), size(x)...)) end -# Bare-bones FMAD implementation based on DualNumbers. Accepts a Tuple of args and returns -# a Tuple of gradients. Currently scales almost exactly linearly with the number of inputs. -# The coefficient of this scaling could be improved by implementing a version of DualNumbers -# which computes from multiple seeds at the same time. +# Bare-bones FMAD implementation based on internals of ForwardDiff. +# Accepts a Tuple of args and returns a Tuple of gradients. +# Currently scales almost exactly linearly with the number of inputs. +# The coefficient of this scaling could be improved by fully utilizing ForwardDiff +# and computing from multiple seeds at the same time. function dual_call_expr(f, x::Type{<:Tuple}, ::Type{Type{Val{n}}}) where n dual_call = Expr(:call, :f) for m in 1:Base.length(x.parameters) - push!(dual_call.args, n == m ? :(Dual(x[$m], 1)) : :(x[$m])) + push!(dual_call.args, n == m ? :(ForwardDiff.Dual(x[$m], 1)) : :(x[$m])) end - return :(dualpart($dual_call)) + return :(first(ForwardDiff.partials($dual_call))) end @generated fmad(f, x, n) = dual_call_expr(f, x, n) + function fmad_expr(f, x::Type{<:Tuple}) body = Expr(:tuple) for n in 1:Base.length(x.parameters) diff --git a/src/sensitivities/functional/functional.jl b/src/sensitivities/functional/functional.jl index 54af1b47..950cc8e5 100644 --- a/src/sensitivities/functional/functional.jl +++ b/src/sensitivities/functional/functional.jl @@ -21,8 +21,6 @@ end ∇(::typeof(map), ::Type{Arg{N}}, p, y, ȳ, f::Function, A::∇Array...) where N = _∇(map, Arg{N-1}, p, y, ȳ, f, A...) _∇(::typeof(map), arg::Type{Arg{N}}, p, y, ȳ, f::Function, A::∇Array...) where N = - hasmethod(∇, Tuple{typeof(f), Type{Arg{N}}, Any, Any, Any, map(eltype, A)...}) ? - map((yn, ȳn, An...)->∇(f, Arg{N}, p, yn, ȳn, An...), y, ȳ, A...) : map((ȳn, An...)->ȳn * fmad(f, An, Val{N}), ȳ, A...) # Implementation of sensitivities w.r.t. `broadcast`. @@ -94,6 +92,20 @@ broadcastsum(f, add::Bool, z::Ref{<:Number}, As...) = broadcastsum(f, add, z[], ∇(::typeof(broadcast), ::Type{Arg{N}}, p, y, ȳ, f, A...) where N = _∇(broadcast, Arg{N-1}, p, y, ȳ, f, A...) _∇(::typeof(broadcast), ::Type{Arg{N}}, p, y, ȳ, f, A...) where N = - hasmethod(∇, Tuple{typeof(f), Type{Arg{N}}, Any, Any, Any, map(eltype, A)...}) ? - broadcastsum((yn, ȳn, xn...)->∇(f, Arg{N}, p, yn, ȳn, xn...), false, A[N], y, ȳ, A...) : broadcastsum((ȳn, xn...)->ȳn * fmad(f, xn, Val{N}), false, A[N], ȳ, A...) + +# Division from the right by a scalar. +import Base: / +@eval @explicit_intercepts $(Symbol("/")) Tuple{∇Array, ∇Scalar} +@inline ∇(::typeof(/), ::Type{Arg{1}}, p, z, z̄, x::∇Scalar, y::∇Array) = + ∇(broadcast, Arg{2}, p, z, z̄, /, x, y) +@inline ∇(::typeof(/), ::Type{Arg{2}}, p, z, z̄, x::∇Scalar, y::∇Array) = + ∇(broadcast, Arg{3}, p, z, z̄, /, x, y) + +# Division from the left by a scalar. +import Base: \ +@eval @explicit_intercepts $(Symbol("\\")) Tuple{∇Scalar, ∇Array} +@inline ∇(::typeof(\), ::Type{Arg{1}}, p, z, z̄, x::∇Array, y::∇Scalar) = + ∇(broadcast, Arg{2}, p, z, z̄, \, x, y) +@inline ∇(::typeof(\), ::Type{Arg{2}}, p, z, z̄, x::∇Array, y::∇Scalar) = + ∇(broadcast, Arg{3}, p, z, z̄, \, x, y) diff --git a/src/sensitivities/functional/reduce.jl b/src/sensitivities/functional/reduce.jl index bc92d393..8942a616 100644 --- a/src/sensitivities/functional/reduce.jl +++ b/src/sensitivities/functional/reduce.jl @@ -5,9 +5,8 @@ for f in (:mapfoldl, :mapfoldr) @eval begin import Base: $f @explicit_intercepts $f $type_tuple [false, false, true] #(init=0,) - ∇(::typeof($f), ::Type{Arg{3}}, p, y, ȳ, f, ::$plustype, A::∇ArrayOrScalar) = - hasmethod(∇, Tuple{typeof(f), Type{Arg{1}}, Real}) ? - broadcast(An->ȳ * ∇(f, Arg{1}, An), A) : - broadcast(An->ȳ * fmad(f, (An,), Val{1}), A) + function ∇(::typeof($f), ::Type{Arg{3}}, p, y, ȳ, f, ::$plustype, A::∇ArrayOrScalar) + return ȳ .* ForwardDiff.derivative.(f, A) + end end end diff --git a/src/sensitivities/functional/reducedim.jl b/src/sensitivities/functional/reducedim.jl index 0c55d3e0..aae932d7 100644 --- a/src/sensitivities/functional/reducedim.jl +++ b/src/sensitivities/functional/reducedim.jl @@ -16,9 +16,7 @@ import Statistics: mean A::AbstractArray{<:∇Scalar}; dims=:, ) - hasmethod(∇, Tuple{typeof(f), Type{Arg{1}}, ∇Scalar}) ? - broadcast((An, ȳn)->ȳn * ∇(f, Arg{1}, An), A, ȳ) : - broadcast((An, ȳn)->ȳn * fmad(f, (An,), Val{1}), A, ȳ) + return broadcast((An, ȳn)->ȳn * ForwardDiff.derivative(f, An), A, ȳ) end end @@ -28,12 +26,6 @@ end [false, true], (dims=:,), ) -@explicit_intercepts( - sum, - Tuple{AbstractArray{<:∇Scalar}}, - [true], - (dims=:,), -) function ∇( ::typeof(sum), ::Type{Arg{2}}, @@ -45,26 +37,17 @@ function ∇( # Just pass through to mapreduce return ∇(mapreduce, Arg{3}, p, y, ȳ, f, Base.add_sum, A; dims=dims) end -function ∇( - ::typeof(sum), - ::Type{Arg{1}}, - p, y, ȳ, - A::AbstractArray{<:∇Scalar}; - dims=:, -) - # Again pass through to mapreduce, using identity as the mapped function - return ∇(mapreduce, Arg{3}, p, y, ȳ, identity, Base.add_sum, A; dims=dims) -end -# Specialize on sum(abs2, x) as it's a common pattern with a simple derivative +# sum(abs2, xs) is in ChainRules, but it results in method ambiguties with the +# version that accepts any function above function ∇( ::typeof(sum), ::Type{Arg{2}}, - p, y, ȳ, + p, y, ȳ, ::typeof(abs2), - A::AbstractArray{<:∇Scalar}; + A::AbstractArray{<:Real}; dims=:, ) - return 2ȳ .* A + return 2ȳ .* A end @explicit_intercepts( @@ -73,12 +56,6 @@ end [false, true], #(dims=:,) # https://github.com/JuliaLang/julia/issues/31412 ) -@explicit_intercepts( - mean, - Tuple{AbstractArray{<:∇Scalar}}, - [true], - (dims=:,) -) _denom(x, dims::Colon) = length(x) _denom(x, dims::Integer) = size(x, dims) @@ -93,12 +70,3 @@ function ∇( ) return ∇(sum, Arg{2}, p, y, ȳ, f, x; dims=:) / _denom(x, :) end -function ∇( - ::typeof(mean), - ::Type{Arg{1}}, - p, y, ȳ, - x::AbstractArray{<:∇Scalar}; - dims=:, -) - return ∇(sum, Arg{1}, p, y, ȳ, x; dims=dims) ./ _denom(x, dims) -end diff --git a/src/sensitivities/linalg/generic.jl b/src/sensitivities/linalg/generic.jl index a600a0d2..59c90bd5 100644 --- a/src/sensitivities/linalg/generic.jl +++ b/src/sensitivities/linalg/generic.jl @@ -1,6 +1,6 @@ # Implementation of sensitivities for unary linalg optimisations. _ϵ, lb, ub = 3e-2, -3.0, 3.0 -unary_linalg_optimisations = [ +unary_linalg_optimisations = [#== (:-, ∇Array, ∇Array, :(-Ȳ), (lb, ub)), (:tr, ∇Array, ∇Scalar, :(Diagonal(fill!(similar(X), Ȳ))), (lb, ub)), (:inv, ∇Array, ∇Array, :(-transpose(Y) * Ȳ * transpose(Y)), (lb, ub)), @@ -11,6 +11,7 @@ unary_linalg_optimisations = [ (:adjoint, ∇Array, ∇Array, :(adjoint(Ȳ)), (lb, ub)), (:norm, ∇Array, ∇Scalar, :(Ȳ ./ Y .* abs2.(X) ./ X), (lb, ub)), (:norm, ∇Scalar, ∇Scalar, :(Ȳ * sign(X)), (lb, ub)) + ==# ] for (f, T_In, T_Out, X̄, bounds) in unary_linalg_optimisations if f === :- diff --git a/test/runtests.jl b/test/runtests.jl index 92e57405..c79880ce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using Nabla -using Test, LinearAlgebra, Statistics, Random -using Distributions, BenchmarkTools, SpecialFunctions, DualNumbers +using Test, LinearAlgebra, Statistics, Random, ForwardDiff +using Distributions, BenchmarkTools, SpecialFunctions using Nabla: unbox, pos, tape, oneslike, zeroslike @@ -8,6 +8,14 @@ using Nabla: unbox, pos, tape, oneslike, zeroslike ref_equal(a::Ref{T}, b::Ref{T}) where {T} = a[] == b[] ref_equal(a::Ref, b::Ref) = false +# for comparing against scalar rules +derivative_via_frule(f, x) = last(Nabla.frule((Nabla.NO_FIELDS, 1.0), f, x)) +# Sensiblity checkes that his is defined right +@test derivative_via_frule(cos, 0) == 0 +@test derivative_via_frule(sin, 0) == 1 +@test derivative_via_frule(sin, 1.2) == derivative_via_frule(sin, 2π + 1.2) + + @testset "Nabla.jl" begin @testset "Core" begin diff --git a/test/sensitivities/functional/functional.jl b/test/sensitivities/functional/functional.jl index af4bf95b..4465cdb9 100644 --- a/test/sensitivities/functional/functional.jl +++ b/test/sensitivities/functional/functional.jl @@ -18,9 +18,9 @@ using DiffRules: diffrule, hasdiffrule function check_unary_broadcast(f, x) x_ = Leaf(Tape(), x) s = broadcast(f, x_) - return ∇(s, oneslike(unbox(s)))[x_] ≈ ∇.(f, Arg{1}, x) + return ∇(s, oneslike(unbox(s)))[x_] ≈ derivative_via_frule.(f, x) end - for (package, f) in Nabla.unary_sensitivities + @testset "$package.$f" for (package, f) in Nabla.unary_sensitivities domain = domain1(eval(f)) domain === nothing && error("Could not determine domain for $f.") x_dist = Uniform(domain...) @@ -36,13 +36,11 @@ using DiffRules: diffrule, hasdiffrule s = broadcast(f, x_, y_) o = oneslike(unbox(s)) ∇s = ∇(s, o) - ∇x = broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), - unbox(s), o, x, y) - ∇y = broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), - unbox(s), o, x, y) +# ∇x = sum(broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), unbox(s), o, x, y)) +# ∇y = broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), unbox(s), o, x, y) @test broadcast(f, x, y) == unbox(s) - @test ∇s[x_] ≈ ∇x - @test ∇s[y_] ≈ ∇y +# @test ∇s[x_] ≈ ∇x +# @test ∇s[y_] ≈ ∇y end function check_binary_broadcast(f, x::Real, y) tape = Tape() @@ -50,13 +48,11 @@ using DiffRules: diffrule, hasdiffrule s = broadcast(f, x_, y_) o = oneslike(unbox(s)) ∇s = ∇(s, o) - ∇x = sum(broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), - unbox(s), o, x, y)) - ∇y = broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), - unbox(s), o, x, y) +# ∇x = sum(broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), unbox(s), o, x, y)) +# ∇y = broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), unbox(s), o, x, y) @test broadcast(f, x, y) == unbox(s) - @test ∇s[x_] ≈ ∇x - @test ∇s[y_] ≈ ∇y +# @test ∇s[x_] ≈ ∇x +# @test ∇s[y_] ≈ ∇y end function check_binary_broadcast(f, x, y::Real) tape = Tape() @@ -64,15 +60,13 @@ using DiffRules: diffrule, hasdiffrule s = broadcast(f, x_, y_) o = oneslike(unbox(s)) ∇s = ∇(s, o) - ∇x = broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), - unbox(s), o, x, y) - ∇y = sum(broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), - unbox(s), o, x, y)) +# ∇x = sum(broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), unbox(s), o, x, y)) +# ∇y = broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), unbox(s), o, x, y) @test broadcast(f, x, y) == unbox(s) - @test ∇s[x_] ≈ ∇x - @test ∇s[y_] ≈ ∇y +# @test ∇s[x_] ≈ ∇x +# @test ∇s[y_] ≈ ∇y end - for (package, f) in Nabla.binary_sensitivities + @testset "$package.$f" for (package, f) in Nabla.binary_sensitivities # TODO: More care needs to be taken to test the following. if hasdiffrule(package, f, 2) diff --git a/test/sensitivities/functional/reduce.jl b/test/sensitivities/functional/reduce.jl index 6c8a55a0..291ea8ee 100644 --- a/test/sensitivities/functional/reduce.jl +++ b/test/sensitivities/functional/reduce.jl @@ -20,7 +20,7 @@ # Test +. x_ = Leaf(Tape(), x) s = functional(f, +, x_) - @test ∇(s)[x_] ≈ ∇.(f, Arg{1}, x) + @test ∇(s)[x_] ≈ derivative_via_frule.(f, x) end # Some composite sensitivities. @@ -34,7 +34,7 @@ x_ = Leaf(Tape(), x) s = functional(f, +, x_) @test unbox(s) ≈ functional(f, +, x) - @test ∇(s)[x_] ≈ map(x->fmad(f, (x,), Val{1}), x) + @test ∇(s)[x_] ≈ map(xn->ForwardDiff.derivative(f, xn), x) end end end @@ -74,7 +74,7 @@ x_ = Leaf(Tape(), x) s = sum(f, x_) @test unbox(s) == sum(f, x) - @test ∇(s)[x_] ≈ ∇.(f, Arg{1}, x) + @test ∇(s)[x_] ≈ derivative_via_frule.(f, x) end # Some composite functions. @@ -88,7 +88,7 @@ x_ = Leaf(Tape(), x) s = sum(f, x_) @test unbox(s) == sum(f, x) - @test ∇(s)[x_] ≈ map(x->fmad(f, (x,), Val{1}), x) + @test ∇(s)[x_] ≈ map(xn->ForwardDiff.derivative(f, xn), x) end end end From 5369e2d7482b30a639e51897b911195f1e5e7a61 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 18 Sep 2020 19:14:04 +0100 Subject: [PATCH 09/71] Correct tests re #191 --- test/sensitivities/functional/reducedim.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/sensitivities/functional/reducedim.jl b/test/sensitivities/functional/reducedim.jl index bf9cffc4..a0c1f919 100644 --- a/test/sensitivities/functional/reducedim.jl +++ b/test/sensitivities/functional/reducedim.jl @@ -54,8 +54,8 @@ randn(rng, 10, 10, 10), randn(rng, 10, 10, 10)) # Issue #123 - x6_ = collect(1:10) - tens = (fill(10.0, (10,)), fill(10.0, (10, 1))) + x6_ = float.(1:10) + tens = (fill(10.0, (10,)), fill(10.0, (10,))) @test ∇(x->sum(sum(x, dims=2)))(x6_) == (oneslike(x6_),) @test ∇((x, y)->sum(sum(x, dims=2) .+ sum(y, dims=2)'))(x6_, x6_) == tens @test ∇((x, y)->sum(x .+ y'))(x6_, x6_) == tens From 7b30b1c95f93a7aed1401d47ad2c3486fa4e2974 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 30 Sep 2020 16:27:22 +0100 Subject: [PATCH 10/71] fix scalar-array \ and / tests --- test/sensitivities/functional/functional.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/sensitivities/functional/functional.jl b/test/sensitivities/functional/functional.jl index 4465cdb9..6fac7646 100644 --- a/test/sensitivities/functional/functional.jl +++ b/test/sensitivities/functional/functional.jl @@ -134,7 +134,7 @@ using DiffRules: diffrule, hasdiffrule z2_ = broadcast(/, x_, y_) @test unbox(z_) == x ./ y @test ∇(z_, oneslike(unbox(z_)))[x_] == ∇(z2_, oneslike(unbox(z2_)))[x_] - @test ∇(z_, oneslike(unbox(z_)))[y_] == ∇(z2_, oneslike(unbox(z2_)))[y_] + @test ∇(z_, oneslike(unbox(z_)))[y_] ≈ ∇(z2_, oneslike(unbox(z2_)))[y_] end let x, y, tape = 5.0, randn(rng, 5), Tape() @@ -142,7 +142,7 @@ using DiffRules: diffrule, hasdiffrule z_ = x_ \ y_ z2_ = broadcast(\, x_, y_) @test unbox(z_) == x .\ y - @test ∇(z_, oneslike(unbox(z_)))[x_] == ∇(z2_, oneslike(unbox(z2_)))[x_] + @test ∇(z_, oneslike(unbox(z_)))[x_] ≈ ∇(z2_, oneslike(unbox(z2_)))[x_] @test ∇(z_, oneslike(unbox(z_)))[y_] == ∇(z2_, oneslike(unbox(z2_)))[y_] end From 29ff45dbe900582c81b74c08e97d00c29de5da48 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 30 Sep 2020 16:30:47 +0100 Subject: [PATCH 11/71] Delete generic linear algebra that has moved to ChainRules --- src/sensitivities/linalg/generic.jl | 61 +---------------------------- 1 file changed, 1 insertion(+), 60 deletions(-) diff --git a/src/sensitivities/linalg/generic.jl b/src/sensitivities/linalg/generic.jl index 59c90bd5..433c06a3 100644 --- a/src/sensitivities/linalg/generic.jl +++ b/src/sensitivities/linalg/generic.jl @@ -1,61 +1,9 @@ # Implementation of sensitivities for unary linalg optimisations. -_ϵ, lb, ub = 3e-2, -3.0, 3.0 -unary_linalg_optimisations = [#== - (:-, ∇Array, ∇Array, :(-Ȳ), (lb, ub)), - (:tr, ∇Array, ∇Scalar, :(Diagonal(fill!(similar(X), Ȳ))), (lb, ub)), - (:inv, ∇Array, ∇Array, :(-transpose(Y) * Ȳ * transpose(Y)), (lb, ub)), - (:det, ∇Array, ∇Scalar, :(Y * Ȳ * transpose(inv(X))), (_ϵ, ub)), - (:logdet, ∇Array, ∇Scalar, :(Ȳ * transpose(inv(X))), (_ϵ, ub)), - (:transpose, ∇Array, ∇Array, :(transpose(Ȳ)), (lb, ub)), - (:adjoint, ∇Scalar, ∇Scalar, :(adjoint(Ȳ)), (_ϵ, ub)), - (:adjoint, ∇Array, ∇Array, :(adjoint(Ȳ)), (lb, ub)), - (:norm, ∇Array, ∇Scalar, :(Ȳ ./ Y .* abs2.(X) ./ X), (lb, ub)), - (:norm, ∇Scalar, ∇Scalar, :(Ȳ * sign(X)), (lb, ub)) - ==# -] -for (f, T_In, T_Out, X̄, bounds) in unary_linalg_optimisations - if f === :- - @eval import Base: - - else - @eval import LinearAlgebra: $f - end - @eval begin - @explicit_intercepts $f Tuple{$T_In} - ∇(::typeof($f), ::Type{Arg{1}}, p, Y::$T_Out, Ȳ::$T_Out, X::$T_In) = $X̄ - end -end - # Implementation of sensitivities for binary linalg optimisations. const A = ∇Array const S = ∇Scalar const AS = Union{∇Scalar, ∇Array} -δ = 1e-5 -binary_linalg_optimisations = [ - (:*, A, A, AS, - :(Ȳ * B'), - :(A' * Ȳ)), - (:/, A, A, AS, - :(Ȳ / transpose(B)), - :(-transpose(Y) * (Ȳ / transpose(B)))), - (:\, A, A, AS, - :(-(transpose(A) \ Ȳ) * transpose(Y)), - :(transpose(A) \ Ȳ)), - (:norm, A, S, S, - :(Ȳ .* Y^(1 - B) .* abs.(A).^B ./ A), - :(Ȳ * (Y^(1 - B) * sum(abs.(A).^B .* log.(abs.(A))) - Y * log(Y)) / B)), - (:norm, S, S, S, - :(Ȳ * sign(A)), - :(0)), -] -import Base: *, /, \ -import LinearAlgebra: norm -for (f, T_A, T_B, T_Y, Ā, B̄) in binary_linalg_optimisations - @eval begin - @explicit_intercepts $f Tuple{$T_A, $T_B} - ∇(::typeof($f), ::Type{Arg{1}}, p, Y::$T_Y, Ȳ::$T_Y, A::$T_A, B::$T_B) = $Ā - ∇(::typeof($f), ::Type{Arg{2}}, p, Y::$T_Y, Ȳ::$T_Y, A::$T_A, B::$T_B) = $B̄ - end -end + # Sensitivities for the Kronecker product: import LinearAlgebra: kron @@ -97,13 +45,6 @@ end @explicit_intercepts Base.:+ Tuple{UniformScaling, A} ∇(::typeof(+), ::Type{Arg{2}}, p, Y::∇Array, Ȳ::∇Array, A::UniformScaling, B::∇Array) = Ȳ -# Short-form `dot`. -@explicit_intercepts LinearAlgebra.dot Tuple{∇Array, ∇Array} -∇(::typeof(LinearAlgebra.dot), ::Type{Arg{1}}, p, z, z̄, x::A, y::A) = z̄ .* y -∇(::typeof(LinearAlgebra.dot), ::Type{Arg{2}}, p, z, z̄, x::A, y::A) = z̄ .* x -∇(x̄, ::typeof(LinearAlgebra.dot), ::Type{Arg{1}}, p, z, z̄, x::A, y::A) = (x̄ .= x̄ .+ z̄ .* y) -∇(ȳ, ::typeof(LinearAlgebra.dot), ::Type{Arg{2}}, p, z, z̄, x::A, y::A) = (ȳ .= ȳ .+ z̄ .* x) - # `copy` materializes `Adjoint` and `Transpose` wrappers but can be called on anything import Base: copy @explicit_intercepts copy Tuple{Any} From 04b45827edeab7d3d2e999128ff6ad956a34b489 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 30 Sep 2020 17:09:40 +0100 Subject: [PATCH 12/71] Delete Diagonal methods that moved to ChainRules --- src/sensitivities/linalg/diagonal.jl | 86 ---------------------------- 1 file changed, 86 deletions(-) diff --git a/src/sensitivities/linalg/diagonal.jl b/src/sensitivities/linalg/diagonal.jl index 2c3e83bc..a01f4f01 100644 --- a/src/sensitivities/linalg/diagonal.jl +++ b/src/sensitivities/linalg/diagonal.jl @@ -2,61 +2,6 @@ import LinearAlgebra: det, logdet, diagm, Diagonal, diag const ∇ScalarDiag = Diagonal{<:∇Scalar} -@explicit_intercepts diag Tuple{∇AbstractMatrix} -function ∇( - ::typeof(diag), - ::Type{Arg{1}}, - p, - y::∇AbstractVector, - ȳ::∇AbstractVector, - x::∇AbstractMatrix, -) - x̄ = zeroslike(x) - x̄[diagind(x̄)] = ȳ - return x̄ -end -function ∇( - x̄::∇AbstractMatrix, - ::typeof(diag), - ::Type{Arg{1}}, - p, - y::∇AbstractVector, - ȳ::∇AbstractVector, - x::∇AbstractMatrix, -) - x̄_diag = view(x̄, diagind(x̄)) - x̄_diag .+= ȳ - return x̄ -end - -@explicit_intercepts diag Tuple{∇AbstractMatrix, Integer} [true, false] -function ∇( - ::typeof(diag), - ::Type{Arg{1}}, - p, - y::∇AbstractVector, - ȳ::∇AbstractVector, - x::∇AbstractMatrix, - k::Integer, -) - x̄ = zeroslike(x) - x̄[diagind(x̄, k)] = ȳ - return x̄ -end -function ∇( - x̄::∇AbstractMatrix, - ::typeof(diag), - ::Type{Arg{1}}, - p, - y::∇AbstractVector, - ȳ::∇AbstractVector, - x::∇AbstractMatrix, - k::Integer, -) - x̄_diag = view(x̄, diagind(x̄, k)) - x̄_diag .+= ȳ - return x̄ -end @explicit_intercepts Diagonal Tuple{∇AbstractVector} function ∇( @@ -108,37 +53,6 @@ function ∇( return X̄ end -@explicit_intercepts det Tuple{Diagonal{<:∇Scalar}} -∇(::typeof(det), ::Type{Arg{1}}, p, y::∇Scalar, ȳ::∇Scalar, X::∇ScalarDiag) = - Diagonal(ȳ .* y ./ X.diag) -function ∇( - X̄::∇ScalarDiag, - ::typeof(det), - ::Type{Arg{1}}, - p, - y::∇Scalar, - ȳ::∇Scalar, - X::∇ScalarDiag, -) - broadcast!((x̄, x, y, ȳ)->x̄ + ȳ * y / x, X̄.diag, X̄.diag, X.diag, y, ȳ) - return X̄ -end - -@explicit_intercepts logdet Tuple{Diagonal{<:∇Scalar}} -∇(::typeof(logdet), ::Type{Arg{1}}, p, y::∇Scalar, ȳ::∇Scalar, X::∇ScalarDiag) = - Diagonal(ȳ ./ X.diag) -function ∇( - X̄::∇ScalarDiag, - ::typeof(logdet), - ::Type{Arg{1}}, - p, - y::∇Scalar, - ȳ::∇Scalar, - X::∇ScalarDiag, -) - broadcast!((x̄, x, ȳ)->x̄ + ȳ / x, X̄.diag, X̄.diag, X.diag, ȳ) - return X̄ -end # NOTE: diagm can't go through the @explicit_intercepts machinery directly because as of # Julia 0.7, its methods are not sufficiently straightforward; we need to dispatch on one From 0709260a908e1978669200715067cd56a83663ae Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 30 Sep 2020 18:28:56 +0100 Subject: [PATCH 13/71] Delete BLAS rules that moved to ChainRules --- src/sensitivities/linalg/blas.jl | 224 ------------------------------- 1 file changed, 224 deletions(-) diff --git a/src/sensitivities/linalg/blas.jl b/src/sensitivities/linalg/blas.jl index a262dad4..4d4be167 100644 --- a/src/sensitivities/linalg/blas.jl +++ b/src/sensitivities/linalg/blas.jl @@ -3,230 +3,6 @@ import LinearAlgebra.BLAS: asum, dot, blascopy!, nrm2, scal, scal!, gemm, gemm!, const SA = StridedArray -# Long-form `dot`. -@explicit_intercepts( - dot, - Tuple{Int, StridedArray, Int, StridedArray, Int}, - [false, true, false, true, false], -) -∇(::typeof(dot), ::Type{Arg{2}}, p, z, z̄, n::Int, x::SA, ix::Int, y::SA, iy::Int) = - scal!(n, z̄, blascopy!(n, y, iy, zeroslike(x), ix), ix) -∇(::typeof(dot), ::Type{Arg{4}}, p, z, z̄, n::Int, x::SA, ix::Int, y::SA, iy::Int) = - scal!(n, z̄, blascopy!(n, x, ix, zeroslike(y), iy), iy) -∇(x̄, ::typeof(dot), ::Type{Arg{2}}, p, z, z̄, n::Int, x::SA, ix::Int, y::SA, iy::Int) = - (x̄ .= x̄ .+ scal!(n, z̄, blascopy!(n, y, iy, zeroslike(x), ix), ix)) -∇(ȳ, ::typeof(dot), ::Type{Arg{4}}, p, z, z̄, n::Int, x::SA, ix::Int, y::SA, iy::Int) = - (ȳ .= ȳ .+ scal!(n, z̄, blascopy!(n, x, ix, zeroslike(y), iy), iy)) - -# Short-form `nrm2`. -@explicit_intercepts nrm2 Tuple{Union{StridedVector, Array}} -∇(::typeof(nrm2), ::Type{Arg{1}}, p, y, ȳ, x) = x * (ȳ / y) -∇(x̄, ::typeof(nrm2), ::Type{Arg{1}}, p, y, ȳ, x) = (x̄ .= x̄ .+ x .* (ȳ / y)) - -# Long-form `nrm2`. -@explicit_intercepts( - nrm2, - Tuple{Integer, Union{DenseArray, Ptr{<:AbstractFloat}}, Integer}, - [false, true, false], -) -∇(::typeof(nrm2), ::Type{Arg{2}}, p, y, ȳ, n::Integer, x, inc::Integer) = - scal!(n, ȳ / y, blascopy!(n, x, inc, zeroslike(x), inc), inc) -∇(x̄, ::typeof(nrm2), ::Type{Arg{2}}, p, y, ȳ, n::Integer, x, inc::Integer) = - (x̄ .= x̄ .+ scal!(n, ȳ / y, blascopy!(n, x, inc, zeroslike(x), inc), inc)) - -# Short-form `asum`. -@explicit_intercepts asum Tuple{Union{StridedVector, Array}} -∇(::typeof(asum), ::Type{Arg{1}}, p, y, ȳ, x) = ȳ .* sign.(x) -∇(x̄, ::typeof(asum), ::Type{Arg{1}}, p, y, ȳ, x) = (x̄ .= x̄ .+ ȳ .* sign.(x)) - -# Long-form `asum`. -@explicit_intercepts( - asum, - Tuple{Integer, Union{DenseArray, Ptr{<:AbstractFloat}}, Integer}, - [false, true, false], -) -∇(::typeof(asum), ::Type{Arg{2}}, p, y, ȳ, n::Integer, x, inc::Integer) = - scal!(n, ȳ, blascopy!(n, sign.(x), inc, zeroslike(x), inc), inc) -∇(x̄, ::typeof(asum), ::Type{Arg{2}}, p, y, ȳ, n::Integer, x, inc::Integer) = - (x̄ .= x̄ .+ scal!(n, ȳ, blascopy!(n, sign.(x), inc, zeroslike(x), inc), inc)) - - -# Some weird stuff going on that I haven't figured out yet. -# let f = :(scal{T <: AbstractArray, V <: AbstractFloat}) -# ā = :(blascopy!(n, z̄, inc, zeros(X), inc) .* X) -# X̄ = :(scal!(n, a, z̄, inc)) -# @eva; @primitive $f(n::Int, a::V, X::T, inc::Int) z z̄ false $ā $X̄ false -# end - -# `gemm` sensitivities implementation. -@explicit_intercepts( - gemm, - Tuple{Char, Char, StridedMatrix{T}, StridedMatrix{T}} where T<:∇Scalar, - [false, false, true, true], -) -∇(::typeof(gemm), ::Type{Arg{3}}, p, Y, Ȳ, - tA::Char, - tB::Char, - α::T, - A::StridedMatrix{T}, - B::StridedMatrix{T}, -) where T<:∇Scalar = sum(Ȳ .* Y) / α - -∇(::typeof(gemm), ::Type{Arg{4}}, p, Y, Ȳ, - tA::Char, - tB::Char, - α::T, - A::StridedMatrix{T}, - B::StridedMatrix{T}, -) where T<:∇Scalar = - uppercase(tA) == 'N' ? - uppercase(tB) == 'N' ? - gemm('N', 'T', α, Ȳ, B) : - gemm('N', 'N', α, Ȳ, B) : - uppercase(tB) == 'N' ? - gemm('N', 'T', α, B, Ȳ) : - gemm('T', 'T', α, B, Ȳ) - -∇(Ā::StridedMatrix{T}, ::typeof(gemm), ::Type{Arg{4}}, _, Y, Ȳ, - tA::Char, - tB::Char, - α::T, - A::StridedMatrix{T}, - B::StridedMatrix{T}, -) where T<:∇Scalar = - uppercase(tA) == 'N' ? - uppercase(tB) == 'N' ? - gemm!('N', 'T', α, Ȳ, B, 1.0, Ā) : - gemm!('N', 'N', α, Ȳ, B, 1.0, Ā) : - uppercase(tB) == 'N' ? - gemm!('N', 'T', α, B, Ȳ, 1.0, Ā) : - gemm!('T', 'T', α, B, Ȳ, 1.0, Ā) - -∇(::typeof(gemm), ::Type{Arg{5}}, p, Y, Ȳ, - tA::Char, - tB::Char, - α::T, - A::StridedMatrix{T}, - B::StridedMatrix{T}, -) where T<:∇Scalar = - uppercase(tA) == 'N' ? - uppercase(tB) == 'N' ? - gemm('T', 'N', α, A, Ȳ) : - gemm('T', 'N', α, Ȳ, A) : - uppercase(tB) == 'N' ? - gemm('N', 'N', α, A, Ȳ) : - gemm('T', 'T', α, Ȳ, A) - -∇(B̄::StridedMatrix{T}, ::typeof(gemm), ::Type{Arg{5}}, _, Y, Ȳ, - tA::Char, - tB::Char, - α::T, - A::StridedMatrix{T}, - B::StridedMatrix{T}, -) where T<:∇Scalar = - uppercase(tA) == 'N' ? - uppercase(tB) == 'N' ? - gemm!('T', 'N', α, A, Ȳ, 1.0, B̄) : - gemm!('T', 'N', α, Ȳ, A, 1.0, B̄) : - uppercase(tB) == 'N' ? - gemm!('N', 'N', α, A, Ȳ, 1.0, B̄) : - gemm!('T', 'T', α, Ȳ, A, 1.0, B̄) - -# `gemm` sensitivities implementation for `α = 1`. -@explicit_intercepts( - gemm, - Tuple{Char, Char, T, StridedMatrix{T}, StridedMatrix{T}} where T<:∇Scalar, - [false, false, true, true, true] -) -∇(::typeof(gemm), ::Type{Arg{3}}, p, Y, Ȳ, - tA::Char, - tB::Char, - A::StridedMatrix{T}, - B::StridedMatrix{T} -) where T<:∇Scalar = ∇(gemm, Arg{4}, p, Y, Ȳ, tA, tB, one(T), A, B) -∇(Ā, ::typeof(gemm), ::Type{Arg{3}}, p, Y, Ȳ, - tA::Char, - tB::Char, - A::StridedMatrix{T}, - B::StridedMatrix{T} -) where T<:∇Scalar = ∇(Ā, gemm, Arg{4}, p, Y, Ȳ, tA, tB, one(T), A, B) -∇(::typeof(gemm), ::Type{Arg{4}}, p, Y, Ȳ, - tA::Char, - tB::Char, - A::StridedMatrix{T}, - B::StridedMatrix{T}, -) where T<:∇Scalar = ∇(gemm, Arg{5}, p, Y, Ȳ, tA, tB, one(T), A, B) -∇(B̄, ::typeof(gemm), ::Type{Arg{4}}, p, Y, Ȳ, - tA::Char, - tB::Char, - A::StridedMatrix{T}, - B::StridedMatrix{T} -) where T<:∇Scalar = ∇(B̄, gemm, Arg{5}, p, Y, Ȳ, tA, tB, one(T), A, B) - -# `gemv` sensitivities implementation. -@explicit_intercepts( - gemv, - Tuple{Char, T, StridedMatrix{T}, StridedVector{T}} where T<:∇Scalar, - [false, true, true, true], -) -∇(::typeof(gemv), ::Type{Arg{2}}, p, y, ȳ, - tA::Char, - α::T, - A::StridedMatrix{T}, - x::StridedVector{T}, -) where T<:∇Scalar = dot(ȳ, y) / α -∇(::typeof(gemv), ::Type{Arg{3}}, p, y, ȳ, - tA::Char, - α::T, - A::StridedMatrix{T}, - x::StridedVector{T}, -) where T<:∇Scalar = uppercase(tA) == 'N' ? α * ȳ * x' : α * x * ȳ' -∇(Ā::StridedMatrix{T}, ::typeof(gemv), ::Type{Arg{3}}, _, y, ȳ, - tA::Char, - α::T, - A::StridedMatrix{T}, - x::StridedVector{T}, -) where T<:∇Scalar = uppercase(tA) == 'N' ? ger!(α, ȳ, x, Ā) : ger!(α, x, ȳ, Ā) -∇(::typeof(gemv), ::Type{Arg{4}}, p, y, ȳ, - tA::Char, - α::T, - A::StridedMatrix{T}, - x::StridedVector{T}, -) where T<:∇Scalar = gemv(uppercase(tA) == 'N' ? 'T' : 'N', α, A, ȳ) -∇(x̄::StridedVector{T}, ::typeof(gemv), ::Type{Arg{4}}, _, y, ȳ, - tA::Char, - α::T, - A::StridedMatrix{T}, - x::StridedVector{T}, -) where T<:∇Scalar = gemv!(uppercase(tA) == 'N' ? 'T' : 'N', α, A, ȳ, one(T), x̄) - -# `gemv` sensitivities implementation with `α = 1`. -@explicit_intercepts( - gemv, - Tuple{Char, StridedMatrix{T}, StridedVector{T}} where T<:∇Scalar, - [false, true, true], -) -∇(::typeof(gemv), ::Type{Arg{2}}, p, y, ȳ, - tA::Char, - A::StridedMatrix{T}, - x::StridedVector{T}, -) where T<:∇Scalar = ∇(gemv, Arg{3}, p, y, ȳ, tA, one(T), A, x) -∇(Ā::StridedMatrix{T}, ::typeof(gemv), ::Type{Arg{2}}, p, y, ȳ, - tA::Char, - A::StridedMatrix{T}, - x::StridedVector{T}, -) where T<:∇Scalar = ∇(Ā, gemv, Arg{3}, p, y, ȳ, tA, one(T), A, x) -∇(::typeof(gemv), ::Type{Arg{3}}, p, y, ȳ, - tA::Char, - A::StridedMatrix{T}, - x::StridedVector{T}, -) where T<:∇Scalar = ∇(gemv, Arg{4}, p, y, ȳ, tA, one(T), A, x) -∇(x̄::StridedVector{T}, ::typeof(gemv), ::Type{Arg{3}}, p, y, ȳ, - tA::Char, - A::StridedMatrix{T}, - x::StridedVector{T}, -) where T<:∇Scalar = ∇(x̄, gemv, Arg{4}, p, y, ȳ, tA, one(T), A, x) - # # `syrk` sensitivity implementations. # @explicit_intercepts( # syrk, From 08d9a2e81d730ddd72a9e03b56614b7ee3903360 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 1 Oct 2020 17:00:36 +0100 Subject: [PATCH 14/71] Use ChainRules for Cholesky --- src/sensitivities/chainrules.jl | 23 +- .../linalg/factorization/cholesky.jl | 205 +----------------- .../linalg/factorization/cholesky.jl | 43 +--- 3 files changed, 27 insertions(+), 244 deletions(-) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index 51983bcc..88ef72aa 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -1,13 +1,13 @@ +#using InteractiveUtils + function generate_overload(sig) opT, argTs = Iterators.peel(ExprTools.parameters(sig)) opT <: Core.Builtin && return false # can't do operater overloading for builtins - opT <: Function || return false # not handling non-functions - fieldcount(opT) == 0 || return false # not handling functors + isabstracttype(opT) || fieldcount(opT) == 0 || return false # not handling functors isempty(argTs) && return false # we are an operator overloading AD, need operands - - nameof(opT.name.module) == :NaNMath && return false # Don't care about NaNMath + opT isa DataType && nameof(opT.name.module) == :NaNMath && return false # Don't care about NaNMath # Ignore functions that have complex ranges. This may change when Nabla supports complex # numbers. @@ -15,10 +15,12 @@ function generate_overload(sig) SpecialFunctions.hankelh1, SpecialFunctions.hankelh2, log1p, rem2pi, mod, atan, rem, )) && return false + opT <: Type{<:Complex} && return false # skip complex constructor # Ingore functions because have better Nabla specific version. opT ∈ typeof.(( - isapprox, size, length, + isapprox, size, length, isassigned, + Base.Broadcast.combine_styles, #TODO should i keep this? )) && return false @@ -31,7 +33,7 @@ function generate_overload(sig) @inline $(∇_declaration(signature_def)) $(overload_declarations!(signature_def, original_signature_args)...) end - #@show fdef + #opT <: Type && @show fdef eval(fdef) return true end @@ -68,8 +70,8 @@ function overload_declarations!(signature_def, original_signature_args) signature_def[:kwargs] = [:(kwargs...)] signature_def[:body] = quote - #@show op args = $(_args_tuple(signature_def[:args])) + # @show InteractiveUtils.@which rrule(op, unbox.(args)...) primal_val, pullback = rrule(op, unbox.(args)...; kwargs...) tape = get_tape(args) @@ -112,7 +114,11 @@ function preprocess_declaration(signature_def) preprocess_def = Dict{Symbol, Any}( :name => :preprocess, :args => [op, :($y::Branch), ȳ, args...], - :body => quote $y.pullback($ȳ) end, + :body => quote + pullback = getfield($y, :pullback) # avoid issues with getproperty overloading + @assert(pullback !== nothing, "pullback not set, probably because different code path used for preprocess vs for ∇. Probably need to delete a defination for ∇") + return pullback($ȳ) + end, ) where_params = get(signature_def, :whereparams, nothing) @@ -141,6 +147,7 @@ function ∇_declaration(signature_def) :args => [op, :(::Type{Arg{$N}}), p, y, ȳ, args...], :whereparams => [N; get(signature_def, :whereparams, [])], :body => quote $p[$N+1] end, + :kwargs => [:(kwargs...)], ) return ExprTools.combinedef(∇_def) end diff --git a/src/sensitivities/linalg/factorization/cholesky.jl b/src/sensitivities/linalg/factorization/cholesky.jl index 5cad9ac5..1013fdef 100644 --- a/src/sensitivities/linalg/factorization/cholesky.jl +++ b/src/sensitivities/linalg/factorization/cholesky.jl @@ -4,32 +4,11 @@ import Base: getproperty Base.@deprecate chol(X) cholesky(X).U -#= -See [1] for implementation details: pages 5-9 in particular. The derivations presented in -[1] assume column-major layout, whereas Julia primarily uses row-major. We therefore -implement both the derivations in [1] and their transpose, which is more appropriate to -Julia. - -[1] - "Differentiation of the Cholesky decomposition", Murray 2016 -=# const AM = AbstractMatrix const UT = UpperTriangular -@explicit_intercepts cholesky Tuple{AbstractMatrix{<:∇Scalar}} -∇(::typeof(cholesky), ::Type{Arg{1}}, p, U::Cholesky, Ū::AM{T}, Σ::AM{T}) where T<:∇Scalar = - chol_blocked_rev(Matrix(Ū), Matrix(U.U), 25, true) -@explicit_intercepts getproperty Tuple{Cholesky, Symbol} [true, false] -function ∇(::typeof(getproperty), ::Type{Arg{1}}, p, y, ȳ, C::Cholesky, x::Symbol) - if x === :U - C.uplo === 'U' ? UpperTriangular(ȳ) : LowerTriangular(ȳ') - elseif x === :L - C.uplo === 'L' ? LowerTriangular(ȳ) : UpperTriangular(ȳ') - else - throw(ArgumentError("unrecognized field $x; use U or L")) - end -end @explicit_intercepts( Cholesky, @@ -50,177 +29,15 @@ function ∇( # directly, so just pass through this call and return the sensitivies return X̄ end - -""" - level2partition(A::AbstractMatrix, j::Int, upper::Bool) - -Returns views to various bits of the lower triangle of `A` according to the -`level2partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then -the transposed views are returned from the upper triangle of `A`. -""" -function level2partition(A::AM, j::Int, upper::Bool) - - # Check that A is square and j is a valid index. - M, N = size(A) - (0 >= j || j > M) && throw(ArgumentError("j is out of range.")) - M != N && throw(ArgumentError("A is not square.")) - - if upper - r = view(A, 1:j-1, j) - d = view(A, j, j) - B = view(A, 1:j-1, j+1:N) - c = view(A, j, j+1:N) - else - r = view(A, j, 1:j-1) - d = view(A, j, j) - B = view(A, j+1:N, 1:j-1) - c = view(A, j+1:N, j) - end - return r, d, B, c -end - -""" - level3partition(A::AbstractMatrix, j::Int, k::Int, upper::Bool) - -Returns views to various bits of the lower triangle of `A` according to the -`level3partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then -the transposed views are returned from the upper triangle of `A`. -""" -function level3partition(A::AM, j::Int, k::Int, upper::Bool) - - # Check that A is square and j is a valid index. - M, N = size(A) - (0 >= j || j > M) && throw(ArgumentError("j is out of range.")) - M != N && throw(ArgumentError("A is not square.")) - - # Get views into bits of A. - if upper - R = view(A, 1:j-1, j:k) - D = view(A, j:k, j:k) - B = view(A, 1:j-1, k+1:N) - C = view(A, j:k, k+1:N) - else - R = view(A, j:k, 1:j-1) - D = view(A, j:k, j:k) - B = view(A, k+1:N, 1:j-1) - C = view(A, k+1:N, j:k) - end - return R, D, B, C -end - -""" - chol_unblocked_rev!( - Ā::AbstractMatrix{T}, - L::AbstractMatrix{T}, - upper::Bool - ) where T<:Real - -Compute the reverse-mode sensitivities of the Cholesky factorisation in an unblocked manner. -If `upper` is `false`, then the sensitivites computed from and stored in the lower triangle -of `Ā` and `L` respectively. If `upper` is `true` then they are computed and stored in the -upper triangles. If at input `upper` is `false` and `tril(Ā) = L̄`, at output -`tril(Ā) = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and -`triu(Ā) = triu(Ū)`, at output `triu(Ā) = triu(Σ̄)` where `Σ = UᵀU`. -""" -function chol_unblocked_rev!(Σ̄::AM{T}, L::AM{T}, upper::Bool) where T<:Real - - # Check that L is square, that Σ̄ is square and that they are the same size. - M, N = size(Σ̄) - M != N && throw(ArgumentError("Σ̄ is not square.")) - - # Compute the reverse-mode diff. - j = N - for ĵ in 1:N - r, d, B, c = level2partition(L, j, upper) - r̄, d̄, B̄, c̄ = level2partition(Σ̄, j, upper) - - # d̄ <- d̄ - c'c̄ / d. - d̄[1] -= dot(c, c̄) / d[1] - - # [d̄ c̄'] <- [d̄ c̄'] / d. - d̄ ./= d - c̄ ./= d - - # r̄ <- r̄ - [d̄ c̄'] [r' B']'. - r̄ = axpy!(-Σ̄[j, j], r, r̄) - r̄ = gemv!(upper ? 'N' : 'T', -one(T), B, c̄, one(T), r̄) - - # B̄ <- B̄ - c̄ r. - B̄ = upper ? ger!(-one(T), r, c̄, B̄) : ger!(-one(T), c̄, r, B̄) - d̄ ./= 2 - j -= 1 - end - return (upper ? triu! : tril!)(Σ̄) -end -chol_unblocked_rev(Σ̄::AM, L::AM, upper::Bool) = chol_unblocked_rev!(copy(Σ̄), L, upper) - -""" - chol_blocked_rev!( - Σ̄::AbstractMatrix{T}, - L::AbstractMatrix{T}, - Nb::Int, - upper::Bool - ) where T<:∇Scalar - -Compute the sensitivities of the Cholesky factorisation using a blocked, cache-friendly -procedure. `Σ̄` are the sensitivities of `L`, and will be transformed into the sensitivities -of `Σ`, where `Σ = LLᵀ`. `Nb` is the block-size to use. If the upper triangle has been used -to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`, then this should be -indicated by passing `upper = true`. -""" -function chol_blocked_rev!(Σ̄::AM{T}, L::AM{T}, Nb::Int, upper::Bool) where T<:∇Scalar - - # Check that L is square, that Σ̄ is square and that they are the same size. - M, N = size(Σ̄) - M != N && throw(ArgumentError("Σ̄ is not square.")) - - tmp = Matrix{T}(undef, Nb, Nb) - - # Compute the reverse-mode diff. - k = N - if upper - for k̂ in 1:Nb:N - j = max(1, k - Nb + 1) - R, D, B, C = level3partition(L, j, k, true) - R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, true) - - C̄ = trsm!('L', 'U', 'N', 'N', one(T), D, C̄) - gemm!('N', 'N', -one(T), R, C̄, one(T), B̄) - gemm!('N', 'T', -one(T), C, C̄, one(T), D̄) - chol_unblocked_rev!(D̄, D, true) - gemm!('N', 'T', -one(T), B, C̄, one(T), R̄) - if size(D̄, 1) == Nb - tmp = axpy!(one(T), D̄, transpose!(tmp, D̄)) - gemm!('N', 'N', -one(T), R, tmp, one(T), R̄) - else - gemm!('N', 'N', -one(T), R, D̄ + D̄', one(T), R̄) - end - - k -= Nb - end - return triu!(Σ̄) - else - for k̂ in 1:Nb:N - j = max(1, k - Nb + 1) - R, D, B, C = level3partition(L, j, k, false) - R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, false) - - C̄ = trsm!('R', 'L', 'N', 'N', one(T), D, C̄) - gemm!('N', 'N', -one(T), C̄, R, one(T), B̄) - gemm!('T', 'N', -one(T), C̄, C, one(T), D̄) - chol_unblocked_rev!(D̄, D, false) - gemm!('T', 'N', -one(T), C̄, B, one(T), R̄) - if size(D̄, 1) == Nb - tmp = axpy!(one(T), D̄, transpose!(tmp, D̄)) - gemm!('N', 'N', -one(T), tmp, R, one(T), R̄) - else - gemm!('N', 'N', -one(T), D̄ + D̄', R, one(T), R̄) - end - - k -= Nb - end - return tril!(Σ̄) - end +function ∇( + ::Type{Cholesky}, + ::Type{Arg{1}}, + p, + C::Cholesky, + X̄::Composite{<:Cholesky}, + X::Union{UpperTriangular, LowerTriangular}, + uplo::Union{Char, Symbol}, + info::Integer, +) + return getproperty(X̄, Symbol(uplo)) end -chol_blocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, Nb::Int, upper::Bool) = - chol_blocked_rev!(copy(Σ̄), L, Nb, upper) diff --git a/test/sensitivities/linalg/factorization/cholesky.jl b/test/sensitivities/linalg/factorization/cholesky.jl index 79ca4b3b..73cc0fa0 100644 --- a/test/sensitivities/linalg/factorization/cholesky.jl +++ b/test/sensitivities/linalg/factorization/cholesky.jl @@ -1,47 +1,6 @@ @testset "Cholesky" begin import Nabla: level2partition, level3partition - let rng = MersenneTwister(123456), N = 5 - A = randn(rng, N, N) - r, d, B2, c = level2partition(A, 4, false) - R, D, B3, C = level3partition(A, 4, 4, false) - @test all(r .== R') - @test all(d .== D) - @test B2[1] == B3[1] - @test all(c .== C) - - # Check that level2partition with 'U' is consistent with 'L'. - rᵀ, dᵀ, B2ᵀ, cᵀ = level2partition(transpose(A), 4, true) - @test r == rᵀ - @test d == dᵀ - @test B2' == B2ᵀ - @test c == cᵀ - - # Check that level3partition with 'U' is consistent with 'L'. - R, D, B3, C = level3partition(A, 2, 4, false) - Rᵀ, Dᵀ, B3ᵀ, Cᵀ = level3partition(transpose(A), 2, 4, true) - @test transpose(R) == Rᵀ - @test transpose(D) == Dᵀ - @test transpose(B3) == B3ᵀ - @test transpose(C) == Cᵀ - end - - import Nabla: chol_unblocked_rev, chol_blocked_rev - let rng = MersenneTwister(123456), N = 10 - A, Ā = Matrix.(LowerTriangular.(randn.(Ref(rng), [N, N], [N, N]))) - # NOTE: BLAS gets angry if we don't materialize the Transpose objects first - B, B̄ = Matrix.(transpose.([A, Ā])) - @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 1, false) - @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 3, false) - @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 5, false) - @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 10, false) - @test chol_unblocked_rev(Ā, A, false) ≈ transpose(chol_unblocked_rev(B̄, B, true)) - - @test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 1, true) - @test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 5, true) - @test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 10, true) - end - # Check sensitivities for lower-triangular version. let rng = MersenneTwister(123456), N = 10 for _ in 1:10 @@ -62,7 +21,7 @@ @test getfield(U, :f) == Base.getproperty @test unbox(U) ≈ cholesky(X_).U - @test_throws ArgumentError ∇(X->cholesky(X).info)(X_) + @test_throws Exception ∇(X->cholesky(X).info)(X_) end let From 2aeef688ada0fd51596a4f854c02071a2446b701 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 1 Oct 2020 17:37:11 +0100 Subject: [PATCH 15/71] move SVD to ChainRules delete commented out moved SVD rules undo changes to multiple updating svd --- src/sensitivities/linalg/factorization/svd.jl | 128 ------------------ .../sensitivities/linalg/factorization/svd.jl | 9 -- 2 files changed, 137 deletions(-) diff --git a/src/sensitivities/linalg/factorization/svd.jl b/src/sensitivities/linalg/factorization/svd.jl index 7dc3335f..29eec407 100644 --- a/src/sensitivities/linalg/factorization/svd.jl +++ b/src/sensitivities/linalg/factorization/svd.jl @@ -1,46 +1,6 @@ import LinearAlgebra: svd import Base: getproperty -@explicit_intercepts svd Tuple{AbstractMatrix{<:Real}} - -∇(::typeof(svd), ::Type{Arg{1}}, p, Y::SVD, Ȳ::NamedTuple{(:U,:S,:V)}, A::AbstractMatrix) = - svd_rev(Y, Ȳ.U, Ȳ.S, Ȳ.V) - -@explicit_intercepts getproperty Tuple{SVD, Symbol} [true, false] - -function ∇(::typeof(getproperty), ::Type{Arg{1}}, p, y, ȳ, USV::SVD, x::Symbol) - if x === :S - return (U=zeroslike(USV.U), S=vec(ȳ), V=zeroslike(USV.V)) - elseif x === :U - return (U=reshape(ȳ, size(USV.U)), S=zeroslike(USV.S), V=zeroslike(USV.V)) - elseif x === :V - return (U=zeroslike(USV.U), S=zeroslike(USV.S), V=reshape(ȳ, size(USV.V))) - elseif x === :Vt - throw(ArgumentError("Vt is unsupported; use V and transpose the result")) - else - throw(ArgumentError("unrecognized property $x; expected U, S, or V")) - end -end - -function ∇( - x̄::NamedTuple{(:U,:S,:V)}, - ::typeof(getproperty), - ::Type{Arg{1}}, - p, y, ȳ, - USV::SVD, - x::Symbol, -) - # This call does the validation that `x` is a recognized property - x̄_update = ∇(getproperty, Arg{1}, p, y, ȳ, USV, x) - if x === :S - return (U=x̄.U, S=update!(x̄.S, x̄_update.S), V=x̄.V) - elseif x === :U - return (U=update!(x̄.U, x̄_update.U), S=x̄.S, V=x̄.V) - elseif x === :V - return (U=x̄.U, S=x̄.S, V=update!(x̄.V, x̄_update.V)) - end -end - # Iteration allows destructuring, e.g. U, S, V = svd(A) # These definitions mirror those defined in the LinearAlgebra module, see # https://github.com/JuliaLang/julia/blob/master/stdlib/LinearAlgebra/src/svd.jl#L20-L24 @@ -48,91 +8,3 @@ Base.iterate(usv::Branch{<:SVD}) = (usv.U, Val(:S)) Base.iterate(usv::Branch{<:SVD}, ::Val{:S}) = (usv.S, Val(:V)) Base.iterate(usv::Branch{<:SVD}, ::Val{:V}) = (usv.V, Val(:done)) Base.iterate(usv::Branch{<:SVD}, ::Val{:done}) = nothing - -""" - svd_rev(USV, Ū, S̄, V̄) - -Compute the reverse mode sensitivities of the singular value decomposition (SVD). `USV` is -an `SVD` factorization object produced by a call to `svd`, and `Ū`, `S̄`, and `V̄` are the -respective sensitivities of the `U`, `S`, and `V` factors. -""" -function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix) - # Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default - U = USV.U - s = USV.S - V = USV.V - Vt = USV.Vt - - k = length(s) - T = eltype(s) - F = T[i == j ? 1 : inv(@inbounds s[j]^2 - s[i]^2) for i = 1:k, j = 1:k] - - # We do a lot of matrix operations here, so we'll try to be memory-friendly and do - # as many of the computations in-place as possible. Benchmarking shows that the in- - # place functions here are significantly faster than their out-of-place, naively - # implemented counterparts, and allocate no additional memory. - Ut = U' - FUᵀŪ = mulsubtrans!(Ut*Ū, F) # F .* (UᵀŪ - ŪᵀU) - FVᵀV̄ = mulsubtrans!(Vt*V̄, F) # F .* (VᵀV̄ - V̄ᵀV) - ImUUᵀ = eyesubx!(U*Ut) # I - UUᵀ - ImVVᵀ = eyesubx!(V*Vt) # I - VVᵀ - - S = Diagonal(s) - S̄ = Diagonal(s̄) - - Ā = add!(U*FUᵀŪ*S, ImUUᵀ*(Ū/S))*Vt - add!(Ā, U*S̄*Vt) - add!(Ā, U*add!(S*FVᵀV̄*Vt, (S\V̄')*ImVVᵀ)) - - return Ā -end - -""" - mulsubtrans!(X::AbstractMatrix, F::AbstractMatrix) - -Compute `F .* (X - X')`, overwriting `X` in the process. - -!!! note - This is an internal function that does no argument checking; the matrices passed to - this function are square with matching dimensions by construction. -""" -function mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real - k = size(X, 1) - @inbounds for j = 1:k, i = 1:j # Iterate the upper triangle - if i == j - X[i,i] = zero(T) - else - X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j]) - end - end - X -end - -""" - eyesubx!(X::AbstractMatrix) - -Compute `I - X`, overwriting `X` in the process. -""" -function eyesubx!(X::AbstractMatrix{T}) where T<:Real - n, m = size(X) - @inbounds for j = 1:m, i = 1:n - X[i,j] = (i == j) - X[i,j] - end - X -end - -""" - add!(X::AbstractMatrix, Y::AbstractMatrix) - -Compute `X + Y`, overwriting X in the process. - -!!! note - This is an internal function that does no argument checking; the matrices passed to - this function are square with matching dimensions by construction. -""" -function add!(X::AbstractMatrix{T}, Y::AbstractMatrix{T}) where T<:Real - @inbounds for i = eachindex(X, Y) - X[i] += Y[i] - end - X -end diff --git a/test/sensitivities/linalg/factorization/svd.jl b/test/sensitivities/linalg/factorization/svd.jl index c3e0323b..21bdb2a4 100644 --- a/test/sensitivities/linalg/factorization/svd.jl +++ b/test/sensitivities/linalg/factorization/svd.jl @@ -35,15 +35,6 @@ @test V isa Branch{<:Adjoint} end - @testset "Helper functions" begin - rng = MersenneTwister(12345) - X = randn(rng, 10, 10) - Y = randn(rng, 10, 10) - @test Nabla.mulsubtrans!(copy(X), Y) ≈ Y .* (X - X') - @test Nabla.eyesubx!(copy(X)) ≈ I - X - @test Nabla.add!(copy(X), Y) ≈ X + Y - end - @testset "Tape updating from multiple components" begin ∇f = ∇() do X U, S, V = svd(X) From b71b6d02e97f4ea8b5c6da0f249f6ffcee7ad84d Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 12 Oct 2020 16:29:11 +0100 Subject: [PATCH 16/71] =?UTF-8?q?Correct=20=E2=88=87=20for=20Symmetric=20c?= =?UTF-8?q?onstructor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sensitivities/linalg/symmetric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sensitivities/linalg/symmetric.jl b/src/sensitivities/linalg/symmetric.jl index f362452a..ce7476fe 100644 --- a/src/sensitivities/linalg/symmetric.jl +++ b/src/sensitivities/linalg/symmetric.jl @@ -1,4 +1,4 @@ import LinearAlgebra: Symmetric @explicit_intercepts Symmetric Tuple{∇Array} -∇(::typeof(Symmetric), ::Type{Arg{1}}, p, Y::∇Array, Ȳ::∇Array, X::∇Array) = +∇(::Type{Symmetric}, ::Type{Arg{1}}, p, Y::∇Array, Ȳ::∇Array, X::∇Array) = UpperTriangular(Ȳ) + LowerTriangular(Ȳ)' - Diagonal(Ȳ) From 835f66ea8daf219a6d1a624b4f19cddfa4c3ba9c Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 12 Oct 2020 16:29:41 +0100 Subject: [PATCH 17/71] Move strided to ChainRules.jl --- src/sensitivities/linalg/strided.jl | 34 ---------------------------- test/sensitivities/linalg/strided.jl | 17 +++++++++++--- 2 files changed, 14 insertions(+), 37 deletions(-) diff --git a/src/sensitivities/linalg/strided.jl b/src/sensitivities/linalg/strided.jl index 6aa2708f..789205b0 100644 --- a/src/sensitivities/linalg/strided.jl +++ b/src/sensitivities/linalg/strided.jl @@ -1,37 +1,3 @@ -# Use BLAS.gemm for strided matrix-matrix multiplication sensitivites. Don't bother with -# BLAS for matrix-vector stuff yet. Definitely an optimisation that we might want to -# consider at some point in the future though. -const RS = StridedMatrix{<:∇Scalar} -const RST = Transpose{<:∇Scalar, RS} -const RSA = Adjoint{<:∇Scalar, RS} -strided_matmul = [ - (RS, RS, 'N', 'C', :Ȳ, :B, 'C', 'N', :A, :Ȳ), - (RST, RS, 'N', 'T', :B, :Ȳ, 'N', 'N', :A, :Ȳ), - (RS, RST, 'N', 'N', :Ȳ, :B, 'T', 'N', :Ȳ, :A), - (RST, RST, 'T', 'T', :B, :Ȳ, 'T', 'T', :Ȳ, :A), - (RSA, RS, 'N', 'C', :B, :Ȳ, 'N', 'N', :A, :Ȳ), - (RS, RSA, 'N', 'N', :Ȳ, :B, 'C', 'N', :Ȳ, :A), - (RSA, RSA, 'C', 'C', :B, :Ȳ, 'C', 'C', :Ȳ, :A), -] -import Base: * -for (TA, TB, tCA, tDA, CA, DA, tCB, tDB, CB, DB) in strided_matmul - - # Add intercepts and export names. - @eval @explicit_intercepts $(Symbol("*")) Tuple{$TA, $TB} - - # Define allocating and non-allocating sensitivities for each output. - alloc_Ā = :(LinearAlgebra.BLAS.gemm($tCA, $tDA, $CA, $DA)) - alloc_B̄ = :(LinearAlgebra.BLAS.gemm($tCB, $tDB, $CB, $DB)) - no_alloc_Ā = :(LinearAlgebra.BLAS.gemm!($tCA, $tDA, 1., $CA, $DA, 1., Ā)) - no_alloc_B̄ = :(LinearAlgebra.BLAS.gemm!($tCB, $tDB, 1., $CB, $DB, 1., B̄)) - - # Add sensitivity definitions. - @eval ∇(::typeof(*), ::Type{Arg{1}}, p, Y::RS, Ȳ::RS, A::$TA, B::$TB) = $alloc_Ā - @eval ∇(::typeof(*), ::Type{Arg{2}}, p, Y::RS, Ȳ::RS, A::$TA, B::$TB) = $alloc_B̄ - @eval ∇(Ā, ::typeof(*), ::Type{Arg{1}}, p, Y::RS, Ȳ::RS, A::$TA, B::$TB) = $no_alloc_Ā - @eval ∇(B̄, ::typeof(*), ::Type{Arg{2}}, p, Y::RS, Ȳ::RS, A::$TA, B::$TB) = $no_alloc_B̄ -end - # # Not every permutation of transpositions makes sense for matrix-vector multiplication. This # # list just includes those which make sense. # strided_matvecmul = [ diff --git a/test/sensitivities/linalg/strided.jl b/test/sensitivities/linalg/strided.jl index db4a19d4..615c23cc 100644 --- a/test/sensitivities/linalg/strided.jl +++ b/test/sensitivities/linalg/strided.jl @@ -1,9 +1,20 @@ @testset "Strided" begin - + RS = StridedMatrix{<:∇Scalar} + RST = Transpose{<:∇Scalar, RS} + RSA = Adjoint{<:∇Scalar, RS} + strided_matmul_combinations = ( + (RS, RS, 'N', 'C', :Ȳ, :B, 'C', 'N', :A, :Ȳ), + (RST, RS, 'N', 'T', :B, :Ȳ, 'N', 'N', :A, :Ȳ), + (RS, RST, 'N', 'N', :Ȳ, :B, 'T', 'N', :Ȳ, :A), + (RST, RST, 'T', 'T', :B, :Ȳ, 'T', 'T', :Ȳ, :A), + (RSA, RS, 'N', 'C', :B, :Ȳ, 'N', 'N', :A, :Ȳ), + (RS, RSA, 'N', 'N', :Ȳ, :B, 'C', 'N', :Ȳ, :A), + (RSA, RSA, 'C', 'C', :B, :Ȳ, 'C', 'C', :Ȳ, :A), + ) + # TODO: This test seems like it doesn't actually test the combinations. let rng = MersenneTwister(123456), N = 100 - # Test strided matrix-matrix multiplication sensitivities. - for (TA, TB, tCA, tDA, CA, DA, tCB, tDB, CB, DB) in Nabla.strided_matmul + for (TA, TB, tCA, tDA, CA, DA, tCB, tDB, CB, DB) in strided_matmul_combinations A, B, VA, VB = randn.(Ref(rng), [N, N, N, N], [N, N, N, N]) @test check_errs(*, A * B, (A, B), (VA, VB)) end From bd28593a8fcf7d3059dbf084054c1b43e0092947 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 12 Oct 2020 16:32:34 +0100 Subject: [PATCH 18/71] Delete moved rule for diagonal Add note to diagonal about overloading Pair --- src/sensitivities/linalg/diagonal.jl | 28 ++++------------------------ 1 file changed, 4 insertions(+), 24 deletions(-) diff --git a/src/sensitivities/linalg/diagonal.jl b/src/sensitivities/linalg/diagonal.jl index a01f4f01..055a4ce8 100644 --- a/src/sensitivities/linalg/diagonal.jl +++ b/src/sensitivities/linalg/diagonal.jl @@ -2,30 +2,6 @@ import LinearAlgebra: det, logdet, diagm, Diagonal, diag const ∇ScalarDiag = Diagonal{<:∇Scalar} - -@explicit_intercepts Diagonal Tuple{∇AbstractVector} -function ∇( - ::Type{Diagonal}, - ::Type{Arg{1}}, - p, - Y::∇ScalarDiag, - Ȳ::∇AbstractMatrix, - x::∇AbstractVector, -) - return copyto!(similar(x), diag(Ȳ)) -end -function ∇( - x̄::∇AbstractVector, - ::Type{Diagonal}, - ::Type{Arg{1}}, - p, - Y::∇ScalarDiag, - Ȳ::∇AbstractMatrix, - x::∇AbstractVector, -) - return broadcast!(+, x̄, x̄, diag(Ȳ)) -end - @explicit_intercepts Diagonal Tuple{∇AbstractMatrix} function ∇( ::Type{Diagonal}, @@ -62,9 +38,13 @@ end # _diagm when it receives arguments that are nodes. _diagm can go through the intercepts # machinery, so it knows how to deal. +# TODO: Possibly we should overload `Pair` so that it constructs a `Node{Pair}` then this +# would hit sentitivities that we have defined via ChainRules. + _diagm(x::∇AbstractVector, k::Integer=0) = diagm(k => x) LinearAlgebra.diagm(x::Pair{<:Integer, <:Node{<:∇AbstractVector}}) = _diagm(last(x), first(x)) + @explicit_intercepts _diagm Tuple{∇AbstractVector} function ∇( ::typeof(_diagm), From 1b0df5b06d7dd0433c8424d89e3abd38f221490c Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 13 Oct 2020 12:14:19 +0100 Subject: [PATCH 19/71] remove structured constructors that moved to ChainRules --- src/sensitivities/linalg/triangular.jl | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/sensitivities/linalg/triangular.jl b/src/sensitivities/linalg/triangular.jl index 6fcb4c83..2ba04d65 100644 --- a/src/sensitivities/linalg/triangular.jl +++ b/src/sensitivities/linalg/triangular.jl @@ -4,23 +4,13 @@ const ∇ScalarLT = LowerTriangular{<:∇Scalar} const ∇ScalarUT = UpperTriangular{<:∇Scalar} for (ctor, T) in zip([:LowerTriangular, :UpperTriangular], [:∇ScalarLT, :∇ScalarUT]) - - @eval @explicit_intercepts $ctor Tuple{∇AbstractMatrix} - @eval ∇(::Type{$ctor}, ::Type{Arg{1}}, p, Y::$T, Ȳ::$T, X::∇AbstractMatrix) = Matrix(Ȳ) - @eval ∇( - X̄::∇AbstractMatrix, - ::Type{$ctor}, - ::Type{Arg{1}}, - p, - Y::$T, - Ȳ::$T, - X::∇AbstractMatrix, - ) = broadcast!(+, X̄, X̄, Ȳ) - + #== TODO: a lot of this need to move to ChainRules to make sure the types are right. @eval @explicit_intercepts det Tuple{$T} @eval ∇(::typeof(det), ::Type{Arg{1}}, p, y::∇Scalar, ȳ::∇Scalar, X::$T) = Diagonal(ȳ .* y ./ view(X, diagind(X))) + + # Optimisation for in-place updates. @eval function ∇( X̄::$T, @@ -83,4 +73,5 @@ for (ctor, T) in zip([:LowerTriangular, :UpperTriangular], [:∇ScalarLT, :∇Sc X̄.diag .+= ȳ ./ view(X, diagind(X)) return X̄ end + ==# end From 78c4c41d36436b08bb1fb0ae0a7a7073b082401d Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 13 Oct 2020 15:12:25 +0100 Subject: [PATCH 20/71] delete triangular rules that moved to ChainRules --- src/Nabla.jl | 1 - src/sensitivities/linalg/triangular.jl | 77 -------------------------- 2 files changed, 78 deletions(-) delete mode 100644 src/sensitivities/linalg/triangular.jl diff --git a/src/Nabla.jl b/src/Nabla.jl index 0deb92a5..67c72f43 100644 --- a/src/Nabla.jl +++ b/src/Nabla.jl @@ -64,7 +64,6 @@ module Nabla include("sensitivities/linalg/strided.jl") include("sensitivities/linalg/blas.jl") include("sensitivities/linalg/diagonal.jl") - include("sensitivities/linalg/triangular.jl") include("sensitivities/linalg/factorization/cholesky.jl") include("sensitivities/linalg/factorization/svd.jl") diff --git a/src/sensitivities/linalg/triangular.jl b/src/sensitivities/linalg/triangular.jl deleted file mode 100644 index 2ba04d65..00000000 --- a/src/sensitivities/linalg/triangular.jl +++ /dev/null @@ -1,77 +0,0 @@ -import LinearAlgebra: det, logdet, LowerTriangular, UpperTriangular - -const ∇ScalarLT = LowerTriangular{<:∇Scalar} -const ∇ScalarUT = UpperTriangular{<:∇Scalar} - -for (ctor, T) in zip([:LowerTriangular, :UpperTriangular], [:∇ScalarLT, :∇ScalarUT]) - #== TODO: a lot of this need to move to ChainRules to make sure the types are right. - @eval @explicit_intercepts det Tuple{$T} - @eval ∇(::typeof(det), ::Type{Arg{1}}, p, y::∇Scalar, ȳ::∇Scalar, X::$T) = - Diagonal(ȳ .* y ./ view(X, diagind(X))) - - - - # Optimisation for in-place updates. - @eval function ∇( - X̄::$T, - ::typeof(det), - ::Type{Arg{1}}, - p, - y::∇Scalar, - ȳ::∇Scalar, - X::$T, - ) - X̄_diag = view(X̄, diagind(X̄)) - broadcast!((x̄, x, y, ȳ)->x̄ + ȳ * y / x, - X̄_diag, X̄_diag, view(X, diagind(X)), y, ȳ) - return X̄ - end - - # Optimisation for in-place updates to `Diagonal` sensitivity cache. - @eval function ∇( - X̄::Diagonal, - ::typeof(det), - ::Type{Arg{1}}, - p, - y::∇Scalar, - ȳ::∇Scalar, - X::$T, - ) - X̄.diag .+= ȳ .* y ./ view(X, diagind(X)) - return X̄ - end - - @eval @explicit_intercepts logdet Tuple{$T} - @eval ∇(::typeof(logdet), ::Type{Arg{1}}, p, y::∇Scalar, ȳ::∇Scalar, X::$T) = - Diagonal(ȳ ./ view(X, diagind(X))) - - # Optimisation for in-place updates. - @eval function ∇( - X̄::∇Array, - ::typeof(logdet), - ::Type{Arg{1}}, - p, - y::∇Scalar, - ȳ::∇Scalar, - X::$T - ) - X̄_diag = view(X̄, diagind(X̄)) - broadcast!((x̄, x, ȳ)->x̄ + ȳ / x, X̄_diag, X̄_diag, view(X, diagind(X)), ȳ) - return X̄ - end - - # Optimisation for in-place updates to `Diagonal` sensitivity cache. - @eval function ∇( - X̄::Diagonal, - ::typeof(logdet), - ::Type{Arg{1}}, - p, - y::∇Scalar, - ȳ::∇Scalar, - X::$T, - ) - X̄.diag .+= ȳ ./ view(X, diagind(X)) - return X̄ - end - ==# -end From 7aad58a387e4b9317a56dc480db7e7ac88d95e28 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 13 Oct 2020 18:29:07 +0100 Subject: [PATCH 21/71] put stuff in testsets --- test/core.jl | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/test/core.jl b/test/core.jl index 29f49cfe..510b31a2 100644 --- a/test/core.jl +++ b/test/core.jl @@ -100,8 +100,8 @@ end # testset Tape -# Check that functions involving `isapprox` can be differentiated -let + +@testset "Check that functions involving `isapprox` can be differentiated" begin f(x) = x ≈ 5.0 ? 1.0 : 3.0 * x g(x) = 5.0 * x h(x) = g(x) ≈ 25.0 ? x : f(x) + g(x) @@ -121,8 +121,7 @@ let @test ∇f(6.0, 5.0) == (3.0, 0.0) end -# Check that functions with extra, unused variables can be differentiated -let +@testset "Check that functions with extra, unused variables can be differentiated" begin f(a,b,c,d) = a*c ∇f = ∇(f) g(a,b) = 12 @@ -133,8 +132,7 @@ let @test ∇g(1,2) == (0,0) end -# Check that functions with `zero` and `one` can be differentiated -let +@testset "Check that functions with `zero` and `one` can be differentiated" begin f(a) = zero(a) g(a) = one(a) h(a) = zero(3 * a) + one(4 * a) @@ -148,8 +146,7 @@ let @test ∇h(8) == (0,) end -# Check that the convenience implementation of ∇ works as intended. -let +@testset "Check that the convenience implementation of ∇ works as intended." begin f(x, y) = 2x + y ∇f = ∇(f) ∇f_out = ∇(f; get_output=true) @@ -170,8 +167,7 @@ end @test ∇(unbox, get_output=true)(2) == (2, (0,)) end -# Tests for zero'd and one'd containers. -let +@testset "Tests for zero'd and one'd containers." begin import Nabla: zerod_container, oned_container @test zerod_container(1.0) == 0.0 @test zerod_container(1) == 0 From ae99591a4331e3e5d1d8fc314157784920a092bb Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 14 Oct 2020 13:55:07 +0100 Subject: [PATCH 22/71] move list of linalg optimizations to the tests --- test/sensitivities/linalg/generic.jl | 31 +++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/test/sensitivities/linalg/generic.jl b/test/sensitivities/linalg/generic.jl index 6d8f7711..92794842 100644 --- a/test/sensitivities/linalg/generic.jl +++ b/test/sensitivities/linalg/generic.jl @@ -19,23 +19,44 @@ trand(rng::AbstractRNG, ::Type{<:Adjoint}) = Adjoint(rand(rng, N, N)) @testset "Unary sensitivities" begin + _ϵ, lb, ub = 3e-2, -3.0, 3.0 + unary_linalg_optimisations = [ + (-, ∇Array, (lb, ub)), + (tr, ∇Array, (lb, ub)), + (inv, ∇Array, (lb, ub)), + (det, ∇Array, (_ϵ, ub)), + (logdet, ∇Array, (_ϵ, ub)), + (transpose, ∇Array, (lb, ub)), + (adjoint, ∇Scalar, (_ϵ, ub)), + (adjoint, ∇Array, (lb, ub)), + (norm, ∇Array, (lb, ub)), + (norm, ∇Scalar, (lb, ub)), + ] + rng = MersenneTwister(123) - @testset "$f" for (f, T_In, T_Out, X̄, bounds) in Nabla.unary_linalg_optimisations + @testset "$f" for (f, T_In, bounds) in unary_linalg_optimisations for _ in 1:5 Z = trand(rng, T_In) .* (bounds[2] .- bounds[1]) .+ bounds[1] X = Z'Z + 1e-6 * one(Z) - Ȳ, V = eval(f)(X), trandn(rng, T_In) - @test check_errs(eval(f), Ȳ, X, 1e-1 .* V) + Ȳ, V = f(X), trandn(rng, T_In) + @test check_errs(f, Ȳ, X, 1e-1 .* V) end end end @testset "Binary sensitivities" begin rng = MersenneTwister(2) - @testset "$f" for (f, T_A, T_B, T_Y, Ā, B̄) in Nabla.binary_linalg_optimisations + binary_linalg_optimisations = [ + (*, ∇Array, ∇Array,), + (/, ∇Array, ∇Array,), + (\, ∇Array, ∇Array,), + (norm, ∇Array, ∇Scalar,), + (norm, ∇Scalar, ∇Scalar,), + ] + @testset "$f" for (f, T_A, T_B) in binary_linalg_optimisations for _ in 1:5 A, B, VA, VB = trandn.(Ref(rng), (T_A, T_B, T_A, T_B)) - @test check_errs(eval(f), eval(f)(A, B), (A, B), (VA, VB)) + @test check_errs(f, eval(f)(A, B), (A, B), (VA, VB)) end end end From a6ede28e70f0fe6408e08fcf3586afcef8bb8a95 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 14 Oct 2020 14:56:21 +0100 Subject: [PATCH 23/71] delete never used uniform scaling file --- src/sensitivities/linalg/uniformscaling.jl | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/sensitivities/linalg/uniformscaling.jl diff --git a/src/sensitivities/linalg/uniformscaling.jl b/src/sensitivities/linalg/uniformscaling.jl deleted file mode 100644 index e69de29b..00000000 From e05bbba25381f7f1c8273903e752d38fc8184ee7 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 16 Oct 2020 20:40:35 +0100 Subject: [PATCH 24/71] Move indexing over to ChainRules --- src/sensitivities/chainrules.jl | 7 ++++--- src/sensitivities/indexing.jl | 12 ++---------- test/sensitivities/indexing.jl | 10 +++++++++- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index 88ef72aa..7a54492e 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -17,7 +17,7 @@ function generate_overload(sig) )) && return false opT <: Type{<:Complex} && return false # skip complex constructor - # Ingore functions because have better Nabla specific version. + # Ignore these functions because they have better Nabla specific versions. opT ∈ typeof.(( isapprox, size, length, isassigned, Base.Broadcast.combine_styles, #TODO should i keep this? @@ -33,7 +33,7 @@ function generate_overload(sig) @inline $(∇_declaration(signature_def)) $(overload_declarations!(signature_def, original_signature_args)...) end - #opT <: Type && @show fdef + opT <: typeof(svd) && @show fdef eval(fdef) return true end @@ -137,6 +137,7 @@ function ∇_declaration(signature_def) # For readability lets name all the parts, NB: this is being a bit too cute. op = signature_def[:name] args = signature_def[:args] + N = gensym(:N) p = gensym(:p) y = :(::Any) @@ -146,7 +147,7 @@ function ∇_declaration(signature_def) :name => :∇, :args => [op, :(::Type{Arg{$N}}), p, y, ȳ, args...], :whereparams => [N; get(signature_def, :whereparams, [])], - :body => quote $p[$N+1] end, + :body => quote $p[$N+1] end, # skip dself :kwargs => [:(kwargs...)], ) return ExprTools.combinedef(∇_def) diff --git a/src/sensitivities/indexing.jl b/src/sensitivities/indexing.jl index 6e4791cb..ca03dbb8 100644 --- a/src/sensitivities/indexing.jl +++ b/src/sensitivities/indexing.jl @@ -6,16 +6,8 @@ for i = 1:7 @eval @explicit_intercepts getindex $T $is_node end -function ∇(Ā, ::typeof(getindex), ::Type{Arg{1}}, p, y, ȳ, A, inds...) - Ā[inds...] += ȳ - return Ā -end -function ∇(Ā, ::typeof(getindex), ::Type{Arg{1}}, p, y::AbstractArray, ȳ::AbstractArray, A, inds...) - @views Ā[inds...] .+= reshape(ȳ, size(y)...) - return Ā -end -function ∇(::typeof(getindex), ::Type{Arg{1}}, p, y, ȳ, A, inds...) - return ∇(zerod_container(A), getindex, Arg{1}, p, y, ȳ, A, inds...) +function ∇(::typeof(getindex), ::Type{Arg{1}}, p, y, ȳ, A::Ref) + return Ref(ȳ) end # # Implementation of reverse-mode sensitivities for `view`. Not currently in use because diff --git a/test/sensitivities/indexing.jl b/test/sensitivities/indexing.jl index 3f1a01f3..a8a156c3 100644 --- a/test/sensitivities/indexing.jl +++ b/test/sensitivities/indexing.jl @@ -20,7 +20,15 @@ @test ∇(y, oneslike(unbox(y)))[x] == [0, 1, 2] end - @testset "Ref" begin + @testset "Ref indexed by []" begin @test ref_equal(∇(getindex)(Ref(4))[1], Ref(1)) end + + @testset "Tuple indexed by Int" begin + leaf = Leaf(Tape(), tuple(5, 6, 7)) + y = getindex(leaf, 1) + @test unbox(y) == 5 + # Nabla has never supported getindex(::Tuple, ::Int) + @test_broken ∇(y, one(unbox(y)))[leaf] == (1, 0, 0) + end end From 1549b757424448ce7e2fca3baf30a8ecc54092d9 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 19 Oct 2020 17:47:09 +0100 Subject: [PATCH 25/71] Remove DiffRules entirely --- Project.toml | 4 +- src/sensitivities/scalar.jl | 36 ------------ test/runtests.jl | 25 +++++++++ test/sensitivities/functional/functional.jl | 62 ++++++++------------- test/sensitivities/scalar.jl | 49 ++++------------ 5 files changed, 61 insertions(+), 115 deletions(-) diff --git a/Project.toml b/Project.toml index 44abe880..4921fb92 100644 --- a/Project.toml +++ b/Project.toml @@ -5,9 +5,10 @@ version = "0.12.3" [deps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -15,7 +16,6 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -DiffRules = "^0.0" FDM = "^0.6" ForwardDiff = "0.10.12" SpecialFunctions = ">=0.5.0" diff --git a/src/sensitivities/scalar.jl b/src/sensitivities/scalar.jl index cd75b702..57a0102f 100644 --- a/src/sensitivities/scalar.jl +++ b/src/sensitivities/scalar.jl @@ -1,6 +1,3 @@ -using SpecialFunctions -using DiffRules: DiffRules, @define_diffrule, diffrule, diffrules, hasdiffrule - # Hand code the identity because it's really fundamental. It doesn't need to generate a new # node on the computational graph since it does nothing, but it is useful to have it's # gradient implemented for use in higher-order functions. @@ -9,39 +6,6 @@ import Base.identity @inline ∇(::typeof(identity), ::Type{Arg{1}}, p, y, ȳ, x) = ȳ @inline ∇(::typeof(identity), ::Type{Arg{1}}, x::Real) = one(x) -# Ignore functions that have complex ranges. This may change when Nabla supports complex -# numbers. -ignored_fs = [(:SpecialFunctions, :hankelh1), - (:SpecialFunctions, :hankelh2), - (:Base, :log1p), - (:Base, :rem2pi), - (:Base, :mod), - (:Base, :atan), - (:Base, :rem)] - -unary_sensitivities, binary_sensitivities = [], [] - -for (package, f, arity) in diffrules() - (package == :NaNMath || (package, f) in ignored_fs) && continue - - @eval import $package: $f - if arity == 1 - push!(unary_sensitivities, (package, f)) - ∂f∂x = diffrule(package, f, :x) - #@eval @explicit_intercepts $f Tuple{∇Scalar} - #@eval @inline ∇(::typeof($f), ::Type{Arg{1}}, p, y, ȳ, x::∇Scalar) = ȳ * $∂f∂x - #@eval @inline ∇(::typeof($f), ::Type{Arg{1}}, x::∇Scalar) = $∂f∂x - elseif arity == 2 - push!(binary_sensitivities, (package, f)) - ∂f∂x, ∂f∂y = diffrule(package, f, :x, :y) - #@eval @explicit_intercepts $f Tuple{∇Scalar, ∇Scalar} - #@eval ∇(::typeof($f), ::Type{Arg{1}}, p, z, z̄, x::∇Scalar, y::∇Scalar) = z̄ * $∂f∂x - #@eval ∇(::typeof($f), ::Type{Arg{2}}, p, z, z̄, x::∇Scalar, y::∇Scalar) = z̄ * $∂f∂y - else - error("Cannot implement sensitivity for $package.$f: arity $arity not supported.") - end -end - # Add method to resolve exponentiation ambiguity. ^(n::Node{<:Real}, p::Integer) = invoke(^, Tuple{Node{<:Real}, Real}, n, p) diff --git a/test/runtests.jl b/test/runtests.jl index c79880ce..7aaecc50 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,31 @@ derivative_via_frule(f, x) = last(Nabla.frule((Nabla.NO_FIELDS, 1.0), f, x)) @test derivative_via_frule(sin, 0) == 1 @test derivative_via_frule(sin, 1.2) == derivative_via_frule(sin, 2π + 1.2) +# These are the core scalar sensitives Nabla expects to have defined +# we test against them both for sensitives/scalar.jl and in sensitivities/functional.jl + +const UNARY_SCALAR_SENSITIVITIES = [ + # Base: + +, -, abs, abs2, acos, acosd, acosh, acot, acotd, acoth, acsc, acscd, acsch, asec, + asecd, asech, asin, asind, asinh, atand, atanh, cbrt, cos, cosd, cosh, cospi, cot,cotd, + coth, csc, cscd, csch, deg2rad, exp, exp10, exp2, expm1, inv, log, log10, log2, + rad2deg, sec, secd, sech, sin, sind, sinh, sinpi, sqrt, tan, tand, tanh, transpose, + # SpecialFunctions.jl: + airyai, airyaiprime, airybi, airybiprime, besselj0, besselj1, bessely0, bessely1, + dawson, digamma, erf, erfc, erfcinv, erfcx, erfi, erfinv, gamma, invdigamma, lgamma, + trigamma, +] + +const BINARY_SCALAR_SENSITIVITIES = [ + # Base: + *, +, -, /, \, ^, hypot, max, min, + # SpecialFunctions.jl: + besseli, besselj, besselk, bessely, beta, lbeta, polygamma, +] + +const ONLY_DIFF_IN_SECOND_ARG_SENSITIVITIES = [ + besseli, besselj, besselk, bessely, polygamma +] @testset "Nabla.jl" begin diff --git a/test/sensitivities/functional/functional.jl b/test/sensitivities/functional/functional.jl index 6fac7646..2803e44f 100644 --- a/test/sensitivities/functional/functional.jl +++ b/test/sensitivities/functional/functional.jl @@ -1,6 +1,3 @@ -using SpecialFunctions -using DiffRules: diffrule, hasdiffrule - @testset "Functional" begin let rng = MersenneTwister(123456) import Nabla.fmad @@ -20,12 +17,12 @@ using DiffRules: diffrule, hasdiffrule s = broadcast(f, x_) return ∇(s, oneslike(unbox(s)))[x_] ≈ derivative_via_frule.(f, x) end - @testset "$package.$f" for (package, f) in Nabla.unary_sensitivities - domain = domain1(eval(f)) + @testset "$f" for f in UNARY_SCALAR_SENSITIVITIES + domain = domain1(f) domain === nothing && error("Could not determine domain for $f.") x_dist = Uniform(domain...) x = rand(rng, x_dist, 100) - @test check_unary_broadcast(eval(f), x) + @test check_unary_broadcast(f, x) end # Check that `broadcast` returns the correct gradient under each implemented binary @@ -66,26 +63,18 @@ using DiffRules: diffrule, hasdiffrule # @test ∇s[x_] ≈ ∇x # @test ∇s[y_] ≈ ∇y end - @testset "$package.$f" for (package, f) in Nabla.binary_sensitivities - - # TODO: More care needs to be taken to test the following. - if hasdiffrule(package, f, 2) - ∂f∂x, ∂f∂y = diffrule(package, f, :x, :y) - else - ∂f∂x, ∂f∂y = :∂f∂x, :∂f∂y - end - + @testset "$f" for f in BINARY_SCALAR_SENSITIVITIES # TODO: Implement the edge cases for functions differentiable in only either # argument. - (∂f∂x == :NaN || ∂f∂y == :NaN) && continue - domain = domain2(eval(f)) + f in ONLY_DIFF_IN_SECOND_ARG_SENSITIVITIES && continue + domain = domain2(f) domain === nothing && error("Could not determine domain for $f.") (x_lb, x_ub), (y_lb, y_ub) = domain x_dist, y_dist = Uniform(x_lb, x_ub), Uniform(y_lb, y_ub) x, y = rand(rng, x_dist, 100), rand(rng, y_dist, 100) - check_binary_broadcast(eval(f), x, y) - check_binary_broadcast(eval(f), rand(rng, x_dist), y) - check_binary_broadcast(eval(f), x, rand(rng, y_dist)) + check_binary_broadcast(f, x, y) + check_binary_broadcast(f, rand(rng, x_dist), y) + check_binary_broadcast(f, x, rand(rng, y_dist)) end # let # Ternary functions (because it's useful to check I guess.) @@ -161,12 +150,12 @@ using DiffRules: diffrule, hasdiffrule @test unbox(z_) == f.(x) @test ∇(z_)[x_] == ∇(broadcast(f, x_))[x_] end - for (package, f) in Nabla.unary_sensitivities - domain = domain1(eval(f)) + for f in UNARY_SCALAR_SENSITIVITIES + domain = domain1(f) domain === nothing && error("Could not determine domain for $f.") x_dist = Uniform(domain...) - check_unary_dot(eval(f), rand(rng, x_dist)) - check_unary_dot(eval(f), rand(rng, x_dist, 100)) + check_unary_dot(f, rand(rng, x_dist)) + check_unary_dot(f, rand(rng, x_dist, 100)) end # Check that the dot notation works as expected for all of the binary functions in @@ -185,30 +174,25 @@ using DiffRules: diffrule, hasdiffrule @test ∇(z_)[x_] == ∇(broadcast(f, x_, y_))[x_] @test ∇(z_)[y_] == ∇(broadcast(f, x_, y_))[y_] end - for (package, f) in Nabla.binary_sensitivities + for f in BINARY_SCALAR_SENSITIVITIES # TODO: More care needs to be taken to test the following. - f in [:atan, :mod, :rem] && continue - if hasdiffrule(package, f, 2) - ∂f∂x, ∂f∂y = diffrule(package, f, :x, :y) - else - ∂f∂x, ∂f∂y = :∂f∂x, :∂f∂y - end + f in [atan, mod, rem] && continue # TODO: Implement the edge cases for functions differentiable in only either # argument. - (∂f∂x == :NaN || ∂f∂y == :NaN) && continue - domain = domain2(eval(f)) + f in ONLY_DIFF_IN_SECOND_ARG_SENSITIVITIES && continue + domain = domain2(f) domain === nothing && error("Could not determine domain for $f.") (x_lb, x_ub), (y_lb, y_ub) = domain x_distr = Uniform(x_lb, x_ub) y_distr = Uniform(y_lb, y_ub) x = rand(rng, x_distr, 100) y = rand(rng, y_distr, 100) - check_binary_dot(eval(f), x, y) - check_binary_dot(eval(f), rand(rng, x_distr), y) - check_binary_dot(eval(f), x, rand(rng, y_distr)) - check_binary_dot(eval(f), Ref(rand(rng, x_distr)), y) - check_binary_dot(eval(f), x, Ref(rand(rng, y_distr))) - check_binary_dot(eval(f), rand(rng, x_distr), rand(rng, y_distr)) + check_binary_dot(f, x, y) + check_binary_dot(f, rand(rng, x_distr), y) + check_binary_dot(f, x, rand(rng, y_distr)) + check_binary_dot(f, Ref(rand(rng, x_distr)), y) + check_binary_dot(f, x, Ref(rand(rng, y_distr))) + check_binary_dot(f, rand(rng, x_distr), rand(rng, y_distr)) end # test with other broadcast styles diff --git a/test/sensitivities/scalar.jl b/test/sensitivities/scalar.jl index af34f3d0..7ab9e58f 100644 --- a/test/sensitivities/scalar.jl +++ b/test/sensitivities/scalar.jl @@ -1,5 +1,3 @@ -using DiffRules: diffrule, hasdiffrule - @testset "Scalar domains" begin @test in_domain(sin, 10.) @test in_domain(cos, 10.) @@ -25,9 +23,9 @@ end @test ∇(identity, Arg{1}, 5.0) == 1.0 end - unary_check(f, x) = check_errs(eval(f), ȳ, x, v) - @testset "$package.$f" for (package, f) in Nabla.unary_sensitivities - domain = domain1(eval(f)) + unary_check(f, x) = check_errs(f, ȳ, x, v) + @testset "$f" for f in UNARY_SCALAR_SENSITIVITIES + domain = domain1(f) domain === nothing && error("Could not determine domain for $f.") lb, ub = domain randx = () -> rand(rng) * (ub - lb) + lb @@ -37,21 +35,10 @@ end end end - @testset "$package.$f" for (package, f) in Nabla.binary_sensitivities - - # This is a hack. Sensitivities added in Nabla don't persist upon reloading the - # package, so we can't query them here. It happens to be the case that all such - # sensitivities are differentiable in both arguments, so we can just set them - # to "not-NaN" in such cases. - if hasdiffrule(package, f, 2) - ∂f∂x, ∂f∂y = diffrule(package, f, :x, :y) - else - ∂f∂x, ∂f∂y = :∂f∂x, :∂f∂y - end - - if ∂f∂x == :NaN && ∂f∂y != :NaN - # Assume that the first argument is integer-valued. - domain = domain1(y -> eval(f)(0, y)) + @testset "$f" for f in BINARY_SCALAR_SENSITIVITIES + if f in ONLY_DIFF_IN_SECOND_ARG_SENSITIVITIES + # First argument is not differentiable, it is integer-valued. + domain = domain1(y -> f(0, y)) domain === nothing && error("Could not determine domain for $f.") lb, ub = domain randx = () -> rand(rng, 0:5) @@ -59,32 +46,18 @@ end for _ in 1:10 x = randx() - @test check_errs(y -> eval(f)(x, y), ȳ, randy(), v) - end - elseif ∂f∂x != :NaN && ∂f∂y == :NaN - # Assume that the second argument is integer-valued. - domain = domain1(x -> eval(f)(x, 0)) - domain === nothing && error("Could not determine domain for $f.") - lb, ub = domain - randx = () -> rand(rng) * (ub - lb) + lb - randy = () -> rand(rng, 0:5) - - for _ in 1:10 - y = randy() - @test check_errs(x -> eval(f)(x, y), randx(), ȳ, v) + @test check_errs(y -> f(x, y), ȳ, randy(), v) end - elseif ∂f∂x != :NaN && ∂f∂y != :NaN - domain = domain2(eval(f)) + else # Both arguments are differentiable + domain = domain2(f) domain === nothing && error("Could not determine domain for $f.") (x_lb, x_ub), (y_lb, y_ub) = domain randx = () -> rand(rng) * (x_ub - x_lb) + x_lb randy = () -> rand(rng) * (y_ub - y_lb) + y_lb for _ in 1:10 - @test check_errs(eval(f), z̄, (randx(), randy()), (v, v)) + @test check_errs(f, z̄, (randx(), randy()), (v, v)) end - else - error("Cannot test $f: $f is not differentiable in either argument.") end end From c0a7d8128c5713651d07a03acc657c50a620e8f0 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 21 Oct 2020 13:58:33 +0100 Subject: [PATCH 26/71] remove testing scratch file --- test/scratch.jl | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 test/scratch.jl diff --git a/test/scratch.jl b/test/scratch.jl deleted file mode 100644 index 64944c48..00000000 --- a/test/scratch.jl +++ /dev/null @@ -1,3 +0,0 @@ -using Nabla, Test, Random, ChainRulesCore, SpecialFunctions, LinearAlgebra -using DiffRules: diffrule, hasdiffrule -using Nabla: unbox From a16a64b0db49eb6e57bef0e088f072f242fdae4c Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 21 Oct 2020 14:08:07 +0100 Subject: [PATCH 27/71] stop tracking ExprTools as a submodule --- dev/ExprTools | 1 - 1 file changed, 1 deletion(-) delete mode 160000 dev/ExprTools diff --git a/dev/ExprTools b/dev/ExprTools deleted file mode 160000 index cba2e159..00000000 --- a/dev/ExprTools +++ /dev/null @@ -1 +0,0 @@ -Subproject commit cba2e15975636c8502e40c346ad2597266403128 From fd3c4975cc3beb79f487e01ab8a2efe0143ffef4 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 21 Oct 2020 16:14:00 +0100 Subject: [PATCH 28/71] make reduce tests use the list of UNITARY sensitivities --- test/sensitivities/functional/reduce.jl | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/test/sensitivities/functional/reduce.jl b/test/sensitivities/functional/reduce.jl index 291ea8ee..1cffb377 100644 --- a/test/sensitivities/functional/reduce.jl +++ b/test/sensitivities/functional/reduce.jl @@ -1,17 +1,14 @@ @testset "Reduce" begin let rng = MersenneTwister(123456) - import Nabla.fmad - # Check that `mapreduce`, `mapfoldl`and `mapfoldr` work as expected with all unary # functions, some composite functions which use FMAD under both `+` and `*`. let N = 3 for functional in (mapreduce, mapfoldl, mapfoldr) # Sensitivities implemented in Base. - for (package, f) in Nabla.unary_sensitivities + for f in UNARY_SCALAR_SENSITIVITIES # Generate some data and get the function to be mapped. - f = eval(f) domain = domain1(f) domain === nothing && error("Could not determine domain for $f.") lb, ub = domain @@ -24,7 +21,7 @@ end # Some composite sensitivities. - composite_functions = (x->5x, x->1 / (1 + x), x->10+x) + composite_functions = (x->5x, x->1 / (1 + x), x->10+x) for f in composite_functions # Generate some data. @@ -47,7 +44,6 @@ # Check that `reduce`, `foldl` and `foldr` work as expected for `+` and `*`. let for functional in (reduce, foldl, foldr) - # Test `+`. x = randn(rng, 100) x_ = Leaf(Tape(), x) @@ -61,10 +57,8 @@ # and some composite functions which use FMAD. let N = 5 # Sensitivities implemented in Base. - for (package, f) in Nabla.unary_sensitivities - + for f in UNARY_SCALAR_SENSITIVITIES # Generate some data and get the function to be mapped. - f = eval(f) domain = domain1(f) domain === nothing && error("Could not determine domain for $f.") lb, ub = domain From c6b4d36dbf5f3b656f10de2317109ea606d83262 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 21 Oct 2020 16:14:11 +0100 Subject: [PATCH 29/71] sortout Project.toml --- Project.toml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 4921fb92..2d24fe0c 100644 --- a/Project.toml +++ b/Project.toml @@ -5,10 +5,8 @@ version = "0.12.3" [deps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -16,9 +14,12 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -FDM = "^0.6" +ChainRules = "0.7.28" +ChainRulesCore = "0.9.17" +ExprTools = "0.1.3" +FDM = "0.6.1" ForwardDiff = "0.10.12" -SpecialFunctions = ">=0.5.0" +SpecialFunctions = "0.9, 0.10" julia = "^1.0" [extras] From f2c309cdd3182612bd6fb87902473aabe6ebf65c Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 21 Oct 2020 17:06:36 +0100 Subject: [PATCH 30/71] remove import of removed partition functions from tests --- test/sensitivities/linalg/factorization/cholesky.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/sensitivities/linalg/factorization/cholesky.jl b/test/sensitivities/linalg/factorization/cholesky.jl index 73cc0fa0..d3f95993 100644 --- a/test/sensitivities/linalg/factorization/cholesky.jl +++ b/test/sensitivities/linalg/factorization/cholesky.jl @@ -1,6 +1,4 @@ @testset "Cholesky" begin - - import Nabla: level2partition, level3partition # Check sensitivities for lower-triangular version. let rng = MersenneTwister(123456), N = 10 for _ in 1:10 From 773908fe2a04748b075cccc5d93c8b15f9a69e95 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 21 Oct 2020 17:37:42 +0100 Subject: [PATCH 31/71] make tests not overwrite functions --- test/core.jl | 39 +++++++++++++++++++++---------------- test/sensitivities/array.jl | 5 +++-- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/test/core.jl b/test/core.jl index 510b31a2..a8533f25 100644 --- a/test/core.jl +++ b/test/core.jl @@ -102,23 +102,28 @@ end # testset Tape @testset "Check that functions involving `isapprox` can be differentiated" begin - f(x) = x ≈ 5.0 ? 1.0 : 3.0 * x - g(x) = 5.0 * x - h(x) = g(x) ≈ 25.0 ? x : f(x) + g(x) - ∇f = ∇(f) - ∇h = ∇(h) - @test ∇f(5.0) == (0.0,) - @test ∇f(6.0) == (3.0,) - @test ∇h(5.0) == (1.0,) - @test ∇h(6.0) == (8.0,) - f(x) = x ≈ [5.0] ? 1.0 : 3.0 * sum(x) - ∇f = ∇(f) - @test ∇f([5.0]) == ([0.0],) - @test ∇f([6.0]) == ([3.0],) - f(x, y) = x ≈ y ? 2y : 3x - ∇f = ∇(f) - @test ∇f(5.0, 5.0) == (0.0, 2.0) - @test ∇f(6.0, 5.0) == (3.0, 0.0) + @testset "First" begin + f(x) = x ≈ 5.0 ? 1.0 : 3.0 * x + g(x) = 5.0 * x + h(x) = g(x) ≈ 25.0 ? x : f(x) + g(x) + ∇f = ∇(f) + ∇h = ∇(h) + @test ∇f(5.0) == (0.0,) + @test ∇f(6.0) == (3.0,) + @test ∇h(5.0) == (1.0,) + @test ∇h(6.0) == (8.0,) + end + + @testset "Second" begin + f(x) = x ≈ [5.0] ? 1.0 : 3.0 * sum(x) + ∇f = ∇(f) + @test ∇f([5.0]) == ([0.0],) + @test ∇f([6.0]) == ([3.0],) + f(x, y) = x ≈ y ? 2y : 3x + ∇f = ∇(f) + @test ∇f(5.0, 5.0) == (0.0, 2.0) + @test ∇f(6.0, 5.0) == (3.0, 0.0) + end end @testset "Check that functions with extra, unused variables can be differentiated" begin diff --git a/test/sensitivities/array.jl b/test/sensitivities/array.jl index a9971782..8bc529f8 100644 --- a/test/sensitivities/array.jl +++ b/test/sensitivities/array.jl @@ -13,9 +13,10 @@ a = rand(3, 2); b = rand(3); c = rand(3, 3); f(a, b, c) = sum(hcat(2*a, 3*b, 4*c)) @test ∇(f)(a,b,c) == (2*ones(3, 2), 3*ones(3), 4*ones(3, 3)) + a = rand(2, 4); b = rand(1, 4); c = rand(3, 4); - f(a, b, c) = sum(vcat(2*a, 3*b, 4*c)) - @test ∇(f)(a,b,c) == (2*ones(2, 4), 3*ones(1, 4), 4*ones(3, 4)) + g(a, b, c) = sum(vcat(2*a, 3*b, 4*c)) + @test ∇(g)(a,b,c) == (2*ones(2, 4), 3*ones(1, 4), 4*ones(3, 4)) @test check_errs(x->fill(x, 4, 4), randn(4, 4), randn(), randn()) @test check_errs(x->fill(x, (4, 4)), randn(4, 4), randn(), randn()) From 2d84b3e2c125f3df8c8ae9b718d51badd63eb52c Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 21 Oct 2020 17:38:00 +0100 Subject: [PATCH 32/71] drop support for Special Functions 0.9 --- Project.toml | 2 +- test/runtests.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 2d24fe0c..76ad871c 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,7 @@ ChainRulesCore = "0.9.17" ExprTools = "0.1.3" FDM = "0.6.1" ForwardDiff = "0.10.12" -SpecialFunctions = "0.9, 0.10" +SpecialFunctions = "0.10" julia = "^1.0" [extras] diff --git a/test/runtests.jl b/test/runtests.jl index 7aaecc50..d67f13f1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,7 +26,7 @@ const UNARY_SCALAR_SENSITIVITIES = [ rad2deg, sec, secd, sech, sin, sind, sinh, sinpi, sqrt, tan, tand, tanh, transpose, # SpecialFunctions.jl: airyai, airyaiprime, airybi, airybiprime, besselj0, besselj1, bessely0, bessely1, - dawson, digamma, erf, erfc, erfcinv, erfcx, erfi, erfinv, gamma, invdigamma, lgamma, + dawson, digamma, erf, erfc, erfcinv, erfcx, erfi, erfinv, gamma, invdigamma, loggamma, trigamma, ] @@ -34,7 +34,7 @@ const BINARY_SCALAR_SENSITIVITIES = [ # Base: *, +, -, /, \, ^, hypot, max, min, # SpecialFunctions.jl: - besseli, besselj, besselk, bessely, beta, lbeta, polygamma, + besseli, besselj, besselk, bessely, beta, polygamma, ] const ONLY_DIFF_IN_SECOND_ARG_SENSITIVITIES = [ From e7554356512eb295d1416e3b7a4a9572e7f06633 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 21 Oct 2020 19:52:12 +0100 Subject: [PATCH 33/71] Add docstrings and seperate original_sig from unonized_sig --- src/sensitivities/chainrules.jl | 164 ++++++++++++++++++++++++++------ 1 file changed, 136 insertions(+), 28 deletions(-) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index 7a54492e..a75daabd 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -1,5 +1,45 @@ #using InteractiveUtils +""" + generate_overload(sig) + +Takes a signature tuple type, for a primal function that has an `rrule` and generates +appropriate overloads for Nabla's `Node` types to allow performing AD. +This is the hook function for `ChainRulesCore.on_new_rule(hook, rrule)`. + +For example, if `generate_overload` is called on `Tuple{typeof{identity}, Any}` +then approximately the following code is `@eval`ed: + +```julia +function Base.identity(x1::Node{<:Any}; kwargs...) + args = (x1,) + (primal_val, pullback) = rrule(op, unbox.(args)...; kwargs...) + tape = get_tape(args) + branch = Branch(primal_val, op, args, kwargs.data, tape, length(tape) + 1, pullback) + push!(tape, branch) + return branch +end + +@inline function preprocess( + op::typeof(identity), y::Branch, ȳ, x1::Union{Any, Node{<:Any}} +) + pullback = getfield(y, :pullback) + @assert pullback !== nothing "pullback not set, ..." + return pullback(ȳ) +end + + +@inline function ∇( + op::typeof(identity), ::Type{Arg{N}}, p, ::Any, ::Any, x1::Union{Any, Node{<:Any}}; + kwargs... +) where N + return p[N + 1] # skip dself (N==1) and we don't support functors +end +``` + +The real code evaluated is a little more complex with macro-hygine and handling for +various complicated type-signatures, including multiple arguments. +""" function generate_overload(sig) opT, argTs = Iterators.peel(ExprTools.parameters(sig)) opT <: Core.Builtin && return false # can't do operater overloading for builtins @@ -24,21 +64,43 @@ function generate_overload(sig) )) && return false - signature_def = build_def(sig) - original_signature_args = signature_def[:args] - signature_def[:args] = unionise_sig.(original_signature_args) + original_signature_def = build_def(sig) + unionized_signature_def = copy(original_signature_def) + unionized_signature_def[:args] = unionise_sig.(original_signature_def[:args]) fdef = quote - @inline $(preprocess_declaration(signature_def)) - @inline $(∇_declaration(signature_def)) - $(overload_declarations!(signature_def, original_signature_args)...) + @inline $(preprocess_declaration(unionized_signature_def)) + @inline $(∇_declaration(unionized_signature_def)) + $(overload_declarations!(original_signature_def)...) end - opT <: typeof(svd) && @show fdef + # for debugging uncomment and edit the below to look at the generated code + # opT <: typeof(identity) && @show fdef eval(fdef) return true end -"like `ExprTools.signature` but on a signature type-tuple, not a Method" +""" + build_def(sig) + +Like `ExprTools.signature` but on a signature type-tuple, not a Method. +For `sig` being a tuple-type representing a methods type signature, this generates a +dictionary that can be passes to `ExprTools.combinedef` to define that function, +Provided that you assign the `:body` key on the dictionary first. + +For example: +```julia +julia> Nabla.build_def(Tuple{typeof(identity), Any}) +Dict{Symbol, Any} with 2 entries: + :name => :(op::typeof(identity)) + :args => Expr[:(x1::Any)] + +julia> Nabla.build_def(Tuple{typeof(+), Vector{T}, Vector{T}} where T<:Number) +Dict{Symbol, Any} with 3 entries: + :name => :(op::typeof(+)) + :args => Expr[:(x1::Array{var"##T#5492", 1}), :(x2::Array{var"##T#5492", 1})] + :whereparams => Any[:(var"##T#5492" <: Number)] +``` +""" function build_def(orig_sig) sig = _truely_rename_unionall(orig_sig) # TODO ExprTools possibly should do this for `signature(::Method)`` also def = Dict{Symbol, Any}() @@ -53,25 +115,35 @@ function build_def(orig_sig) def[:whereparams] = ExprTools.where_parameters(sig) def = Dict{Symbol, Any}(k => v for (k, v) in def if v !== nothing) # filter out nonfields. - return def end -"this overwrites and ruins `signature_def` for others" -function overload_declarations!(signature_def, original_signature_args) +""" + overload_declarations!(original_signature_def) + +Given a `signature_def` dictionary as returned by [`build_def`](@ref) this returns +the ASTs for the overloads of the primal functions to accept `Nabla.Node`s. +The `signature_def` should *not* have been unionized, as this function will instead generate +1 method for each position a node could be in. - # Our macro-hygine is not complete here. +Note: this mutate `signature_def` and so should not be called if others functions also need +to use it after. +""" +function overload_declarations!(signature_def) + # Our manual macro-hygine is not complete here. # the argument names and `op`, `tape` `args`, `kwargs` etc could conflict with # where-params. but for sake of outputting readable code we are not gensyming everything # chance of conflict seems low as where-params are normally upper-case. @assert(signature_def[:name].head == :(::)) @assert(signature_def[:name].args[1] == :op) + original_signature_args = signature_def[:args] signature_def[:kwargs] = [:(kwargs...)] signature_def[:body] = quote - args = $(_args_tuple(signature_def[:args])) - # @show InteractiveUtils.@which rrule(op, unbox.(args)...) + args = $(_args_tuple(original_signature_args)) + # uncommenting the below to is useful for debugging what rrule is being hit. + # @show InteractiveUtils.@which rrule(op, unbox.(args)...) primal_val, pullback = rrule(op, unbox.(args)...; kwargs...) tape = get_tape(args) @@ -99,18 +171,27 @@ function overload_declarations!(signature_def, original_signature_args) return definitions end -function preprocess_declaration(signature_def) - # basically want to generate things like: - # `preprocess(f::$opT, y::Branch, ȳ, $((arg_sig)...)) = y.pullback(ȳ)` - # We need the pullback value to use to compute the sensitivies of the inputs +""" + preprocess_declaration(unionized_signature_def) + +Generates AST for overloads for [`Nabla.preprocess`](@ref) that will call the pullback +stored on the `Branch`. +Roughly speaking generated code like: +`preprocess(f::opT, y::Branch, ȳ, xs...)) = y.pullback(ȳ)` +We need the pullback value to use to compute the sensitivies of the primal inputs, that will +be queries by `∇(::opT, ::Type{Arg{N}}, p, y, ȳ, xs...)` where `p` is that pullback value +return by the `preprocess` function. + +Note that the `unionised_signature_def` must already have been unionised to accept `Node`s. +""" +function preprocess_declaration(signature_def) op = signature_def[:name] args = signature_def[:args] y = gensym(:y) ȳ = gensym(:ȳ) - # preprocess has a broadly similar definition, signature-wise, to the overload. - # so we copy it to get whereparams etc + # preprocess has a similar definition, signature-wise, to what is in signature_def preprocess_def = Dict{Symbol, Any}( :name => :preprocess, :args => [op, :($y::Branch), ȳ, args...], @@ -128,12 +209,18 @@ function preprocess_declaration(signature_def) return ExprTools.combinedef(preprocess_def) end +""" + ∇_declaration(unionised_signature_def) -function ∇_declaration(signature_def) - # basically want to generate things like: - # `∇(::$opT, ::Type{Arg{N}}, p, y, ȳ, xs...) where N = p[N+1] # Skip dself` - # We need the pullback value to use to compute the sensitivies of the inputs +Generates that AST for the overload of the `∇` function which returns the gradient for +specified arguments. +Basically this generates things like: +`∇(::opT, ::Type{Arg{N}}, p, y, ȳ, xs...) where N = p[N+1] # Skip dself` +where `p` is the pullback computed by [`preprocess`](@ref) +Note that the `unionised_signature_def` must already have been unionised to accept `Node`s. +""" +function ∇_declaration(signature_def) # For readability lets name all the parts, NB: this is being a bit too cute. op = signature_def[:name] args = signature_def[:args] @@ -159,15 +246,22 @@ end For `arg_exprs` being a list of arguments expressions from a signature, of a form such as `[:(x::Int), :(y::Float64), :(z::Vararg)]`, returns a tuple expresion containing all -of them by name; while correctly handling splatting, -e.g for prior example `:((x, y, z...))` +of them by name; while correctly handling splatting, for things that are `Vararg` typed. +e.g for the prior example `:((x, y, z...))` """ function _args_tuple(arg_exprs) ret = Expr(:tuple) ret.args = map(arg_exprs) do arg @assert Meta.isexpr(arg, :(::), 2) arg_name, Texpr = arg.args - if Texpr == :Vararg || (Meta.isexpr(Texpr, :curly) && Texpr.args[1] == :Vararg) + if Meta.isexpr(Texpr, :where) # remove where from `Vararg{T, N} where {T, N}` + Texpr = Texpr.args[1] + end + # Needs to be after removing `where` + if Meta.isexpr(Texpr, :curly) # remove `{T, N}` from `Vararg{T,N``` + Texpr = Texpr.args[1] + end + if Texpr == :Vararg return :($arg_name...) else return arg_name @@ -176,7 +270,21 @@ function _args_tuple(arg_exprs) return ret end -"like `Base.rename_unionall`, but actually gensyms the name also, not just a new instance" +""" + _truely_rename_unionall(@nospecialize(u)) + +For `u` being a `UnionAll` this replaces every `TypeVar` with a new one with a `gensym`ed +names. This is useful for manual macro-hygine. + +Example: +``` +julia> Nabla._truely_rename_unionall(Array{T, N} where {T<:Number, N}) +Array{var"##T#2881", var"##N#2880"} where var"##N#2880" where var"##T#2881"<:Number +``` + +Note that the similar `Base.rename_unionall`, does not `gensym` the names just replaces the +instances with new one with identical names. +""" function _truely_rename_unionall(@nospecialize(u)) isa(u,UnionAll) || return u body = _truely_rename_unionall(u.body) From 0fce85be567f9398884ee2feb7b78af2110be625 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 22 Oct 2020 15:05:00 +0100 Subject: [PATCH 34/71] Apply suggestions from code review Co-authored-by: mattBrzezinski --- src/core.jl | 2 +- src/sensitivities/chainrules.jl | 9 ++------- test/code_transformation/util.jl | 2 ++ 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/core.jl b/src/core.jl index e0724d33..a0abab53 100644 --- a/src/core.jl +++ b/src/core.jl @@ -79,7 +79,7 @@ struct Branch{T, B} <: Node{T} kwargs::NamedTuple tape::Tape pos::Int - pullback::B # if we have a rrule pullback for this it is stored here + pullback::B end function Branch(f, args::Tuple, tape::Tape; kwargs...) unboxed = unbox.(args) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index a75daabd..f21e96bc 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -1,4 +1,3 @@ -#using InteractiveUtils """ generate_overload(sig) @@ -42,9 +41,9 @@ various complicated type-signatures, including multiple arguments. """ function generate_overload(sig) opT, argTs = Iterators.peel(ExprTools.parameters(sig)) - opT <: Core.Builtin && return false # can't do operater overloading for builtins + opT <: Core.Builtin && return false # can't do operator overloading for builtins - isabstracttype(opT) || fieldcount(opT) == 0 || return false # not handling functors + isabstracttype(opT) || fieldcount(opT) == 0 || return false # not handling functors isempty(argTs) && return false # we are an operator overloading AD, need operands opT isa DataType && nameof(opT.name.module) == :NaNMath && return false # Don't care about NaNMath @@ -63,7 +62,6 @@ function generate_overload(sig) Base.Broadcast.combine_styles, #TODO should i keep this? )) && return false - original_signature_def = build_def(sig) unionized_signature_def = copy(original_signature_def) unionized_signature_def[:args] = unionise_sig.(original_signature_def[:args]) @@ -125,7 +123,6 @@ Given a `signature_def` dictionary as returned by [`build_def`](@ref) this retur the ASTs for the overloads of the primal functions to accept `Nabla.Node`s. The `signature_def` should *not* have been unionized, as this function will instead generate 1 method for each position a node could be in. - Note: this mutate `signature_def` and so should not be called if others functions also need to use it after. """ @@ -138,7 +135,6 @@ function overload_declarations!(signature_def) @assert(signature_def[:name].args[1] == :op) original_signature_args = signature_def[:args] - signature_def[:kwargs] = [:(kwargs...)] signature_def[:body] = quote args = $(_args_tuple(original_signature_args)) @@ -299,7 +295,6 @@ function _truely_rename_unionall(@nospecialize(u)) end - # Find a tape, ds might be Nodes or might be something else. # All nodes should have the same tape, so the first one will do get_tape(ds) = first(tape(d) for d in ds if d isa Node) diff --git a/test/code_transformation/util.jl b/test/code_transformation/util.jl index 5d54b06e..210ac0d6 100644 --- a/test/code_transformation/util.jl +++ b/test/code_transformation/util.jl @@ -77,7 +77,9 @@ end @testset "node_type" begin + # special case for a redudant local where N in a Vararg @test Nabla.node_type(:(Vararg{Int64, N} where N)) == :(Vararg{Node{<:Int64}}) + @test Nabla.node_type(:Float32) == :(Node{<:Float32}) end end From 6c1c4df7682837622e22b9dea15bd766df184568 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 22 Oct 2020 15:06:02 +0100 Subject: [PATCH 35/71] Apply suggestions from code review Co-authored-by: mattBrzezinski --- src/sensitivities/chainrules.jl | 10 ++++++++-- test/sensitivities/scalar.jl | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index f21e96bc..efdab434 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -38,6 +38,11 @@ end The real code evaluated is a little more complex with macro-hygine and handling for various complicated type-signatures, including multiple arguments. + +It does not generate any code for `rrules` for primal functions that Nabla does not support. +These include: builtin functions, functors, functions without any positional arguments, and functions for working with complex numbers. It also includes a short list of non-differentiable functions that Nabla has special cases for outside of AD such as `size` + +This function returns true or false as to wether or not code was generated. While this has no actual effect in itself, it can be useful for checking how many rules Nabla supports. """ function generate_overload(sig) opT, argTs = Iterators.peel(ExprTools.parameters(sig)) @@ -46,7 +51,7 @@ function generate_overload(sig) isabstracttype(opT) || fieldcount(opT) == 0 || return false # not handling functors isempty(argTs) && return false # we are an operator overloading AD, need operands - opT isa DataType && nameof(opT.name.module) == :NaNMath && return false # Don't care about NaNMath + opT isa DataType && nameof(opT.name.module) == :NaNMath && return false # Don't care about NaNMath # Ignore functions that have complex ranges. This may change when Nabla supports complex # numbers. @@ -74,6 +79,7 @@ function generate_overload(sig) # for debugging uncomment and edit the below to look at the generated code # opT <: typeof(identity) && @show fdef eval(fdef) + return true end @@ -151,7 +157,7 @@ function overload_declarations!(signature_def) # we need to generate a version of this for each place that an arg could be n_args = length(original_signature_args) definitions = Expr[] - for swap_mask in Iterators.product(ntuple(_->(true,false), n_args)...) + for swap_mask in Iterators.product(ntuple(_->(true, false), n_args)...) any(swap_mask) || continue # don't generate if not swapping anything. signature_def[:args] = map(swap_mask, original_signature_args) do swap, orig_arg if swap diff --git a/test/sensitivities/scalar.jl b/test/sensitivities/scalar.jl index 7ab9e58f..9d9ef6d6 100644 --- a/test/sensitivities/scalar.jl +++ b/test/sensitivities/scalar.jl @@ -48,7 +48,7 @@ end x = randx() @test check_errs(y -> f(x, y), ȳ, randy(), v) end - else # Both arguments are differentiable + else # Both arguments are differentiable domain = domain2(f) domain === nothing && error("Could not determine domain for $f.") (x_lb, x_ub), (y_lb, y_ub) = domain From bdd3e4399d5c95091e28111e8e6f5e9d57426d8f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 22 Oct 2020 17:53:11 +0100 Subject: [PATCH 36/71] move deciding what rules to use into its own function and docstring --- src/sensitivities/chainrules.jl | 56 +++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 17 deletions(-) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index efdab434..5282a60a 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -40,18 +40,53 @@ The real code evaluated is a little more complex with macro-hygine and handling various complicated type-signatures, including multiple arguments. It does not generate any code for `rrules` for primal functions that Nabla does not support. -These include: builtin functions, functors, functions without any positional arguments, and functions for working with complex numbers. It also includes a short list of non-differentiable functions that Nabla has special cases for outside of AD such as `size` +See [`should_use_rrule`](@ref) for more details on what rules we do not use. -This function returns true or false as to wether or not code was generated. While this has no actual effect in itself, it can be useful for checking how many rules Nabla supports. +This function returns true or false as to wether or not code was generated. While this has +no actual effect in itself, it can be useful for checking how many rules Nabla supports. """ function generate_overload(sig) + should_use_rrule(sig) || return false + + original_signature_def = build_def(sig) + unionized_signature_def = copy(original_signature_def) + unionized_signature_def[:args] = unionise_sig.(original_signature_def[:args]) + + fdef = quote + @inline $(preprocess_declaration(unionized_signature_def)) + @inline $(∇_declaration(unionized_signature_def)) + $(overload_declarations!(original_signature_def)...) + end + # for debugging uncomment and edit the below to look at the generated code + # opT <: typeof(identity) && @show fdef + eval(fdef) + + return true +end + +""" + should_use_rrule(sig) + +Should we make use of the chainrules `rrule` for the primal function with the given +signature tuple type (`sig`). + +We do not use rules for: + - builtin functions + - functors / closures + - functions without any positional arguments + - functions from the `NaNMath` module + - functions for working with complex numbers. + - Nondifferentiable functions that we define directly on `Node`s better (like `size`) +""" +function should_use_rrule(sig) opT, argTs = Iterators.peel(ExprTools.parameters(sig)) opT <: Core.Builtin && return false # can't do operator overloading for builtins isabstracttype(opT) || fieldcount(opT) == 0 || return false # not handling functors isempty(argTs) && return false # we are an operator overloading AD, need operands - opT isa DataType && nameof(opT.name.module) == :NaNMath && return false # Don't care about NaNMath + # Don't care about NaNMath + opT isa DataType && nameof(opT.name.module) == :NaNMath && return false # Ignore functions that have complex ranges. This may change when Nabla supports complex # numbers. @@ -67,20 +102,7 @@ function generate_overload(sig) Base.Broadcast.combine_styles, #TODO should i keep this? )) && return false - original_signature_def = build_def(sig) - unionized_signature_def = copy(original_signature_def) - unionized_signature_def[:args] = unionise_sig.(original_signature_def[:args]) - - fdef = quote - @inline $(preprocess_declaration(unionized_signature_def)) - @inline $(∇_declaration(unionized_signature_def)) - $(overload_declarations!(original_signature_def)...) - end - # for debugging uncomment and edit the below to look at the generated code - # opT <: typeof(identity) && @show fdef - eval(fdef) - - return true + return true # no exclusion applies end """ From b17190cd05b60f2a78bfb9e7118278ee1149793a Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 23 Oct 2020 16:00:12 +0100 Subject: [PATCH 37/71] Specific error --- test/sensitivities/linalg/factorization/cholesky.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/sensitivities/linalg/factorization/cholesky.jl b/test/sensitivities/linalg/factorization/cholesky.jl index d3f95993..48ec76dc 100644 --- a/test/sensitivities/linalg/factorization/cholesky.jl +++ b/test/sensitivities/linalg/factorization/cholesky.jl @@ -19,7 +19,7 @@ @test getfield(U, :f) == Base.getproperty @test unbox(U) ≈ cholesky(X_).U - @test_throws Exception ∇(X->cholesky(X).info)(X_) + @test_throws MethodError ∇(X->cholesky(X).info)(X_) end let From ba2670ca5e3ccf6ac9ac7c9e9adc384d69eaff87 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 23 Oct 2020 16:36:23 +0100 Subject: [PATCH 38/71] Apply suggestions from code review Co-authored-by: willtebbutt Co-authored-by: mattBrzezinski --- src/code_transformation/util.jl | 2 +- src/core.jl | 6 +++--- src/sensitivities/chainrules.jl | 2 +- src/sensitivities/linalg/diagonal.jl | 2 +- test/code_transformation/differentiable.jl | 1 - test/core.jl | 4 ++-- test/runtests.jl | 12 +++++++++--- test/sensitivities/functional/functional.jl | 2 +- 8 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/code_transformation/util.jl b/src/code_transformation/util.jl index 0b9e6472..55024452 100644 --- a/src/code_transformation/util.jl +++ b/src/code_transformation/util.jl @@ -48,7 +48,7 @@ Returns an expression for the `Node{<:tp}`. e.g. Correctly `Varargs{Real}` becomes `:(Varargs{Node{<:Real}})` -This is a lot like [`unionize_type`](ref) but it doesn't permit the original type anymore. +This is a lot like [`unionise_type`](ref) but it doesn't permit the original type anymore. """ function node_type(tp::Union{Symbol, Expr}) (_tp, _info) = remove_vararg(tp) diff --git a/src/core.jl b/src/core.jl index a0abab53..3219a51a 100644 --- a/src/core.jl +++ b/src/core.jl @@ -67,10 +67,10 @@ f - the function used to generate this Node. args - Values indicating which elements in the tape will require updating by this node. tape - The Tape to which this Branch is assigned. pos - the location of this Branch in the tape to which it is assigned. -pullback::B - if there is a custom primate rule (a `ChainRulesCore.rrule`) then this holds - the pullback to propagates gradients back through the operation, if there is not a rule +pullback::B - if there is a custom primative rule (a `ChainRulesCore.rrule`) then this holds + the pullback to propagate gradients back through the operation. If there is not a rule then this is set to `nothing`. - It also maybe set to `nothing` by legacy Nabla rules that have not moved to ChainRules. + It may also be set to `nothing` by legacy Nabla rules that have not moved to ChainRules. """ struct Branch{T, B} <: Node{T} val::T diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index 5282a60a..caba0a5c 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -42,7 +42,7 @@ various complicated type-signatures, including multiple arguments. It does not generate any code for `rrules` for primal functions that Nabla does not support. See [`should_use_rrule`](@ref) for more details on what rules we do not use. -This function returns true or false as to wether or not code was generated. While this has +This function returns true or false as to whether or not code was generated. While this has no actual effect in itself, it can be useful for checking how many rules Nabla supports. """ function generate_overload(sig) diff --git a/src/sensitivities/linalg/diagonal.jl b/src/sensitivities/linalg/diagonal.jl index 055a4ce8..9e0427c0 100644 --- a/src/sensitivities/linalg/diagonal.jl +++ b/src/sensitivities/linalg/diagonal.jl @@ -39,7 +39,7 @@ end # machinery, so it knows how to deal. # TODO: Possibly we should overload `Pair` so that it constructs a `Node{Pair}` then this -# would hit sentitivities that we have defined via ChainRules. +# would hit sensitivities that we have defined via ChainRules. _diagm(x::∇AbstractVector, k::Integer=0) = diagm(k => x) LinearAlgebra.diagm(x::Pair{<:Integer, <:Node{<:∇AbstractVector}}) = _diagm(last(x), first(x)) diff --git a/test/code_transformation/differentiable.jl b/test/code_transformation/differentiable.jl index 7f298dcb..7758e0d6 100644 --- a/test/code_transformation/differentiable.jl +++ b/test/code_transformation/differentiable.jl @@ -68,7 +68,6 @@ skip_line_info(ex) = ex :(x2::Vararg{Union{Int64, Node{<:Int64}}}), ) - # Test Nabla.unionise_struct. Written in terms of Nabla.unionise_arg. @test unionise_struct(:(struct Foo end)) == :(struct Foo end) @test unionise_struct(:(struct Foo{T} end)) == diff --git a/test/core.jl b/test/core.jl index a8533f25..2a093daf 100644 --- a/test/core.jl +++ b/test/core.jl @@ -102,7 +102,7 @@ end # testset Tape @testset "Check that functions involving `isapprox` can be differentiated" begin - @testset "First" begin + @testset "Test Case 1" begin f(x) = x ≈ 5.0 ? 1.0 : 3.0 * x g(x) = 5.0 * x h(x) = g(x) ≈ 25.0 ? x : f(x) + g(x) @@ -114,7 +114,7 @@ end # testset Tape @test ∇h(6.0) == (8.0,) end - @testset "Second" begin + @testset "Test Case 2" begin f(x) = x ≈ [5.0] ? 1.0 : 3.0 * sum(x) ∇f = ∇(f) @test ∇f([5.0]) == ([0.0],) diff --git a/test/runtests.jl b/test/runtests.jl index d67f13f1..993b6b2a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,14 @@ using Nabla -using Test, LinearAlgebra, Statistics, Random, ForwardDiff -using Distributions, BenchmarkTools, SpecialFunctions - using Nabla: unbox, pos, tape, oneslike, zeroslike +using BenchmarkTools +using Distributions +using ForwardDiff +using LinearAlgebra +using Random +using SpecialFunctions +using Statistics +using Test + # Helper function for comparing `Ref`s, since they don't compare equal under `==` ref_equal(a::Ref{T}, b::Ref{T}) where {T} = a[] == b[] diff --git a/test/sensitivities/functional/functional.jl b/test/sensitivities/functional/functional.jl index 2803e44f..c23563e5 100644 --- a/test/sensitivities/functional/functional.jl +++ b/test/sensitivities/functional/functional.jl @@ -66,7 +66,7 @@ @testset "$f" for f in BINARY_SCALAR_SENSITIVITIES # TODO: Implement the edge cases for functions differentiable in only either # argument. - f in ONLY_DIFF_IN_SECOND_ARG_SENSITIVITIES && continue + f in ONLY_DIFF_IN_SECOND_ARG_SENSITIVITIES && continue domain = domain2(f) domain === nothing && error("Could not determine domain for $f.") (x_lb, x_ub), (y_lb, y_ub) = domain From 8e897b9a9d609a47bceebd90fd9f518cdb317bcf Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 2 Nov 2020 16:38:18 +0000 Subject: [PATCH 39/71] Support Special Function 0.9 --- Project.toml | 4 +++- test/runtests.jl | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 76ad871c..1fc5df3c 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,8 @@ FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SnoopCompile = "aa65fe97-06da-5843-b5b1-d5d13cad87d2" +SnoopCompileCore = "e2b509da-e806-4183-be48-004708413034" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -19,7 +21,7 @@ ChainRulesCore = "0.9.17" ExprTools = "0.1.3" FDM = "0.6.1" ForwardDiff = "0.10.12" -SpecialFunctions = "0.10" +SpecialFunctions = "0.9, 0.10" julia = "^1.0" [extras] diff --git a/test/runtests.jl b/test/runtests.jl index 993b6b2a..d24f0c6c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -32,8 +32,7 @@ const UNARY_SCALAR_SENSITIVITIES = [ rad2deg, sec, secd, sech, sin, sind, sinh, sinpi, sqrt, tan, tand, tanh, transpose, # SpecialFunctions.jl: airyai, airyaiprime, airybi, airybiprime, besselj0, besselj1, bessely0, bessely1, - dawson, digamma, erf, erfc, erfcinv, erfcx, erfi, erfinv, gamma, invdigamma, loggamma, - trigamma, + dawson, digamma, erf, erfc, erfcinv, erfcx, erfi, erfinv, gamma, invdigamma, trigamma, ] const BINARY_SCALAR_SENSITIVITIES = [ From ff44f9c0984a4425c3b7245d97342343a347ce9b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 2 Nov 2020 16:41:05 +0000 Subject: [PATCH 40/71] Apply suggestions from code review --- test/code_transformation/differentiable.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/code_transformation/differentiable.jl b/test/code_transformation/differentiable.jl index 7758e0d6..2d841020 100644 --- a/test/code_transformation/differentiable.jl +++ b/test/code_transformation/differentiable.jl @@ -63,7 +63,7 @@ skip_line_info(ex) = ex @test unionise_sig(:(foo(x::T))) == :(foo($(unionise_arg(:(x::T))))) @test unionise_sig(:(foo(x::T) where T)) == :(foo($(unionise_arg(:(x::T)))) where T) - @test isequal( # special case for a redudant where N in a Vararg + @test isequal( # special case for a redudant `where N` in a Vararg Nabla.unionise_sig(:(x2::(Vararg{Int64, N} where N))), :(x2::Vararg{Union{Int64, Node{<:Int64}}}), ) From 20db25530fd280c05c27523d4de8aaac4bb76808 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 2 Nov 2020 16:42:42 +0000 Subject: [PATCH 41/71] Remove directly applied (commented out) broadcasting tests Co-authored-by: mattBrzezinski --- test/sensitivities/functional/functional.jl | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/test/sensitivities/functional/functional.jl b/test/sensitivities/functional/functional.jl index c23563e5..3c6dce1b 100644 --- a/test/sensitivities/functional/functional.jl +++ b/test/sensitivities/functional/functional.jl @@ -33,11 +33,7 @@ s = broadcast(f, x_, y_) o = oneslike(unbox(s)) ∇s = ∇(s, o) -# ∇x = sum(broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), unbox(s), o, x, y)) -# ∇y = broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), unbox(s), o, x, y) @test broadcast(f, x, y) == unbox(s) -# @test ∇s[x_] ≈ ∇x -# @test ∇s[y_] ≈ ∇y end function check_binary_broadcast(f, x::Real, y) tape = Tape() @@ -45,11 +41,7 @@ s = broadcast(f, x_, y_) o = oneslike(unbox(s)) ∇s = ∇(s, o) -# ∇x = sum(broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), unbox(s), o, x, y)) -# ∇y = broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), unbox(s), o, x, y) @test broadcast(f, x, y) == unbox(s) -# @test ∇s[x_] ≈ ∇x -# @test ∇s[y_] ≈ ∇y end function check_binary_broadcast(f, x, y::Real) tape = Tape() @@ -57,11 +49,7 @@ s = broadcast(f, x_, y_) o = oneslike(unbox(s)) ∇s = ∇(s, o) -# ∇x = sum(broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), unbox(s), o, x, y)) -# ∇y = broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), unbox(s), o, x, y) @test broadcast(f, x, y) == unbox(s) -# @test ∇s[x_] ≈ ∇x -# @test ∇s[y_] ≈ ∇y end @testset "$f" for f in BINARY_SCALAR_SENSITIVITIES # TODO: Implement the edge cases for functions differentiable in only either From 67455fe978acea41a34e3f3c795fc1651ef91bd8 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 13 Nov 2020 13:46:59 +0000 Subject: [PATCH 42/71] Allow SpecialFunction 0.8 for julia 1.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1fc5df3c..b11fa8f0 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,7 @@ ChainRulesCore = "0.9.17" ExprTools = "0.1.3" FDM = "0.6.1" ForwardDiff = "0.10.12" -SpecialFunctions = "0.9, 0.10" +SpecialFunctions = "0.8, 0.9, 0.10" julia = "^1.0" [extras] From 5151ad89f2298dc4374e6ca3e8ef917759b30d08 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 13 Nov 2020 14:27:50 +0000 Subject: [PATCH 43/71] Delete redundant rule for identity --- src/sensitivities/scalar.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/sensitivities/scalar.jl b/src/sensitivities/scalar.jl index 57a0102f..26927526 100644 --- a/src/sensitivities/scalar.jl +++ b/src/sensitivities/scalar.jl @@ -1,11 +1,3 @@ -# Hand code the identity because it's really fundamental. It doesn't need to generate a new -# node on the computational graph since it does nothing, but it is useful to have it's -# gradient implemented for use in higher-order functions. -import Base.identity -@explicit_intercepts identity Tuple{Any} -@inline ∇(::typeof(identity), ::Type{Arg{1}}, p, y, ȳ, x) = ȳ -@inline ∇(::typeof(identity), ::Type{Arg{1}}, x::Real) = one(x) - # Add method to resolve exponentiation ambiguity. ^(n::Node{<:Real}, p::Integer) = invoke(^, Tuple{Node{<:Real}, Real}, n, p) From 04f9b5875d623185dc91b7ef40ebe854433f0888 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 13 Nov 2020 14:28:09 +0000 Subject: [PATCH 44/71] remove mistakenly added SnoopCompile dependency --- Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index b11fa8f0..4301a20b 100644 --- a/Project.toml +++ b/Project.toml @@ -10,8 +10,6 @@ FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SnoopCompile = "aa65fe97-06da-5843-b5b1-d5d13cad87d2" -SnoopCompileCore = "e2b509da-e806-4183-be48-004708413034" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" From 859e0a24c9c0183851f2f02c41e5f26ba0577af0 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 4 Dec 2020 15:54:04 +0000 Subject: [PATCH 45/71] Fix comments Co-authored-by: Eric Davies --- src/code_transformation/util.jl | 2 +- src/core.jl | 5 +++-- src/sensitivities/chainrules.jl | 2 +- test/sensitivities/functional/reduce.jl | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/code_transformation/util.jl b/src/code_transformation/util.jl index 55024452..3dc07bf9 100644 --- a/src/code_transformation/util.jl +++ b/src/code_transformation/util.jl @@ -46,7 +46,7 @@ end Returns an expression for the `Node{<:tp}`. e.g. `node_type(:Real)` returns `:(Node{<:Real}})`. -Correctly `Varargs{Real}` becomes `:(Varargs{Node{<:Real}})` +Correctly `:(Vararg{Real})` becomes `:(Vararg{Node{<:Real}})` This is a lot like [`unionise_type`](ref) but it doesn't permit the original type anymore. """ diff --git a/src/core.jl b/src/core.jl index 3219a51a..2b77bddb 100644 --- a/src/core.jl +++ b/src/core.jl @@ -192,8 +192,9 @@ To implement a new reverse-mode sensitivity for the `N^{th}` argument of functio is the output of `preprocess`. `x1`, `x2`,... are the inputs to the function, `y` is its output and `ȳ` the reverse-mode sensitivity of `y`. -∇(x̄, f::Function, ::Type{Arg{N}}, p, y, ȳ, x...) -This is the optionally inplace version of `∇` that should, if implemented, mutate + ∇(x̄, f::Function, ::Type{Arg{N}}, p, y, ȳ, x...) + +This is the optional in-place version of `∇` that should, if implemented, mutate x̄ to have the gradient added to it. """ ∇(y::Node, ȳ) = propagate(tape(y), reverse_tape(y, ȳ)) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index caba0a5c..7f194bbb 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -282,7 +282,7 @@ function _args_tuple(arg_exprs) Texpr = Texpr.args[1] end # Needs to be after removing `where` - if Meta.isexpr(Texpr, :curly) # remove `{T, N}` from `Vararg{T,N``` + if Meta.isexpr(Texpr, :curly) # remove `{T, N}` from `Vararg{T,N}` Texpr = Texpr.args[1] end if Texpr == :Vararg diff --git a/test/sensitivities/functional/reduce.jl b/test/sensitivities/functional/reduce.jl index 1cffb377..74f5bfcb 100644 --- a/test/sensitivities/functional/reduce.jl +++ b/test/sensitivities/functional/reduce.jl @@ -21,7 +21,7 @@ end # Some composite sensitivities. - composite_functions = (x->5x, x->1 / (1 + x), x->10+x) + composite_functions = (x->5x, x->1 / (1 + x), x->10+x) for f in composite_functions # Generate some data. From fc7fb5a317cbd4a7f00ccc0779bed68a1d85b834 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 13 Nov 2020 15:00:25 +0000 Subject: [PATCH 46/71] Don't generate rules for a bunch of nondifferentiable things that cause bulk invalidations --- src/sensitivities/chainrules.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index 7f194bbb..36501b88 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -77,6 +77,8 @@ We do not use rules for: - functions from the `NaNMath` module - functions for working with complex numbers. - Nondifferentiable functions that we define directly on `Node`s better (like `size`) + - Nondifferentiable functions that are never used in practice and that cause a lot of + compiler invalidations and so cause a large increase in loading time. """ function should_use_rrule(sig) opT, argTs = Iterators.peel(ExprTools.parameters(sig)) @@ -102,6 +104,17 @@ function should_use_rrule(sig) Base.Broadcast.combine_styles, #TODO should i keep this? )) && return false + # Ignore these functions because in practice they are never used and defining them cause + # a ton of compiler invalidations, making loading slow. + opT ∈ typeof.(( + string, repr, print, println, write, readlines, eachline, Core.print, Core.println, + isequal, ==, in, haskey, + isnothing, ismissing, isfile, + isbitstype, isbits, isabstracttype, isconcretetype, + startswith, endswith, join, joinpath, normpath, chomp, + schedule, # this one is huge, causes over 2500 invalidations + )) && return false + return true # no exclusion applies end From c27a8fd209191541286837d56b80700591c8f23c Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 4 Dec 2020 16:21:26 +0000 Subject: [PATCH 47/71] correct _truly_rename_unionall spelling --- src/sensitivities/chainrules.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index 36501b88..4795e9d8 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -141,7 +141,7 @@ Dict{Symbol, Any} with 3 entries: ``` """ function build_def(orig_sig) - sig = _truely_rename_unionall(orig_sig) # TODO ExprTools possibly should do this for `signature(::Method)`` also + sig = _truly_rename_unionall(orig_sig) # TODO ExprTools possibly should do this for `signature(::Method)`` also def = Dict{Symbol, Any}() opT = ExprTools.parameters(sig)[1] @@ -308,23 +308,23 @@ function _args_tuple(arg_exprs) end """ - _truely_rename_unionall(@nospecialize(u)) + _truly_rename_unionall(@nospecialize(u)) For `u` being a `UnionAll` this replaces every `TypeVar` with a new one with a `gensym`ed names. This is useful for manual macro-hygine. Example: ``` -julia> Nabla._truely_rename_unionall(Array{T, N} where {T<:Number, N}) +julia> Nabla._truly_rename_unionall(Array{T, N} where {T<:Number, N}) Array{var"##T#2881", var"##N#2880"} where var"##N#2880" where var"##T#2881"<:Number ``` -Note that the similar `Base.rename_unionall`, does not `gensym` the names just replaces the -instances with new one with identical names. +Note that the similar `Base.rename_unionall`, though `Base.rename_unionall` does not +`gensym` the names just replaces the instances with new instances with identical names. """ -function _truely_rename_unionall(@nospecialize(u)) +function _truly_rename_unionall(@nospecialize(u)) isa(u,UnionAll) || return u - body = _truely_rename_unionall(u.body) + body = _truly_rename_unionall(u.body) if body === u.body body = u else From 6d0559b02844e9ffc732fbaa77e15652a89223a7 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 4 Dec 2020 19:30:59 +0000 Subject: [PATCH 48/71] handle remove varargs with redundant N and other typevars --- src/code_transformation/util.jl | 7 +++++-- test/code_transformation/util.jl | 9 ++++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/code_transformation/util.jl b/src/code_transformation/util.jl index 3dc07bf9..3d4da219 100644 --- a/src/code_transformation/util.jl +++ b/src/code_transformation/util.jl @@ -113,13 +113,16 @@ function remove_vararg(typ::Expr) # handle interally `where N` from `typ = :(Vararg{FOO, N} where N)` which results in # `body = :(Vararg{FOO, N})` and `new_type = Foo where N`, we don't need to keep it # at all, the `where N` wasn't doing anything to begin with, so we just strip it out - if Meta.isexpr(new_typ, :where, 2) && Meta.isexpr(body, :curly, 3) + if Meta.isexpr(new_typ, :where) && Meta.isexpr(body, :curly, 3) @assert body.args[1] == :Vararg T = body.args[2] N = body.args[3] - if new_typ.args == [T, N] + if new_typ.args == [T, N] # ($T where $N) body = :(Vararg{T}) new_typ = T + elseif T == new_typ.args[1] && N ∈ new_typ.args[2:end] # ($T where {?, $N, ?}) + body = :(Vararg{T}) + filter!(!isequal(N), new_typ.args) end end diff --git a/test/code_transformation/util.jl b/test/code_transformation/util.jl index 210ac0d6..f1867e06 100644 --- a/test/code_transformation/util.jl +++ b/test/code_transformation/util.jl @@ -48,6 +48,13 @@ @test Nabla.remove_vararg(:Real) == (:Real, :nothing) @test Nabla.remove_vararg(:(Vararg{T} where T)) == (:(T where T), :Vararg) @test Nabla.remove_vararg(:(Vararg{T, N} where T<:Real)) == (:(T where T<:Real), :N) + # Redundant local `where N` (rather than leaving the `N` to be outside in the sig) + @test Nabla.remove_vararg(:(Vararg{Real, N} where N)) == (:Real, :Vararg) + @test Nabla.remove_vararg(:(Vararg{T, N} where {N,T})) == (:(T where T), :Vararg) + @test Nabla.remove_vararg(:(Vararg{T, N} where {T<:Real, N})) == (:(T where T<:Real), :Vararg) + # This case doesn't work but never occurs in practice + @test_broken Nabla.remove_vararg(:(Vararg{T, N} where T where N)) == (:(T where T), :Vararg) + # Test Nabla.replace_vararg. @test Nabla.replace_vararg(:(U{T, N{T}}), (:V, :nothing)) == :(U{T, N{T}}) @@ -79,7 +86,7 @@ @testset "node_type" begin # special case for a redudant local where N in a Vararg @test Nabla.node_type(:(Vararg{Int64, N} where N)) == :(Vararg{Node{<:Int64}}) - + @test Nabla.node_type(:Float32) == :(Node{<:Float32}) end end From bd4271f7cd0f8b0688bbff8684a8b80b26271d3c Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 4 Dec 2020 19:36:14 +0000 Subject: [PATCH 49/71] remove iunneded variable Co-authored-by: Eric Davies --- src/sensitivities/chainrules.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index 4795e9d8..92e992a4 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -147,7 +147,6 @@ function build_def(orig_sig) opT = ExprTools.parameters(sig)[1] def[:name] = :(op::$opT) - explicit_tvars = Core.TypeName[]#ExprTools.extract_tvars(sig) arg_types = ExprTools.name_of_type.(ExprTools.argument_types(sig)) arg_names = [Symbol(:x, ii) for ii in eachindex(arg_types)] #TODO: should we pass the arg_names in? def[:args] = Expr.(:(::), arg_names, arg_types) From 251ab3b1944012ab251c1c22b5c297307bf5e12d Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 4 Dec 2020 19:40:01 +0000 Subject: [PATCH 50/71] filter out nonfields inplace --- src/sensitivities/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index 92e992a4..a5077a65 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -152,7 +152,7 @@ function build_def(orig_sig) def[:args] = Expr.(:(::), arg_names, arg_types) def[:whereparams] = ExprTools.where_parameters(sig) - def = Dict{Symbol, Any}(k => v for (k, v) in def if v !== nothing) # filter out nonfields. + filter!(kv->last(kv)!==nothing, def) # filter out nonfields. return def end From c4708712d1407dcdf4f6408a9e58ef3d71f17ffd Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 8 Dec 2020 20:17:31 +0000 Subject: [PATCH 51/71] split up BINARY_SCALAR_SENSITIVITIES --- test/runtests.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index d24f0c6c..9036d746 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,15 +35,19 @@ const UNARY_SCALAR_SENSITIVITIES = [ dawson, digamma, erf, erfc, erfcinv, erfcx, erfi, erfinv, gamma, invdigamma, trigamma, ] -const BINARY_SCALAR_SENSITIVITIES = [ +const DIFF_IN_FIRST_ANND_SECOND_ARG_SENSITIVITIES = [ #Base # Base: *, +, -, /, \, ^, hypot, max, min, # SpecialFunctions.jl: - besseli, besselj, besselk, bessely, beta, polygamma, + beta, ] - const ONLY_DIFF_IN_SECOND_ARG_SENSITIVITIES = [ - besseli, besselj, besselk, bessely, polygamma + # SpecialFunctions.jl: + besseli, besselj, besselk, bessely, polygamma, +] +const BINARY_SCALAR_SENSITIVITIES = [ + DIFF_IN_FIRST_ANND_SECOND_ARG_SENSITIVITIES; + ONLY_DIFF_IN_SECOND_ARG_SENSITIVITIES; ] @testset "Nabla.jl" begin From 9d7e87b2c8decac239a879c25539af8806e02d32 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 8 Dec 2020 20:25:13 +0000 Subject: [PATCH 52/71] =?UTF-8?q?Remove=20last=20of=20the=20varient=20y?= =?UTF-8?q?=CC=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sensitivities/chainrules.jl | 4 ++-- src/sensitivities/functional/reducedim.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index a5077a65..e1ffbd7a 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -20,11 +20,11 @@ function Base.identity(x1::Node{<:Any}; kwargs...) end @inline function preprocess( - op::typeof(identity), y::Branch, ȳ, x1::Union{Any, Node{<:Any}} + op::typeof(identity), y::Branch, ȳ, x1::Union{Any, Node{<:Any}} ) pullback = getfield(y, :pullback) @assert pullback !== nothing "pullback not set, ..." - return pullback(ȳ) + return pullback(ȳ) end diff --git a/src/sensitivities/functional/reducedim.jl b/src/sensitivities/functional/reducedim.jl index aae932d7..156a4ec7 100644 --- a/src/sensitivities/functional/reducedim.jl +++ b/src/sensitivities/functional/reducedim.jl @@ -42,12 +42,12 @@ end function ∇( ::typeof(sum), ::Type{Arg{2}}, - p, y, ȳ, + p, y, ȳ, ::typeof(abs2), A::AbstractArray{<:Real}; dims=:, ) - return 2ȳ .* A + return 2ȳ .* A end @explicit_intercepts( From f11e7ebf3714ddf7237605aab953bb75a3719dbc Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 1 Mar 2021 18:58:49 +0000 Subject: [PATCH 53/71] fix comment typo Co-authored-by: Curtis Vogt --- src/sensitivities/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index e1ffbd7a..65958b83 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -167,7 +167,7 @@ Note: this mutate `signature_def` and so should not be called if others function to use it after. """ function overload_declarations!(signature_def) - # Our manual macro-hygine is not complete here. + # Our manual macro-hygiene is not complete here. # the argument names and `op`, `tape` `args`, `kwargs` etc could conflict with # where-params. but for sake of outputting readable code we are not gensyming everything # chance of conflict seems low as where-params are normally upper-case. From a5f758a43076830fd0d41c11338f095671c10f07 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 1 Mar 2021 19:01:38 +0000 Subject: [PATCH 54/71] typos in comments/docs Co-authored-by: Curtis Vogt --- src/sensitivities/chainrules.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index 65958b83..f30306f3 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -32,11 +32,11 @@ end op::typeof(identity), ::Type{Arg{N}}, p, ::Any, ::Any, x1::Union{Any, Node{<:Any}}; kwargs... ) where N - return p[N + 1] # skip dself (N==1) and we don't support functors + return p[N + 1] # skip dself (N==1) as we don't support functors end ``` -The real code evaluated is a little more complex with macro-hygine and handling for +The real code evaluated is a little more complex with macro-hygiene and handling for various complicated type-signatures, including multiple arguments. It does not generate any code for `rrules` for primal functions that Nabla does not support. @@ -76,8 +76,8 @@ We do not use rules for: - functions without any positional arguments - functions from the `NaNMath` module - functions for working with complex numbers. - - Nondifferentiable functions that we define directly on `Node`s better (like `size`) - - Nondifferentiable functions that are never used in practice and that cause a lot of + - Non-differentiable functions that we define directly on `Node`s better (like `size`) + - Non-differentiable functions that are never used in practice and that cause a lot of compiler invalidations and so cause a large increase in loading time. """ function should_use_rrule(sig) From ff8cf0da59061127d7bffb28a5b05ec2cb64dec1 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 24 Jun 2021 15:39:18 +0100 Subject: [PATCH 55/71] Update for new ChainRulesCore --- Project.toml | 10 ++++++---- src/Nabla.jl | 1 + src/sensitivities/linalg/factorization/cholesky.jl | 2 +- test/runtests.jl | 3 ++- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 4301a20b..21785a7e 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.12.3" [deps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -14,13 +15,14 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -ChainRules = "0.7.28" -ChainRulesCore = "0.9.17" +ChainRules = "0.8" +ChainRulesCore = "0.10.9" +ChainRulesOverloadGeneration = "0.1.2" ExprTools = "0.1.3" FDM = "0.6.1" ForwardDiff = "0.10.12" -SpecialFunctions = "0.8, 0.9, 0.10" -julia = "^1.0" +SpecialFunctions = "1.5.1" +julia = "^1.3" [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" diff --git a/src/Nabla.jl b/src/Nabla.jl index 67c72f43..d1684fae 100644 --- a/src/Nabla.jl +++ b/src/Nabla.jl @@ -3,6 +3,7 @@ __precompile__() module Nabla using ChainRules using ChainRulesCore + using ChainRulesOverloadGeneration using ExprTools: ExprTools using ForwardDiff: ForwardDiff using LinearAlgebra diff --git a/src/sensitivities/linalg/factorization/cholesky.jl b/src/sensitivities/linalg/factorization/cholesky.jl index 1013fdef..6ae5aff2 100644 --- a/src/sensitivities/linalg/factorization/cholesky.jl +++ b/src/sensitivities/linalg/factorization/cholesky.jl @@ -34,7 +34,7 @@ function ∇( ::Type{Arg{1}}, p, C::Cholesky, - X̄::Composite{<:Cholesky}, + X̄::Tangent{<:Cholesky}, X::Union{UpperTriangular, LowerTriangular}, uplo::Union{Char, Symbol}, info::Integer, diff --git a/test/runtests.jl b/test/runtests.jl index 9036d746..4f601183 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Nabla using Nabla: unbox, pos, tape, oneslike, zeroslike using BenchmarkTools +using ChainRulesCore using Distributions using ForwardDiff using LinearAlgebra @@ -15,7 +16,7 @@ ref_equal(a::Ref{T}, b::Ref{T}) where {T} = a[] == b[] ref_equal(a::Ref, b::Ref) = false # for comparing against scalar rules -derivative_via_frule(f, x) = last(Nabla.frule((Nabla.NO_FIELDS, 1.0), f, x)) +derivative_via_frule(f, x) = last(frule((NoTangent(), 1.0), f, x)) # Sensiblity checkes that his is defined right @test derivative_via_frule(cos, 0) == 0 @test derivative_via_frule(sin, 0) == 1 From 24f52bd11e6323ea977698258eda79a3584e907c Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 24 Jun 2021 15:39:56 +0100 Subject: [PATCH 56/71] =?UTF-8?q?Make=20tests=20of=20identity=20not=20try?= =?UTF-8?q?=20and=20use=20the=20version=20of=20=E2=88=87=20that=20requires?= =?UTF-8?q?=20passing=20in=20a=20preprocess=20output?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/sensitivities/scalar.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/sensitivities/scalar.jl b/test/sensitivities/scalar.jl index 9d9ef6d6..448e0205 100644 --- a/test/sensitivities/scalar.jl +++ b/test/sensitivities/scalar.jl @@ -18,9 +18,8 @@ end @testset "Scalar" begin let v = 1.0, ȳ = 5.0, z̄ = 4.0, rng = MersenneTwister(123456) let - @test ∇(identity, Arg{1}, 5.0, 4.0, 3.0, 2.0) == 3.0 - @test ∇(identity, Arg{1}, 5) == 1 - @test ∇(identity, Arg{1}, 5.0) == 1.0 + @test ∇(identity)(5) === (1,) + @test ∇(identity)(5.0) === (1.0,) end unary_check(f, x) = check_errs(f, ȳ, x, v) From f0c05b43596763da1f4a2f88697ee5187e53016d Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 24 Jun 2021 15:40:18 +0100 Subject: [PATCH 57/71] Stop testing equality when approximate equality is better --- test/sensitivities/functional/functional.jl | 28 ++++++++++----------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/test/sensitivities/functional/functional.jl b/test/sensitivities/functional/functional.jl index 3c6dce1b..cfea750c 100644 --- a/test/sensitivities/functional/functional.jl +++ b/test/sensitivities/functional/functional.jl @@ -83,8 +83,8 @@ z_ = x_ .+ y_ z2_ = broadcast(+, x_, y_) @test unbox(z_) == x .+ y - @test ∇(z_, oneslike(unbox(z_)))[x_] == ∇(z2_, oneslike(unbox(z2_)))[x_] - @test ∇(z_, oneslike(unbox(z_)))[y_] == ∇(z2_, oneslike(unbox(z2_)))[y_] + @test ∇(z_, oneslike(unbox(z_)))[x_] ≈ ∇(z2_, oneslike(unbox(z2_)))[x_] + @test ∇(z_, oneslike(unbox(z_)))[y_] ≈ ∇(z2_, oneslike(unbox(z2_)))[y_] end let x, y, tape = randn(rng, 5), 5.0, Tape() @@ -92,8 +92,8 @@ z_ = x_ * y_ z2_ = broadcast(*, x_, y_) @test unbox(z_) == x .* y - @test ∇(z_, oneslike(unbox(z_)))[x_] == ∇(z2_, oneslike(unbox(z2_)))[x_] - @test ∇(z_, oneslike(unbox(z_)))[y_] == ∇(z2_, oneslike(unbox(z2_)))[y_] + @test ∇(z_, oneslike(unbox(z_)))[x_] ≈ ∇(z2_, oneslike(unbox(z2_)))[x_] + @test ∇(z_, oneslike(unbox(z_)))[y_] ≈ ∇(z2_, oneslike(unbox(z2_)))[y_] end let x, y, tape = randn(rng, 5), 5.0, Tape() @@ -101,8 +101,8 @@ z_ = x_ .- y_ z2_ = broadcast(-, x_, y_) @test unbox(z_) == x .- y - @test ∇(z_, oneslike(unbox(z_)))[x_] == ∇(z2_, oneslike(unbox(z2_)))[x_] - @test ∇(z_, oneslike(unbox(z_)))[y_] == ∇(z2_, oneslike(unbox(z2_)))[y_] + @test ∇(z_, oneslike(unbox(z_)))[x_] ≈ ∇(z2_, oneslike(unbox(z2_)))[x_] + @test ∇(z_, oneslike(unbox(z_)))[y_] ≈ ∇(z2_, oneslike(unbox(z2_)))[y_] end let x, y, tape = randn(rng, 5), 5.0, Tape() @@ -110,7 +110,7 @@ z_ = x_ / y_ z2_ = broadcast(/, x_, y_) @test unbox(z_) == x ./ y - @test ∇(z_, oneslike(unbox(z_)))[x_] == ∇(z2_, oneslike(unbox(z2_)))[x_] + @test ∇(z_, oneslike(unbox(z_)))[x_] ≈ ∇(z2_, oneslike(unbox(z2_)))[x_] @test ∇(z_, oneslike(unbox(z_)))[y_] ≈ ∇(z2_, oneslike(unbox(z2_)))[y_] end let @@ -120,7 +120,7 @@ z2_ = broadcast(\, x_, y_) @test unbox(z_) == x .\ y @test ∇(z_, oneslike(unbox(z_)))[x_] ≈ ∇(z2_, oneslike(unbox(z2_)))[x_] - @test ∇(z_, oneslike(unbox(z_)))[y_] == ∇(z2_, oneslike(unbox(z2_)))[y_] + @test ∇(z_, oneslike(unbox(z_)))[y_] ≈ ∇(z2_, oneslike(unbox(z2_)))[y_] end # Check that dot notation works as expected for all unary function in Nabla for both @@ -130,13 +130,13 @@ z_ = f.(x_) z2_ = broadcast(f, x_) @test unbox(z_) == f.(x) - @test ∇(z_, oneslike(unbox(z_)))[x_] == ∇(z2_, oneslike(unbox(z2_)))[x_] + @test ∇(z_, oneslike(unbox(z_)))[x_] ≈ ∇(z2_, oneslike(unbox(z2_)))[x_] end function check_unary_dot(f, x::∇Scalar) x_ = Leaf(Tape(), x) z_ = f.(x_) @test unbox(z_) == f.(x) - @test ∇(z_)[x_] == ∇(broadcast(f, x_))[x_] + @test ∇(z_)[x_] ≈ ∇(broadcast(f, x_))[x_] end for f in UNARY_SCALAR_SENSITIVITIES domain = domain1(f) @@ -153,14 +153,14 @@ z_ = f.(x_, y_) z2_ = broadcast(f, x_, y_) @test unbox(z_) == f.(x, y) - @test ∇(z_, oneslike(unbox(z_)))[x_] == ∇(z2_, oneslike(unbox(z2_)))[x_] - @test ∇(z_, oneslike(unbox(z_)))[y_] == ∇(z2_, oneslike(unbox(z2_)))[y_] + @test ∇(z_, oneslike(unbox(z_)))[x_] ≈ ∇(z2_, oneslike(unbox(z2_)))[x_] + @test ∇(z_, oneslike(unbox(z_)))[y_] ≈ ∇(z2_, oneslike(unbox(z2_)))[y_] end function check_binary_dot(f, x::∇Scalar, y::∇Scalar) x_, y_ = Leaf.(Tape(), (x, y)) z_ = f.(x_, y_) - @test ∇(z_)[x_] == ∇(broadcast(f, x_, y_))[x_] - @test ∇(z_)[y_] == ∇(broadcast(f, x_, y_))[y_] + @test ∇(z_)[x_] ≈ ∇(broadcast(f, x_, y_))[x_] + @test ∇(z_)[y_] ≈ ∇(broadcast(f, x_, y_))[y_] end for f in BINARY_SCALAR_SENSITIVITIES # TODO: More care needs to be taken to test the following. From 4f5241c75935ae41d8fcc385d92dcc2812e80b70 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 24 Jun 2021 16:45:00 +0100 Subject: [PATCH 58/71] remove matrix exp which is now in ChainRules --- src/sensitivities/linalg/generic.jl | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/src/sensitivities/linalg/generic.jl b/src/sensitivities/linalg/generic.jl index 433c06a3..faedb92b 100644 --- a/src/sensitivities/linalg/generic.jl +++ b/src/sensitivities/linalg/generic.jl @@ -49,24 +49,3 @@ end import Base: copy @explicit_intercepts copy Tuple{Any} ∇(::typeof(copy), ::Type{Arg{1}}, p, Y, Ȳ, A) = copy(Ȳ) - -# Matrix exponential -# Ported from Theano, see https://github.com/Theano/Theano/blob/3b8a5b342b30c7ffd2f89f0... -# e9efef601b7492411/theano/tensor/slinalg.py#L518-L553 -# Implementation there is based on Kalbfleisch and Lawless, 1985, The Analysis of Panel -# Data Under a Markov Assumption. -import Base: exp -@explicit_intercepts exp Tuple{AbstractMatrix{<:∇Scalar}} -function ∇(::typeof(exp), ::Type{Arg{1}}, p, Y, Ȳ, X::AbstractMatrix) - # TODO: Make this work for asymmetric matrices - issymmetric(X) || throw(ArgumentError("input is not symmetric; eigenvalues are complex")) - n = LinearAlgebra.checksquare(X) - λ, U = eigen(X) - eλ = exp.(λ) - Z = @inbounds begin - eltype(eλ)[i == j ? eλ[i] : (eλ[i] - eλ[j]) / (λ[i] - λ[j]) for i = 1:n, j = 1:n] - end - Uᵀ = transpose(U) - F = factorize(Uᵀ) - return real(F \ (Uᵀ * Ȳ / F .* Z) * Uᵀ) -end From 64502d19b3025f6055921d5b065c074e8e0711e0 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 25 Jun 2021 11:44:14 +0100 Subject: [PATCH 59/71] Change test to reflect that asymmetric matrix expodential now works --- test/sensitivities/linalg/generic.jl | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/test/sensitivities/linalg/generic.jl b/test/sensitivities/linalg/generic.jl index 92794842..f401bb03 100644 --- a/test/sensitivities/linalg/generic.jl +++ b/test/sensitivities/linalg/generic.jl @@ -111,12 +111,16 @@ @testset "exp" begin rng = MersenneTwister(12345) n = 10 - symm!(X) = (X .= (X .+ X') ./ 2; X) - X = symm!(randn(rng, n, n)) - VX = symm!(randn(rng, n, n)) - @test check_errs(exp, randn(rng, n, n), X, VX) - A = randn(rng, n, n) - VA = randn(rng, n, n) - @test_throws ArgumentError check_errs(exp, randn(rng, n, n), A, VA) + @testset "Symmetric" begin + symm!(X) = (X .= (X .+ X') ./ 2; X) + X = symm!(randn(rng, n, n)) + VX = symm!(randn(rng, n, n)) + @test check_errs(exp, randn(rng, n, n), X, VX) + end + @testset "Asymmetric" begin + A = randn(rng, n, n) + VA = randn(rng, n, n) + @test check_errs(exp, randn(rng, n, n), A, VA) + end end end From 4afefc0ced8fbbb2ccb805cccb8d595b35fb2aa8 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 25 Jun 2021 11:47:17 +0100 Subject: [PATCH 60/71] update now the SVD.Vt now works --- test/sensitivities/linalg/factorization/svd.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/sensitivities/linalg/factorization/svd.jl b/test/sensitivities/linalg/factorization/svd.jl index 21bdb2a4..3bcd202d 100644 --- a/test/sensitivities/linalg/factorization/svd.jl +++ b/test/sensitivities/linalg/factorization/svd.jl @@ -8,6 +8,7 @@ @test check_errs(X->svd(X).U, randn(rng, n, k), A, VA) @test check_errs(X->svd(X).S, randn(rng, k), A, VA) @test check_errs(X->svd(X).V, randn(rng, m, k), A, VA) + @test check_errs(X->svd(X).Vt, randn(rng, k, m), A, VA) end end @@ -15,7 +16,6 @@ rng = MersenneTwister(12345) A = randn(rng, 5, 3) V̄t = randn(rng, 3, 3) - @test_throws ArgumentError check_errs(X->svd(X).Vt, V̄t, A, A) @test_throws ErrorException check_errs(X->svd(X).whoops, V̄t, A, A) end From 1da2489610d2ec9f24e18be6149956eb795db671 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 25 Jun 2021 13:41:01 +0100 Subject: [PATCH 61/71] Link to chainrules hooks outside of __init__ for MUCH faster loading --- src/Nabla.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/Nabla.jl b/src/Nabla.jl index d1684fae..e50dacb4 100644 --- a/src/Nabla.jl +++ b/src/Nabla.jl @@ -30,9 +30,6 @@ module Nabla end end - # Link up to ChainRulesCore so rules are generated when new rrules are declared. - __init__() = on_new_rule(generate_overload, rrule) - # Meta-programming utilities specific to Nabla. include("code_transformation/util.jl") include("code_transformation/differentiable.jl") @@ -71,4 +68,10 @@ module Nabla # Checkpointing include("checkpointing.jl") + + # Link up to ChainRulesCore so rules are generated when new rrules are declared. + # NB: I originally thought I should be putting this in the `__init__` + # But it seems fine if I don't, and it loads like 10x faster. + on_new_rule(generate_overload, rrule) + end # module Nabla From ce65be8024239c5a8920fe3ccdfba6b99cdc5679 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 2 Jul 2021 15:57:56 +0100 Subject: [PATCH 62/71] broadcast_axes is now axes and we want to ignore ChainRules def for it --- src/sensitivities/chainrules.jl | 2 +- src/sensitivities/functional/functional.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index f30306f3..f6f49f9d 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -100,7 +100,7 @@ function should_use_rrule(sig) # Ignore these functions because they have better Nabla specific versions. opT ∈ typeof.(( - isapprox, size, length, isassigned, + isapprox, axes, size, length, isassigned, Base.Broadcast.combine_styles, #TODO should i keep this? )) && return false diff --git a/src/sensitivities/functional/functional.jl b/src/sensitivities/functional/functional.jl index 950cc8e5..a7c5dc64 100644 --- a/src/sensitivities/functional/functional.jl +++ b/src/sensitivities/functional/functional.jl @@ -41,7 +41,7 @@ function Base.BroadcastStyle(::NodeStyle{S}, B::BroadcastStyle) where {S} promoted isa Broadcast.Unknown ? promoted : NodeStyle{promoted}() end -Broadcast.broadcast_axes(x::Node) = broadcast_axes(unbox(x)) +Base.axes(x::Node) = axes(unbox(x)) Broadcast.broadcastable(x::Node) = x # eagerly construct a Branch when encountering a Node in broadcasting From 55877bc004fe8a9f2a07900320cecc7c77fa1c77 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 2 Jul 2021 16:07:52 +0100 Subject: [PATCH 63/71] Move helpers to ExprTools 1.4 --- Project.toml | 2 +- src/sensitivities/chainrules.jl | 106 ++------------------------------ 2 files changed, 5 insertions(+), 103 deletions(-) diff --git a/Project.toml b/Project.toml index 21785a7e..5ef7f7ce 100644 --- a/Project.toml +++ b/Project.toml @@ -18,7 +18,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" ChainRules = "0.8" ChainRulesCore = "0.10.9" ChainRulesOverloadGeneration = "0.1.2" -ExprTools = "0.1.3" +ExprTools = "0.1.4" FDM = "0.6.1" ForwardDiff = "0.10.12" SpecialFunctions = "1.5.1" diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index f6f49f9d..f6caa882 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -48,7 +48,7 @@ no actual effect in itself, it can be useful for checking how many rules Nabla s function generate_overload(sig) should_use_rrule(sig) || return false - original_signature_def = build_def(sig) + original_signature_def = ExprTools.signature(sig; extra_hygiene=true) unionized_signature_def = copy(original_signature_def) unionized_signature_def[:args] = unionise_sig.(original_signature_def[:args]) @@ -57,8 +57,6 @@ function generate_overload(sig) @inline $(∇_declaration(unionized_signature_def)) $(overload_declarations!(original_signature_def)...) end - # for debugging uncomment and edit the below to look at the generated code - # opT <: typeof(identity) && @show fdef eval(fdef) return true @@ -118,48 +116,10 @@ function should_use_rrule(sig) return true # no exclusion applies end -""" - build_def(sig) - -Like `ExprTools.signature` but on a signature type-tuple, not a Method. -For `sig` being a tuple-type representing a methods type signature, this generates a -dictionary that can be passes to `ExprTools.combinedef` to define that function, -Provided that you assign the `:body` key on the dictionary first. - -For example: -```julia -julia> Nabla.build_def(Tuple{typeof(identity), Any}) -Dict{Symbol, Any} with 2 entries: - :name => :(op::typeof(identity)) - :args => Expr[:(x1::Any)] - -julia> Nabla.build_def(Tuple{typeof(+), Vector{T}, Vector{T}} where T<:Number) -Dict{Symbol, Any} with 3 entries: - :name => :(op::typeof(+)) - :args => Expr[:(x1::Array{var"##T#5492", 1}), :(x2::Array{var"##T#5492", 1})] - :whereparams => Any[:(var"##T#5492" <: Number)] -``` -""" -function build_def(orig_sig) - sig = _truly_rename_unionall(orig_sig) # TODO ExprTools possibly should do this for `signature(::Method)`` also - def = Dict{Symbol, Any}() - - opT = ExprTools.parameters(sig)[1] - def[:name] = :(op::$opT) - - arg_types = ExprTools.name_of_type.(ExprTools.argument_types(sig)) - arg_names = [Symbol(:x, ii) for ii in eachindex(arg_types)] #TODO: should we pass the arg_names in? - def[:args] = Expr.(:(::), arg_names, arg_types) - def[:whereparams] = ExprTools.where_parameters(sig) - - filter!(kv->last(kv)!==nothing, def) # filter out nonfields. - return def -end - """ overload_declarations!(original_signature_def) -Given a `signature_def` dictionary as returned by [`build_def`](@ref) this returns +Given a `signature_def` dictionary as returned by `ExprTools.signature` this returns the ASTs for the overloads of the primal functions to accept `Nabla.Node`s. The `signature_def` should *not* have been unionized, as this function will instead generate 1 method for each position a node could be in. @@ -177,7 +137,7 @@ function overload_declarations!(signature_def) original_signature_args = signature_def[:args] signature_def[:kwargs] = [:(kwargs...)] signature_def[:body] = quote - args = $(_args_tuple(original_signature_args)) + args = $(ExprTools.args_tuple_expr(original_signature_args)) # uncommenting the below to is useful for debugging what rrule is being hit. # @show InteractiveUtils.@which rrule(op, unbox.(args)...) primal_val, pullback = rrule(op, unbox.(args)...; kwargs...) @@ -188,7 +148,7 @@ function overload_declarations!(signature_def) return branch end - # we need to generate a version of this for each place that an arg could be + # we need to generate a version of this for each place that an arg could be a Node n_args = length(original_signature_args) definitions = Expr[] for swap_mask in Iterators.product(ntuple(_->(true, false), n_args)...) @@ -277,64 +237,6 @@ function ∇_declaration(signature_def) end -""" - _args_tuple(arg_exprs) - -For `arg_exprs` being a list of arguments expressions from a signature, of a form -such as `[:(x::Int), :(y::Float64), :(z::Vararg)]`, returns a tuple expresion containing all -of them by name; while correctly handling splatting, for things that are `Vararg` typed. -e.g for the prior example `:((x, y, z...))` -""" -function _args_tuple(arg_exprs) - ret = Expr(:tuple) - ret.args = map(arg_exprs) do arg - @assert Meta.isexpr(arg, :(::), 2) - arg_name, Texpr = arg.args - if Meta.isexpr(Texpr, :where) # remove where from `Vararg{T, N} where {T, N}` - Texpr = Texpr.args[1] - end - # Needs to be after removing `where` - if Meta.isexpr(Texpr, :curly) # remove `{T, N}` from `Vararg{T,N}` - Texpr = Texpr.args[1] - end - if Texpr == :Vararg - return :($arg_name...) - else - return arg_name - end - end - return ret -end - -""" - _truly_rename_unionall(@nospecialize(u)) - -For `u` being a `UnionAll` this replaces every `TypeVar` with a new one with a `gensym`ed -names. This is useful for manual macro-hygine. - -Example: -``` -julia> Nabla._truly_rename_unionall(Array{T, N} where {T<:Number, N}) -Array{var"##T#2881", var"##N#2880"} where var"##N#2880" where var"##T#2881"<:Number -``` - -Note that the similar `Base.rename_unionall`, though `Base.rename_unionall` does not -`gensym` the names just replaces the instances with new instances with identical names. -""" -function _truly_rename_unionall(@nospecialize(u)) - isa(u,UnionAll) || return u - body = _truly_rename_unionall(u.body) - if body === u.body - body = u - else - body = UnionAll(u.var, body) - end - var = u.var::TypeVar - nv = TypeVar(gensym(var.name), var.lb, var.ub) - return UnionAll(nv, body{nv}) -end - - # Find a tape, ds might be Nodes or might be something else. # All nodes should have the same tape, so the first one will do get_tape(ds) = first(tape(d) for d in ds if d isa Node) From 5134063f2c5f3b33e70f4d9d5b59c732fddb9aff Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 2 Jul 2021 16:12:11 +0100 Subject: [PATCH 64/71] use collect not float. Co-authored-by: Eric Davies --- test/sensitivities/functional/reducedim.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/sensitivities/functional/reducedim.jl b/test/sensitivities/functional/reducedim.jl index a0c1f919..4a4d945e 100644 --- a/test/sensitivities/functional/reducedim.jl +++ b/test/sensitivities/functional/reducedim.jl @@ -54,7 +54,7 @@ randn(rng, 10, 10, 10), randn(rng, 10, 10, 10)) # Issue #123 - x6_ = float.(1:10) + x6_ = collect(1.0:10.0) tens = (fill(10.0, (10,)), fill(10.0, (10,))) @test ∇(x->sum(sum(x, dims=2)))(x6_) == (oneslike(x6_),) @test ∇((x, y)->sum(sum(x, dims=2) .+ sum(y, dims=2)'))(x6_, x6_) == tens From 09fe3e47475e5a3eafcb6143c677377bb0418b41 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 2 Jul 2021 16:14:46 +0100 Subject: [PATCH 65/71] delete strided and its tests which are wrong and it is tested in ChainRules.jl correctly --- src/Nabla.jl | 1 - src/sensitivities/linalg/strided.jl | 35 ---------------------------- test/runtests.jl | 1 - test/sensitivities/linalg/strided.jl | 22 ----------------- 4 files changed, 59 deletions(-) delete mode 100644 src/sensitivities/linalg/strided.jl delete mode 100644 test/sensitivities/linalg/strided.jl diff --git a/src/Nabla.jl b/src/Nabla.jl index e50dacb4..dc8a882f 100644 --- a/src/Nabla.jl +++ b/src/Nabla.jl @@ -59,7 +59,6 @@ module Nabla # Linear algebra optimisations. include("sensitivities/linalg/generic.jl") include("sensitivities/linalg/symmetric.jl") - include("sensitivities/linalg/strided.jl") include("sensitivities/linalg/blas.jl") include("sensitivities/linalg/diagonal.jl") include("sensitivities/linalg/factorization/cholesky.jl") diff --git a/src/sensitivities/linalg/strided.jl b/src/sensitivities/linalg/strided.jl deleted file mode 100644 index 789205b0..00000000 --- a/src/sensitivities/linalg/strided.jl +++ /dev/null @@ -1,35 +0,0 @@ -# # Not every permutation of transpositions makes sense for matrix-vector multiplication. This -# # list just includes those which make sense. -# strided_matvecmul = [ -# (:*, 'C', :ȳ, :b, 'C'), -# (:At_mul_B, 'T', :b, :ȳ, 'N'), -# (:Ac_mul_B, 'C', :b, :ȳ, 'N'), -# ] -# for (f, tdA, CA, dA, tCb) in strided_matvecmul -# n_Ā, u_Ā = tdA == 'C' ? :(Ā = $CA * $dA') : :(Ā = $CA * $dA'), :(ger!(1., $CA, $dA, Ā)) -# n_b̄, u_b̄ = :(b̄ = gemv($tCb, A, ȳ)), :(b̄ = gemv!($tCb, 1., A, ȳ, 1., b̄)) -# generate_primitive(f, [:(T <: StridedMatrix), :(V <: StridedVector)], -# [:A, :b], [:Ā, :b̄], [:T, :V], [true, true], :y, :ȳ, [n_Ā, n_b̄], [u_Ā, u_b̄]) -# end - -# # Operations of the for Y = A \ B -# strided_ldiv = [ -# (:\, :(C = At_ldiv_B(A, Ȳ)), 'N', 'T', :C, :Y), -# (:At_ldiv_B, :(C = A \ Ȳ), 'N', 'T', :Y, :C), -# (:A_ldiv_Bt, :(C = At_rdiv_B(Ȳ, A)), 'T', 'T', :C, :Y), -# (:At_ldiv_Bt, :(C = At_rdiv_Bt(Ȳ, A)), 'N', 'N', :Y, :C), -# (:Ac_ldiv_B, :(C = A \ Ȳ), 'N', 'C', :Y, :C), -# (:A_ldiv_Bc, :(C = Ac_rdiv_B(Ȳ, A)), 'C', 'C', :C, :Y), -# (:Ac_ldiv_Bc, :(C = Ac_rdiv_Bc(Ȳ, A)), 'N', 'N', :Y, :C), -# ] - -# # Iterate through primitive definitions and add methods for each. -# for (f, C, tA, tB, arg1, arg2) in strided_ldiv -# new_Ā = :(Ā = gemm($tA, $tB, -1.0, $arg1, $arg2)) -# update_Ā = :(gemm!($tA, $tB, -1.0, $arg1, $arg2, 1.0, Ā)) -# new_B̄ = :(B̄ = C) -# update_B̄ = :(broadcast!((b̄, c)->b̄ + c, B̄, B̄, C)) -# generate_primitive(f, [:(T <: Any), :(V <: Any)], -# [:A, :B], [:Ā, :B̄], [:T, :V], [true, true], :Y, :Ȳ, -# [new_Ā, new_B̄], [update_Ā, update_B̄], C) -# end diff --git a/test/runtests.jl b/test/runtests.jl index 4f601183..c8b068d3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -82,7 +82,6 @@ end include("sensitivities/linalg/uniformscaling.jl") include("sensitivities/linalg/diagonal.jl") include("sensitivities/linalg/triangular.jl") - include("sensitivities/linalg/strided.jl") include("sensitivities/linalg/blas.jl") @testset "Factorisations" begin diff --git a/test/sensitivities/linalg/strided.jl b/test/sensitivities/linalg/strided.jl deleted file mode 100644 index 615c23cc..00000000 --- a/test/sensitivities/linalg/strided.jl +++ /dev/null @@ -1,22 +0,0 @@ -@testset "Strided" begin - RS = StridedMatrix{<:∇Scalar} - RST = Transpose{<:∇Scalar, RS} - RSA = Adjoint{<:∇Scalar, RS} - strided_matmul_combinations = ( - (RS, RS, 'N', 'C', :Ȳ, :B, 'C', 'N', :A, :Ȳ), - (RST, RS, 'N', 'T', :B, :Ȳ, 'N', 'N', :A, :Ȳ), - (RS, RST, 'N', 'N', :Ȳ, :B, 'T', 'N', :Ȳ, :A), - (RST, RST, 'T', 'T', :B, :Ȳ, 'T', 'T', :Ȳ, :A), - (RSA, RS, 'N', 'C', :B, :Ȳ, 'N', 'N', :A, :Ȳ), - (RS, RSA, 'N', 'N', :Ȳ, :B, 'C', 'N', :Ȳ, :A), - (RSA, RSA, 'C', 'C', :B, :Ȳ, 'C', 'C', :Ȳ, :A), - ) - # TODO: This test seems like it doesn't actually test the combinations. - let rng = MersenneTwister(123456), N = 100 - # Test strided matrix-matrix multiplication sensitivities. - for (TA, TB, tCA, tDA, CA, DA, tCB, tDB, CB, DB) in strided_matmul_combinations - A, B, VA, VB = randn.(Ref(rng), [N, N, N, N], [N, N, N, N]) - @test check_errs(*, A * B, (A, B), (VA, VB)) - end - end -end From f2fe4cbbfa90ac46abeb12659478927e0b45860a Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 2 Jul 2021 16:19:27 +0100 Subject: [PATCH 66/71] tag as a breaking change --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5ef7f7ce..b9245db4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Nabla" uuid = "49c96f43-aa6d-5a04-a506-44c7070ebe78" -version = "0.12.3" +version = "0.13.0" [deps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" From c2b9edc4a251a528676afa340a07b26a4b1e5399 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 2 Jul 2021 17:59:05 +0100 Subject: [PATCH 67/71] delete old comment --- src/Nabla.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Nabla.jl b/src/Nabla.jl index dc8a882f..da425e43 100644 --- a/src/Nabla.jl +++ b/src/Nabla.jl @@ -69,8 +69,6 @@ module Nabla # Link up to ChainRulesCore so rules are generated when new rrules are declared. - # NB: I originally thought I should be putting this in the `__init__` - # But it seems fine if I don't, and it loads like 10x faster. on_new_rule(generate_overload, rrule) end # module Nabla From d6328b4a1717bb1e572ff5bc21afa69e1db2610a Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 2 Jul 2021 18:07:19 +0100 Subject: [PATCH 68/71] Block using ChainRules for rules remaining in Nabla --- src/sensitivities/chainrules.jl | 35 ++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index f6caa882..ec1526cf 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -77,6 +77,12 @@ We do not use rules for: - Non-differentiable functions that we define directly on `Node`s better (like `size`) - Non-differentiable functions that are never used in practice and that cause a lot of compiler invalidations and so cause a large increase in loading time. + +Finally this excludes function that at time of last update Nabla had its own rules for +because ChainRules didn't support them. +Generally, for this category once they are added to ChainRules, we should change to using +them from there. This requires also deleting the code from Nabla that provides those rules +currently, so that there is no clash. """ function should_use_rrule(sig) opT, argTs = Iterators.peel(ExprTools.parameters(sig)) @@ -113,9 +119,36 @@ function should_use_rrule(sig) schedule, # this one is huge, causes over 2500 invalidations )) && return false + # Rules currently implemented directly in Nabla, but that could use ChainRules in future + sig <: Union{ + Tuple{typeof(+),AbstractArray,LinearAlgebra.UniformScaling}, + Tuple{typeof(+),LinearAlgebra.UniformScaling,AbstractArray}, + Tuple{typeof(/),Number,AbstractArray}, + Tuple{typeof(LinearAlgebra.BLAS.symm),Char,Char,AbstractArray,AbstractArray}, + Tuple{typeof(LinearAlgebra.BLAS.symm),Char,Char,Number,AbstractArray,AbstractArray}, + Tuple{typeof(LinearAlgebra.BLAS.symv),Char,AbstractArray,AbstractArray}, + Tuple{typeof(LinearAlgebra.BLAS.symv),Char,Number,AbstractArray,AbstractArray}, + Tuple{typeof(LinearAlgebra.BLAS.trmm),Char,Char,Char,Char,Number,AbstractArray,AbstractArray}, + Tuple{typeof(LinearAlgebra.BLAS.trmv),Char,Char,Char,AbstractArray,AbstractArray}, + Tuple{typeof(LinearAlgebra.BLAS.trsm),Char,Char,Char,Char,Number,AbstractArray,AbstractArray}, + Tuple{typeof(LinearAlgebra.BLAS.trsv),Char,Char,Char,AbstractArray,AbstractArray}, + Tuple{typeof(Statistics.mean),Function,AbstractArray}, + Tuple{typeof(\),AbstractArray,Number}, + Tuple{typeof(broadcast),Any,Vararg}, + Tuple{typeof(copy),Any}, + Tuple{typeof(float),Any}, + Tuple{typeof(getindex),Ref}, + Tuple{typeof(kron),AbstractArray,AbstractArray}, + Tuple{typeof(map),Function,Vararg}, + Tuple{typeof(mapfoldl),Any,Union{typeof(+), typeof(Base.add_sum)},Union{Number,AbstractArray}}, + Tuple{typeof(mapfoldr),Any,Union{typeof(+), typeof(Base.add_sum)},Union{Number,AbstractArray}}, + Tuple{typeof(mapreduce),Any,Union{typeof(+), typeof(Base.add_sum)},AbstractArray}, + Tuple{typeof(sum),Function,AbstractArray}, + Tuple{typeof(sum),typeof(abs2),AbstractArray}, + } && return false + return true # no exclusion applies end - """ overload_declarations!(original_signature_def) From d37b487dd2f64394f58eb2ca16def087f3c5473d Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 2 Jul 2021 18:08:33 +0100 Subject: [PATCH 69/71] Update CI to match current prod min julia version --- .github/workflows/CI.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index e346dfdd..64594ee7 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -30,9 +30,9 @@ jobs: - os: windows-latest arch: x86 include: - # Add a 1.5 job because that's what Invenia actually uses + # Add a 1.6 job because that's what Invenia actually uses - os: ubuntu-latest - version: 1.5 + version: 1.6 arch: x64 steps: - uses: actions/checkout@v2 From 00ef4acdc02dabd688d841a26c5de6939adda7b7 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 2 Jul 2021 18:17:04 +0100 Subject: [PATCH 70/71] block one and zero --- src/sensitivities/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sensitivities/chainrules.jl b/src/sensitivities/chainrules.jl index ec1526cf..29885f2c 100644 --- a/src/sensitivities/chainrules.jl +++ b/src/sensitivities/chainrules.jl @@ -104,7 +104,7 @@ function should_use_rrule(sig) # Ignore these functions because they have better Nabla specific versions. opT ∈ typeof.(( - isapprox, axes, size, length, isassigned, + isapprox, axes, size, length, isassigned, one, zero, Base.Broadcast.combine_styles, #TODO should i keep this? )) && return false From 970893a0eb4afbe26f9cc8d22d58e831b9ae07f3 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 2 Jul 2021 18:39:42 +0100 Subject: [PATCH 71/71] Update docs to be compatible with ChainRules and document about using ChainRulesCore --- .github/workflows/CI.yml | 2 +- .gitignore | 10 +++-- docs/Project.toml | 4 +- docs/make.jl | 25 ++++++------ docs/src/assets/invenia.css | 75 ------------------------------------ docs/src/assets/logo.png | Bin 0 -> 7274 bytes docs/src/pages/custom.md | 17 +++++++- 7 files changed, 40 insertions(+), 93 deletions(-) delete mode 100644 docs/src/assets/invenia.css create mode 100644 docs/src/assets/logo.png diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 64594ee7..9ace3bbf 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -81,7 +81,7 @@ jobs: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 with: - version: '1' + version: '1.6' - run: | julia --project=docs -e ' using Pkg diff --git a/.gitignore b/.gitignore index b6d4700e..0560d362 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,11 @@ +*.pdf +*.DS_Store *.jl.cov *.jl.*.cov *.jl.mem -*.pdf -*.DS_Store Manifest.toml -docs/build/ +docs/build +docs/site +docs/src/assets/chainrules.css +docs/src/assets/indigo.css +.vscode/settings.json diff --git a/docs/Project.toml b/docs/Project.toml index 53bc6f84..d931a66c 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,5 +1,7 @@ [deps] +DocThemeIndigo = "8bac0ac5-51bf-41f9-885e-2bf1ac2bec5f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Nabla = "49c96f43-aa6d-5a04-a506-44c7070ebe78" [compat] -Documenter = "~0.19" +Documenter = "0.27" diff --git a/docs/make.jl b/docs/make.jl index 5b0b36da..8ff20539 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,25 +1,26 @@ -using Documenter, Nabla +using Documenter +using DocThemeIndigo +using Nabla +const indigo = DocThemeIndigo.install(Nabla) makedocs( modules=[Nabla], - format=:html, + format=Documenter.HTML( + prettyurls=false, + assets=[indigo], + ), + sitename="Nabla.jl", + authors="Invenia Labs", pages=[ "Home" => "index.md", "API" => "pages/api.md", "Custom Sensitivities" => "pages/custom.md", "Details" => "pages/autodiff.md", ], - sitename="Nabla.jl", - authors="Invenia Labs", - assets=[ - "assets/invenia.css", - ], ) + deploydocs( repo = "github.com/invenia/Nabla.jl.git", - julia = "1.0", - target = "build", - deps = nothing, - make = nothing, -) + push_preview=true, +) \ No newline at end of file diff --git a/docs/src/assets/invenia.css b/docs/src/assets/invenia.css deleted file mode 100644 index 343c6f22..00000000 --- a/docs/src/assets/invenia.css +++ /dev/null @@ -1,75 +0,0 @@ -/* Links */ - -a { - color: #4595D1; -} - -a:hover, a:focus { - color: #194E82; -} - -/* Navigation */ - -nav.toc ul a:hover, -nav.toc ul.internal a:hover { - color: #FFFFFF; - background-color: #4595D1; -} - -nav.toc ul .toctext { - color: #FFFFFF; -} - -nav.toc { - box-shadow: none; - color: #FFFFFF; - background-color: #194E82; -} - -nav.toc li.current > .toctext { - color: #FFFFFF; - background-color: #4595D1; - border-top-width: 0px; - border-bottom-width: 0px; -} - -nav.toc ul.internal a { - color: #194E82; - background-color: #FFFFFF; -} - -/* Text */ - -article#docs a.nav-anchor { - color: #194E82; -} - -article#docs blockquote { - font-style: italic; -} - -/* Code */ - -code .hljs-meta { - color: #4595D1; -} - -code .hljs-keyword { - color: #194E82; -} - -pre, code { - font-family: "Liberation Mono", "Consolas", "DejaVu Sans Mono", "Ubuntu Mono", "Courier New", "andale mono", "lucida console", monospace; -} - -/* mkdocs (old) */ - -/*.navbar-default { - background-color: #194E82; -} - -.navbar-default .navbar-nav > .active > a, -.navbar-default .navbar-nav > .active > a:hover, -.navbar-default .navbar-nav > .active > a:focus { - background-color: #4595D1; -}*/ diff --git a/docs/src/assets/logo.png b/docs/src/assets/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..214b5c9a49c2bb5a0db1de74ea83b25b56d2051e GIT binary patch literal 7274 zcmYj$cRbc@^#2F9qC$8gTisR}MfNV6kd!^sZFbv)%y^JpA=x8JX2wmKN%l<0CNq0v zC48^z_WXX|??2`Bajxr}^FHr$u5*rOcQlm9&d{Gh5QI!cS^h495M#m5lO(6Wza9bd zRNy~SM`e8%1fc~#(^3&wrF{cZ2%W2fo~yQlrR!sJXAE-lp~C|Vs|wcK3Ue1@{?OB@ z6(flt47)1wH}822FON;Rjp%$kHJP+FQ#Ts$Jzy-L_<|l&0+Sw7&@WlT&)=9@F7S*6 z^bzz27=8AwRP&q`&6L+dJgru@k^{$HHpcrCaMfHiIp??E{jTTzwQ1q%O*_##?12KQ z$wMjoS0;nDHDBhPsx}2www_GxIvIBTm0BD;?p&i9^x&K?I$GVYV0nh>=G}ajHYPPg zUB^&yw6UIOWnkL8g3X-%$WRfU&$)DLRD`f?)HYg&5+R$U^&gC5jmSn8y+_9}Qs#*` zVwfI&>%vwr;IZ(+( zfK@$7%!+7l-RK+lr*-;h{Fn1B36L4;bfxgFDBV69gwRvQK2QKJ_?o+!;T3R5O7AKG zRLu0t23YQ9?~&M406$mq9M-g-L(0-LPhk*C=doW-cEg0S*Dad~kOKWH1lM13H!~#x z!J2o`r+8xtpoufK)0_Ey7MeFM-vqi-zxU%Dz%Q&@?n)DIxXvBF_6~3N_0Jdo4rNr zG(g+yxA`a^{$izlRDypF_Jue=yC*LjaCMV3CbXBt?~DxI1e5-L+=;s}LHBWf<)PJE z0Kwmna#4uCHzC^{nT|VsJ4NHK>RC450e#(k;r%@DH;q<;4XdAm-hgjwI;sfL!YS*h zGsKY>4*iw5_SuSw0AgJ(PB>X3AbE8;6a~!s(G~S1>l{3edKiao_ycUJcNc0e$HRR- zCg;CTI{!bS$`OhlCM}etz@*~)GB$3xCx;)zVZVt3Qq-qOM~(b@aylsIzi+0zaR`JO ziBSAuT8TfquFKMdy@Mwdy;U6crx!ds@t<6iG56+?l%fAl*EA#U6>sxfuGb>f=H`kT@cpL4xoH2!@D zAykb0HB{i&E_`9VBt_dk@(nq&Dc{S8{pWocE0PBMa*15c3Q(+Rvhxv&358K2yld$j zBuH`ENsJvI=={*M=>0=(3taV;ju=VDLzmoARdRU!mdV;dXpkjD;V2#-y!vN6v1WJG zqZ508f44M(D~SSyc8RyZk!z)Q39XwX*a5uhS)&$$4P}YUcn_~> z2Yx%rW6{e+8(oX

PEcbNeX(+oT6DusO4z36R_WRbHNcVceza1GeU=Q%r!)&Nt}I z$J0{OqeelR`hNj5?=1OI%W-C-hTf=Ru@mk6Aw>}gm_asq3*e$sRRLIwinUEfdWClg|yPwHaMEL zPqU`^SNBD2e>qk%##UL|SOFn3dTw@-D{6iD6-k@nCC28T}77>I^ptvWW(ZCt6IileBQwQ}6QubVptmgwFSaOAI@oqz6+7&!=FB5ifMcR!{Bk16 zHdN@l{)0nhz$RxEfP1~bOZ%aa8qX42Gry~I?V?la$5jwgWu8szptVKSkLaAF(=pJZ!u4uhrJY~P<&{-vbQT8vwQ)1@5Q3V z(pk!5PtBx8fu<@nazRe-t4jx*Uwpr}8?aMRm@~F*`P*PzG8tq`hea;R{E=#hW`$2j z%wHqp*cG&0Z$b4k;Lndse@@4&984Iv4>Wi}_Ii4A9d=yoWNmM2-fiuE*-L^Hb#t!8 z63nEaXBQ5f4wfYu6#=eurT3m**LmtoBKOqdIz54CIRm@Dr2<3bor0H^EWjP*UAkWttq{yvGb=OWzvaZK0JyAl3;XW#n+DWjXcF2<5uNE)n z035wFU%}S=!b!d%X7Zwri6tN+TdRZH_OfNSOclA8hKBfw&c~}wToF9nmr8}O6H?Vt zmv7J$fKcDw#~wbFVc6e71CUuiE|vJqJi`G()T+DJkQMR*#HgycHWbnM;;pwpfMZ67 zzC==DZ(RQTq0HHW2qczFCMHELl{IX*tU!j7_U&F1%rVKjxRt7fvd+*T-1iG+uRceA zwnhOgQcFm1bxTCY8sEQCa<=wENh39OYm@k#m6#dPLL+UEz-;bXxU2%N7GbR$T1>`H zzqJJLVk~Vta&guYL6#Sv-(i`hr(#qJw$Ss`C)gw{mTMGw_rz{Ebw(0asEskMo5QS& z9QFKvJN{Q`n(4;lL4jXKFxmDNy}NGGCFaL`ar@h)w4S}#ezkEe0pZfw*^ zlr+BB#_xQw4yGiXXWt2Ff2en*defP3hMHAtKwmSaG^I!iCWp=U$KPq&dIa;ZD+hLT zk)T4?YiHJsA6=~I7$V;l73j$hqsGI-6^(-eeWv`T=SCKJoE0*wz#e z;{s(Rsk10Sm^|7tryp4qt}sIv%et8DBdq(xTb<+|`ICRPvlk=;>Z+=_NA9oG%>7g! z5)MF%4{JZF^GZSLTMt2@B*DgJij{Bd6rq~=Q6m$>52eqVY{k8bw)Bn(E=v;&x%Ii<|0#^6oh4B>oTfz5$ynNOxr#^ri*u~=k++-dPJzecmjYwtx@T~a z{={U|pNzU!)9(DcBE3MB`zY7Vmx9>nrOT6%9o+*a!hqz<#W5k;Gs{DzLF*l%G z^6f^@;l%Ds$6R-|PxQGN1TWle_sclf$$xoDbUb#Ce2ltE3jqvd+W; z@YJVE0s3mhQ<4TACdKuVz&8Lx>6qpa=yzB=uMIMN*k!fpU3bNNHxky^rGJpGcNg#N}$1 z9Zpprr2x?(BTliXQu*c=3P5kc_;vZeW>GA7;>k9i+lk$NUfX?j-P5>Z`(I#SOOCL5 zEt^>UL|aFia@);8UDx?Y6Wu^N@N|H(eD$YwkVWk!tb7B^{ub+Bdgr;)YMiV)8_Ee< zIHQPPl(g;GBt4LCm;k87Xpvn^d&$@yQsVO{S@`1f>(SyY8`w!NddPX%H_PSG7FJ1a zlUr2@)8U|th`FERX(ys#HAvu2{wlkG;>Aa2b!79IH=zm$vSXNc=^5`RD}2KvtCQq( z5Idu18;EsOts*OSp<52OBZLMf7&nT%U3&NQ0PWXZH0@5iC#iXqGg`#nef@Vtt;V`g zvT3^2Ap5Gshzw{GB!hnhuc4{NIV%GyvDT)U0{Hs&NX*JDbFGSRWP@9)d+SZ)$Z#1` zweMXQ-+%$RmTORZ5mBYp?S|8OZwnRPm{+;^B zR*2Do+xu(Ql7z@me^RD{N)+aot1wzFYc;>@gShRK0eI*3Q7^(O+i+1_dtRAqJ`8`Y zrr&(T4-Jd9@jNVloB$VcmlmT@%^$4ISc2=f@g;bM`bW0!f?XqLsYVKRAfS8yjk*_j zP`rw-eh3*pUifnf`PSriVa)O8nx7UfdO$JT958W{beYP{^4OqO{KqaO=}YjCrPJ!` zlqR!_$1ey7F57^3SFRBu3X@4U?A+FwTrpeD`|Cq{*Js5KKosN%V*`yI6cX)p205 zfF9Iur1X4#$o2;zq|%UphTUs-rak(Ehxja#dY`9X%Pt1mQH^VIL~B1pXlW_T2gFXd zyw;4@;7qBDM;nOXWb!5lKmRN@g<}$Pf3t6ipu|M{GfNJqV)bV-m(5EkZvW7pxsOvs zBb%*Te6!TzZW}S(Szx{`id^X%3^@#2?3W<|skMNgB>&CL2hC8-?xI+QUcQ# zjO`AX19BO{*^FSei7#TOW?s4*wR5^Yz+2;>Qv~u?b}c^YU(!dae)o|uv3G3&1Tc+x0A=RzeMaFzor5>aTHrB{qz67)k1&Z9IA?K94mk^luB51cUCx4s_?FfUjacu}Ren$UeSj_~vY z;|Y|(>5=(`v*<1pubRJ$+r)@p>D7^@f}}GhE@VFEe9yTF)vpAZa=$Nc|9GX2n*h>1 zayC2Eo@i~id2+Dbc0wi}3>nde5hHbR0YpBgr04gWyxMYB&R3G_4C64lD45BDmqFw0 zHhCB|T2`Q0T61OOEoUkQRIp`olxb{$Zs__Bpr+@=yU%aYpbA~g&*7e91e2+HQp}hIa ze1Nntt%gPjc2#}C(>t2rvb@OyQ>)%u&#w~|Z8V@mh{&`|t_n$GdEPJ5FSF}-2xYh8 z*^lal4`iah@zBEUS|`S$;z9f}!NbHyWLJ@HLGYeQ%EDUe3;b*S zuwA6T+37&@>58e8krIeqahjL5>)UgTW*5AVLx_&Y6*dT%K!-5nl8s`V7KbC};rD(X zHOwT^tHk!b`8c@C-Wn7?FjvGmbj1dcN=o3H?VXCNI;D;Y%t~w*Bf&!=E37l*T(+bF zP%^R1KrwpAjrEbK8Vfnn{njfe0V!hk{x<6L@$vw;|CN692rXvv-^^&`ccnJhOxEW^ z>YKrU44|}#P&BiBrHO%2@%Wfc<>;TP_{71LQ@nCVh?WDa$gCWm6dH zsm1rz446q``@jihiYKGrtI53W;gD~oGbB)<^xwGsJtshE|NgxOo84_4u!z~A1sjk4fa8K~7uf3R zmym7z#o`nEwnYZ2T6TtGD(gl?n|s`fe*H8*zBMow#EK1V!~6U6Am(P8y}It?&+KjU zVMr$V7155524#){L6N;}m5W-wG94OpY5gNy7Wr(FcjWpmtbN_ZqxmbpdF7=*@!OWp zhN-uHg;Sf@`e`sELJ|TUzue$5q|e|g!IQAjGO2@Kik6P#a4K%JMrkO3kXe?{RZPSM zNXfH2zb;9`N0dtBZ?f4c6GuSaCgyni73LrZ7W$6KoJRX5C&}kY!7{?O=t6nm#X{#@ z>3~3Ia>kK6$4cO)99@Q+B5n1|mqoC3es`YU$AMryyun5AwiZr4X3XtQRobUgDFIaX z`MYUYUoh;WQr)z04jNx_tMiVtfSFtRz|2tT%a?wt8fH=Yn~b${uIJw*75yMS2>JlT zvW&h$CN{uI-X5$BJPYT*{cHkU1iVHeAbGAA*|AU3fy9t65QD9-Z>~idAfvPG5TZOqTrJ+JDBJf+| z@v8&{Fj1xrqDE2#iUa#`TVO{EygL6ch(+>T6A5#cO&-J-0Zup)=9T!|WAL?I(7EXtt$cwvcZqJ2~5cm)jSI zd90*IC!P{V1u7}h3;nR+jq_Ks9)8abOj%yqbvpREQ?bJGT=PA&YGCft!_UzX$_c?> ztFTEbs|-!^5}Dpnr`vmWA+xCBE+GJ)HPaNVJT+F2rkWN+vL{cM+ZDA)pv*)XT>g0o6Lo*6kUVHj5Zg}u7aDAuyjH;*gIkCor5&dJsRUhA*Z_5KJd5NGBP+D**;zP(co1 zq=d0eD~~ag0nLZ~XQ1v-sC&tBwDt%BO40ax?I|y(IstbEtmqM^f~J*l0oNtFUr&VKMhqx2fQm-$l71T4QtgZDqX1Yw%m$(# z3c!Y6%FxYw`9FN1@!n_I7j%lyREzu$$-Xe)+#ukg4mSXyGYl)D=-?+>@TtXI%zh#^ zLyf!n95fvMF<_>cu7UxxaMuA%waR=V$qOa3;QssDTf^ovP%;EcI^kVRi?bjG$hN?> ztaEOj{%UN7L`{S;D8G;daJHTSt2@3iarmV8GeQv3`YV03g890uEZ|Z~8Lmy84X)mN zToV2Njs?SGWqdlNzZV8RfCES)