From e642eb23b59b16673ab8431555ce3def0a2836d9 Mon Sep 17 00:00:00 2001 From: utsavrai Date: Fri, 26 Apr 2024 17:51:36 +0100 Subject: [PATCH] Convert tensors to scalars for plotting compatibility --- 04_pytorch_custom_datasets.ipynb | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/04_pytorch_custom_datasets.ipynb b/04_pytorch_custom_datasets.ipynb index cba13564..e870495a 100644 --- a/04_pytorch_custom_datasets.ipynb +++ b/04_pytorch_custom_datasets.ipynb @@ -2498,10 +2498,11 @@ " )\n", "\n", " # 5. Update results dictionary\n", - " results[\"train_loss\"].append(train_loss)\n", - " results[\"train_acc\"].append(train_acc)\n", - " results[\"test_loss\"].append(test_loss)\n", - " results[\"test_acc\"].append(test_acc)\n", + " # Ensure all data is moved to CPU and converted to float for storage\n", + " results[\"train_loss\"].append(train_loss.item() if isinstance(train_loss, torch.Tensor) else train_loss)\n", + " results[\"train_acc\"].append(train_acc.item() if isinstance(train_acc, torch.Tensor) else train_acc)\n", + " results[\"test_loss\"].append(test_loss.item() if isinstance(test_loss, torch.Tensor) else test_loss)\n", + " results[\"test_acc\"].append(test_acc.item() if isinstance(test_acc, torch.Tensor) else test_acc)\n", "\n", " # 6. Return the filled results at the end of the epochs\n", " return results"