Skip to content

Commit

Permalink
Merge pull request #909 from utsavrai/main
Browse files Browse the repository at this point in the history
Bug Fix: Convert tensors to scalars for plotting compatibility
  • Loading branch information
mrdbourke authored Aug 22, 2024
2 parents 65ed32c + e642eb2 commit 9f1084e
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions 04_pytorch_custom_datasets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2498,10 +2498,11 @@
" )\n",
"\n",
" # 5. Update results dictionary\n",
" results[\"train_loss\"].append(train_loss)\n",
" results[\"train_acc\"].append(train_acc)\n",
" results[\"test_loss\"].append(test_loss)\n",
" results[\"test_acc\"].append(test_acc)\n",
" # Ensure all data is moved to CPU and converted to float for storage\n",
" results[\"train_loss\"].append(train_loss.item() if isinstance(train_loss, torch.Tensor) else train_loss)\n",
" results[\"train_acc\"].append(train_acc.item() if isinstance(train_acc, torch.Tensor) else train_acc)\n",
" results[\"test_loss\"].append(test_loss.item() if isinstance(test_loss, torch.Tensor) else test_loss)\n",
" results[\"test_acc\"].append(test_acc.item() if isinstance(test_acc, torch.Tensor) else test_acc)\n",
"\n",
" # 6. Return the filled results at the end of the epochs\n",
" return results"
Expand Down

0 comments on commit 9f1084e

Please sign in to comment.