Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TransformTrace, TransformTracer and TransformInterpreter classes for transforming PLxPR #6389

Open
wants to merge 222 commits into
base: master
Choose a base branch
from

Conversation

mudit2812
Copy link
Contributor

@mudit2812 mudit2812 commented Oct 11, 2024

This PR adds 3 new classes to qml.capture to facilitate transforming PLxPR natively without the need to first create QuantumScripts.

  • TransformInterpreter will be used to evaluate PLxPR for the purposes of applying transforms.
  • TransformInterpreter will use TransformTrace and TransformTracer to transform primitives that are being evaluated.
  • Scaffolding needed to create transforms for PLxPR has been added to TransformDispatcher, TransformContainer, and TransformProgram.
  • A markdown file has been added to qml.capture to give a detailed explanation of how the framework for applying transforms natively to PLxPR works.
  • Update qml.capture.enable() and qml.capture.disable() to dispatch to jax when using pennylane classes with capture enabled, and to autograd when capture is disabled. This assumes that users won't be silly and try to use pennylane.numpy with capture enabled.

Future work: Storing transforms as primitives in PLxPR (Or as metadata belonging to a "transform program" HOP), and better integration with the execution pipeline.

[sc-75560]

Copy link
Contributor

@PietropaoloFrisoni PietropaoloFrisoni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before I forget, I realized that the line:

trace_stack.append(MainTrace(0, EvalTrace, None))

is necessary within the markdown explanation file (after having defined the EvalTrace class), otherwise the example does not work

Edit: I added a suggestion in the file directly : )

tests/capture/test_capture_transforms.py Outdated Show resolved Hide resolved
idx = max(t.idx for t in tracers if isinstance(t, TransformTracer))
is_qml_primitive = (
primitive.__class__.__module__.split(".")[0] == "pennylane"
or primitive.name.split("_")[0] in _mp_return_types
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So measurements don't have the module name of "pennylane"?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, what happens if someone is using a custom operator? Should we have a backlog task to figure that out?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So measurements don't have the module name of "pennylane"?

No, because measurement primitives are jax.core.Primitives so their namespace is jax, but the rest of the primitives in pennylane are pennylane.capture.NonInterpPrimitives. We could try to make the measurement primitives NonInterpPrimitives as well, but I feel like that is not soemthing that should be done in this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, what happens if someone is using a custom operator? Should we have a backlog task to figure that out?

Good point, it will not work out of the box. I considered standardizing the primitive names by adding qml. to the front of the primitive names, but eventually reverted that change, because that also seems like it should be done in a different PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @dwierichs in hindsight, why aren't measurement primitives not NonInterpPrimitives?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dwierichs For context, the NonInterpPrimitive is the new name for NonJVPPrimitive (or a similar name) you created some time ago. I re-named it after intercepting BatchTracers along with JVPTracers as well

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the context, @PietropaoloFrisoni

@mudit2812 I just added it where it was needed from the standpoint of the epic requirements, basically.
I suppose if we want to differentiate through measurements, that actually will currently break due to them not being NonInterpPrimitives (?)
There was no inherent reason to not make them a custom Primitive.

mudit2812 added a commit that referenced this pull request Dec 2, 2024
[sc-72804] [sc-72803]

This PR adds a primitive for capturing transforms.

**Description of change**
* Add a primitive to capture transforms: every time a transform is
created, it automatically gets registered as a primitive
* This primitive is currently only for capturing transforms and does not
actually transform the input function.
* This primitive gets bound (binded?) when calling the transform on the
QNode or a function.
* When applied to a function, the transform is captured as a
higher-order primitive. Evaluating the subsequent JAXPR doesn't do
anything (i.e, the inner JAXPR is evaluated without any transformations)
* When applied to a QNode, the transform is appended to the QNode's
`TransformProgram` as well so that executing QNodes applies the
transform(s) to the constructed tape, and also captured as a primitive.
* Add `plxpr_transform` attribute to `TransformDispatcher` and
`TransformContainer`. This is not currently used and is being added here
to set up scaffolding in `pennylane.transfoms.core` for #6389 .

**Benefits**
Transforms can be captured into PLxPR and no extra work is required when
creating new transforms

---------

Co-authored-by: Pietropaolo Frisoni <pietropaolo.frisoni@xanadu.ai>
Co-authored-by: Christina Lee <christina@xanadu.ai>
Comment on lines 299 to 313
for inval in invals:
# The following branch is added because we want observables to get transformed.
# However, due to the logic used in `interpret_operator_eqn`, we only transform
# operators that do not get consumed by other primitives. So, we do the transforming
# (binding) here instead. The global state is updated with the "op_is_observable" key
# because transforms may want special handling for observables of measurements.
if isinstance(inval, qml.operation.Operator):
# pylint: disable=protected-access
op_tracers, op_params = self._env[id(inval)]
self._state[-1]["is_measurement_obs"] = True
new_inval = inval._primitive.bind(*op_tracers, **op_params)
self._state[-1].pop("is_measurement_obs")
traced_invals.append(self.read_with_trace(new_inval))
continue
traced_invals.append(self.read_with_trace(inval))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is new. I hope that the comment I've left to explain why it was added is enough. Please let me know if you need more info.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.