Skip to content

Commit

Permalink
Unit tests to evaluate AtomicRef and fetch_add
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Dec 28, 2023
1 parent bd17740 commit 6497c26
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import dpnp
import pytest

import numba_dpex as dpex
import numba_dpex.experimental as dpex_exp
from numba_dpex.experimental.kernel_iface import AtomicRef
from numba_dpex.tests._helper import get_all_dtypes

list_of_supported_dtypes = get_all_dtypes(
no_bool=True, no_float16=True, no_none=True, no_complex=True
)


@pytest.fixture(params=list_of_supported_dtypes)
def input_arrays(request):
# The size of input and out arrays to be used
N = 10
a = dpnp.ones(N, dtype=request.param)
b = dpnp.zeros(N, dtype=request.param)
return a, b


def test_fetch_add(input_arrays):
@dpex_exp.kernel
def atomic_ref_kernel(a, b, ref_index):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=ref_index)
v.fetch_add(a[i])

a, b = input_arrays
ref_index = 0

dpex_exp.call_kernel(atomic_ref_kernel, dpex.Range(10), a, b, ref_index)

# Verify that `a` was accumulated at b[ref_index]
assert b[ref_index] == 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import dpnp
import pytest
from numba.core.errors import TypingError

import numba_dpex as dpex
import numba_dpex.experimental as dpex_exp
from numba_dpex.experimental.kernel_iface import AddressSpace, AtomicRef


def test_atomic_ref_compilation():
@dpex_exp.kernel
def atomic_ref_kernel(a, b):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=0, address_space=AddressSpace.GLOBAL)
v.fetch_add(a[i])

a = dpnp.ones(10)
b = dpnp.zeros(10)
try:
dpex_exp.call_kernel(atomic_ref_kernel, dpex.Range(10), a, b)
except Exception:
pytest.fail("Unexpected execution failure")


def test_atomic_ref_compilation_failure():
"""A negative test that verifies that a TypingError is raised if we try to
create an AtomicRef in the local address space from a global address space
ref.
"""

@dpex_exp.kernel
def atomic_ref_kernel(a, b):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=0, address_space=AddressSpace.LOCAL)
v.fetch_add(a[i])

a = dpnp.ones(10)
b = dpnp.zeros(10)

with pytest.raises(TypingError):
dpex_exp.call_kernel(atomic_ref_kernel, dpex.Range(10), a, b)

0 comments on commit 6497c26

Please sign in to comment.