diff --git a/sxt/base/device/property.cc b/sxt/base/device/property.cc index 63d7ff7f..e12eab6d 100644 --- a/sxt/base/device/property.cc +++ b/sxt/base/device/property.cc @@ -68,4 +68,26 @@ int get_cuda_version() noexcept { }(); return version; } + +//-------------------------------------------------------------------------------------------------- +// get_device_mem_info +//-------------------------------------------------------------------------------------------------- +void get_device_mem_info(size_t& bytes_free, size_t& bytes_total) noexcept { + auto rcode = cudaMemGetInfo(&bytes_free, &bytes_total); + if (rcode != cudaSuccess) { + baser::panic("cudaMemGetInfo failed: {}", cudaGetErrorString(rcode)); + } +} + +//-------------------------------------------------------------------------------------------------- +// get_total_device_memory +//-------------------------------------------------------------------------------------------------- +size_t get_total_device_memory() noexcept { + static size_t res = []() noexcept { + size_t bytes_free, bytes_total; + get_device_mem_info(bytes_free, bytes_total); + return bytes_total; + }(); + return res; +} } // namespace sxt::basdv diff --git a/sxt/base/device/property.h b/sxt/base/device/property.h index 0710f21e..647d653a 100644 --- a/sxt/base/device/property.h +++ b/sxt/base/device/property.h @@ -16,6 +16,8 @@ */ #pragma once +#include + #include "sxt/base/type/raw_stream.h" namespace sxt::basdv { @@ -33,4 +35,17 @@ int get_latest_cuda_version_supported_by_driver() noexcept; // get_cuda_version //-------------------------------------------------------------------------------------------------- int get_cuda_version() noexcept; + +//-------------------------------------------------------------------------------------------------- +// get_device_mem_info +//-------------------------------------------------------------------------------------------------- +void get_device_mem_info(size_t& bytes_free, size_t& bytes_total) noexcept; + +//-------------------------------------------------------------------------------------------------- +// get_total_device_memory +//-------------------------------------------------------------------------------------------------- +// Get the total amount of memory available for a single GPU device. +// +// Note: assumes each device has the same amount of memory +size_t get_total_device_memory() noexcept; } // namespace sxt::basdv diff --git a/sxt/base/device/property.t.cc b/sxt/base/device/property.t.cc index dccf2c39..17dec34e 100644 --- a/sxt/base/device/property.t.cc +++ b/sxt/base/device/property.t.cc @@ -32,3 +32,11 @@ TEST_CASE("we can get the version of the driver running") { auto v2 = get_cuda_version(); REQUIRE(v1 >= v2); } + +TEST_CASE("we can query info about device memory") { + size_t bytes_free, bytes_total; + get_device_mem_info(bytes_free, bytes_total); + REQUIRE(0 < bytes_free); + REQUIRE(bytes_free <= bytes_total); + REQUIRE(bytes_total == get_total_device_memory()); +}