Efficient Inplace Accumulating Strided MatMul #273
I think this can all be done much more easily now with 5 arg. I think we might even be able to just replace the generic code for the rrule/frule of matmul with that.
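As a rough illustration of the idea, here is a minimal sketch that assumes "5 arg" refers to the five-argument `LinearAlgebra.mul!(C, A, B, α, β)`, which computes `C = A*B*α + C*β` in place. The helper name `accumulate_matmul_pullback!` and the variable names are hypothetical and are not taken from this PR.

```julia
using LinearAlgebra

# Hypothetical helper (not from the PR): accumulate the pullback of Y = A * B
# into existing gradient buffers dA and dB, given the output cotangent Ybar.
function accumulate_matmul_pullback!(dA, dB, Ybar, A, B)
    # With α = true and β = true, mul! adds the product to the existing buffer.
    mul!(dA, Ybar, B', true, true)  # dA += Ybar * B'
    mul!(dB, A', Ybar, true, true)  # dB += A' * Ybar
    return dA, dB
end

A, B = rand(3, 4), rand(4, 2)
Ybar = ones(3, 2)
dA, dB = ones(3, 4), ones(4, 2)  # non-zero start so the accumulation is visible
accumulate_matmul_pullback!(dA, dB, Ybar, A, B)
```

Passing `β = true` is what turns an overwrite into an accumulation, so no temporary array is needed for the product.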
This is basically done, but it can't be merged until we are testing accumulation properly, as it could just be wrong.
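One possible shape for such a test, as a sketch only: start from a non-zero accumulator and compare the in-place result against the out-of-place sum. The function `accumulate!` below is a hypothetical stand-in for whatever in-place rule is under test, and is not this package's API.

```julia
using LinearAlgebra, Test

# Hypothetical in-place accumulator under test: acc += Ybar * B'
accumulate!(acc, Ybar, B) = mul!(acc, Ybar, B', true, true)

# Returns true when in-place accumulation matches the out-of-place reference.
function accumulation_matches(acc0, Ybar, B)
    acc = copy(acc0)                   # keep the starting value for comparison
    ret = accumulate!(acc, Ybar, B)
    return ret === acc && acc ≈ acc0 + Ybar * B'
end

@testset "in-place accumulation" begin
    Ybar, B = rand(3, 2), rand(4, 2)
    @test accumulation_matches(rand(3, 4), Ybar, B)   # non-zero start
    @test accumulation_matches(zeros(3, 4), Ybar, B)  # zero start
end
```

Starting from a random accumulator matters here: a rule that overwrites instead of accumulating would still pass a test that only starts from zeros.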
Another advantage of this over the Nabla code is that we don't have to worry about the tests not hitting the right method, since this is the most general one.
@@ -4,6 +4,7 @@ version = "0.7.23"
Version number needs to be bumped
Looking good. Just a couple of comments.
src/rulesets/Base/arraymath.jl
Outdated
@@ -19,23 +19,47 @@ end
 ##### `*`
 #####

-function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
+function rrule(::typeof(*), A::AbstractMatrix{<:Number}, B::AbstractMatrix{<:Number})
@sethaxen pointed out that we probably want to restrict these to Real and Complex numbers.
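A small sketch of what that restriction could look like at the dispatch level. The alias `RealOrComplex`, the function `handles`, and the toy type `MyNum` are all hypothetical names used for illustration; one plausible motivation is that the usual matmul pullback formula assumes commutative scalar multiplication, which an arbitrary `Number` subtype need not provide.

```julia
# Hypothetical alias (not from the PR) for the element types the rule would accept.
const RealOrComplex = Union{Real,Complex}

# Hypothetical stand-in for the rule's signature: the first method only matches
# when both matrices have real or complex elements.
handles(::AbstractMatrix{<:RealOrComplex}, ::AbstractMatrix{<:RealOrComplex}) = true
handles(::AbstractMatrix, ::AbstractMatrix) = false  # everything else falls through

# Toy Number subtype, purely illustrative, that is neither Real nor Complex.
struct MyNum <: Number
    x::Float64
end

real_complex = handles(rand(2, 2), rand(ComplexF64, 2, 2))
custom = handles(rand(2, 2), fill(MyNum(1.0), 2, 2))
```

With `<:Number` in the signature, the `MyNum` matrix would also have matched the first method.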
These come from Nabla, but I think they are wrong: they don't currently work, and they were not actually being hit.
See invenia/Nabla.jl#192
I think we need to unwrap the Transpose and Adjoint objects before passing them to GEMM, and/or maybe consider the double transpose case carefully.
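A hedged sketch of the unwrapping step, assuming the target is `LinearAlgebra.BLAS.gemm!(tA, tB, alpha, A, B, beta, C)`, which updates `C` as `alpha*op(A)*op(B) + beta*C`. The helpers `unwrap` and `accumulate_gemm!` are hypothetical names, not code from this PR.

```julia
using LinearAlgebra

# Hypothetical helpers: map a possibly wrapped matrix to (BLAS flag, plain parent array).
unwrap(A::StridedMatrix) = ('N', A)
unwrap(A::Transpose{<:Any,<:StridedMatrix}) = ('T', parent(A))
unwrap(A::Adjoint{<:Real,<:StridedMatrix}) = ('T', parent(A))     # adjoint is transpose for reals
unwrap(A::Adjoint{<:Complex,<:StridedMatrix}) = ('C', parent(A))  # conjugate transpose

# Accumulate op(A) * op(B) into C using a single GEMM call.
function accumulate_gemm!(C, A, B)
    tA, pA = unwrap(A)
    tB, pB = unwrap(B)
    return BLAS.gemm!(tA, tB, one(eltype(C)), pA, pB, one(eltype(C)), C)
end

A, B = rand(3, 4), rand(3, 2)
C = ones(4, 2)
accumulate_gemm!(C, A', B)  # C += A' * B, with A' passed to BLAS as ('T', A)
```

Doubly wrapped inputs, such as the transpose of a complex adjoint, are deliberately not covered by this sketch: no `unwrap` method matches them, so they would need a separate fallback.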
I may split this PR in two and get the simple case, the one with no Transposes or Adjoints, in first, since Nabla actually uses that.