From 0c71f5a52bb92f74484b7403b50886c1242d8110 Mon Sep 17 00:00:00 2001 From: bsh98 Date: Mon, 28 Jun 2021 22:06:51 -0700 Subject: [PATCH] updated torch.load params --- model_coverage.py | 2 +- model_generator.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/model_coverage.py b/model_coverage.py index f5ba1a3..98c9a4a 100644 --- a/model_coverage.py +++ b/model_coverage.py @@ -91,7 +91,7 @@ def mask_text(self, text_tokenized): return masked def reload_model(self, model_file): - print(self.model.load_state_dict(torch.load(model_file), strict=False)) + print(self.model.load_state_dict(torch.load(model_file, map_location=torch.device(self.device)), strict=False)) def save_model(self, model_file): torch.save(self.model.state_dict(), model_file) diff --git a/model_generator.py b/model_generator.py index a8a0efc..b1eca97 100644 --- a/model_generator.py +++ b/model_generator.py @@ -36,7 +36,7 @@ def __init__(self, max_output_length=25, max_input_length=300, device='cpu', tok self.mode = "train" def reload(self, from_file): - print(self.model.load_state_dict(torch.load(from_file), strict=False)) + print(self.model.load_state_dict(torch.load(from_file, map_location=torch.device(self.device)), strict=False)) def save(self, to_file): torch.save(self.model.state_dict(), to_file)