Skip to content

Commit

Permalink
feat(client): support custom client session factory (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
Chaoses-Ib committed Apr 9, 2024
1 parent fff2d8c commit 6ee4dd9
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 44 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "comfy-script"
version = "0.4.3"
version = "0.4.4"
description = "A Python front end and library for ComfyUI"
readme = "README.md"
# ComfyUI: >=3.8
Expand All @@ -23,6 +23,8 @@ classifiers = [
client = [
# Already required by ComfyUI
"aiohttp",
# Used by aiohttp
"yarl",

# 1.5.9: https://github.com/erdewit/nest_asyncio/issues/87
"nest_asyncio ~= 1.0, >= 1.5.9",
Expand Down
62 changes: 46 additions & 16 deletions src/comfy_script/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,54 @@
from __future__ import annotations
import json
import os
from pathlib import PurePath
import sys
import traceback
from typing import Callable

import asyncio
import nest_asyncio
import aiohttp
from yarl import URL

nest_asyncio.apply()

endpoint = 'http://127.0.0.1:8188/'
client: Client | None = None
'''The global client object.'''

def set_endpoint(api_endpoint: str):
global endpoint
if not api_endpoint.startswith('http://'):
api_endpoint = 'http://' + api_endpoint
if not api_endpoint.endswith('/'):
api_endpoint += '/'
endpoint = api_endpoint
class Client:
def __init__(
self,
base_url: str | URL = 'http://127.0.0.1:8188/',
*,
session_factory: Callable[[], aiohttp.ClientSession] = aiohttp.ClientSession
):
'''
- `base_url`: The base URL of the ComfyUI server API.
e.g. `'http://127.0.0.1:8188/'`
- `session_factory`: A callable factory that returns a new [`aiohttp.ClientSession`](https://docs.aiohttp.org/en/latest/client_reference.html#aiohttp.ClientSession) object.
e.g. `lambda: aiohttp.ClientSession(auth=aiohttp.BasicAuth('Aladdin', 'open sesame'))`
'''
if base_url is None:
base_url = 'http://127.0.0.1:8188/'
elif not isinstance(base_url, str):
base_url = str(base_url)

if not base_url.startswith('http://'):
base_url = 'http://' + base_url
if not base_url.endswith('/'):
base_url += '/'
self.base_url = base_url

# Do not pass base_url to ClientSession, as it only supports absolute URLs without path part
self._session_factory = session_factory

def session(self) -> aiohttp.ClientSession:
'''Because `aiohttp.ClientSession` is not event-loop-safe (thread-safe), a new session should be created for each request to avoid potential issues. Also, `aiohttp.ClientSession` cannot be closed in a sync manner.'''
return self._session_factory()

async def response_to_str(response: aiohttp.ClientResponse) -> str:
try:
Expand Down Expand Up @@ -64,9 +94,9 @@ def node_info(node_class):
traceback.print_exc()
return out

async with aiohttp.ClientSession() as session:
# http://127.0.0.1:8188/object_info
async with session.get(f'{endpoint}object_info') as response:
async with client.session() as session:
# http://127.0.0.1:8188/object_info
async with session.get(f'{client.base_url}object_info') as response:
if response.status == 200:
return await response.json()
else:
Expand All @@ -81,9 +111,9 @@ async def _get_embeddings() -> list[str]:
embeddings = folder_paths.get_filename_list("embeddings")
return list(map(lambda a: os.path.splitext(a)[0], embeddings))

async with aiohttp.ClientSession() as session:
# http://127.0.0.1:8188/embeddings
async with session.get(f'{endpoint}embeddings') as response:
async with client.session() as session:
# http://127.0.0.1:8188/embeddings
async with session.get(f'{client.base_url}embeddings') as response:
if response.status == 200:
return await response.json()
else:
Expand All @@ -99,8 +129,8 @@ def default(self, o):
return super().default(o)

__all__ = [
'endpoint'
'set_endpoint',
'client',
'Client',
'_get_nodes_info',
'get_nodes_info',
'_get_embeddings',
Expand Down
52 changes: 28 additions & 24 deletions src/comfy_script/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
_client_id = str(uuid.uuid4())
_save_script_source = True

def load(comfyui: str | Path = None, args: ComfyUIArgs | None = None, vars: dict | None = None, watch: bool = True, save_script_source: bool = True):
def load(comfyui: str | Client | Path = None, args: ComfyUIArgs | None = None, vars: dict | None = None, watch: bool = True, save_script_source: bool = True):
'''
- `comfyui`: A URL of the ComfyUI server API, or a path to the ComfyUI directory, or `'comfyui'` to use the [`comfyui` package](https://github.com/comfyanonymous/ComfyUI/pull/298).
- `comfyui`: The base URL of the ComfyUI server API, or a `Client` object, or a path to the ComfyUI directory, or `'comfyui'` to use the [`comfyui` package](https://github.com/comfyanonymous/ComfyUI/pull/298).
If not specified, the following ones will be tried in order:
1. Local server API: http://127.0.0.1:8188/
Expand All @@ -30,7 +30,7 @@ def load(comfyui: str | Path = None, args: ComfyUIArgs | None = None, vars: dict
'''
asyncio.run(_load(comfyui, args, vars, watch, save_script_source))

async def _load(comfyui: str | Path = None, args: ComfyUIArgs | None = None, vars: dict | None = None, watch: bool = True, save_script_source: bool = True):
async def _load(comfyui: str | Client | Path = None, args: ComfyUIArgs | None = None, vars: dict | None = None, watch: bool = True, save_script_source: bool = True):
global _save_script_source, queue

_save_script_source = save_script_source
Expand All @@ -39,15 +39,17 @@ async def _load(comfyui: str | Path = None, args: ComfyUIArgs | None = None, var
if comfyui is None:
try:
nodes_info = await client._get_nodes_info()
if comfyui_server != client.endpoint:
print(f'ComfyScript: Using ComfyUI from {client.endpoint}')
if comfyui_base_url != client.client.base_url:
print(f'ComfyScript: Using ComfyUI from {client.client.base_url}')
except Exception as e:
# To avoid "During handling of the above exception, another exception occurred"
pass
if nodes_info is None:
start_comfyui(comfyui, args)
elif isinstance(comfyui, str) and (comfyui.startswith('http://') or comfyui.startswith('https://')):
client.set_endpoint(comfyui)
client.client = client.Client(comfyui)
elif isinstance(comfyui, client.Client):
client.client = comfyui
else:
start_comfyui(comfyui, args)

Expand Down Expand Up @@ -155,7 +157,7 @@ def to_argv(self) -> list[str]:
return self.argv

comfyui_started = False
comfyui_server = None
comfyui_base_url = None

def start_comfyui(comfyui: Path | str = None, args: ComfyUIArgs | None = None, *, no_server: bool = False, join_at_exit: bool = True, autonomy: bool = False):
'''
Expand All @@ -175,11 +177,11 @@ def start_comfyui(comfyui: Path | str = None, args: ComfyUIArgs | None = None, *
- `autonomy`: If enabled, currently, the server will not be started even if `no_server=False`.
'''
global comfyui_started, comfyui_server
if comfyui_started and (comfyui_server is not None or no_server):
global comfyui_started, comfyui_base_url
if comfyui_started and (comfyui_base_url is not None or no_server):
return
comfyui_started = False
comfyui_server = None
comfyui_base_url = None

if comfyui is None:
default_comfyui = Path(__file__).resolve().parents[5]
Expand Down Expand Up @@ -345,8 +347,8 @@ def enable_args_parsing_hook():
if not no_server:
threading.Thread(target=main.server.loop.run_until_complete, args=(main.server.publish_loop(),), daemon=True).start()

comfyui_server = f'http://127.0.0.1:{main.args.port}/'
client.set_endpoint(comfyui_server)
comfyui_base_url = f'http://127.0.0.1:{main.args.port}/'
client.client = client.Client(comfyui_base_url)
else:
if comfyui != 'comfyui':
print(f'ComfyScript: Importing ComfyUI from {comfyui}')
Expand Down Expand Up @@ -443,8 +445,8 @@ def __init__(self):
self.queue_remaining = 0

async def _get_history(self, prompt_id: str) -> dict | None:
async with aiohttp.ClientSession() as session:
async with session.get(f'{client.endpoint}history/{prompt_id}') as response:
async with client.client.session() as session:
async with session.get(f'{client.client.base_url}history/{prompt_id}') as response:
if response.status == 200:
json = await response.json()
# print(json)
Expand All @@ -455,8 +457,8 @@ async def _get_history(self, prompt_id: str) -> dict | None:
async def _watch(self):
while True:
try:
async with aiohttp.ClientSession() as session:
async with session.ws_connect(f'{client.endpoint}ws', params={'clientId': _client_id}) as ws:
async with client.client.session() as session:
async with session.ws_connect(f'{client.client.base_url}ws', params={'clientId': _client_id}) as ws:
self.queue_remaining = 0
executing = False
async for msg in ws:
Expand Down Expand Up @@ -582,15 +584,15 @@ async def _put(self, workflow: data.NodeOutput | Iterable[data.NodeOutput] | Wor
raise TypeError(f'ComfyScript: Invalid workflow type: {workflow}')
# print(prompt)

async with aiohttp.ClientSession() as session:
async with client.client.session() as session:
extra_data = {}
if _save_script_source:
extra_data = {
'extra_pnginfo': {
'ComfyScriptSource': source
}
}
async with session.post(f'{client.endpoint}prompt', json={
async with session.post(f'{client.client.base_url}prompt', json={
'prompt': prompt,
'extra_data': extra_data,
'client_id': _client_id,
Expand Down Expand Up @@ -648,8 +650,8 @@ def cancel_current(self):
'''Interrupt the current task'''
return asyncio.run(self._cancel_current())
async def _cancel_current(self):
async with aiohttp.ClientSession() as session:
async with session.post(f'{client.endpoint}interrupt', json={
async with client.client.session() as session:
async with session.post(f'{client.client.base_url}interrupt', json={
'client_id': _client_id,
}) as response:
if response.status != 200:
Expand All @@ -659,8 +661,8 @@ def cancel_remaining(self):
'''Clear the queue'''
return asyncio.run(self._cancel_remaining())
async def _cancel_remaining(self):
async with aiohttp.ClientSession() as session:
async with session.post(f'{client.endpoint}queue', json={
async with client.client.session() as session:
async with session.post(f'{client.client.base_url}queue', json={
'clear': True,
'client_id': _client_id,
}) as response:
Expand Down Expand Up @@ -754,8 +756,8 @@ def wait_result(self, output: data.NodeOutput) -> data.Result | None:
# def wait(self):
# return asyncio.run(self._wait())
# async def _wait(self):
# async with aiohttp.ClientSession() as session:
# async with session.ws_connect(f'{client.endpoint}ws?clientId={_client_id}') as ws:
# async with client.client.session() as session:
# async with session.ws_connect(f'{client.client.base_url}ws?clientId={_client_id}') as ws:
# async for msg in ws:
# if msg.type == aiohttp.WSMsgType.TEXT:
# msg = msg.json()
Expand Down Expand Up @@ -896,12 +898,14 @@ def __exit__(self, exc_type, exc_value, traceback):
queue = TaskQueue()

from .. import client
from ..client import Client
from . import nodes
from . import data
from .data import *

__all__ = [
'load',
'Client',
'ComfyUIArgs',
'start_comfyui',
'TaskQueue',
Expand Down
4 changes: 2 additions & 2 deletions src/comfy_script/runtime/data/Images.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
class ImageBatchResult(Result):
# TODO: Lazy cell
async def _get_image(self, image: dict) -> Image.Image | None:
async with aiohttp.ClientSession() as session:
async with session.get(f'{client.endpoint}view', params=image) as response:
async with client.client.session() as session:
async with session.get(f'{client.client.base_url}view', params=image) as response:
if response.status == 200:
return Image.open(io.BytesIO(await response.read()))
else:
Expand Down
2 changes: 1 addition & 1 deletion src/comfy_script/transpile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self, workflow: str | dict, api_endpoint: str = None):
- `workflow`: Can be in web UI format or API format.
'''
if api_endpoint is not None:
client.set_endpoint(api_endpoint)
client.client = client.Client(api_endpoint)
self.nodes_info = client.get_nodes_info()

if isinstance(workflow, str):
Expand Down

0 comments on commit 6ee4dd9

Please sign in to comment.