Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/nod-ai/sharktank into users…
Browse files Browse the repository at this point in the history
…/Groverkss/dep-manage
  • Loading branch information
Groverkss committed Aug 22, 2024
2 parents 0230c0f + 66cfc17 commit 3553365
Show file tree
Hide file tree
Showing 31 changed files with 1,934 additions and 129 deletions.
12 changes: 10 additions & 2 deletions libshortfin/bindings/python/_shortfin/asyncio_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,16 @@ def call_later(
return handle

def call_exception_handler(self, context) -> None:
# TODO: Should route this to the central exception handler.
raise RuntimeError(f"Async exception on {self._worker}: {context}")
# TODO: Should route this to the central exception handler. Should
# also play with ergonomics of how the errors get reported in
# various contexts and optimize.
source_exception = context.get("exception")
if isinstance(source_exception, BaseException):
raise RuntimeError(
f"Async exception on {self._worker}): {source_exception}"
).with_traceback(source_exception.__traceback__)
else:
raise RuntimeError(f"Async exception on {self._worker}: {context}")

def _timer_handle_cancelled(self, handle):
# We don't do anything special: just skip it if it comes up.
Expand Down
310 changes: 241 additions & 69 deletions libshortfin/bindings/python/lib_ext.cc

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions libshortfin/bindings/python/lib_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <nanobind/nanobind.h>
#include <nanobind/operators.h>
#include <nanobind/stl/filesystem.h>
#include <nanobind/stl/function.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
Expand Down
16 changes: 14 additions & 2 deletions libshortfin/bindings/python/shortfin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@

# Most classes from the native "local" namespace are aliased to the top
# level of the public API.
CompletionEvent = _sfl.local.CompletionEvent
Device = _sfl.local.Device
Message = _sfl.local.Message
Node = _sfl.local.Node
Process = _sfl.local.Process
Program = _sfl.local.Program
ProgramModule = _sfl.local.ProgramModule
Queue = _sfl.local.Queue
QueueReader = _sfl.local.QueueReader
QueueWriter = _sfl.local.QueueWriter
Scope = _sfl.local.Scope
ScopedDevice = _sfl.local.ScopedDevice
CompletionEvent = _sfl.local.CompletionEvent
System = _sfl.local.System
SystemBuilder = _sfl.local.SystemBuilder
Worker = _sfl.local.Worker
Expand All @@ -26,11 +32,17 @@
from . import host

__all__ = [
"CompletionEvent",
"Device",
"Message",
"Node",
"Program",
"ProgramModule",
"Queue",
"QueueReader",
"QueueWriter",
"Scope",
"ScopedDevice",
"CompletionEvent",
"System",
"SystemBuilder",
"Worker",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def do_something(i, delay):
total_delay = 0.0
max_delay = 0.0
for i in range(20):
delay = random.random() * 2
delay = random.random() * 0.25
total_delay += delay
max_delay = max(max_delay, delay)
print("SCHEDULE", i)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@

lsys = sf.host.CPUSystemBuilder().create_system()

total_processes = 0
lock = threading.Lock()


def tick_total():
global total_processes
with lock:
total_processes += 1


class MyProcess(sf.Process):
def __init__(self, arg, **kwargs):
Expand All @@ -22,10 +31,11 @@ async def run(self):
print(f"[pid:{self.pid}] Hello async:", self.arg, self)
processes = []
if self.arg < 10:
await asyncio.sleep(0.3)
await asyncio.sleep(0.1)
processes.append(MyProcess(self.arg + 1, scope=self.scope).launch())
await asyncio.gather(*processes)
print(f"[pid:{self.pid}] Goodbye async:", self.arg, self)
tick_total()


async def main():
Expand All @@ -47,4 +57,6 @@ def create_worker(i):
return i


print("RESULT:", lsys.run(main()))
result = lsys.run(main())
assert result == 9, f"{result}"
assert total_processes == 105, f"{total_processes}"
83 changes: 83 additions & 0 deletions libshortfin/examples/python/async/queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import asyncio

import shortfin as sf

lsys = sf.host.CPUSystemBuilder().create_system()

received_payloads = []


class Message(sf.Message):
def __init__(self, payload):
super().__init__()
self.payload = payload

def __repr__(self):
return f"Message(payload='{self.payload}')"


class WriterProcess(sf.Process):
def __init__(self, queue, **kwargs):
super().__init__(**kwargs)
self.writer = queue.writer()

async def run(self):
print("Start writer")
counter = 0
while (counter := counter + 1) <= 500:
msg = Message(f"Msg#{counter:03}")
await self.writer(msg)
print(f"Wrote message: {counter}")
self.writer.close()


class ReaderProcess(sf.Process):
def __init__(self, queue, **kwargs):
super().__init__(**kwargs)
self.reader = queue.reader()

async def run(self):
count = 0
while message := await self.reader():
print(f"[pid={self.pid}] Received message:", message)
received_payloads.append(message.payload)
count += 1
# After 100 messages, let the writer get ahead of the readers.
# Ensures that backlog and async close with a backlog works.
if count == 100:
await asyncio.sleep(0.25)


async def main():
queue = lsys.create_queue("infeed")
main_scope = lsys.create_scope()
w1 = lsys.create_worker("w1")
w1_scope = lsys.create_scope(w1)
await asyncio.gather(
WriterProcess(queue, scope=main_scope).launch(),
# By having a reader on the main worker and a separate worker,
# we test both intra and inter worker future resolution, which
# take different paths internally.
ReaderProcess(queue, scope=main_scope).launch(),
ReaderProcess(queue, scope=w1_scope).launch(),
)


lsys.run(main())


# Validate.
# May have come in slightly out of order so sort.
received_payloads.sort()
expected_payloads = [f"Msg#{i:03}" for i in range(1, 501)]
expected_payloads.sort()

assert (
received_payloads == expected_payloads
), f"EXPECTED: {repr(expected_payloads)}\nACTUAL:{received_payloads}"
180 changes: 180 additions & 0 deletions libshortfin/examples/python/http/http_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
import asyncio
from contextlib import asynccontextmanager
import threading
import sys

from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
import shortfin as sf
import uvicorn


class FastAPIResponder(sf.Message):
"""Bridge between FastAPI and shortfin that can be put on a queue and used to
send a response back at an arbitrary point.
This object is constructed in a FastAPI handler, capturing the current event loop
used by the web server. Then it can be put on a shortfin Queue and once within
a shortfin worker, an arbitrary worker can call `send_response` to send a simple
FastAPI response back to the webserver loop and onto the client.
"""

def __init__(self, request: Request):
super().__init__()
self.request = request
# Capture the running loop so that we can send responses back.
self._loop = asyncio.get_running_loop()
self.response = asyncio.Future(loop=self._loop)
self._responded = False
self._streaming_queue: asyncio.Queue | None = None
self.is_disconnected = False

def send_response(self, response: Response):
"""Sends a response back for this transaction.
This is intended for sending single part responses back. See
start_response() for sending back a streaming, multi-part response.
"""
assert not self._responded, "Response already sent"
if self._loop.is_closed():
raise IOError("Web server is shut down")
self._responded = True
self._loop.call_soon_threadsafe(self.response.set_result, response)

def start_response(self, **kwargs):
"""Starts a streaming response, passing the given kwargs to the
fastapi.responses.StreamingResponse constructor.
This is appropriate to use for generating a sparse response stream as is
typical of chat apps. As it will hop threads for each part, other means should
be used for bulk transfer (i.e. by scheduling on the webserver loop
directly).
"""
assert not self._responded, "Response already sent"
if self._loop.is_closed():
raise IOError("Web server is shut down")
self._responded = True
self._streaming_queue = asyncio.Queue()

async def gen():
while True:
if await self.request.is_disconnected():
self.is_disconnected = True
part = await self._streaming_queue.get()
if part is None:
break
yield part

def start():
response = StreamingResponse(gen(), **kwargs)
self.response.set_result(response)

self._loop.call_soon_threadsafe(start)

def stream_part(self, content: bytes | None):
"""Streams content to a response started with start_response().
Streaming must be ended by sending None.
"""
assert self._streaming_queue is not None, "start_response() not called"
if self._loop.is_closed():
raise IOError("Web server is shut down")
self._loop.call_soon_threadsafe(self._streaming_queue.put_nowait, content)


class System:
def __init__(self):
self.ls = sf.host.CPUSystemBuilder().create_system()
# TODO: Come up with an easier bootstrap thing than manually
# running a thread.
self.t = threading.Thread(target=lambda: self.ls.run(self.run()))
self.request_queue = self.ls.create_queue("request")
self.request_writer = self.request_queue.writer()

def start(self):
self.t.start()

def shutdown(self):
self.request_queue.close()

async def run(self):
print("*** Sytem Running ***")
request_reader = self.request_queue.reader()
while responder := await request_reader():
print("Got request:", responder)
# Can send a single response:
# request.send_response(JSONResponse({"answer": 42}))
# Or stream:
responder.start_response()
for i in range(20):
if responder.is_disconnected:
print("Cancelled!")
break
responder.stream_part(f"Iteration {i}\n".encode())
await asyncio.sleep(0.2)
else:
responder.stream_part(None)


@asynccontextmanager
async def lifespan(app: FastAPI):
system.start()
yield
print("Shutting down shortfin")
system.shutdown()


system = System()
app = FastAPI(lifespan=lifespan)


@app.get("/predict")
async def predict(request: Request):
transaction = FastAPIResponder(request)
system.request_writer(transaction)
return await transaction.response


def main(argv):
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument(
"--root-path",
type=str,
default=None,
help="Root path to use for installing behind path based proxy.",
)
parser.add_argument(
"--timeout-keep-alive", type=int, default=5, help="Keep alive timeout"
)
parser.add_argument(
"--testing-mock-service",
action="store_true",
help="Enable the mock testing service",
)
parser.add_argument(
"--device-uri", type=str, default="local-task", help="Device URI to serve on"
)

args = parser.parse_args(argv)

uvicorn.run(
app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=args.timeout_keep_alive,
)


if __name__ == "__main__":
main(sys.argv[1:])
3 changes: 3 additions & 0 deletions libshortfin/examples/python/mobilenet_server/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*.mlir
*.onnx
*.vmfb
Loading

0 comments on commit 3553365

Please sign in to comment.