From d82a1e973ce822d05834ae597694f0f038095b7d Mon Sep 17 00:00:00 2001 From: Jesper Stemann Andersen Date: Sun, 15 Dec 2024 23:54:41 +0100 Subject: [PATCH 1/2] Added mlx_device_get_index Also exposed mlx_device_equal. --- mlx/c/device.cpp | 13 +++++++++++++ mlx/c/device.h | 8 ++++++++ 2 files changed, 21 insertions(+) diff --git a/mlx/c/device.cpp b/mlx/c/device.cpp index f1922a0..77bc1d3 100644 --- a/mlx/c/device.cpp +++ b/mlx/c/device.cpp @@ -43,6 +43,15 @@ extern "C" int mlx_device_set(mlx_device* dev, const mlx_device src) { return 0; } +extern "C" int mlx_device_get_index(mlx_device dev) { + try { + return mlx_device_get_(dev).index; + } catch (std::exception& e) { + mlx_error(e.what()); + return 0; // DEBUG: could have a specific value + } +} + extern "C" mlx_device_type mlx_device_get_type(mlx_device dev) { try { return mlx_device_type_to_c(mlx_device_get_(dev).type); @@ -51,9 +60,11 @@ extern "C" mlx_device_type mlx_device_get_type(mlx_device dev) { return MLX_CPU; // DEBUG: could have a specific value } } + extern "C" bool mlx_device_equal(mlx_device lhs, mlx_device rhs) { return mlx_device_get_(lhs) == mlx_device_get_(rhs); } + extern "C" int mlx_get_default_device(mlx_device* dev) { try { mlx_device_set_(*dev, mlx::core::default_device()); @@ -63,6 +74,7 @@ extern "C" int mlx_get_default_device(mlx_device* dev) { return 1; } } + extern "C" int mlx_set_default_device(mlx_device dev) { try { mlx::core::set_default_device(mlx_device_get_(dev)); @@ -72,6 +84,7 @@ extern "C" int mlx_set_default_device(mlx_device dev) { } return 0; } + extern "C" int mlx_device_free(mlx_device dev) { try { mlx_device_free_(dev); diff --git a/mlx/c/device.h b/mlx/c/device.h index f97a0ec..c94af90 100644 --- a/mlx/c/device.h +++ b/mlx/c/device.h @@ -48,6 +48,14 @@ int mlx_device_set(mlx_device* dev, const mlx_device src); * Get device description. */ int mlx_device_tostring(mlx_string* str, mlx_device dev); +/** + * Check if devices are the same. + */ +bool mlx_device_equal(mlx_device lhs, mlx_device rhs); +/** + * Returns the index of the device. + */ +int mlx_device_get_index(mlx_device dev); /** * Returns the type of the device. */ From 028977ea4dbec3b1c412d30ffbd8ffc1c5f4da80 Mon Sep 17 00:00:00 2001 From: Jesper Stemann Andersen Date: Mon, 16 Dec 2024 10:18:34 +0100 Subject: [PATCH 2/2] Also added mlx_stream_get_index --- mlx/c/stream.cpp | 8 ++++++++ mlx/c/stream.h | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/mlx/c/stream.cpp b/mlx/c/stream.cpp index 1749a19..e7ed3cc 100644 --- a/mlx/c/stream.cpp +++ b/mlx/c/stream.cpp @@ -62,6 +62,14 @@ extern "C" int mlx_stream_get_device(mlx_device* dev, mlx_stream stream) { return 1; } } +extern "C" int mlx_stream_get_index(mlx_stream stream) { + try { + return mlx_stream_get_(stream).index; + } catch (std::exception& e) { + mlx_error(e.what()); + return 0; // DEBUG: could have a specific value + } +} extern "C" int mlx_synchronize(mlx_stream stream) { try { mlx::core::synchronize(mlx_stream_get_(stream)); diff --git a/mlx/c/stream.h b/mlx/c/stream.h index 6af4f58..6aab0c1 100644 --- a/mlx/c/stream.h +++ b/mlx/c/stream.h @@ -53,6 +53,10 @@ bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs); * Return the device of the stream. */ int mlx_stream_get_device(mlx_device* dev, mlx_stream stream); +/** + * Return the index of the stream. + */ +int mlx_stream_get_index(mlx_stream stream); /** * Synchronize with the provided stream. */