-
Notifications
You must be signed in to change notification settings - Fork 5
Cholesky passing CI #217
base: master
Are you sure you want to change the base?
Cholesky passing CI #217
Conversation
idx = hasfield(T, :factors) && sym in (:U, :L) ? :factors : sym | ||
hasfield(T, idx) || return ZeroTangent() | ||
return unthunk(getfield(ChainRulesCore.backing(tangent), idx)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this return an upper/lower triangular matrix if tangent.U
/tangent.L
is requested? It seems this implementation would just return tangent.factors
in both cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC in the CR PR it was also discussed if getproperty
for these tangents should be added to CR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this return an upper/lower triangular matrix if tangent.U/tangent.L is requested?
Yeah, this was just to see what was needed to make tests pass. AFAIK, tangent.U
and tangent.L
have just been renamed to factors
shouldn't this just work as is in most cases? I guess the concern is that factors
would be the incorrect type?
IIRC in the CR PR it was also discussed if getproperty for these tangents should be added to CR.
Ideally, this is something that should be added to CR, but if folks disagree it can live here to keep things working.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's worth at least opening an issue on ChainRules to discuss adding this there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# https://github.com/JuliaDiff/ChainRules.jl/pull/630 | ||
|
||
# Single arg function was dropped | ||
function ChainRules.rrule(::typeof(cholesky), A::AbstractMatrix{<:Real}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this necessary? In LinearAlgebra, the single-arg method calls the 2-arg method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm still trying to wrap my head around how all the overloading works here, but I think it has to do with how Nabla.jl generates overloads for Nabla.Node
types based of existing rrule
signatures. I could probably dig a bit deeper into how to define an explicit overloads in Nabla, but the rrule
solution seems easier :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps, but this issue would likely then crop up elsewhere, since ChainRules tries to define rrule
s for methods that are called by other methods with fewer arguments. e.g. the rrule
for lu
would probably also be missed by Nabla: https://github.com/JuliaDiff/ChainRules.jl/blob/6ff4c319f8fd25f27636d28144d78c92f81d8753/src/rulesets/LinearAlgebra/factorization.jl#L134-L136
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe the "missing" single-arg method is only breaking some very specific tests such as
Nabla.jl/test/sensitivities/linalg/factorization/cholesky.jl
Lines 12 to 15 in f5adedb
X_ = Matrix{Float64}(I, 5, 5) | |
X = Leaf(Tape(), X_) | |
C = cholesky(X) | |
@test C isa Branch{<:Cholesky} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nabla.jls Node
type doesn't subtype AbstractMatrix
(or Number
).
This means it had problems going through things that have that kind of type restriction.
Well done. I wonder why we didn't spot them in the reverse dependency checks? (I am still away sick) |
I don't think type piracy is the right solution here, but it narrows down the specific changes that broke our codebase. Perhaps we should re-add these methods to ChainRules?
Closes #216