From dee87fb2281a6a69088e34460d2113c9e1dcb702 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 20 Oct 2024 07:15:48 -0700
Subject: [PATCH] small helper method

---
 ema_pytorch/ema_pytorch.py | 9 +++++++++
 setup.py                   | 2 +-
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/ema_pytorch/ema_pytorch.py b/ema_pytorch/ema_pytorch.py
index 7d31e56..fde3ffe 100644
--- a/ema_pytorch/ema_pytorch.py
+++ b/ema_pytorch/ema_pytorch.py
@@ -200,6 +200,15 @@ def model(self):
     def eval(self):
         return self.ema_model.eval()
 
+    @torch.no_grad()
+    def forward_eval(self, *args, **kwargs):
+        # handy function for invoking ema model with no grad + eval
+        training = self.ema_model.training
+        self.ema_model.eval()
+        out = self.ema_model(*args, **kwargs)
+        self.ema_model.train(training)
+        return out
+
     def restore_ema_model_device(self):
         device = self.initted.device
         self.ema_model.to(device)
diff --git a/setup.py b/setup.py
index adb1b9f..61b30a0 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ema-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.7.2',
+  version = '0.7.3',
   license='MIT',
   description = 'Easy way to keep track of exponential moving average version of your pytorch module',
   author = 'Phil Wang',