Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] OpenXLA PJRT plugin #33

Merged
merged 3 commits into from
Feb 11, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 221 additions & 0 deletions rfcs/20230123-pjrt-plugin.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
## RFC: OpenXLA PJRT Plugin

| Status | Proposed |
| :------------ | :------------------------------------------------------ |
| **RFC #** | [33](https://github.com/openxla/community/pull/33) |
| **Author(s)** | Skye Wanderman-Milne (skyewm@google.com), Jieying Luo (jieying@google.com), Jacques Pienaar (jpienaar@google.com) |
| **Sponsor** | Stella Laurenzo (laurenzo@google.com), James Rubin (jamesrubin@google.com) |
| **Updated** | 2023-01-23 |

## Objective

* Framework integration of a packaged compiler and runtime solution;

## Proposal

* Adopt PJRT as the supported device plugin mechanism for OpenXLA;
* Create new repo openxla/openxla-pjrt-plugin for the OpenXLA PJRT plugin;

## Background: PJRT

[PJRT](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h)
is a uniform Device API that we want to add to the OpenXLA ecosystem. The long
term vision for PJRT is that: (1) frameworks (TensorFlow, PyTorch, JAX, etc.)
will call PJRT, which has device-specific implementations that are opaque to the
frameworks; (2) each device focuses on implementing PJRT APIs, and can be opaque
to the frameworks.

![PJRT plugin integrating OpenXLA into ML frameworks](20230123-pjrt-plugin/frameworks.png)
<p align = "center"> PJRT provides a platform independent interfacing for
compilers and corresponding runtimes. </p>

PJRT API will provide an easy interface with which frameworks can integrate a
packaged compiler and runtime solution. It will be the supported interface that
will be used by TensorFlow and JAX for all compiler and runtime integration. And
as such it will be easy for other compilers and runtimes that implement the PJRT
interface to integrate with these systems.

## PJRT plugin mechanism goal

The PJRT plugin mechanism should support the following features:

* Different devices (e.g. TPU, GPU) can have different implementations
(through
[PJRT C API interface](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h)).
* Registration of multiple PJRT plugins.
* Loading multiple PJRT plugins (e.g. both CPU and TPU) in the same process.
* Passing configuration values from client (e.g. JAX lib) or from json files
provided by the plugin (default configs).
* Plugin discovery and choosing which plugins to load.

## High level plugin structure

![High-level plugin structure](20230123-pjrt-plugin/plugin-structure.png)

## Current Status

As of Dec 14, 2022

