Skip to content

Commit

Permalink
refactor: change function _get_convol_img_fn for more clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
framunoz authored Jan 2, 2025
1 parent dadb470 commit 1030815
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions ot/bregman/_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,7 @@ def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False):
t2 = nx.linspace(0, 1, height, type_as=type_as)
Y2, X2 = nx.meshgrid(t2, t2)
M2 = -((X2 - Y2) ** 2) / reg

# As M1 and M2 are computed first, we can use them to compute the convolution in log-domain
def convol_imgs(log_imgs):
log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1)
log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T
return log_imgs


# If normal domain is selected, we can use M1 and M2 to compute the convolution
if not log_domain:
K1, K2 = nx.exp(M1), nx.exp(M2)
Expand All @@ -47,6 +41,13 @@ def convol_imgs(imgs):
kxy = nx.einsum("...ij,klj->kli", K2, kx)
return kxy

# Else, we can use M1 and M2 to compute the convolution in log-domain
else:
def convol_imgs(log_imgs):
log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1)
log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T
return log_imgs

return convol_imgs


Expand Down

0 comments on commit 1030815

Please sign in to comment.