Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Cholesky passing CI #217

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Nabla"
uuid = "49c96f43-aa6d-5a04-a506-44c7070ebe78"
version = "0.13.4"
version = "0.13.5"

[deps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
Expand All @@ -15,7 +15,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRules = "0.8, 1"
ChainRules = "0.8, 1.35.3"
ChainRulesCore = "0.10.9, 1"
ChainRulesOverloadGeneration = "0.1.2"
ExprTools = "0.1.4"
Expand Down
2 changes: 1 addition & 1 deletion src/sensitivities/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ function overload_declarations!(signature_def)
definitions = Expr[]
for swap_mask in Iterators.product(ntuple(_->(true, false), n_args)...)
any(swap_mask) || continue # don't generate if not swapping anything.

# Also don't generate if swapping only final varadic argument
# as this could be a emptry varadic argument and thus result in type-pirating
# original function.
Expand Down
18 changes: 17 additions & 1 deletion src/sensitivities/linalg/factorization/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ const AM = AbstractMatrix
const UT = UpperTriangular



@explicit_intercepts(
Cholesky,
Tuple{AbstractMatrix{<:∇Scalar}, Union{Char, Symbol}, Integer},
[true, false, false],
)

function ∇(
::Type{Cholesky},
::Type{Arg{1}},
Expand All @@ -41,3 +41,19 @@ function ∇(
)
return getproperty(X̄, Symbol(uplo))
end

# Yar, some work arounds for breaking changes in ChainRules.jl
# https://github.com/JuliaDiff/ChainRules.jl/pull/630

# Single arg function was dropped
function ChainRules.rrule(::typeof(cholesky), A::AbstractMatrix{<:Real})

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this necessary? In LinearAlgebra, the single-arg method calls the 2-arg method.

Copy link
Member Author

@rofinn rofinn Jun 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still trying to wrap my head around how all the overloading works here, but I think it has to do with how Nabla.jl generates overloads for Nabla.Node types based of existing rrule signatures. I could probably dig a bit deeper into how to define an explicit overloads in Nabla, but the rrule solution seems easier :)

https://github.com/invenia/Nabla.jl/blob/f12de3ea148f1b348615b1ee24ab2a63e68d92d5/src/sensitivities/chainrules.jl

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps, but this issue would likely then crop up elsewhere, since ChainRules tries to define rrules for methods that are called by other methods with fewer arguments. e.g. the rrule for lu would probably also be missed by Nabla: https://github.com/JuliaDiff/ChainRules.jl/blob/6ff4c319f8fd25f27636d28144d78c92f81d8753/src/rulesets/LinearAlgebra/factorization.jl#L134-L136

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the "missing" single-arg method is only breaking some very specific tests such as

X_ = Matrix{Float64}(I, 5, 5)
X = Leaf(Tape(), X_)
C = cholesky(X)
@test C isa Branch{<:Cholesky}
but AD is still working? That is, maybe one just has to update the tests?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nabla.jls Node type doesn't subtype AbstractMatrix (or Number).
This means it had problems going through things that have that kind of type restriction.

return ChainRules.rrule(cholesky, A, Val(false))
end

# U and L properties were replaced with factors
# This should probably be moved to ChainRules to support both options.
function Base.getproperty(tangent::Tangent{P, T}, sym::Symbol) where {P <: Cholesky, T <: NamedTuple}
idx = hasfield(T, :factors) && sym in (:U, :L) ? :factors : sym
hasfield(T, idx) || return ZeroTangent()
return unthunk(getfield(ChainRulesCore.backing(tangent), idx))
Comment on lines +56 to +58

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this return an upper/lower triangular matrix if tangent.U/tangent.L is requested? It seems this implementation would just return tangent.factors in both cases.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC in the CR PR it was also discussed if getproperty for these tangents should be added to CR.

Copy link
Member Author

@rofinn rofinn Jun 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this return an upper/lower triangular matrix if tangent.U/tangent.L is requested?

Yeah, this was just to see what was needed to make tests pass. AFAIK, tangent.U and tangent.L have just been renamed to factors shouldn't this just work as is in most cases? I guess the concern is that factors would be the incorrect type?

IIRC in the CR PR it was also discussed if getproperty for these tangents should be added to CR.

Ideally, this is something that should be added to CR, but if folks disagree it can live here to keep things working.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's worth at least opening an issue on ChainRules to discuss adding this there.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end
3 changes: 2 additions & 1 deletion test/sensitivities/linalg/factorization/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
@test getfield(U, :f) == Base.getproperty
@test unbox(U) ≈ cholesky(X_).U

@test_throws MethodError ∇(X->cholesky(X).info)(X_)
# @test_throws MethodError ∇(X->cholesky(X).info)(X_)
@test_throws ErrorException ∇(X->cholesky(X).info)(X_)
end

let
Expand Down