diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py index 11a1c5903..d4ad92466 100644 --- a/ot/bregman/_convolutional.py +++ b/ot/bregman/_convolutional.py @@ -21,7 +21,9 @@ def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False): - """Return the convolution operator for 2D images. The function constructed is equivalent to blurring on horizontal then vertical directions.""" + """Return the convolution operator for 2D images. + + The function constructed is equivalent to blurring on horizontal then vertical directions.""" t1 = nx.linspace(0, 1, width, type_as=type_as) Y1, X1 = nx.meshgrid(t1, t1) M1 = -((X1 - Y1) ** 2) / reg @@ -30,11 +32,13 @@ def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False): Y2, X2 = nx.meshgrid(t2, t2) M2 = -((X2 - Y2) ** 2) / reg + # As M1 and M2 are computed first, we can use them to compute the convolution in log-domain def convol_imgs(log_imgs): log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1) log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T return log_imgs + # If normal domain is selected, we can use M1 and M2 to compute the convolution if not log_domain: K1, K2 = nx.exp(M1), nx.exp(M2)