Skip to content

Commit

Permalink
bugfix for latest jax
Browse files Browse the repository at this point in the history
  • Loading branch information
sokrypton authored Oct 9, 2022
1 parent 9494856 commit c90c4d3
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions colabdesign/shared/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ def clear_mem():
for obj_name in dir(module):
obj = getattr(module, obj_name)
if hasattr(obj, "cache_clear"):
obj.cache_clear()
try:
obj.cache_clear()
except:
pass
gc.collect()

def update_dict(D, *args, **kwargs):
Expand Down Expand Up @@ -105,4 +108,4 @@ def softmax(x, axis=-1):
return x / x.sum(axis,keepdims=True)

def categorical(p):
return (p.cumsum(-1) >= np.random.uniform(size=p.shape[:-1])[..., None]).argmax(-1)
return (p.cumsum(-1) >= np.random.uniform(size=p.shape[:-1])[..., None]).argmax(-1)

0 comments on commit c90c4d3

Please sign in to comment.