-
Notifications
You must be signed in to change notification settings - Fork 577
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RFC: Tensor Indexed Updates #433
base: master
Are you sure you want to change the base?
Conversation
This is to allow numpy-like (or rather JAX-like) indexed tensor modifications like slice assignment.
a02ff39
to
77aa8fd
Compare
creating a simplified API around existing ones. Performance for existing ops | ||
will be unaffected. | ||
|
||
Such a copy-and-update method can lead to poor practices: users may overuse and |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there the risk that users start to update single elements in a loop with this new syntax?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, which would lead to excessive copying of tensors.
As mentioned in this paragraph, we can try to emphasize in the docs that this shouldn't be done, and it's not something we've observed with scatter_nd_update
, which could also be misused in the same way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I saw but I think it's a bit less natural "to abuse" scatter_nd_update
in a loop instead of a numpy like syntax.
Of course it will be slow also on numpy as we suppose to prefer a "vectorized" use of the op so generally this kind of "individual" access could probably rely on ufunc (See google/mlir-npcomp#1). But in our case we need also to add the extra copy penalty.
Also I don't know if this misuse could be detected on autograph tracing at least to emit a warning (like the one we have for excessive retracing).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also I don't know if this misuse could be detected on autograph tracing at least to emit a warning
Potentially. I'm sure it's worth it until if/when we see that this kind of abuse is common. For those who do care about performance, benchmarking/profiling the model will likely expose these repeated calls as the issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you have any feedback instead about the mentioned "ufunc like" needs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure I understand what the "ufunc like" needs are. I agree if a user is modifying a tensor in a loop one element at a time, they could use a ufunc. The tensor.at[index_expression].apply(ufunc)
should support this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure but I meant how is this mapped efficiently "down to the stack"? I meant if we suppose to use ufunc
when we need to access to the Tensor without the compositional TF/*HLO ops can we have something like e.g. numba @vectorize decorators or sompething probably similar jax.numpy.vectorize
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Like tf.map_fn? I agree an efficient vectorized element-wise operation would be useful. This is not what we're proposing here though. Well, I suppose .at[...].apply(ufunc)
should probably do something smart like that. The focus here is on the update API though, so we'll just re-use whatever already exists in TF for the implementation details to start. These can be improved later via other additions like a tf.vectorize
API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, I suppose .at[...].apply(ufunc) should probably do something smart like that.
Yes, thank you for the clarification. I supposed that this was the right occasion but if we'll just re-use whatever already exists it is still not the case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be still nice if we could put this in a notice as a future work or just to disambiguate the perimeter of the RFC.
v = v.at[:, :, 2].set(0) | ||
``` | ||
Which would create a modified copy of the original variable. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A possible naive confusion is the user thinking that at[]
by itself returns something usable, a mutable Tensor constructed from the sliced values. In reality, the user only gets something usable (outside of the indexed update operation) by "finishing the sentence" and calling set()
.
This could perhaps be mitigated with clear documentation and informative error messages.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Hopefully through usage examples and the tutorial, it would become clear. I'm not sure if there's a python equivalent to ABSL_MUST_USE_RESULT
to force the user to actually assign the result of something like set(...)
to something, which may help mitigate this.
## Questions and Discussion Topics | ||
|
||
- Does the `.at[]` syntax seem appealing to the TF community? | ||
- Should we consider deprecating existing gather/scatter methods in favor of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The existing methods may save users some Python overhead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. Will this only be a concern for eager? I don't expect the overhead to be significant, but we can do some basic benchmarking and make sure that the most common use-cases are as fast as possible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's also a concern for graph-building overhead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. Relative to the rest of graph-building, I again suspect the cost of this will be negligible, but maybe you have found otherwise with the tf-numpy project? We can do some benchmarking to confirm for models that currently contain tensor_scatter_nd_update
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We didn't find any real-world problems (but tf-numpy wasn't used much anyway). Agree that benchmarking will help. We can punt on benchmarking until we see real problems.
Added basic type annotations to the implementation details.
This is great, thanks for the proposal. LGTM. Note that this is currently doable in TF in a performant way via a XLA op: from tensorflow.compiler.tf2xla.python import xla
output = xla.dynamic_update_slice(input, update, indices) |
This is a great initiative! |
Thanks for the proposal! |
Please. |
This is to allow numpy-like (or rather JAX-like) indexed tensor modifications like slice assignment.