Skip to content

Commit

Permalink
[MRG] Avoid changing precision in the backend (#572)
Browse files Browse the repository at this point in the history
* Avoid changing precision in the backend

* Update RELEASES.md
  • Loading branch information
kachayev authored Nov 10, 2023
1 parent a56e1b2 commit 91c67fb
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
- Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520)
- Handle documentation and warnings when integers are provided to (f)gw solvers based on cg (Issue #530, PR #559)
- Correct independence of `fgw_barycenters` to `init_C` and `init_X` (Issue #547, PR #566)
- Avoid precision change when computing norm using PyTorch backend (Discussion #570, PR #572)

## 0.9.1
*August 2023*
Expand Down
2 changes: 1 addition & 1 deletion ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1941,7 +1941,7 @@ def power(self, a, exponents):
return torch.pow(a, exponents)

def norm(self, a, axis=None, keepdims=False):
return torch.linalg.norm(a.double(), dim=axis, keepdims=keepdims)
return torch.linalg.norm(a, dim=axis, keepdims=keepdims)

def any(self, a):
return torch.any(a)
Expand Down
2 changes: 1 addition & 1 deletion test/test_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_empirical_bures_wasserstein_mapping_numerical_error_warning():


def test_bures_wasserstein_distance(nx):
ms, mt = np.array([0]), np.array([10])
ms, mt = np.array([0]).astype(np.float32), np.array([10]).astype(np.float32)
Cs, Ct = np.array([[1]]).astype(np.float32), np.array([[1]]).astype(np.float32)
msb, mtb, Csb, Ctb = nx.from_numpy(ms, mt, Cs, Ct)
Wb_log, log = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=True)
Expand Down

0 comments on commit 91c67fb

Please sign in to comment.