Skip to content

Commit

Permalink
Expose configuration of exchange peer info timeout (#1103)
Browse files Browse the repository at this point in the history
Add support to configure the timeout of messages during exchange peer info via constructor argument and environment variable.

This feature is important in Distributed (see #994 for details), more so on larger scale clusters. A new configuration variable will be added to allow controlling this feature via Dask config and environment variables.

Authors:
  - Peter Andreas Entschev (https://github.com/pentschev)

Approvers:
  - Mads R. B. Kristensen (https://github.com/madsbk)

URL: #1103
  • Loading branch information
pentschev authored Jan 21, 2025
1 parent f43a7b8 commit 120ded5
Showing 1 changed file with 34 additions and 10 deletions.
44 changes: 34 additions & 10 deletions ucp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def _get_ctx():
return _ctx


async def exchange_peer_info(endpoint, msg_tag, ctrl_tag, listener, stream_timeout=5.0):
async def exchange_peer_info(
endpoint, msg_tag, ctrl_tag, listener, connect_timeout=5.0
):
"""Help function that exchange endpoint information"""

# Pack peer information incl. a checksum
Expand All @@ -50,20 +52,20 @@ async def exchange_peer_info(endpoint, msg_tag, ctrl_tag, listener, stream_timeo
if listener is True:
await asyncio.wait_for(
comm.stream_send(endpoint, my_info_arr, my_info_arr.nbytes),
timeout=stream_timeout,
timeout=connect_timeout,
)
await asyncio.wait_for(
comm.stream_recv(endpoint, peer_info_arr, peer_info_arr.nbytes),
timeout=stream_timeout,
timeout=connect_timeout,
)
else:
await asyncio.wait_for(
comm.stream_recv(endpoint, peer_info_arr, peer_info_arr.nbytes),
timeout=stream_timeout,
timeout=connect_timeout,
)
await asyncio.wait_for(
comm.stream_send(endpoint, my_info_arr, my_info_arr.nbytes),
timeout=stream_timeout,
timeout=connect_timeout,
)

# Unpacking and sanity check of the peer information
Expand All @@ -74,7 +76,7 @@ async def exchange_peer_info(endpoint, msg_tag, ctrl_tag, listener, stream_timeo

if expected_checksum != ret["checksum"]:
raise RuntimeError(
f'Checksum invalid! {hex(expected_checksum)} != {hex(ret["checksum"])}'
f"Checksum invalid! {hex(expected_checksum)} != {hex(ret['checksum'])}"
)

return ret
Expand Down Expand Up @@ -157,6 +159,7 @@ async def _listener_handler_coroutine(conn_request, ctx, func, endpoint_error_ha
msg_tag=msg_tag,
ctrl_tag=ctrl_tag,
listener=True,
connect_timeout=ctx.connect_timeout,
)
tags = {
"msg_send": peer_info["msg_tag"],
Expand Down Expand Up @@ -216,7 +219,9 @@ class ApplicationContext:
The context of the Asyncio interface of UCX.
"""

def __init__(self, config_dict={}, blocking_progress_mode=None):
def __init__(
self, config_dict={}, blocking_progress_mode=None, connect_timeout=None
):
self.progress_tasks = []

# For now, a application context only has one worker
Expand All @@ -230,6 +235,10 @@ def __init__(self, config_dict={}, blocking_progress_mode=None):
else:
self.blocking_progress_mode = True

if connect_timeout is None:
self.connect_timeout = float(os.environ.get("UCXPY_CONNECT_TIMEOUT", 5))
else:
self.connect_timeout = connect_timeout
if self.blocking_progress_mode:
self.epoll_fd = self.worker.init_blocking_progress_mode()
weakref.finalize(
Expand Down Expand Up @@ -330,6 +339,7 @@ async def create_endpoint(self, ip_address, port, endpoint_error_handling=True):
msg_tag=msg_tag,
ctrl_tag=ctrl_tag,
listener=False,
connect_timeout=self.connect_timeout,
)
tags = {
"msg_send": peer_info["msg_tag"],
Expand Down Expand Up @@ -548,6 +558,7 @@ class Endpoint:

def __init__(self, endpoint, ctx, tags=None):
self._ep = endpoint
self._uid = self._ep.handle
self._ctx = ctx
self._send_count = 0 # Number of calls to self.send()
self._recv_count = 0 # Number of calls to self.recv()
Expand All @@ -559,7 +570,7 @@ def __init__(self, endpoint, ctx, tags=None):
@property
def uid(self):
"""The unique ID of the underlying UCX endpoint"""
return self._ep.handle
return self._uid

def closed(self):
"""Is this endpoint closed?"""
Expand Down Expand Up @@ -896,7 +907,12 @@ def set_close_callback(self, callback_func):
# The following functions initialize and use a single ApplicationContext instance


def init(options={}, env_takes_precedence=False, blocking_progress_mode=None):
def init(
options={},
env_takes_precedence=False,
blocking_progress_mode=None,
connect_timeout=None,
):
"""Initiate UCX.
Usually this is done automatically at the first API call
Expand All @@ -914,6 +930,10 @@ def init(options={}, env_takes_precedence=False, blocking_progress_mode=None):
If None, blocking UCX progress mode is used unless the environment variable
`UCXPY_NON_BLOCKING_MODE` is defined.
Otherwise, if True blocking mode is used and if False non-blocking mode is used.
connect_timeout: float, optional
The timeout in seconds for exchanging endpoint information upon endpoint
establishment. If None, use the value from `UCXPY_CONNECT_TIMEOUT` if defined,
otherwise fallback to the default of 5 seconds.
"""
global _ctx
if _ctx is not None:
Expand All @@ -935,7 +955,11 @@ def init(options={}, env_takes_precedence=False, blocking_progress_mode=None):
logger.debug(
f"Ignoring environment {env_k}={env_v}; using option {k}={v}"
)
_ctx = ApplicationContext(options, blocking_progress_mode=blocking_progress_mode)
_ctx = ApplicationContext(
options,
blocking_progress_mode=blocking_progress_mode,
connect_timeout=connect_timeout,
)


def reset():
Expand Down

0 comments on commit 120ded5

Please sign in to comment.