Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TEST/GTEST: Added cuda gpu switching testing. #10388

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions test/gtest/common/mem_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ bool mem_buffer::is_mem_type_supported(ucs_memory_type_t mem_type)
mem_types.end();
}

void mem_buffer::set_device_context()
void mem_buffer::set_device_context(int device)
{
static __thread bool device_set = false;

Expand All @@ -179,7 +179,7 @@ void mem_buffer::set_device_context()

#if HAVE_CUDA
if (is_cuda_supported()) {
cudaSetDevice(0);
cudaSetDevice(device);
/* need to call free as context maybe lazily initialized when calling
* cudaSetDevice(0) but calling cudaFree(0) should guarantee context
* creation upon return */
Expand All @@ -189,7 +189,7 @@ void mem_buffer::set_device_context()

#if HAVE_ROCM
if (is_rocm_supported()) {
hipSetDevice(0);
hipSetDevice(device);
}
#endif

Expand Down
2 changes: 1 addition & 1 deletion test/gtest/common/mem_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class mem_buffer {
static bool is_gpu_supported();

/* set device context if compiled with GPU support */
static void set_device_context();
static void set_device_context(int device = 0);

/* returns whether ROCM device supports managed memory */
static bool is_rocm_managed_supported();
Expand Down
52 changes: 52 additions & 0 deletions test/gtest/ucp/test_ucp_mmap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ extern "C" {
#include <ucs/type/float8.h>
}

#if HAVE_CUDA
#include <cuda_runtime.h>
#endif

#include <cmath>
#include <list>

Expand Down Expand Up @@ -1248,3 +1252,51 @@ UCS_TEST_P(test_ucp_mmap_export, export_import) {
}

UCP_INSTANTIATE_TEST_CASE_GPU_AWARE(test_ucp_mmap_export)

#if HAVE_CUDA
class test_ucp_mmap_mgpu : public ucs::test {
};

UCS_TEST_F(test_ucp_mmap_mgpu, switch_gpu) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to add tests for transfer cuda_copy or cuda_ipc, mem type cuda or vmm/mallocasync:

  • buf1 on device1, buf2 on device2, copy happening under progress with context of unrelated device3

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, we will need more test cases.
The test in this PR is the simplest one. And it fails. Once it is fixed, we will add more scenarios.

if (!mem_buffer::is_mem_type_supported(UCS_MEMORY_TYPE_CUDA)) {
UCS_TEST_SKIP_R("cuda is not supported");
}

int num_devices;
ASSERT_EQ(cudaGetDeviceCount(&num_devices), cudaSuccess);

if (num_devices < 2) {
UCS_TEST_SKIP_R("less than two cuda devices available");
}

ucs::handle<ucp_config_t*> config;
UCS_TEST_CREATE_HANDLE(ucp_config_t*, config, ucp_config_release,
ucp_config_read, NULL, NULL);

ucs::handle<ucp_context_h> context;
ucp_params_t params;
params.field_mask = UCP_PARAM_FIELD_FEATURES;
params.features = UCP_FEATURE_TAG;
UCS_TEST_CREATE_HANDLE(ucp_context_h, context, ucp_cleanup, ucp_init,
&params, config.get());

int device;
ASSERT_EQ(cudaGetDevice(&device), cudaSuccess);
ASSERT_EQ(cudaSetDevice((device + 1) % num_devices), cudaSuccess);

const size_t size = 16;
mem_buffer buffer(size, UCS_MEMORY_TYPE_CUDA);

ASSERT_EQ(cudaSetDevice(device), cudaSuccess);

ucp_mem_map_params_t mem_map_params;
mem_map_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
UCP_MEM_MAP_PARAM_FIELD_LENGTH;
mem_map_params.address = buffer.ptr();
mem_map_params.length = size;

ucp_mem_h ucp_mem;
ASSERT_EQ(ucp_mem_map(context.get(), &mem_map_params, &ucp_mem), UCS_OK);
EXPECT_EQ(ucp_mem_unmap(context.get(), ucp_mem), UCS_OK);
}
#endif
Loading