From c84c71dac9d43a1a335fa7d150b6aa166091e1a5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Jun 2024 20:44:51 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/autoembedder/evaluator.py | 8 +++++--- src/autoembedder/learner.py | 19 ++++++++++++------- src/autoembedder/model.py | 8 +++++--- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/autoembedder/evaluator.py b/src/autoembedder/evaluator.py index fde110d..2f2b7ca 100644 --- a/src/autoembedder/evaluator.py +++ b/src/autoembedder/evaluator.py @@ -37,9 +37,11 @@ def _predict( device = torch.device( "cuda" if torch.cuda.is_available() - else "mps" - if torch.backends.mps.is_available() and parameters.get("use_mps", False) - else "cpu" + else ( + "mps" + if torch.backends.mps.is_available() and parameters.get("use_mps", False) + else "cpu" + ) ) with torch.no_grad(): diff --git a/src/autoembedder/learner.py b/src/autoembedder/learner.py index 3bc5610..cf2533b 100644 --- a/src/autoembedder/learner.py +++ b/src/autoembedder/learner.py @@ -362,9 +362,12 @@ def fit( torch.device( "cuda" if torch.cuda.is_available() - else "mps" - if torch.backends.mps.is_available() and parameters.get("use_mps", False) - else "cpu" + else ( + "mps" + if torch.backends.mps.is_available() + and parameters.get("use_mps", False) + else "cpu" + ) ) ) if ( @@ -437,10 +440,12 @@ def fit( map_location=torch.device( "cuda" if torch.cuda.is_available() - else "mps" - if torch.backends.mps.is_available() - and parameters.get("use_mps", False) - else "cpu" + else ( + "mps" + if torch.backends.mps.is_available() + and parameters.get("use_mps", False) + else "cpu" + ) ), ) Checkpoint.load_objects( diff --git a/src/autoembedder/model.py b/src/autoembedder/model.py index 5a1296a..700ea2f 100644 --- a/src/autoembedder/model.py +++ b/src/autoembedder/model.py @@ -58,9 +58,11 @@ def model_input( device = torch.device( "cuda" if torch.cuda.is_available() - else "mps" - if torch.backends.mps.is_available() and parameters.get("use_mps", False) - else "cpu" + else ( + "mps" + if torch.backends.mps.is_available() and parameters.get("use_mps", False) + else "cpu" + ) ) cat = [] cont = []