Skip to content

Commit

Permalink
updated torch.load params
Browse files Browse the repository at this point in the history
  • Loading branch information
bsh98 committed Jun 29, 2021
1 parent 9615326 commit 0c71f5a
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion model_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def mask_text(self, text_tokenized):
return masked

def reload_model(self, model_file):
print(self.model.load_state_dict(torch.load(model_file), strict=False))
print(self.model.load_state_dict(torch.load(model_file, map_location=torch.device(self.device)), strict=False))

def save_model(self, model_file):
torch.save(self.model.state_dict(), model_file)
Expand Down
2 changes: 1 addition & 1 deletion model_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, max_output_length=25, max_input_length=300, device='cpu', tok
self.mode = "train"

def reload(self, from_file):
print(self.model.load_state_dict(torch.load(from_file), strict=False))
print(self.model.load_state_dict(torch.load(from_file, map_location=torch.device(self.device)), strict=False))

def save(self, to_file):
torch.save(self.model.state_dict(), to_file)
Expand Down

0 comments on commit 0c71f5a

Please sign in to comment.