diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py index 412cca9e5..6f0de8330 100644 --- a/ot/bregman/_convolutional.py +++ b/ot/bregman/_convolutional.py @@ -246,11 +246,6 @@ def _convolutional_barycenter2d_log( A = list_to_array(A) nx = get_backend(A) - if nx.__name__ in ("jax", "tf"): - raise NotImplementedError( - "Log-domain functions are not yet implemented" - " for Jax and TF. Use numpy or torch arrays instead." - ) n_hists, width, height = A.shape @@ -483,11 +478,7 @@ def _convolutional_barycenter2d_debiased_log( A = list_to_array(A) n_hists, width, height = A.shape nx = get_backend(A) - if nx.__name__ in ("jax", "tf"): - raise NotImplementedError( - "Log-domain functions are not yet implemented" - " for Jax and TF. Use numpy or torch arrays instead." - ) + if weights is None: weights = nx.ones((n_hists,), type_as=A) / n_hists else: