Skip to content

Commit

Permalink
[runtime][hip] Propagate errors through semaphores (iree-org#18021)
Browse files Browse the repository at this point in the history
Adds proper handling of errors that occur when executing operations on
the device or when a semaphore fails in the wait list. These errors will
propagate to downstream semaphores that are in the signal list of the
operation.

This change includes some refactor the pending queue actions:
* Make the context hold a sticky status instead of just an status code.
* Remove the worker threads' "error" state. This can be handled by the
context's status.
* Remove "exit committed" thread state in favor of standard thread
joining.
* Make the "exit requested" thread state a separate boolean variable and
guard against submitting more work after an exit is requested. Also wait
on all work to complete before exiting worker threads, not just on the
currently ran actions.
* Make pending work items increment immediately when an action is
enqueued instead of when scheduled on the HIP stream. This is required
to properly count outstanding work.
* Remove and merge some of the redundant state for the worker and
completion threads.
* Remove reference counting from the pending queue actions context. It
has a clear owner, which is the device.
* Rework when the threads exit, which is pretty much only when exit is
requested and there is no more queued or executing actions. Errors don't
cause the threads to exit.

Here is not included moving the destruction and cleanup of actions from
the worker thread to the completion thread. This is an optimization and
code simplification that is now possible since we are not using HIP
stream callbacks, so we could do that right after an action completes.
Technically, actions get destroyed on the completion thread as well when
not on the happy path and actions fail.
  • Loading branch information
sogartar authored Aug 2, 2024
1 parent 242d69e commit 3fd336f
Show file tree
Hide file tree
Showing 14 changed files with 724 additions and 515 deletions.
3 changes: 3 additions & 0 deletions runtime/src/iree/base/internal/threading.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ void iree_thread_request_affinity(iree_thread_t* thread,
// This has no effect if the thread is not suspended.
void iree_thread_resume(iree_thread_t* thread);

// Blocks the current thread until |thread| has finished its execution.
void iree_thread_join(iree_thread_t* thread);

void iree_thread_yield(void);

#ifdef __cplusplus
Expand Down
6 changes: 6 additions & 0 deletions runtime/src/iree/base/internal/threading_darwin.c
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,12 @@ void iree_thread_resume(iree_thread_t* thread) {
IREE_TRACE_ZONE_END(z0);
}

void iree_thread_join(iree_thread_t* thread) {
IREE_TRACE_ZONE_BEGIN(z0);
pthread_join(thread->handle, NULL);
IREE_TRACE_ZONE_END(z0);
}

void iree_thread_yield(void) { sched_yield(); }

#endif // IREE_PLATFORM_APPLE
6 changes: 6 additions & 0 deletions runtime/src/iree/base/internal/threading_pthreads.c
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,12 @@ void iree_thread_resume(iree_thread_t* thread) {
IREE_TRACE_ZONE_END(z0);
}

void iree_thread_join(iree_thread_t* thread) {
IREE_TRACE_ZONE_BEGIN(z0);
pthread_join(thread->handle, NULL);
IREE_TRACE_ZONE_END(z0);
}

void iree_thread_yield(void) { sched_yield(); }

#endif // IREE_PLATFORM_*
6 changes: 6 additions & 0 deletions runtime/src/iree/base/internal/threading_win32.c
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,12 @@ void iree_thread_resume(iree_thread_t* thread) {
IREE_TRACE_ZONE_END(z0);
}

void iree_thread_join(iree_thread_t* thread) {
IREE_TRACE_ZONE_BEGIN(z0);
WaitForSingleObject(thread->handle, INFINITE);
IREE_TRACE_ZONE_END(z0);
}

void iree_thread_yield(void) { YieldProcessor(); }

#endif // IREE_PLATFORM_WINDOWS
19 changes: 19 additions & 0 deletions runtime/src/iree/hal/cts/cts_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <set>
#include <string>
#include <string_view>

#include "iree/base/api.h"
#include "iree/base/string_view.h"
Expand Down Expand Up @@ -276,6 +277,24 @@ class CTSTestBase : public BaseType, public CTSTestResources {
IREE_EXPECT_OK(iree_hal_semaphore_query(semaphore, &value));
EXPECT_EQ(expected_value, value);
}

// Check that a contains b.
// That is the codes of a and b are equal and the message of b is contained
// in the message of a.
void CheckStatusContains(iree_status_t a, iree_status_t b) {
EXPECT_EQ(iree_status_code(a), iree_status_code(b));
iree_allocator_t allocator = iree_allocator_system();
char* a_str = NULL;
iree_host_size_t a_str_length = 0;
EXPECT_TRUE(iree_status_to_string(a, &allocator, &a_str, &a_str_length));
char* b_str = NULL;
iree_host_size_t b_str_length = 0;
EXPECT_TRUE(iree_status_to_string(b, &allocator, &b_str, &b_str_length));
EXPECT_TRUE(std::string_view(a_str).find(std::string_view(b_str)) !=
std::string_view::npos);
iree_allocator_free(allocator, a_str);
iree_allocator_free(allocator, b_str);
}
};

} // namespace cts
Expand Down
72 changes: 72 additions & 0 deletions runtime/src/iree/hal/cts/semaphore_submission_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,78 @@ TEST_F(SemaphoreSubmissionTest, BatchWaitingOnSmallerValueBeforeSignaled) {
iree_hal_command_buffer_release(command_buffer);
}

