Cholesky passing CI #217
base: master
@@ -9,12 +9,12 @@ const AM = AbstractMatrix
const UT = UpperTriangular

@explicit_intercepts(
    Cholesky,
    Tuple{AbstractMatrix{<:∇Scalar}, Union{Char, Symbol}, Integer},
    [true, false, false],
)

function ∇(
    ::Type{Cholesky},
    ::Type{Arg{1}},
@@ -41,3 +41,19 @@ function ∇(
)
    return getproperty(X̄, Symbol(uplo))
end

# Yar, some work arounds for breaking changes in ChainRules.jl
# https://github.com/JuliaDiff/ChainRules.jl/pull/630

# Single arg function was dropped
function ChainRules.rrule(::typeof(cholesky), A::AbstractMatrix{<:Real})
    return ChainRules.rrule(cholesky, A, Val(false))
end

# U and L properties were replaced with factors
# This should probably be moved to ChainRules to support both options.
function Base.getproperty(tangent::Tangent{P, T}, sym::Symbol) where {P <: Cholesky, T <: NamedTuple}
    idx = hasfield(T, :factors) && sym in (:U, :L) ? :factors : sym
    hasfield(T, idx) || return ZeroTangent()
    return unthunk(getfield(ChainRulesCore.backing(tangent), idx))
end

Comment on lines +56 to +58

Shouldn't this return an upper/lower triangular matrix if

IIRC in the CR PR it was also discussed if

Yeah, this was just to see what was needed to make tests pass. AFAIK,
Ideally, this is something that should be added to CR, but if folks disagree it can live here to keep things working.

I think it's worth at least opening an issue on ChainRules to discuss adding this there.
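The field-name fallback in the `getproperty` overload above can be tried in isolation. The following is only a rough sketch on a plain `NamedTuple`: `lookup_factor` is a hypothetical helper invented for illustration, it returns `nothing` where the PR returns `ZeroTangent()`, and it does not touch `Tangent` or any ChainRulesCore API.

```julia
# Hypothetical helper (not part of Nabla or ChainRules): mirrors the lookup
# order used in the overload above, but on a plain NamedTuple backing.
function lookup_factor(backing::NamedTuple, sym::Symbol)
    # Redirect :U / :L to :factors when the backing uses the newer layout.
    idx = haskey(backing, :factors) && sym in (:U, :L) ? :factors : sym
    # Unknown field: the PR returns ZeroTangent() here; this sketch uses nothing.
    haskey(backing, idx) || return nothing
    return getfield(backing, idx)
end

# Assumed example layouts, chosen only to exercise both branches.
new_style = (factors = [1.0 2.0; 0.0 3.0], uplo = 'U')
old_style = (U = [1.0 2.0; 0.0 3.0],)

from_new = lookup_factor(new_style, :U)   # falls back to the whole `factors` matrix
from_old = lookup_factor(old_style, :U)   # direct hit on the `U` field
missing_field = lookup_factor(old_style, :L)
```

Note that with the newer layout both `:U` and `:L` resolve to the same full `factors` matrix, which is the behaviour the first review comment questions.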
Why is this necessary? In LinearAlgebra, the single-arg method calls the 2-arg method.
I'm still trying to wrap my head around how all the overloading works here, but I think it has to do with how Nabla.jl generates overloads for `Nabla.Node` types based off existing `rrule` signatures. I could probably dig a bit deeper into how to define explicit overloads in Nabla, but the `rrule` solution seems easier :) https://github.com/invenia/Nabla.jl/blob/f12de3ea148f1b348615b1ee24ab2a63e68d92d5/src/sensitivities/chainrules.jl
Perhaps, but this issue would likely then crop up elsewhere, since ChainRules tries to define `rrule`s for methods that are called by other methods with fewer arguments. E.g., the `rrule` for `lu` would probably also be missed by Nabla: https://github.com/JuliaDiff/ChainRules.jl/blob/6ff4c319f8fd25f27636d28144d78c92f81d8753/src/rulesets/LinearAlgebra/factorization.jl#L134-L136
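To make the concern concrete, here is a toy sketch of a rule attached only to a two-argument signature. It assumes a tool that looks rules up by exact signature, as described in the comments above. `factorize_toy` and `toy_rule` are hypothetical stand-ins, not Nabla or ChainRules code.

```julia
# Hypothetical function with a forwarding one-argument method, similar in
# shape to a one-argument method that calls a two-argument method.
factorize_toy(A::Matrix{Float64}, ::Val{false}) = sum(A)
factorize_toy(A::Matrix{Float64}) = factorize_toy(A, Val(false))

# A rule defined only for the two-argument signature. It returns the value
# and a pullback (the gradient of sum(A) is an array of ones).
function toy_rule(::typeof(factorize_toy), A::Matrix{Float64}, p::Val{false})
    return factorize_toy(A, p), dy -> dy .* ones(size(A))
end

# A signature-based lookup finds the two-argument rule...
has_two_arg_rule = hasmethod(
    toy_rule, Tuple{typeof(factorize_toy), Matrix{Float64}, Val{false}},
)
# ...but not a one-argument rule, even though the one-argument call
# forwards to the two-argument method at runtime.
has_one_arg_rule = hasmethod(
    toy_rule, Tuple{typeof(factorize_toy), Matrix{Float64}},
)
```

Under that assumption, a forwarding one-argument rule like the one added in this PR is what makes the second lookup succeed.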
Maybe the "missing" single-arg method is only breaking some very specific tests such as Nabla.jl/test/sensitivities/linalg/factorization/cholesky.jl, lines 12 to 15 in f5adedb.
Nabla.jl's `Node` type doesn't subtype `AbstractMatrix` (or `Number`). This means it had problems going through things that have that kind of type restriction.
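The dispatch problem can be reproduced with a minimal wrapper type. This is only an illustrative sketch: `ToyNode` and `total` are hypothetical names, and `ToyNode` is not Nabla's actual `Node`.

```julia
# Hypothetical tracing wrapper that, like the Node type described above,
# does not subtype AbstractMatrix.
struct ToyNode{T}
    val::T
end

# A function restricted to real matrices, like many LinearAlgebra methods.
total(A::AbstractMatrix{<:Real}) = sum(A)

A = [1.0 2.0; 3.0 4.0]
node = ToyNode(A)

plain_ok = applicable(total, A)     # the plain matrix matches the signature
node_ok = applicable(total, node)   # the wrapper does not match

# One way around it: an explicit overload that unwraps the node.
total(n::ToyNode) = total(n.val)
node_ok_after = applicable(total, node)
```

Every function with such a type restriction needs an overload of this kind before a wrapped value can pass through it.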