BUG: jnp.cumsum(np.arange(2**14)) causes a segmentation fault. #19

Open
AlexanderMath opened this issue Sep 10, 2023 · 2 comments
Labels
bug Something isn't working

Comments

AlexanderMath commented Sep 10, 2023

Description

Reproducer

import numpy as np
import jax
import jax.numpy as jnp

# Works: 2**13 elements.
# x = np.arange(2**13)
# print(np.cumsum(x))
# print(jax.jit(jnp.cumsum, backend="ipu")(x))

# Segmentation fault: 2**14 elements.
x = np.arange(2**14)
print(np.cumsum(x))
print(jax.jit(jnp.cumsum, backend="ipu")(x))

Output

$ python reproduce.py
[        0         1         3 ... 134176771 134193153 134209536]
Segmentation fault (core dumped)

Note:

  • I initially suspected int32 overflow, but the largest value (about 134M) is far below 2**31.
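The overflow argument in the note can be checked with the closed form for the sum of 0..n-1. This is only a quick arithmetic sketch; the helper name `arange_sum` is made up for illustration and is not part of any library.

```python
# Largest value produced by cumsum(arange(n)) is the sum 0 + 1 + ... + (n - 1).
def arange_sum(n):
    """Closed form for sum(range(n)); hypothetical helper for this check."""
    return n * (n - 1) // 2

INT32_MAX = 2**31 - 1

total = arange_sum(2**14)
print(total)               # 134209536, the last value in the output above
print(total <= INT32_MAX)  # True: about 16x below the int32 limit

# For comparison, the total still fits in int32 for n = 2**16
# and first exceeds it at n = 2**16 + 1.
print(arange_sum(2**16) <= INT32_MAX < arange_sum(2**16 + 1))  # True
```

So the input that crashes is roughly a factor of 16 away from the int32 limit, which supports ruling out overflow as the cause.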

Meta comment: the reproducer took 2 hours to make because jnp.cumsum was used inside ~400 lines of code, and I wrongly assumed jnp.cumsum was a less likely cause of the segmentation fault than tessellate-ipu, the C code, the use of uint in the C code, the Poplar simulation of uint64 in the C code, passing data from Python to the C code, index computations, and so on. Would it be a lot of work to add automated testing for these basic (np, jnp) functions?
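One practical wrinkle for such tests is that a segmentation fault kills the whole test process, so each case probably needs to run in its own interpreter. Below is a minimal sketch of that idea using only the standard library. Everything in it is an assumption for illustration: the helper names (`run_isolated`, `describe`, `run_cases`), the timeout, and the function and size lists are not from this project; only `backend="ipu"` and the `jnp.cumsum` call pattern come from the reproducer.

```python
import signal
import subprocess
import sys

def run_isolated(snippet, timeout=600):
    """Run a Python snippet in a fresh interpreter so a crash cannot take down the caller."""
    return subprocess.run(
        [sys.executable, "-c", snippet],
        capture_output=True,
        text=True,
        timeout=timeout,  # arbitrary illustrative limit
    )

def describe(result):
    """Summarise how the child process ended."""
    code = result.returncode
    if code == 0:
        return "ok"
    if code < 0:
        # On POSIX, a negative return code means the child was killed by that signal.
        return "killed by " + signal.Signals(-code).name
    return f"failed with exit code {code}"

# Hypothetical test template: compare one jnp function against NumPy on the IPU backend.
TEMPLATE = """
import numpy as np, jax, jax.numpy as jnp
x = np.arange({size})
got = np.asarray(jax.jit(jnp.{name}, backend="ipu")(x))
np.testing.assert_allclose(got, np.{name}(x))
"""

def run_cases(names=("cumsum", "cumprod"), sizes=(2**13, 2**14)):
    """Run every (function, size) pair in isolation and collect the outcomes."""
    return {
        (name, size): describe(run_isolated(TEMPLATE.format(name=name, size=size)))
        for name in names
        for size in sizes
    }
```

Calling `run_cases()` on a machine with the IPU backend would return a dictionary keyed by `(name, size)`; with the behaviour reported here, the `2**14` entries would be expected to show a signal rather than `"ok"`.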

What jax/jaxlib version are you using?

0.3.16

Which accelerator(s) are you using?

IPU MK2

Additional System Info

No response

@AlexanderMath AlexanderMath added the bug Something isn't working label Sep 10, 2023
AlexanderMath commented

Found the same issue with jax.cumprod when trying to use log tricks as a temporary workaround.


AlexanderMath commented Sep 10, 2023

Here's a hacky temporary workaround. It splits the array into chunks of 2**7 elements, computes the cumulative sum of each chunk with a matrix multiplication, and then adds the correct offset to each chunk. Use with caution: more than 90% of the time is spent adding the offsets.

import numpy as np
import jax
import jax.numpy as jnp

def matmul_cumsum_jax(arr):
    # A lower-triangular matrix of ones times arr gives the cumulative sum of arr.
    return jnp.tril(jnp.ones((len(arr), len(arr)))) @ arr

def cumsum_jax(arr):
    chunk_size = 2**7
    original_shape = arr.shape
    # Zero-pad so that the length is a multiple of chunk_size.
    padding = chunk_size - (len(arr) % chunk_size) if len(arr) % chunk_size != 0 else 0
    arr = jnp.pad(arr, (0, padding))
    num_chunks = -(-len(arr) // chunk_size)
    chunks = arr.T.reshape(num_chunks, chunk_size)
    # Cumulative sum within each chunk.
    chunks = jax.vmap(matmul_cumsum_jax)(chunks)
    # Offset of each chunk = sum of all earlier chunks.
    offset = 0
    offsets = [offset]
    for chunk in chunks:
        offset += chunk[-1]
        offsets.append(offset)
    chunks = jax.vmap(jax.lax.add, in_axes=(0, 0))(chunks, jnp.array(offsets[:-1]))
    # Flatten and drop the padding.
    return jnp.concatenate(chunks).reshape(-1)[:original_shape[0]]

arange = np.arange(2**14)
arange = np.concatenate((np.zeros(1), np.diff(arange))).astype(np.int32)
true_indxs = np.cumsum(arange)
us_indxs = np.asarray(jax.jit(cumsum_jax, backend="ipu")(arange)).astype(np.int32)
print(true_indxs[::127])
print(us_indxs[::127])
print(np.max(np.abs(true_indxs - us_indxs)))
print(np.all(true_indxs==us_indxs))
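For checking the workaround on other inputs, the same chunk-then-offset scheme can be written in plain Python with exact integer arithmetic. This is only a reference sketch: the name `chunked_cumsum` is made up, and it replaces the matrix multiplication with `itertools.accumulate`, so it says nothing about IPU behaviour or performance.

```python
from itertools import accumulate

def chunked_cumsum(values, chunk_size=2**7):
    """Cumulative sum computed per chunk, then shifted by the total of earlier chunks."""
    values = list(values)
    out = []
    offset = 0  # sum of all elements in the chunks already processed
    for start in range(0, len(values), chunk_size):
        chunk = values[start:start + chunk_size]
        local = list(accumulate(chunk))         # cumulative sum within this chunk
        out.extend(x + offset for x in local)   # shift by the preceding total
        offset += local[-1]                     # last local value is the chunk total
    return out

reference = chunked_cumsum(range(2**14))
print(reference[-1])  # 134209536, matching the NumPy output in the reproducer
```

The exact arithmetic also points at one thing that may be worth checking in the matmul version: `jnp.ones` defaults to a floating-point dtype, and float32 represents every integer exactly only up to 2**24, which is below the 134209536 total from the original reproducer. The test above uses an input of ones, so its totals stay well inside that range.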
