-
Notifications
You must be signed in to change notification settings - Fork 615
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add TransformTrace
, TransformTracer
and TransformInterpreter
classes for transforming PLxPR
#6389
base: master
Are you sure you want to change the base?
Conversation
…AI/pennylane into plxpr-interpreter-base
Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Before I forget, I realized that the line:
trace_stack.append(MainTrace(0, EvalTrace, None))
is necessary within the markdown explanation file (after having defined the EvalTrace
class), otherwise the example does not work
Edit: I added a suggestion in the file directly : )
idx = max(t.idx for t in tracers if isinstance(t, TransformTracer)) | ||
is_qml_primitive = ( | ||
primitive.__class__.__module__.split(".")[0] == "pennylane" | ||
or primitive.name.split("_")[0] in _mp_return_types |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So measurements don't have the module name of "pennylane"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, what happens if someone is using a custom operator? Should we have a backlog task to figure that out?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So measurements don't have the module name of "pennylane"?
No, because measurement primitives are jax.core.Primitive
s so their namespace is jax
, but the rest of the primitives in pennylane are pennylane.capture.NonInterpPrimitive
s. We could try to make the measurement primitives NonInterpPrimitive
s as well, but I feel like that is not soemthing that should be done in this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, what happens if someone is using a custom operator? Should we have a backlog task to figure that out?
Good point, it will not work out of the box. I considered standardizing the primitive names by adding qml.
to the front of the primitive names, but eventually reverted that change, because that also seems like it should be done in a different PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @dwierichs in hindsight, why aren't measurement primitives not NonInterpPrimitive
s?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dwierichs For context, the NonInterpPrimitive
is the new name for NonJVPPrimitive
(or a similar name) you created some time ago. I re-named it after intercepting BatchTracers
along with JVPTracers
as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the context, @PietropaoloFrisoni
@mudit2812 I just added it where it was needed from the standpoint of the epic requirements, basically.
I suppose if we want to differentiate through measurements, that actually will currently break due to them not being NonInterpPrimitive
s (?)
There was no inherent reason to not make them a custom Primitive.
[sc-72804] [sc-72803] This PR adds a primitive for capturing transforms. **Description of change** * Add a primitive to capture transforms: every time a transform is created, it automatically gets registered as a primitive * This primitive is currently only for capturing transforms and does not actually transform the input function. * This primitive gets bound (binded?) when calling the transform on the QNode or a function. * When applied to a function, the transform is captured as a higher-order primitive. Evaluating the subsequent JAXPR doesn't do anything (i.e, the inner JAXPR is evaluated without any transformations) * When applied to a QNode, the transform is appended to the QNode's `TransformProgram` as well so that executing QNodes applies the transform(s) to the constructed tape, and also captured as a primitive. * Add `plxpr_transform` attribute to `TransformDispatcher` and `TransformContainer`. This is not currently used and is being added here to set up scaffolding in `pennylane.transfoms.core` for #6389 . **Benefits** Transforms can be captured into PLxPR and no extra work is required when creating new transforms --------- Co-authored-by: Pietropaolo Frisoni <pietropaolo.frisoni@xanadu.ai> Co-authored-by: Christina Lee <christina@xanadu.ai>
…ts and storing state
for inval in invals: | ||
# The following branch is added because we want observables to get transformed. | ||
# However, due to the logic used in `interpret_operator_eqn`, we only transform | ||
# operators that do not get consumed by other primitives. So, we do the transforming | ||
# (binding) here instead. The global state is updated with the "op_is_observable" key | ||
# because transforms may want special handling for observables of measurements. | ||
if isinstance(inval, qml.operation.Operator): | ||
# pylint: disable=protected-access | ||
op_tracers, op_params = self._env[id(inval)] | ||
self._state[-1]["is_measurement_obs"] = True | ||
new_inval = inval._primitive.bind(*op_tracers, **op_params) | ||
self._state[-1].pop("is_measurement_obs") | ||
traced_invals.append(self.read_with_trace(new_inval)) | ||
continue | ||
traced_invals.append(self.read_with_trace(inval)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is new. I hope that the comment I've left to explain why it was added is enough. Please let me know if you need more info.
This PR adds 3 new classes to
qml.capture
to facilitate transforming PLxPR natively without the need to first createQuantumScript
s.TransformInterpreter
will be used to evaluate PLxPR for the purposes of applying transforms.TransformInterpreter
will useTransformTrace
andTransformTracer
to transform primitives that are being evaluated.TransformDispatcher
,TransformContainer
, andTransformProgram
.qml.capture
to give a detailed explanation of how the framework for applying transforms natively to PLxPR works.qml.capture.enable()
andqml.capture.disable()
to dispatch tojax
when using pennylane classes with capture enabled, and toautograd
when capture is disabled. This assumes that users won't be silly and try to usepennylane.numpy
with capture enabled.Future work: Storing transforms as primitives in PLxPR (Or as metadata belonging to a "transform program" HOP), and better integration with the execution pipeline.
[sc-75560]