Skip to content

Commit

Permalink
Convert tensors to scalars for plotting compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
utsavrai committed Apr 26, 2024
1 parent 0fa794b commit e642eb2
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions 04_pytorch_custom_datasets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2498,10 +2498,11 @@
" )\n",
"\n",
" # 5. Update results dictionary\n",
" results[\"train_loss\"].append(train_loss)\n",
" results[\"train_acc\"].append(train_acc)\n",
" results[\"test_loss\"].append(test_loss)\n",
" results[\"test_acc\"].append(test_acc)\n",
" # Ensure all data is moved to CPU and converted to float for storage\n",
" results[\"train_loss\"].append(train_loss.item() if isinstance(train_loss, torch.Tensor) else train_loss)\n",
" results[\"train_acc\"].append(train_acc.item() if isinstance(train_acc, torch.Tensor) else train_acc)\n",
" results[\"test_loss\"].append(test_loss.item() if isinstance(test_loss, torch.Tensor) else test_loss)\n",
" results[\"test_acc\"].append(test_acc.item() if isinstance(test_acc, torch.Tensor) else test_acc)\n",
"\n",
" # 6. Return the filled results at the end of the epochs\n",
" return results"
Expand Down

0 comments on commit e642eb2

Please sign in to comment.