* [LoadPjrtPlugin(plugin_name, library_path)](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/stream_executor/tpu/pjrt_api.cc#L73)
can be used to load a PJRT plugin. We also provide a Python method
[load_pjrt_plugin](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla.cc#L329)
binded to it.
* [GetCApiClient(plugin_name)](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc#L1150)
can be used to create a PJRT client. We also provide a Python method
[get_c_api_client](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla.cc#L391)
binded to it.

## Design Considerations

### <a name="heading=h.782ksg6rl5bj"></a> Loading plugins

We will provide a low level C++ method to load a PJRT plugin, which takes
`library_path` and `config_values` as inputs. A python method `load_pjrt_plugin`
binded to it will be provided as well.

```c++
Status LoadPjrtPlugin(string library_path, map<string, string> config_values) {
library_handle = dlopen(library_path);
function_prt = dlsym(library_handle, "GetPjrtApi");
if (function_prt != nullptr) {
PJRT_Api* api = function_prt();
plugin_name = parse(library_path);
PluginInfo plugin_info(api, config_values);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to check if plugin_name is already loaded and if so, do an early return.
I wouldn't return an error. Allowing to load the same pluging many times isn't wrong.

This will also allows function_prt() to do all the initialization that it needs. (related to some open questions bellow)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion! Edited.

global_pjrt_plugin_map[plugin_name] = plugin_info;
}
}
```

* `GetPjrtApi` returns a `PJRT_Api*` pointer that contains the implementation
of the C APIs.
* `global_pjrt_plugin_map` is a global `map<plugin_name, PluginInfo>` with the
same lifetime as the program
* `plugin_name` has the same meaning as `platform_name` in JAX.
* To be able to store the config values, and be future-proofed for other
information such as version, we propose to use a class PluginInfo. This
class is immutable after constructed, and contains getter method for
`PJRT_Api*` and config values.

### Discovering and automatic loading plugins

To allow framework users to pip-install plugins without requiring further code
changes, we'll implement a Python discovery mechanism to automatically find and
load plugins. The plugin discovery will be based on the
[naming convention of the Python module](https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-naming-convention),
and full paths set in the environment variable `PJRT_PLUGIN_LIBRARY_PATH` which
is added to allow users to manually specify directories and/or .so files). For
modules found, it will be imported.

* Tentative naming convention for .so/json files:
"pjrt-plugin-<plugin_name>.so" or "pjrt-plugin-<plugin_name>.json".

There are two options to automatically load plugins (decision has not been made
yet):

Option 1: Plugin module updates PJRT_PLUGIN_LIBRARY_PATH on import. A python
method `load_pjrt_plugins` will be added to discover .so/json files related to
PJRT plugins and load them.

```python
def load_pjrt_plugins():
for directory in env[PJRT_PLUGIN_LIBRARY_PATH]:
if json file:
Library_path, config_values = parse json file
load_pjrt_plugin(library_path, config_values)
elif .so file:
load_pjrt_plugin(.so file path)
```

Option 2: Plugin module is responsible for calling
[load_pjrt_plugin](#heading=h.782ksg6rl5bj) with its default options on import.

Open questions:

* Are there requirements about what file should not be loaded?

### <a name="heading=h.396bmv8gkskz"></a> Create PJRT client(s)

Frameworks decide which PJRT clients to create and use. For example, a framework
can create PJRT clients for all the plugins loaded. It can also choose to only
create a subset of PJRT clients based on some priorities rules. It can also come
from the user configuration.

Two python methods binding to C++ methods will be provided to facilitate
creating PJRT clients:

1. `get_loaded_plugin_names()` which gets all loaded plugin names from
`global_pjrt_plugin_map`.
2. `create_pjrt_client(plugin_name)` which creates a PJRT C API client (similar
to
[GetCApiClient](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc#L1150)).
It will retrieve the PluginInfo stored in `global_pjrt_plugin_map` and run
`PJRT_Client_Create` (see
[Config values for creating PJRT client(s)](#heading=h.bjuf0soco0sj)
section).

Open questions:

* What about plugin initialization that are not related to creating a PJRT
client? For example, these two functions in
[tpu_initializer_helper.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc#L269-L270)
should be run when initializing TPU. Currently they are called every time a
PJRT TPU client is created. Shall we add another method InitializePlugin and
only run it once? Alternatively, the plugin can implement it in
`PJRT_Client_Create` and run it once in the first time a client was created.
* Do we want to create PJRT clients for every plugin that is found? Will that
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't do that automatically as this will increase some resource utilization even if the end user doesn't want to use it.
I would let frameworks decide which behavior they want. I wouldn't impose that decision at the PJRT level.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. Updated the text accordingly.

involve some initialization for the device which should not run multiple
times?
* We may need to store loaded PluginInfo in a nested map `map<device_type,
<plugin_name, PluginInfo>>` if we want to know what plugins are
available for a device (e.g. current PJRT GPU and IREE GPU) and only
create one PJRT client per device.

### <a name="heading=h.bjuf0soco0sj"></a> Config values for creating PJRT client(s)

GetCApiClient will be changed to take a map of config values. This map can be
passed to `PJRT_Client_Create` through
[PJRT_NamedValue](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h#L210).
The default config values can (1) come from the json file which will be stored
in PluginInfo, or (2) plugin implementation during import.

### Framework integration

All the methods mentioned above can be reused across frameworks. For example,
JAX lib can integrate as follows:

* Call [load_pjrt_plugins](#heading=h.782ksg6rl5bj) when
[initializing backends](https://github.com/google/jax/blob/a66b3dcdd378b723e275a19258de826260b4c83e/jax/_src/lib/xla_bridge.py#L381).
* Call [get_loaded_plugin_names](#heading=h.396bmv8gkskz) to get loaded PJRT
`plugin_name`, have some framework specific logics to decide whether to call
[create_pjrt_client](#heading=h.396bmv8gkskz) to create the PJRT client.

```python
def create_pjrt_clients():
loaded_plugin_names = get_loaded_plugin_names()
for plugin_name in loaded_plugin_names:
# Framework specific logics to decide whether to create
if should_create_pjrt_client(plugin_name):
pjrt_client = create_pjrt_client(plugin_name)
```

For TensorFlow, discovery will be added to
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if 2 frameworks create to clients for the same device. Is this supported? If so, it would be great to specify it.
If this isn't supported, it will be harder to have in the same python script different frameworks. So supporting multiple clients would be great.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is up to the specific hardware and plugin implementation. If the hardware only allows exclusive access, then the software will abide by that constraint. Otherwise, some plugins may use the hardware in an exclusive way (ie. Allocate all memory). The openxla plugin that we envision now will default to allowing multiple clients and supporting on demand memory allocation (with a possible option for more greedy access as an optimization for cases that need it).

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sound good. Thanks.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is more to this than greedy access as an optimisation. A Graphcore IPU can only be owned by a single context on the host. So two processes, or indeed two clients in a single process, can't share an IPU.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that becomes a restriction to use of a Graphcore IPU then -- the PJRT API layer isn't going to do any kind of virtualization or remoting. If a software mechanism is needed to arbitrate multi-party access to a device, then that would be up to the device implementation.

Side note: this is currently an issue when using Jax by default with the XLA GPU backend as it allocates all memory, effectively making it impossible to share. There are environment variable workarounds to cause it to dynamically allocate.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that becomes a platform specific restriction, and I don't want a virtualisation layer. My concern is this becomes an implicit assumption in all users of this API.

The openxla plugin that we envision now will default to allowing multiple clients

My only point being that greedy or exclusive access isn't necessarily an optimisation, it can be a requirement (there's a reason in silicon for it).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see us doing anything in the software stack that makes multi-tenancy either harder or easier, but we will probably seek to make the default openxla implementation more user friendly on this front by default as it is a frequent pain point.

There really should be one Client per process, and if there can only be one Client per system then that would limit to only launching one process for that category of devices.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a paragraph to summarize this discussion thread.

[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/api_template.__init__.py#L142),
and loading PJRT plugins will be added to
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc#L124).
Depending on the plugin, the client can be created implicitly in the first time
it is used, or created in a plugin custom op kernel. Created PJRT clients will
be saved in the
[global TensorFlow ResourceManager](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/tfrt/common/pjrt_util.h#L26).

For more information about PJRT, please consult the PJRT
[C header](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h),
[C++ header](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/pjrt_client.h),
[plugin mechanism design doc](https://docs.google.com/document/d/1FieBwiflD6DT9Z38wFIrnNxrUto1AOP3UuibQ4vpvEQ/edit)
and
[integration guide](https://docs.google.com/document/d/1EL62DkASMFZbk89K6MIPNGUFBTbkC21v5-uL5fB8VsM/edit).

## Initial contribution

We will use the plugin in
[IREE samples](https://github.com/iree-org/iree-samples/tree/main/pjrt-plugin)
to seed the repo and build full PJRT support as part of OpenXLA project.
Binary file added rfcs/20230123-pjrt-plugin/frameworks.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added rfcs/20230123-pjrt-plugin/plugin-structure.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.