// Submit an batch and check that the wait semaphore fails when the signal
// semaphore fails.
TEST_F(SemaphoreSubmissionTest, PropagateFailSignal) {
// signal-wait relation:
//
// semaphore1
//
// command_buffer
//
// semaphore2

iree_hal_command_buffer_t* command_buffer = CreateEmptyCommandBuffer();
iree_hal_semaphore_t* semaphore1 = CreateSemaphore();
iree_hal_semaphore_t* semaphore2 = CreateSemaphore();

// Submit the command buffer.
uint64_t semaphore1_wait_value = 1;
iree_hal_semaphore_list_t command_buffer_wait_list = {
/*count=*/1, &semaphore1, &semaphore1_wait_value};
uint64_t semaphore2_signal_value = 1;
iree_hal_semaphore_list_t command_buffer_signal_list = {
/*count=*/1, &semaphore2, &semaphore2_signal_value};
IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_, IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/command_buffer_wait_list,
/*signal_semaphore_list=*/command_buffer_signal_list, 1, &command_buffer,
/*binding_tables=*/NULL));

iree_status_t status =
iree_make_status(IREE_STATUS_CANCELLED, "PropagateFailSignal test.");
std::thread signal_thread([&]() {
iree_hal_semaphore_fail(semaphore1, iree_status_clone(status));
});

iree_status_t wait_status =
iree_hal_semaphore_wait(semaphore2, semaphore2_signal_value,
iree_make_deadline(IREE_TIME_INFINITE_FUTURE));
EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED);
uint64_t value = 1234;
iree_status_t query_status = iree_hal_semaphore_query(semaphore2, &value);
EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
CheckStatusContains(query_status, status);

signal_thread.join();
iree_hal_semaphore_release(semaphore1);
iree_hal_semaphore_release(semaphore2);
iree_hal_command_buffer_release(command_buffer);
iree_status_ignore(status);
iree_status_ignore(wait_status);
iree_status_ignore(query_status);
}

// Submit an invalid dispatch and check that the wait semaphore fails.
TEST_F(SemaphoreSubmissionTest, PropagateDispatchFailure) {
// signal-wait relation:
//
// semaphore1
//
// command_buffer
//
// semaphore2

// TODO (sogartar):
// I tried to add a kernel that stores into a null pointer or
// traps(aborts), but with HIP that causes the whole executable to abort,
// which is not what we want.
// We want a failure of the kernel launch or when waiting on the stream for
// the kernel to complete.
// This needs to be "soft" failure that result in a returned error from the
// underlying API call.
}

} // namespace iree::hal::cts

#endif // IREE_HAL_CTS_SEMAPHORE_SUBMISSION_TEST_H_
134 changes: 134 additions & 0 deletions runtime/src/iree/hal/cts/semaphore_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,140 @@ TEST_F(SemaphoreTest, SimultaneousMultiWaitAll) {
iree_hal_semaphore_release(semaphore2);
}

// Wait on a semaphore that is then failed.
TEST_F(SemaphoreTest, FailThenWait) {
iree_hal_semaphore_t* semaphore = this->CreateSemaphore();

iree_status_t status =
iree_make_status(IREE_STATUS_CANCELLED, "FailThenWait test.");
iree_hal_semaphore_fail(semaphore, iree_status_clone(status));

iree_status_t wait_status = iree_hal_semaphore_wait(
semaphore, 1, iree_make_deadline(IREE_TIME_INFINITE_FUTURE));
EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED);
uint64_t value = 1234;
iree_status_t query_status = iree_hal_semaphore_query(semaphore, &value);
EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
CheckStatusContains(query_status, status);

iree_hal_semaphore_release(semaphore);
iree_status_ignore(status);
iree_status_ignore(wait_status);
iree_status_ignore(query_status);
}

// Wait on a semaphore that is then failed.
TEST_F(SemaphoreTest, WaitThenFail) {
iree_hal_semaphore_t* semaphore = this->CreateSemaphore();

// It is possible that the order becomes fail than wait.
// We assume that it is less likely since starting the thread takes time.
iree_status_t status =
iree_make_status(IREE_STATUS_CANCELLED, "WaitThenFail test.");
std::thread signal_thread(
[&]() { iree_hal_semaphore_fail(semaphore, iree_status_clone(status)); });

iree_status_t wait_status = iree_hal_semaphore_wait(
semaphore, 1, iree_make_deadline(IREE_TIME_INFINITE_FUTURE));
EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED);
uint64_t value = 1234;
iree_status_t query_status = iree_hal_semaphore_query(semaphore, &value);
EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
CheckStatusContains(query_status, status);

signal_thread.join();
iree_hal_semaphore_release(semaphore);
iree_status_ignore(status);
iree_status_ignore(wait_status);
iree_status_ignore(query_status);
}

