Perhaps it's a bit impolite to say, but this code is truly amazing! In my tests with the latest NumPy, it still outperforms `np.argpartition`, which is designed specifically for top-k problems!
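For reference, the `np.argpartition` baseline can be applied row-wise with `axis=1`. This is a minimal sketch. The helper name `argpartition_top_k_batch` is hypothetical, and the code is not necessarily how the comparison above was benchmarked:

```python
import numpy as np


def argpartition_top_k_batch(array, K):
    """Indices of the K largest values in each row of a 2-D array.

    Hypothetical helper illustrating the np.argpartition baseline.
    The returned indices are not sorted by value.
    """
    # kth=-K moves each row's K-th largest element to column -K and places
    # every larger element after it, so the last K columns are the top K.
    return np.argpartition(array, -K, axis=1)[:, -K:]


scores = np.array([[0.1, 0.9, 0.4, 0.7],
                   [0.8, 0.2, 0.6, 0.3]])
top = argpartition_top_k_batch(scores, 2)
print(np.sort(top, axis=1))  # rows: [1, 3] and [0, 2]
```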
For anyone looking for a batched top-k snippet, I slightly modified the original code to:
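```python
import numba as nb
import numpy as np

FLOAT_TYPE = np.float64  # assumed here; use the float type of your data
# In some situations, because of different resolution, you get k-1 results;
# subtracting this small buffer from the threshold avoids that.
FLOAT_BUFFER = np.finfo(FLOAT_TYPE).resolution


@nb.njit(cache=True)
def fast_arg_top_k_batch(array, N, K):
    """
    Gets the indices of the top K values in each row of an (N, M) array.
    * NOTE: The returned indices are not sorted based on the top values.
    * NOTE: Assumes non-negative values, since the buffer starts at zero.
    """
    # Running buffer of the K largest *values* (not indices) seen per row
    batch_sorted_indices = np.zeros((N, K), dtype=FLOAT_TYPE)
    # Per-row cutoff; one column so it broadcasts across each row below
    thresholds = np.empty((N, 1), dtype=FLOAT_TYPE)
    for i in range(N):
        arr = array[i]
        sorted_indices = batch_sorted_indices[i]
        # Reset the running minimum for every row
        minimum_index = 0
        minimum_index_value = 0.0
        for value in arr:
            if value > minimum_index_value:
                # Evict the smallest buffered value, then locate the new smallest
                sorted_indices[minimum_index] = value
                minimum_index = sorted_indices.argmin()
                minimum_index_value = sorted_indices[minimum_index]
        thresholds[i, 0] = minimum_index_value - FLOAT_BUFFER
    # Compare each row against its own threshold, not the last row's
    return (array >= thresholds).nonzero()
```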
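The snippet's output amounts to thresholding each row at its K-th largest value. The plain-NumPy sketch below builds that mask directly, which can help sanity-check results on small inputs. `top_k_mask_reference` is a hypothetical name, and this version compares exactly, without the `FLOAT_BUFFER` margin:

```python
import numpy as np


def top_k_mask_reference(array, K):
    """Boolean mask of values >= each row's K-th largest value.

    Hypothetical pure-NumPy reference. Ties at the K-th value can give
    more than K True entries in a row.
    """
    # After partitioning along axis 1, column -K holds each row's K-th largest value
    kth_largest = np.partition(array, -K, axis=1)[:, -K]
    # kth_largest[:, None] has shape (N, 1) and broadcasts across each row
    return array >= kth_largest[:, None]


scores = np.array([[0.1, 0.9, 0.4, 0.7],
                   [0.8, 0.2, 0.6, 0.3]])
rows, cols = top_k_mask_reference(scores, 2).nonzero()
print(rows, cols)  # rows [0, 0, 1, 1], cols [1, 3, 0, 2]
```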
Returning indices for a 2-D array is a bit troublesome, so I simply return the `nonzero()` result (or you could return just the boolean mask).
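`nonzero()` on a 2-D mask returns a `(rows, cols)` pair in row-major order. Because of that ordering, each row's column indices form one contiguous slice of `cols` and can be recovered by splitting at the row boundaries. A minimal sketch follows. `split_nonzero_by_row` is a hypothetical helper, not part of the snippet above:

```python
import numpy as np


def split_nonzero_by_row(rows, cols, n_rows):
    """Group the (rows, cols) output of a 2-D nonzero() into per-row arrays.

    nonzero() returns indices in row-major (C) order, so `rows` is already
    sorted and each row's columns form one contiguous slice of `cols`.
    """
    counts = np.bincount(rows, minlength=n_rows)
    # Split points are the cumulative counts, excluding the final total
    return np.split(cols, np.cumsum(counts)[:-1])


mask = np.array([[False, True, False, True],
                 [True, False, True, False]])
rows, cols = mask.nonzero()
per_row = split_nonzero_by_row(rows, cols, mask.shape[0])
print([c.tolist() for c in per_row])  # [[1, 3], [0, 2]]
```

If every row has exactly K hits (no ties at the threshold), `cols.reshape(N, K)` gives a dense index array instead.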
It's also much faster:
Thanks for your work :) 👍