diff --git a/04_pytorch_custom_datasets.ipynb b/04_pytorch_custom_datasets.ipynb index b09dde55..006659b9 100644 --- a/04_pytorch_custom_datasets.ipynb +++ b/04_pytorch_custom_datasets.ipynb @@ -2498,10 +2498,11 @@ " )\n", "\n", " # 5. Update results dictionary\n", - " results[\"train_loss\"].append(train_loss)\n", - " results[\"train_acc\"].append(train_acc)\n", - " results[\"test_loss\"].append(test_loss)\n", - " results[\"test_acc\"].append(test_acc)\n", + " # Ensure all data is moved to CPU and converted to float for storage\n", + " results[\"train_loss\"].append(train_loss.item() if isinstance(train_loss, torch.Tensor) else train_loss)\n", + " results[\"train_acc\"].append(train_acc.item() if isinstance(train_acc, torch.Tensor) else train_acc)\n", + " results[\"test_loss\"].append(test_loss.item() if isinstance(test_loss, torch.Tensor) else test_loss)\n", + " results[\"test_acc\"].append(test_acc.item() if isinstance(test_acc, torch.Tensor) else test_acc)\n", "\n", " # 6. Return the filled results at the end of the epochs\n", " return results"