Skip to content

Commit

Permalink
feat: add functions to query device memory (PROOF-923) (#210)
Browse files Browse the repository at this point in the history
* support query device memory

* undo memory_utility change

* document function
  • Loading branch information
rnburn authored Jan 6, 2025
1 parent af4bf60 commit 3bc8138
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 0 deletions.
22 changes: 22 additions & 0 deletions sxt/base/device/property.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,26 @@ int get_cuda_version() noexcept {
}();
return version;
}

//--------------------------------------------------------------------------------------------------
// get_device_mem_info
//--------------------------------------------------------------------------------------------------
void get_device_mem_info(size_t& bytes_free, size_t& bytes_total) noexcept {
auto rcode = cudaMemGetInfo(&bytes_free, &bytes_total);
if (rcode != cudaSuccess) {
baser::panic("cudaMemGetInfo failed: {}", cudaGetErrorString(rcode));
}
}

//--------------------------------------------------------------------------------------------------
// get_total_device_memory
//--------------------------------------------------------------------------------------------------
size_t get_total_device_memory() noexcept {
static size_t res = []() noexcept {
size_t bytes_free, bytes_total;
get_device_mem_info(bytes_free, bytes_total);
return bytes_total;
}();
return res;
}
} // namespace sxt::basdv
15 changes: 15 additions & 0 deletions sxt/base/device/property.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
#pragma once

#include <cstddef>

#include "sxt/base/type/raw_stream.h"

namespace sxt::basdv {
Expand All @@ -33,4 +35,17 @@ int get_latest_cuda_version_supported_by_driver() noexcept;
// get_cuda_version
//--------------------------------------------------------------------------------------------------
int get_cuda_version() noexcept;

//--------------------------------------------------------------------------------------------------
// get_device_mem_info
//--------------------------------------------------------------------------------------------------
void get_device_mem_info(size_t& bytes_free, size_t& bytes_total) noexcept;

//--------------------------------------------------------------------------------------------------
// get_total_device_memory
//--------------------------------------------------------------------------------------------------
// Get the total amount of memory available for a single GPU device.
//
// Note: assumes each device has the same amount of memory
size_t get_total_device_memory() noexcept;
} // namespace sxt::basdv
8 changes: 8 additions & 0 deletions sxt/base/device/property.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,11 @@ TEST_CASE("we can get the version of the driver running") {
auto v2 = get_cuda_version();
REQUIRE(v1 >= v2);
}

TEST_CASE("we can query info about device memory") {
size_t bytes_free, bytes_total;
get_device_mem_info(bytes_free, bytes_total);
REQUIRE(0 < bytes_free);
REQUIRE(bytes_free <= bytes_total);
REQUIRE(bytes_total == get_total_device_memory());
}

0 comments on commit 3bc8138

Please sign in to comment.