Skip to content

Commit

Permalink
make sure model on target device
Browse files Browse the repository at this point in the history
  • Loading branch information
mrdbourke authored Jul 27, 2022
1 parent 6237b93 commit defa4fa
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions going_modular/going_modular/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def train(model: torch.nn.Module,
"test_loss": [],
"test_acc": []
}

# Make sure model on target device
model.to(device)

# Loop through training and testing steps for a number of epochs
for epoch in tqdm(range(epochs)):
Expand Down

0 comments on commit defa4fa

Please sign in to comment.