Skip to content

Commit

Permalink
[iree.build] Wire up out of process concurrency. (iree-org#19291)
Browse files Browse the repository at this point in the history
* Introduces an explicit thunk creation stage which gives a way to
create a fully remotable object.
* Reworks process concurrency to occupy a host thread in addition to a
sub-process, which keeps the task concurrency accounting simple and
makes errors propagate more easily.
* Adds a test action for invoking a thunk out of process.
* This is the boilerplate required while implementing a turbine AOT
export action.

Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
  • Loading branch information
stellaraccident authored Nov 26, 2024
1 parent 53e9601 commit ef4ecf3
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 13 deletions.
1 change: 1 addition & 0 deletions compiler/bindings/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ SOURCES
net_actions.py
onnx_actions.py
target_machine.py
test_actions.py
)

add_mlir_python_modules(IREECompilerBuildPythonModules
Expand Down
59 changes: 46 additions & 13 deletions compiler/bindings/python/iree/build/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,17 @@ def __str__(self) -> str:
return self.value


class BuildAction(BuildDependency, abc.ABC):
"""An action that must be carried out."""
class BuildAction(BuildDependency):
"""An action that must be carried out.
This class is designed to be subclassed by concrete actions. In-process
only actions should override `_invoke`, whereas those that can be executed
out-of-process must override `_remotable_thunk`.
Note that even actions that are marked for `PROCESS` concurrency will
run on a dedicated thread within the host process. Only the `_remotable_thunk`
result will be scheduled out of process.
"""

def __init__(
self,
Expand All @@ -289,20 +298,43 @@ def __init__(
):
super().__init__(executor=executor, deps=deps)
self.desc = desc
self.concurrnecy = concurrency
self.concurrency = concurrency

def __str__(self):
return self.desc

def __repr__(self):
return f"Action[{type(self).__name__}]('{self.desc}')"

def invoke(self):
self._invoke()
def invoke(self, scheduler: "Scheduler"):
# Invoke is run within whatever in-process execution context was requested:
# - On the scheduler thread for NONE
# - On a worker thread for THREAD or PROCESS
# For PROCESS concurrency, we have to create a compatible invocation
# thunk, schedule that on the process pool and wait for it.
if self.concurrency == ActionConcurrency.PROCESS:
thunk = self._remotable_thunk()
fut = scheduler.process_pool_executor.submit(thunk)
fut.result()
else:
self._invoke()

@abc.abstractmethod
def _invoke(self):
...
self._remotable_thunk()()

def _remotable_thunk(self) -> Callable[[], None]:
"""Creates a remotable no-arg thunk that will execute this out of process.
This must return a no arg/result callable that can be pickled. While there
are various ways to ensure this, here are a few guidelines:
* Must be a type/function defined at a module level.
* Cannot be decorated.
* Must only contain attributes with the same constraints.
"""
raise NotImplementedError(
f"Action '{self}' does not implement remotable invocation"
)


class BuildContext(BuildDependency):
Expand Down Expand Up @@ -513,19 +545,20 @@ def _schedule_action(self, dep: BuildDependency):
if isinstance(dep, BuildAction):

def invoke():
dep.invoke()
dep.invoke(self)
return dep

print(f"Scheduling action: {dep}", file=self.stderr)
if dep.concurrnecy == ActionConcurrency.NONE:
if dep.concurrency == ActionConcurrency.NONE:
invoke()
elif dep.concurrnecy == ActionConcurrency.THREAD:
elif (
dep.concurrency == ActionConcurrency.THREAD
or dep.concurrency == ActionConcurrency.PROCESS
):
dep.start(self.thread_pool_executor.submit(invoke))
elif dep.concurrnecy == ActionConcurrency.PROCESS:
dep.start(self.process_pool_executor.submit(invoke))
else:
raise AssertionError(
f"Unhandled ActionConcurrency value: {dep.concurrnecy}"
f"Unhandled ActionConcurrency value: {dep.concurrency}"
)
else:
# Not schedulable. Just mark it as done.
Expand Down
31 changes: 31 additions & 0 deletions compiler/bindings/python/iree/build/test_actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Callable
from iree.build.executor import ActionConcurrency, BuildAction


class _ThunkTrampoline:
def __init__(self, thunk, args):
self.thunk = thunk
self.args = args

def __call__(self):
self.thunk(*self.args)


class ExecuteOutOfProcessThunkAction(BuildAction):
"""Executes a callback thunk with arguments.
Both the thunk and args must be pickleable.
"""

def __init__(self, thunk, args, concurrency=ActionConcurrency.PROCESS, **kwargs):
super().__init__(concurrency=concurrency, **kwargs)
self.trampoline = _ThunkTrampoline(thunk, args)

def _remotable_thunk(self) -> Callable[[], None]:
return self.trampoline
7 changes: 7 additions & 0 deletions compiler/bindings/python/test/build_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,10 @@ if(IREE_INPUT_TORCH)
"mnist_builder_test.py"
)
endif()

iree_py_test(
NAME
concurrency_test
SRCS
"concurrency_test.py"
)
61 changes: 61 additions & 0 deletions compiler/bindings/python/test/build_api/concurrency_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
from pathlib import Path
import tempfile
import unittest

from iree.build import *
from iree.build.executor import BuildContext
from iree.build.test_actions import ExecuteOutOfProcessThunkAction


@entrypoint
def write_out_of_process_pid():
context = BuildContext.current()
output_file = context.allocate_file("pid.txt")
action = ExecuteOutOfProcessThunkAction(
_write_pid_file,
args=[output_file.get_fs_path()],
desc="Writing pid file",
executor=context.executor,
)
output_file.deps.add(action)
return output_file


def _write_pid_file(output_path: Path):
pid = os.getpid()
print(f"Running action out of process: pid={pid}")
output_path.write_text(str(pid))


class ConcurrencyTest(unittest.TestCase):
def setUp(self):
self._temp_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
self._temp_dir.__enter__()
self.output_path = Path(self._temp_dir.name)

def tearDown(self) -> None:
self._temp_dir.__exit__(None, None, None)

def testProcessConcurrency(self):
parent_pid = os.getpid()
print(f"Testing out of process concurrency: pid={parent_pid}")
iree_build_main(
args=["write_out_of_process_pid", "--output-dir", str(self.output_path)]
)
pid_file = (
self.output_path / "genfiles" / "write_out_of_process_pid" / "pid.txt"
)
child_pid = int(pid_file.read_text())
print(f"Got child pid={child_pid}")
self.assertNotEqual(parent_pid, child_pid)


if __name__ == "__main__":
unittest.main()

0 comments on commit ef4ecf3

Please sign in to comment.