// Wait 2 semaphores then fail one of them.
TEST_F(SemaphoreTest, MultiWaitThenFail) {
iree_hal_semaphore_t* semaphore1 = this->CreateSemaphore();
iree_hal_semaphore_t* semaphore2 = this->CreateSemaphore();

// It is possible that the order becomes fail than wait.
// We assume that it is less likely since starting the thread takes time.
iree_status_t status =
iree_make_status(IREE_STATUS_CANCELLED, "MultiWaitThenFail test.");
std::thread signal_thread([&]() {
iree_hal_semaphore_fail(semaphore1, iree_status_clone(status));
});

iree_hal_semaphore_t* semaphore_array[] = {semaphore1, semaphore2};
uint64_t payload_array[] = {1, 1};
iree_hal_semaphore_list_t semaphore_list = {
IREE_ARRAYSIZE(semaphore_array),
semaphore_array,
payload_array,
};
iree_status_t wait_status = iree_hal_semaphore_list_wait(
semaphore_list, iree_make_deadline(IREE_TIME_INFINITE_FUTURE));
EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED);
uint64_t value = 1234;
iree_status_t semaphore1_query_status =
iree_hal_semaphore_query(semaphore1, &value);
EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
CheckStatusContains(semaphore1_query_status, status);

// semaphore2 must not have changed.
uint64_t semaphore2_value = 1234;
IREE_ASSERT_OK(iree_hal_semaphore_query(semaphore2, &semaphore2_value));
EXPECT_EQ(semaphore2_value, 0);

signal_thread.join();
iree_hal_semaphore_release(semaphore1);
iree_hal_semaphore_release(semaphore2);
iree_status_ignore(status);
iree_status_ignore(wait_status);
iree_status_ignore(semaphore1_query_status);
}

// Wait 2 semaphores using iree_hal_device_wait_semaphores then fail
// one of them.
TEST_F(SemaphoreTest, DeviceMultiWaitThenFail) {
iree_hal_semaphore_t* semaphore1 = this->CreateSemaphore();
iree_hal_semaphore_t* semaphore2 = this->CreateSemaphore();

// It is possible that the order becomes fail than wait.
// We assume that it is less likely since starting the thread takes time.
iree_status_t status =
iree_make_status(IREE_STATUS_CANCELLED, "DeviceMultiWaitThenFail test.");
std::thread signal_thread([&]() {
iree_hal_semaphore_fail(semaphore1, iree_status_clone(status));
});

iree_hal_semaphore_t* semaphore_array[] = {semaphore1, semaphore2};
uint64_t payload_array[] = {1, 1};
iree_hal_semaphore_list_t semaphore_list = {
IREE_ARRAYSIZE(semaphore_array),
semaphore_array,
payload_array,
};
iree_status_t wait_status = iree_hal_device_wait_semaphores(
device_, IREE_HAL_WAIT_MODE_ANY, semaphore_list,
iree_make_deadline(IREE_TIME_INFINITE_FUTURE));
EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED);
uint64_t value = 1234;
iree_status_t semaphore1_query_status =
iree_hal_semaphore_query(semaphore1, &value);
EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
CheckStatusContains(semaphore1_query_status, status);

// semaphore2 must not have changed.
uint64_t semaphore2_value = 1234;
IREE_ASSERT_OK(iree_hal_semaphore_query(semaphore2, &semaphore2_value));
EXPECT_EQ(semaphore2_value, 0);

signal_thread.join();
iree_hal_semaphore_release(semaphore1);
iree_hal_semaphore_release(semaphore2);
iree_status_ignore(status);
iree_status_ignore(wait_status);
iree_status_ignore(semaphore1_query_status);
}

} // namespace iree::hal::cts

#endif // IREE_HAL_CTS_SEMAPHORE_TEST_H_
12 changes: 12 additions & 0 deletions runtime/src/iree/hal/drivers/cuda/cts/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

unset(FILTER_TESTS)
string(APPEND FILTER_TESTS "SemaphoreTest.WaitThenFail:")
string(APPEND FILTER_TESTS "SemaphoreTest.FailThenWait:")
string(APPEND FILTER_TESTS "SemaphoreTest.MultiWaitThenFail:")
string(APPEND FILTER_TESTS "SemaphoreTest.DeviceMultiWaitThenFail:")
string(APPEND FILTER_TESTS "SemaphoreSubmissionTest.PropagateFailSignal:")
set(FILTER_TESTS_ARGS
"--gtest_filter=-${FILTER_TESTS}"
)

iree_hal_cts_test_suite(
DRIVER_NAME
cuda
Expand All @@ -19,6 +29,7 @@ iree_hal_cts_test_suite(
"\"PTXE\""
ARGS
"--cuda_use_streams=false"
${FILTER_TESTS_ARGS}
DEPS
iree::hal::drivers::cuda::registration
EXCLUDED_TESTS
Expand All @@ -44,6 +55,7 @@ iree_hal_cts_test_suite(
"\"PTXE\""
ARGS
"--cuda_use_streams=true"
${FILTER_TESTS_ARGS}
DEPS
iree::hal::drivers::cuda::registration
EXCLUDED_TESTS
Expand Down
Loading

0 comments on commit 3fd336f

Please sign in to comment.