From e2f9998be92c237753654f720880d87308a6e847 Mon Sep 17 00:00:00 2001 From: Andy-wyx Date: Thu, 14 Dec 2023 14:51:51 -0500 Subject: [PATCH] upload Feedback Alignment folder --- feedback_alignment/FeedbackAlignment.ipynb | 2116 ++++++++++++++++++++ 1 file changed, 2116 insertions(+) create mode 100644 feedback_alignment/FeedbackAlignment.ipynb diff --git a/feedback_alignment/FeedbackAlignment.ipynb b/feedback_alignment/FeedbackAlignment.ipynb new file mode 100644 index 0000000..2c4f2eb --- /dev/null +++ b/feedback_alignment/FeedbackAlignment.ipynb @@ -0,0 +1,2116 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "3HTi6mdDfuOU" + }, + "source": [ + "\n", + "# **CLPS 1291 Final Project**\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ArERhSYMDEVx" + }, + "outputs": [], + "source": [ + "#@title Enter your details - {display-mode: \"form\"}\n", + "\n", + "Name = 'Yunxi Liang' #@param {type: \"string\"}\n", + "Collaborators = '' #@param {type: \"string\"}\n", + "\n" + ] + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "from torch.utils.data import TensorDataset, DataLoader, Dataset\n", + "from torch.optim import Adam" + ], + "metadata": { + "id": "96js0CwIEHBi" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### FA Module Setup" + ], + "metadata": { + "id": "7MigEk06unDt" + } + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "import torch.nn as nn\n", + "from torch import autograd\n", + "from torch.autograd import Variable\n", + "\n", + "\n", + "class LinearFANetwork(nn.Module):\n", + " \"\"\"\n", + " Linear feed-forward networks with feedback alignment learning\n", + " Does NOT perform non-linear activation after each layer\n", + " \"\"\"\n", + " def __init__(self, in_features, num_layers, num_hidden_list):\n", + " \"\"\"\n", + " :param in_features: dimension of input features (784 for MNIST)\n", + " :param num_layers: number of layers for feed-forward net\n", + " :param num_hidden_list: list of integers indicating hidden nodes of each layer = output features?\n", + " \"\"\"\n", + " super(LinearFANetwork, self).__init__()\n", + " self.in_features = in_features\n", + " self.num_layers = num_layers\n", + " self.num_hidden_list = num_hidden_list\n", + "\n", + " # create list of linear layers\n", + " # first hidden layer\n", + " self.linear = [LinearFAModule(self.in_features, self.num_hidden_list[0])] # calls FA module to build hidden layers\n", + " # append additional hidden layers to list\n", + " for idx in range(self.num_layers - 1):\n", + " self.linear.append(LinearFAModule(self.num_hidden_list[idx], self.num_hidden_list[idx+1]))\n", + "\n", + " # create ModuleList to make list of layers work\n", + " self.linear = nn.ModuleList(self.linear)\n", + "\n", + " def forward(self, inputs):\n", + " \"\"\"\n", + " forward pass, which is same for conventional feed-forward net\n", + " :param inputs: inputs with shape [batch_size, in_features]\n", + " :return: logit outputs from the network\n", + " \"\"\"\n", + "\n", + " # first layer\n", + " linear1 = self.linear[0](inputs)\n", + "\n", + " # second layer\n", + " linear2 = self.linear[1](linear1)\n", + "\n", + " return linear2\n", + "\n", + "class LinearFAModule(nn.Module):\n", + "\n", + " def __init__(self, input_features, output_features, bias=True):\n", + " super(LinearFAModule, self).__init__()\n", + " self.input_features = input_features\n", + " self.output_features = output_features\n", + "\n", + " # weight and bias for forward pass\n", + " # weight has transposed form; more efficient (so i heard) (transposed at forward pass)\n", + " self.weight = nn.Parameter(torch.Tensor(output_features, input_features))\n", + " if bias:\n", + " self.bias = nn.Parameter(torch.Tensor(output_features))\n", + " else:\n", + " self.register_parameter('bias', None)\n", + "\n", + " # fixed random weight and bias for FA backward pass\n", + " # does not need gradient\n", + " self.weight_fa = Variable(torch.FloatTensor(output_features, input_features), requires_grad=False)\n", + "\n", + " # weight initialization\n", + " torch.nn.init.kaiming_uniform(self.weight)\n", + " torch.nn.init.kaiming_uniform(self.weight_fa)\n", + " torch.nn.init.constant(self.bias, 1)\n", + "\n", + " def forward(self, input):\n", + " return LinearFAFunction.apply(input, self.weight, self.weight_fa, self.bias)\n", + "\n", + "\n", + "class LinearFAFunction(autograd.Function): # is this like the relu activation function?\n", + "\n", + " @staticmethod\n", + " # same as reference linear function, but with additional fa tensor for backward\n", + " def forward(context, input, weight, weight_fa, bias=None):\n", + " context.save_for_backward(input, weight, weight_fa, bias)\n", + " output = input.mm(weight.t())\n", + " if bias is not None:\n", + " output += bias.unsqueeze(0).expand_as(output)\n", + " return output\n", + "\n", + " @staticmethod\n", + " def backward(context, grad_output):\n", + " input, weight, weight_fa, bias = context.saved_variables\n", + " grad_input = grad_weight = grad_weight_fa = grad_bias = None\n", + "\n", + " if context.needs_input_grad[0]:\n", + " # all of the logic of FA resides in this one line\n", + " # calculate the gradient of input with fixed fa tensor, rather than the \"correct\" model weight\n", + " grad_input = grad_output.mm(weight_fa.to(grad_output.device))\n", + " if context.needs_input_grad[1]:\n", + " # grad for weight with FA'ed grad_output from downstream layer\n", + " # it is same with original linear function\n", + " grad_weight = grad_output.t().mm(input)\n", + " if bias is not None and context.needs_input_grad[3]:\n", + " grad_bias = grad_output.sum(0).squeeze(0)\n", + "\n", + " return grad_input, grad_weight, grad_weight_fa, grad_bias" + ], + "metadata": { + "id": "yl5Y76uwElGo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Processing data and training" + ], + "metadata": { + "id": "oirh8wKYEuFP" + } + }, + { + "cell_type": "code", + "source": [ + "# SETUP and load the datasets\n", + "import os\n", + "import torch\n", + "import tensorflow_datasets as datasets\n", + "from torchvision import transforms\n", + "import torchvision\n", + "\n", + "# check whether CUDA is available\n", + "cuda_available = torch.cuda.is_available()\n", + "if cuda_available:\n", + " print(\"CUDA is available\")\n", + "else:\n", + " print(\"CUDA is not available\")\n", + "\n", + "# CIFAR-10 consists of 60,000 32x32 color images in 10 different classes, with 6,000 images per class.\n", + "from torchvision import datasets\n", + "from torchvision.transforms import ToTensor\n", + "import matplotlib.pyplot as plt\n", + "\n", + "transform = transforms.Compose([\n", + " transforms.ToTensor(), # Convert images to PyTorch tensors\n", + " transforms.Normalize((0.5,), (0.5,)) # Normalize pixel values to the range [-1, 1]\n", + "])\n", + "\n", + "\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5w79vSxlDjgz", + "outputId": "66d8866f-635a-48ec-98bb-e7ae0d080c0a" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "CUDA is available\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "'''\n", + "def load_cifar10(batch_size):\n", + " DOWNLOAD_CIFAR10 = False\n", + " if not(os.path.exists('./cifar10/')) or not os.listdir('./cifar10/'):\n", + " DOWNLOAD_CIFAR10 = True\n", + "\n", + " transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #for each RGB channel.\n", + " ])\n", + "\n", + " train_data = torchvision.datasets.CIFAR10(\n", + " root='./cifar10/', train=True,\n", + " download=DOWNLOAD_CIFAR10, transform=transform\n", + " )\n", + " train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)\n", + "\n", + " test_data = torchvision.datasets.CIFAR10(\n", + " root='./cifar10/', train=False,\n", + " download=DOWNLOAD_CIFAR10, transform=transform\n", + " )\n", + " test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)\n", + "\n", + " return train_loader, test_loader\n", + "\n", + "train_loader,test_loader = load_cifar10(64)\n", + "'''" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "e8nuWyPZr0-b", + "outputId": "fa626c46-40d5-4238-ab78-90303023f392" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar10/cifar-10-python.tar.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 170498071/170498071 [00:05<00:00, 30406874.84it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Extracting ./cifar10/cifar-10-python.tar.gz to ./cifar10/\n", + "Files already downloaded and verified\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Download and load the MNIST training dataset\n", + "train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n", + "\n", + "# Create a DataLoader for training data\n", + "train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n", + "\n", + "# Download and load the MNIST test dataset\n", + "test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n", + "\n", + "# Create a DataLoader for test data\n", + "test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)" + ], + "metadata": { + "id": "sJOj7aXYoQsF", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "92d62d16-b444-47ec-baed-8e6c4dbcbba0" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 9912422/9912422 [00:00<00:00, 61505351.94it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 28881/28881 [00:00<00:00, 108060387.00it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 1648877/1648877 [00:00<00:00, 21288181.11it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 4542/4542 [00:00<00:00, 19439315.07it/s]" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", + "\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Baseline1: Build the Network: simple Con\n" + ], + "metadata": { + "id": "7uL7DDGKEn2Z" + } + }, + { + "cell_type": "code", + "source": [ + "class SimpleConvNet(nn.Module):\n", + " def __init__(self, out_features):\n", + " super().__init__()\n", + " # reshape the data first to pass into conv1\n", + " self.conv1 = nn.Conv2d(1, 10, kernel_size=5) # first conv layer (1->10)\n", + " # maxpool2d (kernel=2), and then apply relu\n", + " self.conv2 = nn.Conv2d(10, 20, kernel_size=5)# second conv layer\n", + " self.conv2_drop = nn.Dropout2d() # dropout layer\n", + " # apply maxpool2d(kernel=2) after dropout, then apply relu\n", + " # flatten tensor using view, prepare for fc\n", + " self.fc1 = nn.Linear(320, 50) # fc1\n", + " # apply relu again, and then dropout for regularization\n", + " self.fc2 = nn.Linear(50, out_features) # fc2\n", + " # pass this through softmax\n", + "\n", + " def forward(self, x):\n", + " x = x.view(len(x), 1, 28, 28)\n", + " x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", + " x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", + " # flatten the tensor to prepare for fc\n", + " x = x.view(-1, 320)\n", + " x = F.relu(self.fc1(x))\n", + " x = F.dropout(x, training=self.training)\n", + " x = self.fc2(x)\n", + " return F.log_softmax(x, dim=-1)\n" + ], + "metadata": { + "id": "Cm7s_tl6C_6x" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Train Simple CNN\n", + "import torch.optim as optim\n", + "SimpleConvNet_model = SimpleConvNet(10)\n", + "\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.SGD(SimpleConvNet_model.parameters(), lr=0.01) # used sgd as optimizer" + ], + "metadata": { + "id": "iePTveEJDYVm" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "epochs = 5\n", + "for epoch in range(epochs):\n", + " for batch_id, (data, target) in enumerate(train_loader):\n", + " optimizer.zero_grad()\n", + " output = SimpleConvNet_model(data)\n", + " loss = criterion(output, target) # criterion ouptut a loss item\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " if batch_id % 100 ==0:\n", + " print('Epoch {}, Batch {}, Loss: {:.4f}'.format(epoch, batch_id, loss.item()))\n", + "\n", + "SimpleConvNet_model.eval()\n", + "test_loss = 0\n", + "correct = 0\n", + "\n", + "with torch.no_grad():\n", + " for data, target in test_loader:\n", + " output = SimpleConvNet_model(data)\n", + " test_loss += criterion(output, target).item()\n", + " pred = output.argmax(dim=1, keepdim=True)\n", + " correct += pred.eq(target.view_as(pred)).sum().item()\n", + "\n", + "test_loss /= len(test_loader.dataset) # returns average test loss\n", + "accuracy = correct / len(test_loader.dataset)\n", + "\n", + "print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(\n", + " test_loss, correct, len(test_loader.dataset), 100. * accuracy))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 391 + }, + "id": "OLATIP1IfbDE", + "outputId": "2b53b5d6-82d1-42ee-e8b5-217e4bd4f1ff" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "error", + "ename": "RuntimeError", + "evalue": "ignored", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch_id\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSimpleConvNet_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# criterion ouptut a loss item\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1517\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1518\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1519\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1520\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1525\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1526\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1528\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1529\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m28\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m28\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_pool2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_pool2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2_drop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: shape '[64, 1, 28, 28]' is invalid for input of size 196608" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Baseline2: Test if loaded weights are working" + ], + "metadata": { + "id": "Yf7L9DuZSi-c" + } + }, + { + "cell_type": "code", + "source": [ + "# Weight extraction\n", + "trained_model = SimpleConvNet_model\n", + "state_dict = trained_model.state_dict()\n", + "conv1_weights = state_dict['conv1.weight']\n", + "conv2_weights = state_dict['conv2.weight']" + ], + "metadata": { + "id": "sAalhmGWFFeO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from torch.nn import Parameter\n", + "class SimpleCNN(nn.Module):\n", + " def __init__(self, out_features, conv1_weights, conv2_weights):\n", + " super().__init__()\n", + " self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n", + " self.conv1.weight = Parameter(conv1_weights)\n", + " self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n", + " self.conv2.weight = Parameter(conv2_weights)\n", + " self.conv2_drop = nn.Dropout2d()\n", + " # freeze the above 3 layers\n", + " self.conv1.requires_grad = False\n", + " self.conv2.requires_grad = False\n", + " self.conv2_drop.requires_grad = False\n", + "\n", + " self.fc1 = nn.Linear(320, 50) # fc1\n", + " # apply relu again, and then dropout for regularization\n", + " self.fc2 = nn.Linear(50, out_features) # fc2\n", + " # pass this through softmax\n", + "\n", + " def forward(self, x):\n", + " x = x.view(len(x), 1, 28, 28)\n", + " x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", + " x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", + " # flatten the tensor to prepare for fc\n", + " x = x.view(-1, 320)\n", + " x = F.relu(self.fc1(x))\n", + " x = F.dropout(x, training=self.training)\n", + " x = self.fc2(x)\n", + " return F.log_softmax(x, dim=-1)\n", + "\n", + "just_conv = SimpleCNN(10, conv1_weights, conv2_weights)\n", + "optimizer_cnn = torch.optim.SGD(just_conv.parameters(),\n", + " lr=0.01)" + ], + "metadata": { + "id": "fU8RUmjv7-Wj" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "epochs = 5\n", + "for epoch in range(epochs):\n", + " for batch_id, (inputs, targets) in enumerate(train_loader):\n", + " optimizer_cnn.zero_grad()\n", + " outputs_cnn = just_conv(inputs)\n", + " #print(inputs.shape, targets.shape, outputs_fa.shape)\n", + " # calculate loss\n", + " loss_cnn = criterion(outputs_cnn, targets)\n", + " loss_cnn.backward()\n", + " optimizer_cnn.step()\n", + "\n", + " if batch_id % 100 ==0:\n", + " print('Epoch {}, Batch {}, Loss: {:.4f}'.format(epoch, batch_id, loss_cnn.item()))\n", + "\n", + "just_conv.eval()\n", + "test_loss = 0\n", + "correct = 0\n", + "\n", + "with torch.no_grad():\n", + " for data, target in test_loader:\n", + " output = just_conv(data)\n", + " test_loss += criterion(output, target).item()\n", + " pred = output.argmax(dim=1, keepdim=True)\n", + " correct += pred.eq(target.view_as(pred)).sum().item()\n", + "\n", + "test_loss /= len(test_loader.dataset) # returns average test loss\n", + "accuracy = correct / len(test_loader.dataset)\n", + "\n", + "print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(\n", + " test_loss, correct, len(test_loader.dataset), 100. * accuracy))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "qQ9SftKJ8sHS", + "outputId": "d4455fa6-9741-41fa-826b-676982470713" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 0, Batch 0, Loss: 2.5834\n", + "Epoch 0, Batch 100, Loss: 1.1622\n", + "Epoch 0, Batch 200, Loss: 0.9284\n", + "Epoch 0, Batch 300, Loss: 0.8592\n", + "Epoch 0, Batch 400, Loss: 0.5966\n", + "Epoch 0, Batch 500, Loss: 0.7824\n", + "Epoch 0, Batch 600, Loss: 0.6595\n", + "Epoch 0, Batch 700, Loss: 0.3930\n", + "Epoch 0, Batch 800, Loss: 0.3622\n", + "Epoch 0, Batch 900, Loss: 0.3888\n", + "Epoch 1, Batch 0, Loss: 0.4752\n", + "Epoch 1, Batch 100, Loss: 0.2869\n", + "Epoch 1, Batch 200, Loss: 0.4019\n", + "Epoch 1, Batch 300, Loss: 0.2408\n", + "Epoch 1, Batch 400, Loss: 0.1853\n", + "Epoch 1, Batch 500, Loss: 0.2997\n", + "Epoch 1, Batch 600, Loss: 0.4690\n", + "Epoch 1, Batch 700, Loss: 0.3131\n", + "Epoch 1, Batch 800, Loss: 0.3871\n", + "Epoch 1, Batch 900, Loss: 0.1463\n", + "Epoch 2, Batch 0, Loss: 0.2966\n", + "Epoch 2, Batch 100, Loss: 0.3476\n", + "Epoch 2, Batch 200, Loss: 0.2582\n", + "Epoch 2, Batch 300, Loss: 0.2038\n", + "Epoch 2, Batch 400, Loss: 0.2793\n", + "Epoch 2, Batch 500, Loss: 0.3335\n", + "Epoch 2, Batch 600, Loss: 0.1498\n", + "Epoch 2, Batch 700, Loss: 0.4266\n", + "Epoch 2, Batch 800, Loss: 0.3741\n", + "Epoch 2, Batch 900, Loss: 0.2491\n", + "Epoch 3, Batch 0, Loss: 0.3343\n", + "Epoch 3, Batch 100, Loss: 0.2170\n", + "Epoch 3, Batch 200, Loss: 0.2580\n", + "Epoch 3, Batch 300, Loss: 0.3171\n", + "Epoch 3, Batch 400, Loss: 0.1433\n", + "Epoch 3, Batch 500, Loss: 0.2407\n", + "Epoch 3, Batch 600, Loss: 0.2798\n", + "Epoch 3, Batch 700, Loss: 0.1814\n", + "Epoch 3, Batch 800, Loss: 0.2447\n", + "Epoch 3, Batch 900, Loss: 0.2266\n", + "Epoch 4, Batch 0, Loss: 0.2229\n", + "Epoch 4, Batch 100, Loss: 0.0959\n", + "Epoch 4, Batch 200, Loss: 0.2022\n", + "Epoch 4, Batch 300, Loss: 0.3450\n", + "Epoch 4, Batch 400, Loss: 0.1550\n", + "Epoch 4, Batch 500, Loss: 0.1553\n", + "Epoch 4, Batch 600, Loss: 0.2527\n", + "Epoch 4, Batch 700, Loss: 0.1740\n", + "Epoch 4, Batch 800, Loss: 0.1469\n", + "Epoch 4, Batch 900, Loss: 0.2044\n", + "Test set: Average loss: 0.0013, Accuracy: 9740/10000 (97.40%)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Baseline3: Weight Extraction and build FA" + ], + "metadata": { + "id": "ShWHnzBOSDZa" + } + }, + { + "cell_type": "code", + "source": [ + "from torch.nn import Parameter\n", + "class SimpleConvNet_withFA(nn.Module):\n", + " def __init__(self, out_features, conv1_weights, conv2_weights):\n", + " super().__init__()\n", + " self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n", + " self.conv1.weight = Parameter(conv1_weights)\n", + " self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n", + " self.conv2.weight = Parameter(conv2_weights)\n", + " self.conv2_drop = nn.Dropout2d()\n", + " # freeze the above 3 layers\n", + " self.conv1.requires_grad = False\n", + " self.conv2.requires_grad = False\n", + " self.conv2_drop.requires_grad = False\n", + "\n", + " self.fa1 = LinearFAModule(320, 50) # just specify size?\n", + " self.fa2 = LinearFAModule(50, out_features)\n", + "\n", + " def forward(self, x):\n", + " #print(\"Input size:\", x.size())\n", + " x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", + " #print(\"After conv1 size:\", x.size())\n", + " x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", + " #print(\"After conv2 size:\", x.size())\n", + " # Flatten the tensor to prepare for fc\n", + " x = x.view(-1, 320)\n", + " #print(\"After view size:\", x.size())\n", + " x = F.relu(self.fa1(x))\n", + " #print(\"After fa1 size:\", x.size())\n", + " x = F.dropout(x, training=self.training)\n", + " x = self.fa2(x)\n", + " #print(\"Final output size:\", x.size())\n", + " return F.log_softmax(x, dim=-1)\n", + "\n", + " # # x = x.view(-1, 1, 28, 28)\n", + " # x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", + " # x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", + " # # flatten the tensor to prepare for fc\n", + " # x = x.view(-1, self.conv2.out_channels)\n", + " # x = F.relu(self.fa1(x))\n", + " # x = F.dropout(x, training=self.training)\n", + " # x = self.fa2(x)\n", + " # return F.log_softmax(x, dim=-1)\n", + "\n", + "conv_withfa = SimpleConvNet_withFA(10, conv1_weights, conv2_weights)\n", + "optimizer_fa = torch.optim.SGD(conv_withfa.parameters(),\n", + " lr=0.0001)" + ], + "metadata": { + "id": "vVQ-0J_yDdJk", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "026613a7-2de1-45d7-ec9d-0ba099639f1d" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + ":69: UserWarning: nn.init.kaiming_uniform is now deprecated in favor of nn.init.kaiming_uniform_.\n", + " torch.nn.init.kaiming_uniform(self.weight)\n", + ":70: UserWarning: nn.init.kaiming_uniform is now deprecated in favor of nn.init.kaiming_uniform_.\n", + " torch.nn.init.kaiming_uniform(self.weight_fa)\n", + ":71: UserWarning: nn.init.constant is now deprecated in favor of nn.init.constant_.\n", + " torch.nn.init.constant(self.bias, 1)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "epochs = 5\n", + "for epoch in range(epochs):\n", + " for batch_id, (inputs, targets) in enumerate(train_loader):\n", + " optimizer_fa.zero_grad()\n", + " outputs_fa = conv_withfa(inputs)\n", + " #print(inputs.shape, targets.shape, outputs_fa.shape)\n", + " # calculate loss\n", + " loss_fa = criterion(outputs_fa, targets)\n", + " loss_fa.backward()\n", + " optimizer_fa.step()\n", + "\n", + " if batch_id % 100 ==0:\n", + " print('Epoch {}, Batch {}, Loss: {:.4f}'.format(epoch, batch_id, loss_fa.item()))\n", + "\n", + "conv_withfa.eval()\n", + "test_loss = 0\n", + "correct = 0\n", + "\n", + "with torch.no_grad():\n", + " for data, target in test_loader:\n", + " output = conv_withfa(data)\n", + " test_loss += criterion(output, target).item()\n", + " pred = output.argmax(dim=1, keepdim=True)\n", + " correct += pred.eq(target.view_as(pred)).sum().item()\n", + "\n", + "test_loss /= len(test_loader.dataset) # returns average test loss\n", + "accuracy = correct / len(test_loader.dataset)\n", + "\n", + "print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(\n", + " test_loss, correct, len(test_loader.dataset), 100. * accuracy))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "D--dfI7TmE1B", + "outputId": "ceb7df1f-6e92-4f77-94c7-de11751a7f97" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 0, Batch 0, Loss: 16.0889\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + ":90: DeprecationWarning: 'saved_variables' is deprecated; use 'saved_tensors'\n", + " input, weight, weight_fa, bias = context.saved_variables\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 0, Batch 100, Loss: 12.2939\n", + "Epoch 0, Batch 200, Loss: 15.1310\n", + "Epoch 0, Batch 300, Loss: 14.4827\n", + "Epoch 0, Batch 400, Loss: 17.6298\n", + "Epoch 0, Batch 500, Loss: 15.8998\n", + "Epoch 0, Batch 600, Loss: 13.9132\n", + "Epoch 0, Batch 700, Loss: 15.9931\n", + "Epoch 0, Batch 800, Loss: 16.1339\n", + "Epoch 0, Batch 900, Loss: 15.8832\n", + "Epoch 1, Batch 0, Loss: 19.5445\n", + "Epoch 1, Batch 100, Loss: 17.5985\n", + "Epoch 1, Batch 200, Loss: 14.8734\n", + "Epoch 1, Batch 300, Loss: 18.7231\n", + "Epoch 1, Batch 400, Loss: 14.8385\n", + "Epoch 1, Batch 500, Loss: 12.2053\n", + "Epoch 1, Batch 600, Loss: 15.9272\n", + "Epoch 1, Batch 700, Loss: 15.5226\n", + "Epoch 1, Batch 800, Loss: 11.7830\n", + "Epoch 1, Batch 900, Loss: 14.3114\n", + "Epoch 2, Batch 0, Loss: 14.4862\n", + "Epoch 2, Batch 100, Loss: 11.8447\n", + "Epoch 2, Batch 200, Loss: 11.6348\n", + "Epoch 2, Batch 300, Loss: 12.3725\n", + "Epoch 2, Batch 400, Loss: 10.2506\n", + "Epoch 2, Batch 500, Loss: 10.0633\n", + "Epoch 2, Batch 600, Loss: 13.8340\n", + "Epoch 2, Batch 700, Loss: 13.6265\n", + "Epoch 2, Batch 800, Loss: 8.9138\n", + "Epoch 2, Batch 900, Loss: 7.4068\n", + "Epoch 3, Batch 0, Loss: 8.7327\n", + "Epoch 3, Batch 100, Loss: 9.6699\n", + "Epoch 3, Batch 200, Loss: 10.0646\n", + "Epoch 3, Batch 300, Loss: 9.1900\n", + "Epoch 3, Batch 400, Loss: 8.6116\n", + "Epoch 3, Batch 500, Loss: 7.3793\n", + "Epoch 3, Batch 600, Loss: 7.9442\n", + "Epoch 3, Batch 700, Loss: 6.7128\n", + "Epoch 3, Batch 800, Loss: 5.9199\n", + "Epoch 3, Batch 900, Loss: 7.2058\n", + "Epoch 4, Batch 0, Loss: 6.7629\n", + "Epoch 4, Batch 100, Loss: 5.9891\n", + "Epoch 4, Batch 200, Loss: 4.0608\n", + "Epoch 4, Batch 300, Loss: 4.2900\n", + "Epoch 4, Batch 400, Loss: 7.5622\n", + "Epoch 4, Batch 500, Loss: 4.8813\n", + "Epoch 4, Batch 600, Loss: 5.8307\n", + "Epoch 4, Batch 700, Loss: 6.8185\n", + "Epoch 4, Batch 800, Loss: 6.0583\n", + "Epoch 4, Batch 900, Loss: 6.9000\n", + "Test set: Average loss: 0.0079, Accuracy: 8488/10000 (84.88%)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Past Drafts" + ], + "metadata": { + "id": "gDpwu5vnSwWF" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JhNXaMGAwvKd", + "outputId": "02e623e6-4626-49f9-dff8-d08e5a3967b2" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading...\n", + "From: https://drive.google.com/uc?id=1-0RKObNKSfKmxPfBkZGZnQ66Y-vQR9DY\n", + "To: /content/colab_pdf.py\n", + "\r 0% 0.00/1.83k [00:0010)\n", + " # maxpool2d (kernel=2), and then apply relu\n", + " self.conv2 = nn.Conv2d(10, 20, kernel_size=5)# second conv layer\n", + " self.conv2_drop = nn.Dropout2d() # dropout layer\n", + " # apply maxpool2d(kernel=2) after dropout, then apply relu\n", + " # flatten tensor using view, prepare for fc\n", + " self.fc1 = nn.Linear(320, 50) # fc1\n", + " # apply relu again, and then dropout for regularization\n", + " self.fc2 = nn.Linear(50, out_features) # fc2\n", + " # pass this through softmax\n", + "\n", + " def forward(self, x):\n", + " x = x.view(len(x), 1, 28, 28)\n", + " x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", + " x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", + " # flatten the tensor to prepare for fc\n", + " x = x.view(-1, 320)\n", + " x = F.relu(self.fc1(x))\n", + " x = F.dropout(x, training=self.training)\n", + " x = self.fc2(x)\n", + " return F.log_softmax(x, dim=-1)\n", + "\n", + "class SimpleConvNet_withoutFA(nn.Module):\n", + " def __init__(self, out_features):\n", + " super().__init__()\n", + " self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n", + " self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n", + " self.conv2_drop = nn.Dropout2d()\n", + " # freeze the above 3 layers\n", + " self.conv1.requires_grad = False\n", + " self.conv2.requires_grad = False\n", + " self.conv2_drop.requires_grad = False\n", + "\n", + " self.fa1 = nn.Linear(self.conv2_drop.size(0), 50, bias=False)\n", + " self.fa2 = nn.Linear(50, out_features, bias=False)\n", + "\n", + " def forward(self, x):\n", + " x = x.view(len(x), 1, 28, 28)\n", + " x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", + " x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", + " # flatten the tensor to prepare for fc\n", + " x = x.view(-1, self.conv2_drop.size(0))\n", + " x = F.relu(self.fa1(x))\n", + " x = F.dropout(x, training=self.training)\n", + " x = self.fa2(x)\n", + " return F.log_softmax(x, dim=-1)" + ], + "metadata": { + "id": "UhqhXavdADV2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# CIFAR-10 consists of 60,000 32x32 color images in 10 different classes, with 6,000 images per class.\n", + "(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()\n", + "\n", + "# Normalize pixel values to be between 0 and 1\n", + "train_images, test_images = train_images / 255.0, test_images / 255.0\n", + "\n", + "# train" + ], + "metadata": { + "id": "h_ukzvdrEQBK" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',\n", + " 'dog', 'frog', 'horse', 'ship', 'truck']\n", + "\n", + "plt.figure(figsize=(10,10))\n", + "for i in range(25):\n", + " plt.subplot(5,5,i+1)\n", + " plt.xticks([])\n", + " plt.yticks([])\n", + " plt.grid(False)\n", + " plt.imshow(train_images[i])\n", + " # The CIFAR labels happen to be arrays,\n", + " # which is why you need the extra index\n", + " plt.xlabel(class_names[train_labels[i][0]])\n", + "plt.show()" + ], + "metadata": { + "id": "KQZAGcRhEbdo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "#### STEP2: Construct the CNN+FC model as BASELINE" + ], + "metadata": { + "id": "W9nOgx7r3M7y" + } + }, + { + "cell_type": "code", + "source": [ + "model = models.Sequential()\n", + "model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))\n", + "model.add(layers.MaxPooling2D((2, 2)))\n", + "model.add(layers.Conv2D(64, (3, 3), activation='relu'))\n", + "# model.add(layers.MaxPooling2D((2, 2)))\n", + "# model.add(layers.Conv2D(32, (3, 3), activation='relu'))\n", + "\n", + "model.add(layers.Flatten())\n", + "model.add(layers.Dense(65, activation='relu'))\n", + "model.add(layers.Dense(10))\n", + "\n", + "model.summary()" + ], + "metadata": { + "id": "IHDf4-2fEh59", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "f1e11e34-bb5d-4106-8361-f5b3ad6d8319" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model: \"sequential_4\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " conv2d_12 (Conv2D) (None, 26, 26, 32) 320 \n", + " \n", + " max_pooling2d_8 (MaxPoolin (None, 13, 13, 32) 0 \n", + " g2D) \n", + " \n", + " conv2d_13 (Conv2D) (None, 11, 11, 64) 18496 \n", + " \n", + " flatten_4 (Flatten) (None, 7744) 0 \n", + " \n", + " dense_8 (Dense) (None, 65) 503425 \n", + " \n", + " dense_9 (Dense) (None, 10) 660 \n", + " \n", + "=================================================================\n", + "Total params: 522901 (1.99 MB)\n", + "Trainable params: 522901 (1.99 MB)\n", + "Non-trainable params: 0 (0.00 Byte)\n", + "_________________________________________________________________\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "#### STEP3: Compile the baseline model and obtain training weights" + ], + "metadata": { + "id": "BHQ-6ijb3Urj" + } + }, + { + "cell_type": "code", + "source": [ + "model.compile(optimizer='adam',\n", + " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " metrics=['accuracy'])\n", + "\n", + "history = model.fit(train_images, train_labels, epochs=10,\n", + " validation_data=(test_images, test_labels))" + ], + "metadata": { + "id": "k9k_N8wLEmtb", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "42b76204-0f9d-469b-8f49-86c9e7bb4e0c" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 1/10\n", + "1875/1875 [==============================] - 9s 4ms/step - loss: 0.1138 - accuracy: 0.9648 - val_loss: 0.0474 - val_accuracy: 0.9845\n", + "Epoch 2/10\n", + "1875/1875 [==============================] - 8s 4ms/step - loss: 0.0384 - accuracy: 0.9878 - val_loss: 0.0317 - val_accuracy: 0.9890\n", + "Epoch 3/10\n", + "1875/1875 [==============================] - 8s 5ms/step - loss: 0.0234 - accuracy: 0.9927 - val_loss: 0.0318 - val_accuracy: 0.9909\n", + "Epoch 4/10\n", + "1875/1875 [==============================] - 7s 4ms/step - loss: 0.0182 - accuracy: 0.9941 - val_loss: 0.0393 - val_accuracy: 0.9879\n", + "Epoch 5/10\n", + "1875/1875 [==============================] - 9s 5ms/step - loss: 0.0133 - accuracy: 0.9956 - val_loss: 0.0387 - val_accuracy: 0.9893\n", + "Epoch 6/10\n", + "1875/1875 [==============================] - 7s 4ms/step - loss: 0.0091 - accuracy: 0.9968 - val_loss: 0.0356 - val_accuracy: 0.9896\n", + "Epoch 7/10\n", + "1875/1875 [==============================] - 10s 5ms/step - loss: 0.0073 - accuracy: 0.9976 - val_loss: 0.0438 - val_accuracy: 0.9887\n", + "Epoch 8/10\n", + "1875/1875 [==============================] - 10s 5ms/step - loss: 0.0069 - accuracy: 0.9978 - val_loss: 0.0467 - val_accuracy: 0.9887\n", + "Epoch 9/10\n", + "1875/1875 [==============================] - 7s 4ms/step - loss: 0.0062 - accuracy: 0.9981 - val_loss: 0.0447 - val_accuracy: 0.9898\n", + "Epoch 10/10\n", + "1875/1875 [==============================] - 9s 5ms/step - loss: 0.0047 - accuracy: 0.9984 - val_loss: 0.0599 - val_accuracy: 0.9891\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "plt.plot(history.history['accuracy'], label='accuracy')\n", + "plt.plot(history.history['val_accuracy'], label = 'val_accuracy')\n", + "plt.xlabel('Epoch')\n", + "plt.ylabel('Accuracy')\n", + "plt.ylim([0.5, 1])\n", + "plt.legend(loc='lower right')\n", + "\n", + "test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)" + ], + "metadata": { + "id": "GyyzLIF3EwnB", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 641 + }, + "outputId": "532cdbef-defa-4b03-b040-62aaf7ac9c1c" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "313/313 - 1s - loss: 0.0599 - accuracy: 0.9891 - 1s/epoch - 3ms/step\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAmUAAAJfCAYAAAA+bqHsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABCmElEQVR4nO3deXxTVf7/8XeStuliW5bSlkLZUUHZCxXEDXEYGfkqOgqIWnFwBRQ7joKyuEEVB0QFZUAWF0DcUH6D4mAdRREFwaIOiwsIqLQFka50S/L7o21o6EILaXPavp6Px30k99xz7/2kDc2be0/utbhcLpcAAADgU1ZfFwAAAABCGQAAgBEIZQAAAAYglAEAABiAUAYAAGAAQhkAAIABCGUAAAAGIJQBAAAYgFAGAABgAEIZAACAAXwayjZs2KBhw4YpJiZGFotF77zzzknX+fjjj9W7d2/Z7XZ16tRJy5Ytq/U6AQAAaptPQ1lOTo569Oih+fPnV6v/3r179Ze//EWXXHKJUlJSNHHiRI0dO1YffPBBLVcKAABQuyym3JDcYrFo9erVuuqqqyrt88ADD2jt2rX67rvv3G0jR47U0aNHtW7dujqoEgAAoHb4+bqAmti0aZMGDx7s0TZkyBBNnDix0nXy8/OVn5/vnnc6nTpy5IiaN28ui8VSW6UCAABIklwul7KyshQTEyOrtfKTlPUqlKWmpioqKsqjLSoqSpmZmTp27JiCgoLKrZOUlKRHHnmkrkoEAACo0IEDB9S6detKl9erUHYqJk+erMTERPd8RkaG2rRpowMHDigsLMyHlQHA6XG5XCpyuuQomY4/d7rbHCV9nE6Xihyl/ZxyulxyOFQ87yq7vlNOp1RUso2iMttxlulT5Dy+fulyh6t4H87S5RVMRa7i7TpdxfW7pJLnpS9Kcrpccrkkl0oeyz5X2eUl2yhdrpK+JRtzlW5Lx/flcnk+L91W6c/T6Souovz2T6irwm0V95OrdFvHl8F8744foI4tQmtl25mZmYqNjVVoaNXbr1ehLDo6WmlpaR5taWlpCgsLq/AomSTZ7XbZ7fZy7WFhYYQyAHKVBJJCh0sFDqeKHE4VOlwqdDhV4HCq0OFUYZHL/bzoxGUVLPdY5nCpoOj4vOfykm0Vec6XPi9yOD1CkcNxPCQVBytf//TqK0slz09jU5VsxnIae7BYJKvFIqtFssjini/7aJFktVqK51U8PttqKbtu8d6tVrn7uNct6Vs8X7KsTL/j2/Ks4cQ6VLbvCdv3nC/fVrovWSquz3LCNt2v54TXX1pvcZ+ytZ64v7Lb9ezfLrqFwkICTvG3Vd3fadXvhnoVyvr376/33nvPo239+vXq37+/jyoCUFMOp0v5RQ4VFDmVX+RUfqFT+UWO4udFjpL5sm2l/crOH+9XNvCcGGwKHC4VFlW8rLBMOGqIRzL8bRbZrBbZLMWPfjZr8aPV4n60uuetHu3F/Ys/0D2W244vt1mK+xSvU81tl6nJWvKJejwAqNyHvMcHtDw/1MuGD4sq2JaKl6uCQFPRtsruS9IJ/Y/vS/IME2U/3MsFHY86LR6BoKqwVdoPjY9PQ1l2drZ+/PFH9/zevXuVkpKiZs2aqU2bNpo8ebJ+/fVXvfzyy5KkO+64Q/PmzdP999+vW265RR999JFef/11rV271lcvAahXXK7iI0JlA09BRWHoxGBU6FCBw1lB+4nhqbKwdTxUFdWDwzt+Vov8bVb52ywK8LPKz2qVv19xW4DN6l7mb7OWLC/p71e6vHR9z+UBficssxVv189a2vf4stJ+1pLwU1l4slnLByOrlQ90oD7yaSj76quvdMkll7jnS8d+JSQkaNmyZTp48KD279/vXt6+fXutXbtW9957r5555hm1bt1aL774ooYMGVLntQMnUzrep/T0VGlYKSgzX+BwqrDIqfyybRX0KddWWXsV65UGJJOOClktUqC/TXY/q+x+Ntn9rbL7FYcXu19pu+ey0vbSPqVBp/jxeKApDU9+pc/9yi/zLw1JVs/nhBoAvmDMdcrqSmZmpsLDw5WRkcGYskbM5XLpaG6hDmXnKz0zX4ey8/R7doHyCkuO9JQJN4Xlwo3nfOEJQajsctP/dQX4eQYdd9jxrywQnRCWKugXYLOW9C8ftkpDVOk6fjbu9Aag4atu9qhXY8qAkykocupwdr7Ss/J1KCtf6Vl5JY+l4StfhzLzdCg7X4WOuk1MFosUYLO6Q0npc/fkPr11PBx59LHZyjy3lFlu89iG3e/4abOy2ygOUWWONNk4IgQAJiGUwXgul0uZeUUeIeuQO3R5hq8/cgsluRSgIgWqoHiyFChI+QpUgYIsBTpTBequAgVZ89U0wKEIu0PN7E6F+zvl9AtWkX+oigJC5QgIldM/VE57mGQPk8seJqs9xDME+VllLxOA3GGobLAqmeeoEACgKoQy1B1HoVR4rHgqOqaivBwdzcxSRmamMrMylZ2dpdycLB3LzVZebrYK83JVmJ8jZ8Ex+TvzFKhCBVnyFaziYNXDUhK0SoJXoAoUaC8OYDZLDY6C5ZdM1WGxSYHFIa34MfyE+TAp8MS2E+b9gyW+WdVwuS+s5VTxBaucJ5mvoo/FKln9JJtf8aPVX7L5F7fzHsKpcDolZ6HkKCj+m+woLDNfVPzoLGk/sY+zqPhvoNVPstpKJr/ybR7zfsffxx79y6zvbuM/roQyVM5RJOX+LuWkS9npUs5hKeeQVJDtEa7cz0smR0GuHPm5cpbMW4vyZHMek83l8Ni8n6SIkqlK1pLpVFisxSHIP0jyCyp+LDu52wIlm10qyJHyM6W8TCk/o+SxZN7lKJ6O/VE8nSqrn2QPLQ5vpSHOI9SVeaysj3+QmR/KLlfxH/Ki/JLHvDLPyz7mS0UFJzxW1K+i/if0dRadEGZUvUDkcp28T023qTo8JW71Lwls/iUfbv4VBziPZVX0dc97o29F836ez018/3qDy1X8nqwq9DgKTrNP2fnqBqyS9hP+DpvFUkHQs55k/sQgWFk4rEZYHHivFBrt058AoayxKcwrCVmHigOWR+A64XnuEZ3Kh4ytZKpKrsuuPPkrT3YVWOxy2ALl9AuS/AJlDQiWzR4kP3uIAoJCFBh0hoKCz5B/YHBxwPILPB60KgpXZUOYzd87f/xdrhMCWyXBLS+j6j4uZ/EfWq8EuxOPzlUR7vxDSv5QnyQMVRSAqhWoSh4dBaf/s27sLFYVX1zLUhL2KvkQdZZ8ABcdq9Py0NBYJFtA8d9Km39JmC8zbwsoCdMBxeHFWfKfU2dRyVG3opL/GJ3QVul8URW1uKrRpxbF3UIow2lyuYo/7KsTsrIPSQVZNdyBRa7g5srxb6a9x4K191iQslzBOqYA5SlAx8qEqzxXgI4pQMdkl8UvUIHBZygkJFQhoaEKOyNU4eFhahIWrogmYWoRGqjIMLuiggNkqw+DzS0WyX5G8RQWc2rbKA125YJbBUGusnCXn1Um2B0pnkxm9Zf87MV/0D0e7ZJfQHHArmxZhY8VbMtaejrPWnJV0NJQU3KKr8r5mvS3VGN7Xuh/Iqez+APNUXj8A8v9vLD4Q/J05h1FZZYVnTBfk23VpK6ihn3voYoCTXVCzyn1KW2vxnYq3O7J/gtdC0rf086i4vdF6eOJbaV/69zzjpLnZeerGw4dJ99H8EnP29Q6QpmJnI7io1TukHXohOfpxfOl7Y7qDogqYfWXzoiUQloUT5U8z7M317vf52nhZ/v0U1qOpOKLanaOClVkqF0tQu2KDLWrdahdkaGB7vkWoXaF2HlrlVM22KnVqW3D5So+fVxhcKsi3BXmlvwxLgk2VQWlk4Wg6m7DFsAYEW+wlpy/t/n7uhLAO3hPV4pPzrpSlF8SrA5VEKxOCFm5h0vGp9RAwBknDVnu54HhVZ7S+yOnQK9+sU8vbfpOh7OLT0eF2v10fXwb3Xx+O7UMr/g+o6gDFkvxeDR7qBR+isEOAGAkQpm3ZadLnzxZPnDlZdR8W0HNKghWLaSQyPLtAcGnXfq+33O0+LO9ev2rA8orLA6FMeGBumVge43oG6vQQP5XAwBAbSGUeZuzSNryYsXLLLYKglUlz0Mi6uzQ7tZ9f+jFT/do3f9S3cM8zokJ020XdtDQbi3lz/W1AACodYQybwtpIV1wX8VHtQKbGDPGxuF0af2ONC36dI+27jv+LcBLzmqhWy/ooP4dm8vSUL+yDgCAgQhl3mbzly6d6usqKnWswKE3t/2ixZ/u0c+/50oqvvXPVb1iNPaCDjozKtTHFQIA0DgRyhqJw9n5ennTPr2y6eeSWxFJ4UH+uuG8Nkro306RYYE+rhAAgMaNUNbA/XQoWy9+uldvbftFBUXFg/djmwXpb+e317VxsVy6AgAAQ/CJ3AC5XC5t3ntEiz7dqw93prnbe8Q20W0XdNCQc6K4OTYAAIYhlDUgRQ6n1v0vVYs27NH2X4ovwWGxSIO7ROm2Czsorm1TBu8DAGAoQlkDkJNfpNe/OqDFn+3VL38U3wcvwM+qv/Zprb8NbK+OLc7wcYUAAOBkCGX1WHpmnpZ9/rNe/WKfMvOKb+DaLCRAN57XVjf2b6uIM+w+rhAAAFQXoawe2p2apRc/3aN3Un5VoaP4aq/tI0L0t4HtdU3v1goK8MENZgEAwGkhlNUTLpdLn//0uxZu2KNPvj/kbu/brqnGXtBBg7tEyWZlvBgAAPUVocxwhQ6n1n5zUAs37NGOg5mSJKtF+vO50Rp7QQf1btPUxxUCAABvIJQZKiuvUK9tPqAlG/fqYEaeJCnI36br4lrrloHt1bZ5iI8rBAAA3kQoM8xvR49p2ec/a+WX+5WVXzx4P+IMu24e0Faj49uqaUiAjysEAAC1gVBmiP/9lqEXP92r/7f9NxU5iwfvd4o8Q7de0F5X9mylQH8G7wMA0JARynzI5XLpk+8PadGne7Txx9/d7f07NNdtF3bQRWe2kJXB+wAANAqEMh/IL3JoTcpvevHTvdqdliVJslkt+ku3lrr1gg7q1jrcxxUCAIC6RiirQxm5hVq+eZ+WbfxZ6Vn5kqSQAJtG9mujMee3U+umwT6uEAAA+AqhrA4cOJKrJRv3atWWA8otcEiSosLsGnN+e43q10bhQf4+rhAAAPgaoawWbT9wVAs/3aP3vz2okrH7Ojs6VLde0EHDesQowM/q2wIBAIAxCGVe5nS69NGudC38dI827z3ibr+gc4Ruu7CDBnaKkMXC4H0AAOCJUOZlX+z9XWNf/kqS5Ge16P96xmjswA7qGhPm48oAAIDJCGVe1r9Dc/Vt11S92zbVzQPaqWV4kK9LAgAA9QChzMssFotev70/pygBAECNMNK8FhDIAABATRHKAAAADEAoAwAAMAChDAAAwACEMgAAAAMQygAAAAxAKAMAADAAoQwAAMAAhDIAAAADEMoAAAAMQCgDAAAwAKEMAADAAIQyAAAAAxDKAAAADEAoAwAAMAChDAAAwACEMgAAAAMQygAAAAxAKAMAADAAoQwAAMAAhDIAAAADEMoAAAAMQCgDAAAwAKEMAADAAIQyAAAAAxDKAAAADEAoAwAAMAChDAAAwACEMgAAAAMQygAAAAxAKAMAADAAoQwAAMAAhDIAAAADEMoAAAAMQCgDAAAwAKEMAADAAIQyAAAAAxDKAAAADEAoAwAAMAChDAAAwACEMgAAAAMQygAAAAxAKAMAADAAoQwAAMAAhDIAAAADEMoAAAAMQCgDAAAwAKEMAADAAIQyAAAAAxDKAAAADEAoAwAAMAChDAAAwACEMgAAAAMQygAAAAxAKAMAADAAoQwAAMAAhDIAAAADEMoAAAAMQCgDAAAwAKEMAADAAIQyAAAAAxDKAAAADEAoAwAAMAChDAAAwAA+D2Xz589Xu3btFBgYqPj4eG3evLnSvoWFhXr00UfVsWNHBQYGqkePHlq3bl0dVgsAAFA7fBrKVq1apcTERE2fPl3btm1Tjx49NGTIEKWnp1fYf8qUKfrXv/6l5557Tjt27NAdd9yh4cOH6+uvv67jygEAALzL4nK5XL7aeXx8vPr27at58+ZJkpxOp2JjYzVhwgRNmjSpXP+YmBg99NBDGjdunLvtmmuuUVBQkF599dVq7TMzM1Ph4eHKyMhQWFiYd14IAABAJaqbPXx2pKygoEBbt27V4MGDjxdjtWrw4MHatGlThevk5+crMDDQoy0oKEifffZZpfvJz89XZmamxwQAAGAan4Wyw4cPy+FwKCoqyqM9KipKqampFa4zZMgQzZkzRz/88IOcTqfWr1+vt99+WwcPHqx0P0lJSQoPD3dPsbGxXn0dAAAA3uDzgf418cwzz6hz5846++yzFRAQoPHjx2vMmDGyWit/GZMnT1ZGRoZ7OnDgQB1WDAAAUD0+C2URERGy2WxKS0vzaE9LS1N0dHSF67Ro0ULvvPOOcnJytG/fPu3atUtnnHGGOnToUOl+7Ha7wsLCPCYAAADT+CyUBQQEqE+fPkpOTna3OZ1OJScnq3///lWuGxgYqFatWqmoqEhvvfWWrrzyytouFwAAoFb5+XLniYmJSkhIUFxcnPr166e5c+cqJydHY8aMkSTddNNNatWqlZKSkiRJX375pX799Vf17NlTv/76qx5++GE5nU7df//9vnwZAAAAp82noWzEiBE6dOiQpk2bptTUVPXs2VPr1q1zD/7fv3+/x3ixvLw8TZkyRXv27NEZZ5yhoUOH6pVXXlGTJk189AoAAAC8w6fXKfMFrlMGAADqkvHXKQMAAMBxhDIAAAADEMoAAAAMQCgDAAAwAKEMAADAAIQyAAAAAxDKAAAADEAoAwAAMAChDAAAwACEMgAAAAMQygAAAAxAKAMAADAAoQwAAMAAhDIAAAADEMoAAAAMQCgDAAAwAKEMAADAAIQyAAAAAxDKAAAADEAoAwAAMAChDAAAwACEMgAAAAMQygAAAAxAKAMAADAAoQwAAMAAhDIAAAADEMoAAAAMQCgDAAAwAKEMAADAAIQyAAAAAxDKAAAADEAoAwAAMAChDAAAwACEMgAAAAMQygAAAAxAKAMAADAAoQwAAMAAhDIAAAADEMoAAAAMQCgDAAAwAKEMAADAAIQyAAAAAxDKAAAADEAoAwAAMAChDAAAwACEMgAAAAMQygAAAAxAKAMAADAAoQwAAMAAhDIAAAADEMoAAAAMQCgDAAAwAKEMAADAAIQyAAAAAxDKAAAADEAoAwAAMAChDAAAwACEMgAAAAMQygAAAAxAKAMAADAAoQwAAMAAhDIAAAADEMoAAAAMQCgDAAAwAKEMAADAAIQyAAAAAxDKAAAADEAoAwAAMAChDAAAwACEMgAAAAMQygAAAAxAKAMAADAAoQwAAMAAhDIAAAADEMoAAAAMQCgDAAAwAKEMAADAAIQyAAAAAxDKAAAADEAoAwAAMAChDAAAwACEMgAAAAMQygAAAAxAKAMAADAAoQwAAMAAhDIAAAADEMoAAAAMQCgDAAAwAKEMAADAAIQyAAAAAxDKAAAADEAoAwAAMAChDAAAwACEMgAAAAMQygAAAAxAKAMAADAAoQwAAMAAPg9l8+fPV7t27RQYGKj4+Hht3ry5yv5z587VWWedpaCgIMXGxuree+9VXl5eHVULAABQO3waylatWqXExERNnz5d27ZtU48ePTRkyBClp6dX2H/FihWaNGmSpk+frp07d2rx4sVatWqVHnzwwTquHAAAwLssLpfL5audx8fHq2/fvpo3b54kyel0KjY2VhMmTNCkSZPK9R8/frx27typ5ORkd9vf//53ffnll/rss88q3Ed+fr7y8/Pd85mZmYqNjVVGRobCwsK8/IoAAAA8ZWZmKjw8/KTZw2dHygoKCrR161YNHjz4eDFWqwYPHqxNmzZVuM6AAQO0detW9ynOPXv26L333tPQoUMr3U9SUpLCw8PdU2xsrHdfCAAAgBf4+WrHhw8flsPhUFRUlEd7VFSUdu3aVeE6119/vQ4fPqyBAwfK5XKpqKhId9xxR5WnLydPnqzExET3fOmRMgAAAJP4fKB/TXz88ceaOXOmnn/+eW3btk1vv/221q5dq8cee6zSdex2u8LCwjwmAAAA0/jsSFlERIRsNpvS0tI82tPS0hQdHV3hOlOnTtWNN96osWPHSpK6deumnJwc3XbbbXrooYdktdarjAkAAODmsxQTEBCgPn36eAzadzqdSk5OVv/+/StcJzc3t1zwstlskiQffl8BAADgtPnsSJkkJSYmKiEhQXFxcerXr5/mzp2rnJwcjRkzRpJ00003qVWrVkpKSpIkDRs2THPmzFGvXr0UHx+vH3/8UVOnTtWwYcPc4QwAAKA+8mkoGzFihA4dOqRp06YpNTVVPXv21Lp169yD//fv3+9xZGzKlCmyWCyaMmWKfv31V7Vo0ULDhg3TjBkzfPUSAAAAvMKn1ynzhepeKwQAAMAbjL9OGQAAAI4jlAEAABiAUAYAAGAAQhkAAIABCGUAAAAGIJQBAAAYgFAGAABgAEIZAACAAQhlAAAABiCUAQAAGIBQBgAAYABCGQAAgAEIZQAAAAYglAEAABiAUAYAAGAAQhkAAIABCGUAAAAGIJQBAAAYoMahrF27dnr00Ue1f//+2qgHAACgUapxKJs4caLefvttdejQQZdddplee+015efn10ZtAAAAjcYphbKUlBRt3rxZXbp00YQJE9SyZUuNHz9e27Ztq40aAQAAGjyLy+Vync4GCgsL9fzzz+uBBx5QYWGhunXrprvvvltjxoyRxWLxVp1ek5mZqfDwcGVkZCgsLMzX5QAAgAauutnD71R3UFhYqNWrV2vp0qVav369zjvvPP3tb3/TL7/8ogcffFAffvihVqxYcaqbBwAAaFRqHMq2bdumpUuXauXKlbJarbrpppv09NNP6+yzz3b3GT58uPr27evVQgEAABqyGoeyvn376rLLLtMLL7ygq666Sv7+/uX6tG/fXiNHjvRKgQAAAI1BjUPZnj171LZt2yr7hISEaOnSpadcFAAAQGNT429fpqen68svvyzX/uWXX+qrr77ySlEAAACNTY1D2bhx43TgwIFy7b/++qvGjRvnlaIAAAAamxqHsh07dqh3797l2nv16qUdO3Z4pSgAAIDGpsahzG63Ky0trVz7wYMH5ed3ylfYAAAAaNRqHMr+9Kc/afLkycrIyHC3HT16VA8++KAuu+wyrxYHAADQWNT40NY///lPXXjhhWrbtq169eolSUpJSVFUVJReeeUVrxcIAADQGNQ4lLVq1UrffPONli9fru3btysoKEhjxozRqFGjKrxmGQAAAE7ulAaBhYSE6LbbbvN2LQAAAI3WKY/M37Fjh/bv36+CggKP9v/7v/877aIAAAAam1O6ov/w4cP17bffymKxyOVySZIsFoskyeFweLdCAACARqDG376855571L59e6Wnpys4OFj/+9//tGHDBsXFxenjjz+uhRIBAAAavhofKdu0aZM++ugjRUREyGq1ymq1auDAgUpKStLdd9+tr7/+ujbqBAAAaNBqfKTM4XAoNDRUkhQREaHffvtNktS2bVvt3r3bu9UBAAA0EjU+Unbuuedq+/btat++veLj4zVr1iwFBARo4cKF6tChQ23UCAAA0ODVOJRNmTJFOTk5kqRHH31UV1xxhS644AI1b95cq1at8nqBAAAAjYHFVfr1ydNw5MgRNW3a1P0NTJNlZmYqPDxcGRkZCgsL83U5AACggatu9qjRmLLCwkL5+fnpu+++82hv1qxZvQhkAAAApqpRKPP391ebNm24FhkAAICX1fjblw899JAefPBBHTlypDbqAQAAaJRqPNB/3rx5+vHHHxUTE6O2bdsqJCTEY/m2bdu8VhwAAEBjUeNQdtVVV9VCGQAAAI2bV759WZ/w7UsAAFCXauXblwAAAKgdNT59abVaq7z8Bd/MBAAAqLkah7LVq1d7zBcWFurrr7/WSy+9pEceecRrhQEAADQmXhtTtmLFCq1atUrvvvuuNzZXaxhTBgAA6lKdjyk777zzlJyc7K3NAQAANCpeCWXHjh3Ts88+q1atWnljcwAAAI1OjceUnXjjcZfLpaysLAUHB+vVV1/1anEAAACNRY1D2dNPP+0RyqxWq1q0aKH4+Hg1bdrUq8UBAAA0FjUOZTfffHMtlAEAANC41XhM2dKlS/XGG2+Ua3/jjTf00ksveaUoAACAxqbGoSwpKUkRERHl2iMjIzVz5kyvFAUAANDY1DiU7d+/X+3bty/X3rZtW+3fv98rRQEAADQ2NQ5lkZGR+uabb8q1b9++Xc2bN/dKUQAAAI1NjUPZqFGjdPfdd+u///2vHA6HHA6HPvroI91zzz0aOXJkbdQIAADQ4NX425ePPfaYfv75Z1166aXy8yte3el06qabbmJMGQAAwCk65Xtf/vDDD0pJSVFQUJC6deumtm3beru2WsG9LwEAQF2qbvao8ZGyUp07d1bnzp1PdXUAAACUUeMxZddcc42efPLJcu2zZs3Stdde65WiAAAAGpsah7INGzZo6NCh5dovv/xybdiwwStFAQAANDY1DmXZ2dkKCAgo1+7v76/MzEyvFAUAANDY1DiUdevWTatWrSrX/tprr6lr165eKQoAAKCxqfFA/6lTp+rqq6/WTz/9pEGDBkmSkpOTtWLFCr355pteLxAAAKAxqHEoGzZsmN555x3NnDlTb775poKCgtSjRw999NFHatasWW3UCAAA0OCd8nXKSmVmZmrlypVavHixtm7dKofD4a3aagXXKQMAAHWputmjxmPKSm3YsEEJCQmKiYnR7NmzNWjQIH3xxRenujkAAIBGrUanL1NTU7Vs2TItXrxYmZmZuu6665Sfn6933nmHQf4AAACnodpHyoYNG6azzjpL33zzjebOnavffvtNzz33XG3WBgAA0GhU+0jZ+++/r7vvvlt33nknt1cCAADwsmofKfvss8+UlZWlPn36KD4+XvPmzdPhw4drszYAAIBGo9qh7LzzztOiRYt08OBB3X777XrttdcUExMjp9Op9evXKysrqzbrBAAAaNBO65IYu3fv1uLFi/XKK6/o6NGjuuyyy7RmzRpv1ud1XBIDAADUpVq/JIYknXXWWZo1a5Z++eUXrVy58nQ2BQAA0Kid9sVj6xuOlAEAgLpUJ0fKAAAA4B2EMgAAAAMQygAAAAxAKAMAADAAoQwAAMAAhDIAAAADEMoAAAAMQCgDAAAwAKEMAADAAIQyAAAAAxDKAAAADEAoAwAAMAChDAAAwACEMgAAAAMQygAAAAxgRCibP3++2rVrp8DAQMXHx2vz5s2V9r344otlsVjKTX/5y1/qsGIAAADv8nkoW7VqlRITEzV9+nRt27ZNPXr00JAhQ5Senl5h/7ffflsHDx50T999951sNpuuvfbaOq4cAADAe3weyubMmaNbb71VY8aMUdeuXbVgwQIFBwdryZIlFfZv1qyZoqOj3dP69esVHBxMKAMAAPWaT0NZQUGBtm7dqsGDB7vbrFarBg8erE2bNlVrG4sXL9bIkSMVEhJS4fL8/HxlZmZ6TAAAAKbxaSg7fPiwHA6HoqKiPNqjoqKUmpp60vU3b96s7777TmPHjq20T1JSksLDw91TbGzsadcNAADgbT4/fXk6Fi9erG7duqlfv36V9pk8ebIyMjLc04EDB+qwQgAAgOrx8+XOIyIiZLPZlJaW5tGelpam6OjoKtfNycnRa6+9pkcffbTKfna7XXa7/bRrBQAAqE0+PVIWEBCgPn36KDk52d3mdDqVnJys/v37V7nuG2+8ofz8fN1www21XSYAAECt8+mRMklKTExUQkKC4uLi1K9fP82dO1c5OTkaM2aMJOmmm25Sq1atlJSU5LHe4sWLddVVV6l58+a+KBsAAMCrfB7KRowYoUOHDmnatGlKTU1Vz549tW7dOvfg//3798tq9Tygt3v3bn322Wf6z3/+44uSAQAAvM7icrlcvi6iLmVmZio8PFwZGRkKCwvzdTkAAKCBq272qNffvgQAAGgoCGUAAAAGIJQBAAAYgFAGAABgAEIZAACAAQhlAAAABiCUAQAAGIBQBgAAYABCGQAAgAEIZQAAAAYglAEAABiAUAYAAGAAQhkAAIABCGUAAAAGIJQBAAAYgFAGAABgAEIZAACAAQhlAAAABiCUAQAAGIBQBgAAYABCGQAAgAEIZQAAAAYglAEAABiAUAYAAGAAQhkAAIABCGUAAAAGIJQBAAAYgFAGAABgAEIZAACAAQhlAAAABiCUAQAAGIBQBgAAYABCGQAAgAEIZQAAAAYglAEAABiAUAYAAGAAQhkAAIABCGUAAAAGIJQBAAAYgFAGAABgAEIZAACAAQhlAAAABiCUAQAAGIBQBgAAYABCGQAAgAEIZQAAAAYglAEAABiAUAYAAGAAQhkAAIABCGUAAAAGIJQBAAAYgFAGAABgAEIZAACAAQhlAAAABiCUAQAAGIBQBgAAYABCGQAAgAEIZQAAAAYglAEAABiAUAYAAGAAQhkAAIABCGUAAAAGIJQBAAAYgFAGAABgAEIZAACAAQhlAAAABiCUAQAAGIBQBgAAYABCGQAAgAEIZQAAAAYglAEAABiAUAYAAGAAQhkAAIABCGUAAAAGIJQBAAAYgFAGAABgAEIZAACAAQhlAAAABiCUAQAAGIBQBgAAYABCGQAAgAEIZQAAAAYglAEAABiAUAYAAGAAQhkAAIABCGUAAAAGIJQBAAAYgFAGAABgAEIZAACAAQhlAAAABiCUAQAAGIBQBgAAYABCGQAAgAEIZQAAAAYglAEAABiAUAYAAGAAQhkAAIABfB7K5s+fr3bt2ikwMFDx8fHavHlzlf2PHj2qcePGqWXLlrLb7TrzzDP13nvv1VG1AAAAtcPPlztftWqVEhMTtWDBAsXHx2vu3LkaMmSIdu/ercjIyHL9CwoKdNlllykyMlJvvvmmWrVqpX379qlJkyZ1XzwAAIAXWVwul8tXO4+Pj1ffvn01b948SZLT6VRsbKwmTJigSZMmleu/YMECPfXUU9q1a5f8/f1PaZ+ZmZkKDw9XRkaGwsLCTqt+AACAk6lu9vDZ6cuCggJt3bpVgwcPPl6M1arBgwdr06ZNFa6zZs0a9e/fX+PGjVNUVJTOPfdczZw5Uw6Ho9L95OfnKzMz02MCAAAwjc9C2eHDh+VwOBQVFeXRHhUVpdTU1ArX2bNnj9588005HA699957mjp1qmbPnq3HH3+80v0kJSUpPDzcPcXGxnr1dQAAAHiDzwf614TT6VRkZKQWLlyoPn36aMSIEXrooYe0YMGCSteZPHmyMjIy3NOBAwfqsGIAAIDq8dlA/4iICNlsNqWlpXm0p6WlKTo6usJ1WrZsKX9/f9lsNndbly5dlJqaqoKCAgUEBJRbx263y263e7d4AAAAL/PZkbKAgAD16dNHycnJ7jan06nk5GT179+/wnXOP/98/fjjj3I6ne6277//Xi1btqwwkAEAANQXPj19mZiYqEWLFumll17Szp07deeddyonJ0djxoyRJN10002aPHmyu/+dd96pI0eO6J577tH333+vtWvXaubMmRo3bpyvXgIAAIBX+PQ6ZSNGjNChQ4c0bdo0paamqmfPnlq3bp178P/+/ftltR7PjbGxsfrggw907733qnv37mrVqpXuuecePfDAA756CQAAAF7h0+uU+QLXKQMAAHXJ+OuUAQAA4DhCGQAAgAEIZQAAAAYglAEAABiAUAYAAGAAQhkAAIABCGUAAAAGIJQBAAAYgFAGAABgAEIZAACAAQhlAAAABiCUAQAAGIBQBgAAYABCGQAAgAEIZQAAAAYglAEAABiAUAYAAGAAQhkAAIABCGUAAAAGIJQBAAAYgFAGAABgAEIZAACAAQhlAAAABiCUAQAAGIBQBgAAYABCGQAAgAEIZQAAAAYglAEAABiAUAYAAGAAQhkAAIABCGUAAAAGIJQBAAAYgFAGAABgAEIZAACAAQhlAAAABiCUAQAAGIBQBgAAYABCGQAAgAEIZQAAAAYglAEAABiAUAYAAGAAQhkAAIABCGUAAAAGIJQBAAAYgFAGAABgAEIZAACAAQhlAAAABvDzdQEAAJjK5XKpqKhIDofD16XAYDabTX5+frJYLKe1HUIZAAAVKCgo0MGDB5Wbm+vrUlAPBAcHq2XLlgoICDjlbRDKAAA4gdPp1N69e2Wz2RQTE6OAgIDTPgqChsnlcqmgoECHDh3S3r171blzZ1mtpzY6jFAGAMAJCgoK5HQ6FRsbq+DgYF+XA8MFBQXJ399f+/btU0FBgQIDA09pOwz0BwCgEqd6xAONjzfeK7zbAAAADEAoAwAAMAChDAAAwACEMgAAAAMQygAAQK0qLCz0dQn1AqEMAIBqcLlcyi0o8snkcrlqVOu6des0cOBANWnSRM2bN9cVV1yhn376yb38l19+0ahRo9SsWTOFhIQoLi5OX375pXv5//t//099+/ZVYGCgIiIiNHz4cPcyi8Wid955x2N/TZo00bJlyyRJP//8sywWi1atWqWLLrpIgYGBWr58uX7//XeNGjVKrVq1UnBwsLp166aVK1d6bMfpdGrWrFnq1KmT7Ha72rRpoxkzZkiSBg0apPHjx3v0P3TokAICApScnFyjn4+puE4ZAADVcKzQoa7TPvDJvnc8OkTBAdX/yM7JyVFiYqK6d++u7OxsTZs2TcOHD1dKSopyc3N10UUXqVWrVlqzZo2io6O1bds2OZ1OSdLatWs1fPhwPfTQQ3r55ZdVUFCg9957r8Y1T5o0SbNnz1avXr0UGBiovLw89enTRw888IDCwsK0du1a3XjjjerYsaP69esnSZo8ebIWLVqkp59+WgMHDtTBgwe1a9cuSdLYsWM1fvx4zZ49W3a7XZL06quvqlWrVho0aFCN6zMRoQwAgAbmmmuu8ZhfsmSJWrRooR07dujzzz/XoUOHtGXLFjVr1kyS1KlTJ3ffGTNmaOTIkXrkkUfcbT169KhxDRMnTtTVV1/t0Xbfffe5n0+YMEEffPCBXn/9dfXr109ZWVl65plnNG/ePCUkJEiSOnbsqIEDB0qSrr76ao0fP17vvvuurrvuOknSsmXLdPPNNzeYuy0QygAAqIYgf5t2PDrEZ/uuiR9++EHTpk3Tl19+qcOHD7uPgu3fv18pKSnq1auXO5CdKCUlRbfeeutp1xwXF+cx73A4NHPmTL3++uv69ddfVVBQoPz8fPcdE3bu3Kn8/HxdeumlFW4vMDBQN954o5YsWaLrrrtO27Zt03fffac1a9acdq2mIJQBAFANFoulRqcQfWnYsGFq27atFi1apJiYGDmdTp177rkqKChQUFBQleuebLnFYik3xq2igfwhISEe80899ZSeeeYZzZ07V926dVNISIgmTpyogoKCau1XKj6F2bNnT/3yyy9aunSpBg0apLZt2550vfqCgf4AADQgv//+u3bv3q0pU6bo0ksvVZcuXfTHH3+4l3fv3l0pKSk6cuRIhet37969yoHzLVq00MGDB93zP/zwg3Jzc09a18aNG3XllVfqhhtuUI8ePdShQwd9//337uWdO3dWUFBQlfvu1q2b4uLitGjRIq1YsUK33HLLSfdbnxDKAABoQJo2barmzZtr4cKF+vHHH/XRRx8pMTHRvXzUqFGKjo7WVVddpY0bN2rPnj166623tGnTJknS9OnTtXLlSk2fPl07d+7Ut99+qyeffNK9/qBBgzRv3jx9/fXX+uqrr3THHXfI39//pHV17txZ69ev1+eff66dO3fq9ttvV1pamnt5YGCgHnjgAd1///16+eWX9dNPP+mLL77Q4sWLPbYzduxYPfHEE3K5XB7fCm0ICGUAADQgVqtVr732mrZu3apzzz1X9957r5566in38oCAAP3nP/9RZGSkhg4dqm7duumJJ56QzVY8bu3iiy/WG2+8oTVr1qhnz54aNGiQNm/e7F5/9uzZio2N1QUXXKDrr79e9913n3tcWFWmTJmi3r17a8iQIbr44ovdwbCsqVOn6u9//7umTZumLl26aMSIEUpPT/foM2rUKPn5+WnUqFEKDAw8jZ+UeSyuml78pJ7LzMxUeHi4MjIyFBYW5utyAAAGysvL0969e9W+ffsG98Ff3/3888/q2LGjtmzZot69e/u6HLeq3jPVzR71Y8QiAABo1AoLC/X7779rypQpOu+884wKZN7C6UsAAGC8jRs3qmXLltqyZYsWLFjg63JqBUfKAACA8S6++OIa326qvuFIGQAAgAEIZQAAAAYglAEAABiAUAYAAGAAQhkAAIABCGUAAAAGIJQBAAC3du3aae7cub4uo1EilAEAABiAUAYAABoEh8Mhp9Pp6zJOGaEMAIDqcLmkghzfTNW8kv3ChQsVExNTLphceeWVuuWWW/TTTz/pyiuvVFRUlM444wz17dtXH3744Sn/SObMmaNu3bopJCREsbGxuuuuu5Sdne3RZ+PGjbr44osVHByspk2basiQIfrjjz8kSU6nU7NmzVKnTp1kt9vVpk0bzZgxQ5L08ccfy2Kx6OjRo+5tpaSkyGKx6Oeff5YkLVu2TE2aNNGaNWvUtWtX2e127d+/X1u2bNFll12miIgIhYeH66KLLtK2bds86jp69Khuv/12RUVFKTAwUOeee67+/e9/KycnR2FhYXrzzTc9+r/zzjsKCQlRVlbWKf+8TobbLAEAUB2FudLMGN/s+8HfpICQk3a79tprNWHCBP33v//VpZdeKkk6cuSI1q1bp/fee0/Z2dkaOnSoZsyYIbvdrpdfflnDhg3T7t271aZNmxqXZbVa9eyzz6p9+/bas2eP7rrrLt1///16/vnnJRWHqEsvvVS33HKLnnnmGfn5+em///2vHA6HJGny5MlatGiRnn76aQ0cOFAHDx7Url27alRDbm6unnzySb344otq3ry5IiMjtWfPHiUkJOi5556Ty+XS7NmzNXToUP3www8KDQ2V0+nU5ZdfrqysLL366qvq2LGjduzYIZvNppCQEI0cOVJLly7VX//6V/d+SudDQ0Nr/HOqLkIZAAANRNOmTXX55ZdrxYoV7lD25ptvKiIiQpdccomsVqt69Ojh7v/YY49p9erVWrNmjcaPH1/j/U2cONH9vF27dnr88cd1xx13uEPZrFmzFBcX556XpHPOOUeSlJWVpWeeeUbz5s1TQkKCJKljx44aOHBgjWooLCzU888/7/G6Bg0a5NFn4cKFatKkiT755BNdccUV+vDDD7V582bt3LlTZ555piSpQ4cO7v5jx47VgAEDdPDgQbVs2VLp6el67733TuuoYnUQygAAqA7/4OIjVr7adzWNHj1at956q55//nnZ7XYtX75cI0eOlNVqVXZ2th5++GGtXbtWBw8eVFFRkY4dO6b9+/efUlkffvihkpKStGvXLmVmZqqoqEh5eXnKzc1VcHCwUlJSdO2111a47s6dO5Wfn+8Oj6cqICBA3bt392hLS0vTlClT9PHHHys9PV0Oh0O5ubnu15mSkqLWrVu7A9mJ+vXrp3POOUcvvfSSJk2apFdffVVt27bVhRdeeFq1ngxjygAAqA6LpfgUoi8mi6XaZQ4bNkwul0tr167VgQMH9Omnn2r06NGSpPvuu0+rV6/WzJkz9emnnyolJUXdunVTQUFBjX8cP//8s6644gp1795db731lrZu3ar58+dLknt7QUFBla5f1TKp+NSoJLnKjKcrLCyscDuWE34+CQkJSklJ0TPPPKPPP/9cKSkpat68ebXqKjV27FgtW7ZMUvGpyzFjxpTbj7cRygAAaEACAwN19dVXa/ny5Vq5cqXOOuss9e7dW1LxoPubb75Zw4cPV7du3RQdHe0eNF9TW7duldPp1OzZs3XeeefpzDPP1G+/eR5J7N69u5KTkytcv3PnzgoKCqp0eYsWLSRJBw8edLelpKRUq7aNGzfq7rvv1tChQ3XOOefIbrfr8OHDHnX98ssv+v777yvdxg033KB9+/bp2Wef1Y4dO9ynWGsToQwAgAZm9OjRWrt2rZYsWeI+SiYVB6G3335bKSkp2r59u66//vpTvoREp06dVFhYqOeee0579uzRK6+8ogULFnj0mTx5srZs2aK77rpL33zzjXbt2qUXXnhBhw8fVmBgoB544AHdf//9evnll/XTTz/piy++0OLFi93bj42N1cMPP6wffvhBa9eu1ezZs6tVW+fOnfXKK69o586d+vLLLzV69GiPo2MXXXSRLrzwQl1zzTVav3699u7dq/fff1/r1q1z92natKmuvvpq/eMf/9Cf/vQntW7d+pR+TjVBKAMAoIEZNGiQmjVrpt27d+v66693t8+ZM0dNmzbVgAEDNGzYMA0ZMsR9FK2mevTooTlz5ujJJ5/Uueeeq+XLlyspKcmjz5lnnqn//Oc/2r59u/r166f+/fvr3XfflZ9f8ZD2qVOn6u9//7umTZumLl26aMSIEUpPT5ck+fv7a+XKldq1a5e6d++uJ598Uo8//ni1alu8eLH++OMP9e7dWzfeeKPuvvtuRUZGevR566231LdvX40aNUpdu3bV/fff7/5WaKm//e1vKigo0C233HJKP6Oasrhc1bz4SQORmZmp8PBwZWRkKCwszNflAAAMlJeXp71796p9+/YKDAz0dTnwkVdeeUX33nuvfvvtNwUEBFTZt6r3THWzB9++BAAAKCM3N1cHDx7UE088odtvv/2kgcxbOH0JAADKWb58uc4444wKp9JrjTVUs2bN0tlnn63o6GhNnjy5zvbL6UsAAE7A6cvii7umpaVVuMzf319t27at44rMxulLAABQK0JDQ2v1lkIoj9OXAABUopGdTMJp8MZ7hVAGAMAJ/P39JRUP+Aaqo/S9UvreORWcvgQA4AQ2m01NmjRxXzMrODi41m+xg/rJ5XIpNzdX6enpatKkiWw22ylvi1AGAEAFoqOjJckdzICqNGnSxP2eOVWEMgAAKmCxWNSyZUtFRkZWeCNsoJS/v/9pHSErRSgDAKAKNpvNKx+4wMkYMdB//vz5ateunQIDAxUfH6/NmzdX2nfZsmWyWCweU2O9hgwAAGg4fB7KVq1apcTERE2fPl3btm1Tjx49NGTIkCrP4YeFhengwYPuad++fXVYMQAAgPf5PJTNmTNHt956q8aMGaOuXbtqwYIFCg4O1pIlSypdx2KxKDo62j1FRUXVYcUAAADe59MxZQUFBdq6davHfaWsVqsGDx6sTZs2Vbpedna22rZtK6fTqd69e2vmzJmV3ocrPz9f+fn57vmMjAxJxbc8AAAAqG2lmeNkF5j1aSg7fPiwHA5HuSNdUVFR2rVrV4XrnHXWWVqyZIm6d++ujIwM/fOf/9SAAQP0v//9T61bty7XPykpSY888ki59tjYWO+8CAAAgGrIyspSeHh4pcvr3bcv+/fvr/79+7vnBwwYoC5duuhf//qXHnvssXL9J0+erMTERPe80+nUkSNH1Lx581q7EGBmZqZiY2N14MABbnpej/B7q3/4ndVP/N7qH35np8flcikrK0sxMTFV9vNpKIuIiJDNZit3F/q0tLRqX4DN399fvXr10o8//ljhcrvdLrvd7tHWpEmTU6q3psLCwnjz1kP83uoffmf1E7+3+off2amr6ghZKZ8O9A8ICFCfPn2UnJzsbnM6nUpOTvY4GlYVh8Ohb7/9Vi1btqytMgEAAGqdz09fJiYmKiEhQXFxcerXr5/mzp2rnJwcjRkzRpJ00003qVWrVkpKSpIkPfroozrvvPPUqVMnHT16VE899ZT27dunsWPH+vJlAAAAnBafh7IRI0bo0KFDmjZtmlJTU9WzZ0+tW7fOPfh///79slqPH9D7448/dOuttyo1NVVNmzZVnz599Pnnn6tr166+egnl2O12TZ8+vdxpU5iN31v9w++sfuL3Vv/wO6sbFtfJvp8JAACAWufzi8cCAACAUAYAAGAEQhkAAIABCGUAAAAGIJR52fz589WuXTsFBgYqPj5emzdv9nVJqEJSUpL69u2r0NBQRUZG6qqrrtLu3bt9XRZq4IknnpDFYtHEiRN9XQpO4tdff9UNN9yg5s2bKygoSN26ddNXX33l67JQBYfDoalTp6p9+/YKCgpSx44d9dhjj530Ho44NYQyL1q1apUSExM1ffp0bdu2TT169NCQIUOUnp7u69JQiU8++UTjxo3TF198ofXr16uwsFB/+tOflJOT4+vSUA1btmzRv/71L3Xv3t3XpeAk/vjjD51//vny9/fX+++/rx07dmj27Nlq2rSpr0tDFZ588km98MILmjdvnnbu3Kknn3xSs2bN0nPPPefr0hokLonhRfHx8erbt6/mzZsnqfjuBLGxsZowYYImTZrk4+pQHYcOHVJkZKQ++eQTXXjhhb4uB1XIzs5W79699fzzz+vxxx9Xz549NXfuXF+XhUpMmjRJGzdu1KeffurrUlADV1xxhaKiorR48WJ32zXXXKOgoCC9+uqrPqysYeJImZcUFBRo69atGjx4sLvNarVq8ODB2rRpkw8rQ01kZGRIkpo1a+bjSnAy48aN01/+8hePf3Mw15o1axQXF6drr71WkZGR6tWrlxYtWuTrsnASAwYMUHJysr7//ntJ0vbt2/XZZ5/p8ssv93FlDZPPr+jfUBw+fFgOh8N9J4JSUVFR2rVrl4+qQk04nU5NnDhR559/vs4991xfl4MqvPbaa9q2bZu2bNni61JQTXv27NELL7ygxMREPfjgg9qyZYvuvvtuBQQEKCEhwdfloRKTJk1SZmamzj77bNlsNjkcDs2YMUOjR4/2dWkNEqEMKDFu3Dh99913+uyzz3xdCqpw4MAB3XPPPVq/fr0CAwN9XQ6qyel0Ki4uTjNnzpQk9erVS999950WLFhAKDPY66+/ruXLl2vFihU655xzlJKSookTJyomJobfWy0glHlJRESEbDab0tLSPNrT0tIUHR3to6pQXePHj9e///1vbdiwQa1bt/Z1OajC1q1blZ6ert69e7vbHA6HNmzYoHnz5ik/P182m82HFaIiLVu2LHeP4i5duuitt97yUUWojn/84x+aNGmSRo4cKUnq1q2b9u3bp6SkJEJZLWBMmZcEBASoT58+Sk5Odrc5nU4lJyerf//+PqwMVXG5XBo/frxWr16tjz76SO3bt/d1STiJSy+9VN9++61SUlLcU1xcnEaPHq2UlBQCmaHOP//8cpeb+f7779W2bVsfVYTqyM3NldXqGRVsNpucTqePKmrYOFLmRYmJiUpISFBcXJz69eunuXPnKicnR2PGjPF1aajEuHHjtGLFCr377rsKDQ1VamqqJCk8PFxBQUE+rg4VCQ0NLTfmLyQkRM2bN2csoMHuvfdeDRgwQDNnztR1112nzZs3a+HChVq4cKGvS0MVhg0bphkzZqhNmzY655xz9PXXX2vOnDm65ZZbfF1ag8QlMbxs3rx5euqpp5SamqqePXvq2WefVXx8vK/LQiUsFkuF7UuXLtXNN99ct8XglF188cVcEqMe+Pe//63Jkyfrhx9+UPv27ZWYmKhbb73V12WhCllZWZo6dapWr16t9PR0xcTEaNSoUZo2bZoCAgJ8XV6DQygDAAAwAGPKAAAADEAoAwAAMAChDAAAwACEMgAAAAMQygAAAAxAKAMAADAAoQwAAMAAhDIAAAADEMoAoJZZLBa98847vi4DgOEIZQAatJtvvlkWi6Xc9Oc//9nXpQGAB25IDqDB+/Of/6ylS5d6tNntdh9VAwAV40gZgAbPbrcrOjraY2ratKmk4lOLL7zwgi6//HIFBQWpQ4cOevPNNz3W//bbbzVo0CAFBQWpefPmuu2225Sdne3RZ8mSJTrnnHNkt9vVsmVLjR8/3mP54cOHNXz4cAUHB6tz585as2ZN7b5oAPUOoQxAozd16lRdc8012r59u0aPHq2RI0dq586dkqScnBwNGTJETZs21ZYtW/TGG2/oww8/9AhdL7zwgsaNG6fbbrtN3377rdasWaNOnTp57OORRx7Rddddp2+++UZDhw7V6NGjdeTIkTp9nQAM5wKABiwhIcFls9lcISEhHtOMGTNcLpfLJcl1xx13eKwTHx/vuvPOO10ul8u1cOFCV9OmTV3Z2dnu5WvXrnVZrVZXamqqy+VyuWJiYlwPPfRQpTVIck2ZMsU9n52d7ZLkev/99732OgHUf4wpA9DgXXLJJXrhhRc82po1a+Z+3r9/f49l/fv3V0pKiiRp586d6tGjh0JCQtzLzz//fDmdTu3evVsWi0W//fabLr300ipr6N69u/t5SEiIwsLClJ6efqovCUADRCgD0OCFhISUO53oLUFBQdXq5+/v7zFvsVjkdDproyQA9RRjygA0el988UW5+S5dukiSunTpou3btysnJ8e9fOPGjbJarTrrrLMUGhqqdu3aKTk5uU5rBtDwcKQMQIOXn5+v1NRUjzY/Pz9FRERIkt544w3FxcVp4MCBWr58uTZv3qzFixdLkkaPHq3p06crISFBDz/8sA4dOqQJEyboxhtvVFRUlCTp4Ycf1h133KHIyEhdfvnlysrK0saNGzVhwoS6faEA6jVCGYAGb926dWrZsqVH21lnnaVdu3ZJKv5m5Guvvaa77rpLLVu21MqVK9W1a1dJUnBwsD744APdc8896tu3r4KDg3XNNddozpw57m0lJCQoLy9PTz/9tO677z5FRETor3/9a929QAANgsXlcrl8XQQA+IrFYtHq1at11VVX+boUAI0cY8oAAAAMQCgDAAAwAGPKADRqjOAAYAqOlAEAABiAUAYAAGAAQhkAAIABCGUAAAAGIJQBAAAYgFAGAABgAEIZAACAAQhlAAAABvj/VkzzVz8jTo4AAAAASUVORK5CYII=\n" + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### 2. FA model\n", + "(1) install weights of pretrained CNN layer. This set of weights is from CNN+FCN that used backpropagation\n", + "(2) feed this CNN result into a separate set of FCN layers, which we would use FA to update." + ], + "metadata": { + "id": "dDQDCCwGopEj" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Layer Extraction & Building the hybrid" + ], + "metadata": { + "id": "T27pMjdHbP3j" + } + }, + { + "cell_type": "code", + "source": [ + "# Extract layers\n", + "all_weights = [layer.get_weights() for layer in model.layers]\n", + "\n", + "# should I also train the CNN separately?" + ], + "metadata": { + "id": "4EO0kNHyqHI4" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "class CNN_withFA(nn.Module):\n", + " def __init__(self, in_features, out_features=10): # self defined in_features & number of hidden layers\n", + " # output_feature sis static bc both 10 for cifar10 and mnist\n", + " super(CNN_withFA, self).__init__()\n", + " self.in_features = in_features\n", + " self.out_features = 10\n", + "\n", + " self.layer1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)\n", + " self.\n", + " # self.layer1 = nn.Sequential(\n", + " # nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),\n", + " # nn.ReLU(), nn.MaxPool2d(kernel_size = 2, stride = 2))\n", + " # self.layer2 = nn.Sequential(\n", + " # nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),\n", + " # nn.ReLU())\n", + " self.layer3 = LinearFAModule(self.in_features, 100)\n", + " self.layer4 = LinearFAModule(100, out_features)\n", + " # self.layer2 = nn.Sequential(\n", + " # nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),\n", + " # nn.BatchNorm2d(64),\n", + " # nn.ReLU(),\n", + " # nn.MaxPool2d(kernel_size = 2, stride = 2))\n", + " def forward(self, x):\n", + " out = self.layer1(x)\n", + " out = self.layer2(out)\n", + " out = self.layer3(out)\n", + " out = self.layer4(out)\n", + " # out = self.layer7(out)\n", + " # out = self.layer8(out)\n", + " # out = self.layer9(out)\n", + " # out = self.layer10(out)\n", + " # out = self.layer11(out)\n", + " # out = self.layer12(out)\n", + " # out = self.layer13(out)\n", + " # out = out.reshape(out.size(0), -1)\n", + " # out = self.fc(out)\n", + " # out = self.fc1(out)\n", + " # out = self.fc2(out)\n", + " return out" + ], + "metadata": { + "id": "sH_jXPwcV3cn" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "mnist_input_features = 784\n", + "num_epochs = 10\n", + "batch_size = 16\n", + "learning_rate = 0.005\n", + "\n", + "model = CNN_withFA(mnist_input_features).to(device)\n", + "\n", + "for i, layer in enumerate(model.layers):\n", + " layer_name = layer.name\n", + " layer_weights = layer.get_weights()\n", + " print(f\"Weights of Layer {i} ({layer_name}): {layer_weights}\")\n", + "\n", + "# Set weights of the CNN using pretrained weights\n", + "# for i, layer_weights in enumerate(all_weights[:-2]): # only transfer for CNN\n", + "# if layer_weights:\n", + "# model.layers[i].set_weights(layer_weights)\n", + "\n", + "\n", + "\n", + "# maybe loose this part\n", + "# Now, freeze the CNN\n", + "# for layer in model.layers[:-1]:\n", + "# layer.trainable = False" + ], + "metadata": { + "id": "CJJ5Bo6Lr52R", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 471 + }, + "outputId": "8b15a4a8-702f-4bbd-b72a-546a42186017" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + ":69: UserWarning: nn.init.kaiming_uniform is now deprecated in favor of nn.init.kaiming_uniform_.\n", + " torch.nn.init.kaiming_uniform(self.weight)\n", + ":70: UserWarning: nn.init.kaiming_uniform is now deprecated in favor of nn.init.kaiming_uniform_.\n", + " torch.nn.init.kaiming_uniform(self.weight_fa)\n", + ":71: UserWarning: nn.init.constant is now deprecated in favor of nn.init.constant_.\n", + " torch.nn.init.constant(self.bias, 1)\n" + ] + }, + { + "output_type": "error", + "ename": "AttributeError", + "evalue": "ignored", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCNN_withFA\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmnist_input_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0mlayer_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mlayer_weights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1693\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1694\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1695\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mAttributeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"'{type(self).__name__}' object has no attribute '{name}'\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1696\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1697\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__setattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'Module'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'CNN_withFA' object has no attribute 'layers'" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Loss and optimizer\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay = 0.005, momentum = 0.9)\n", + "\n", + "\n", + "# Train the model\n", + "total_step = len(train_loader)\n", + "\n", + "for epoch in range(num_epochs):\n", + " for i, (images, labels) in enumerate(train_loader):\n", + " # Move tensors to the configured device\n", + " images = images.to(device)\n", + " labels = labels.to(device)\n", + "\n", + " # Forward pass\n", + " outputs = model(images)\n", + " loss = criterion(outputs, labels)\n", + "\n", + " # Backward and optimize\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'\n", + " .format(epoch+1, num_epochs, i+1, total_step, loss.item()))\n", + "\n", + " # Validation\n", + " with torch.no_grad():\n", + " correct = 0\n", + " total = 0\n", + " for images, labels in valid_loader:\n", + " images = images.to(device)\n", + " labels = labels.to(device)\n", + " outputs = model(images)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total += labels.size(0)\n", + " correct += (predicted == labels).sum().item()\n", + " del images, labels, outputs\n", + "\n", + " print('Accuracy of the network on the {} validation images: {} %'.format(5000, 100 * correct / total))" + ], + "metadata": { + "id": "guvyfv-0Yq5w" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### 2. A linear FA implementation" + ], + "metadata": { + "id": "YCgk0-FBA0n0" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Train the FA Part of the model" + ], + "metadata": { + "id": "BZUVAciYbc7K" + } + }, + { + "cell_type": "code", + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from torchvision import datasets, transforms\n", + "from torch.utils.data import DataLoader\n", + "from torch.autograd import Variable\n", + "import torch\n", + "import os\n", + "\n", + "BATCH_SIZE = 32\n", + "\n", + "train_loader = DataLoader(datasets.CIFAR10('./data', train=True, download=True,\n", + " transform=transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])),\n", + " batch_size=BATCH_SIZE, shuffle=True)\n", + "test_loader = DataLoader(datasets.CIFAR10('./data', train=False, download=True,\n", + " transform=transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])),\n", + " batch_size=BATCH_SIZE, shuffle=True)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KIijMZDj6NnF", + "outputId": "8a06aa2d-3c69-4417-d95e-55126fcdf339" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 170498071/170498071 [00:02<00:00, 72429883.88it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Extracting ./data/cifar-10-python.tar.gz to ./data\n", + "Files already downloaded and verified\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "torch.utils.data.dataloader.DataLoader" + ] + }, + "metadata": {}, + "execution_count": 3 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Or, construct and train the FA model, and then import its weights\n", + "model_fa = LinearFANetwork(in_features=64*64, num_layers=2, num_hidden_list=[1000, 10]).to(device)\n", + "# after the CNN, the images were upscaled from 32 to 64\n", + "optimizer_fa = torch.optim.SGD(model_fa.parameters(),\n", + " lr=1e-4, momentum=0.9, weight_decay=0.001, nesterov=True)\n", + "# define loss function\n", + "loss_crossentropy = torch.nn.CrossEntropyLoss()\n", + "# TRAINING LOOP\n", + "epochs = 5\n", + "# Plotting vessel: --> maybe change to validation accuracy later\n", + "fa_loss = []\n", + "for epoch in range(epochs):\n", + " for idx_batch, (inputs, targets) in enumerate(train_loader):\n", + " # flatten the inputs from square image to 1d vector\n", + " inputs = inputs.view(BATCH_SIZE, -1)\n", + " # wrap them into varaibles\n", + " inputs, targets = Variable(inputs), Variable(targets)\n", + "\n", + " # get outputs from the model\n", + " outputs_fa = model_fa(inputs.to(device))\n", + " # calculate loss\n", + " loss_fa = loss_crossentropy(outputs_fa, targets.to(device))\n", + "\n", + " model_fa.zero_grad()\n", + " loss_fa.backward()\n", + " optimizer_fa.step()\n", + "\n", + " if (idx_batch + 1) % 500 == 0: # do this every 10 batch?\n", + " train_log = 'epoch ' + str(epoch) + ' step ' + str(idx_batch + 1) + \\\n", + " ' loss_fa ' + str(loss_fa.item())\n", + " fa_loss.append(loss_fa.item())\n", + " print(train_log)\n", + "\n", + "plt.plot(fa_loss)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 495 + }, + "id": "9Xd4wTDkpKDm", + "outputId": "db6b0274-5117-4e95-e063-ab1b1dcffc11" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + ":53: UserWarning: nn.init.kaiming_uniform is now deprecated in favor of nn.init.kaiming_uniform_.\n", + " torch.nn.init.kaiming_uniform(self.weight)\n", + ":54: UserWarning: nn.init.kaiming_uniform is now deprecated in favor of nn.init.kaiming_uniform_.\n", + " torch.nn.init.kaiming_uniform(self.weight_fa)\n", + ":55: UserWarning: nn.init.constant is now deprecated in favor of nn.init.constant_.\n", + " torch.nn.init.constant(self.bias, 1)\n" + ] + }, + { + "output_type": "error", + "ename": "RuntimeError", + "evalue": "ignored", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;31m# get outputs from the model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m \u001b[0moutputs_fa\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel_fa\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 21\u001b[0m \u001b[0;31m# calculate loss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mloss_fa\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss_crossentropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs_fa\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1517\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1518\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1519\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1520\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1525\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1526\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1528\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1529\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;31m# first layer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 102\u001b[0;31m \u001b[0mlinear1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 103\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;31m# second layer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1517\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1518\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1519\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1520\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1525\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1526\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1528\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1529\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mLinearFAFunction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight_fa\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py\u001b[0m in \u001b[0;36mapply\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m 537\u001b[0m \u001b[0;31m# See NOTE: [functorch vjp and autograd interaction]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 538\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_functorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munwrap_dead_wrappers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 539\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 540\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 541\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetup_context\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_SingleLevelFunction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetup_context\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(context, input, weight, weight_fa, bias)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight_fa\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave_for_backward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight_fa\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexpand_as\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (32x3072 and 4096x1000)" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Compare model performance" + ], + "metadata": { + "id": "tFCtLuY-Wxol" + } + }, + { + "cell_type": "code", + "source": [ + "#validate on CNN+FCN\n", + "#validate on CNN+FA" + ], + "metadata": { + "id": "jGIC_1_LWxOO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "41tHk6e8jvaz" + }, + "source": [ + "### Linear.py" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IR5MykiuflIM" + }, + "outputs": [], + "source": [ + "from torch.autograd import Function\n", + "from torch import nn\n", + "# already imported: import torch\n", + "# already imported: import torch.nn.functional as F\n", + "\n", + "# Inherit from Function\n", + "class LinearFunction(Function):\n", + " # Note that both forward and backward are @staticmethods\n", + " @staticmethod\n", + " # bias is an optional argument\n", + " def forward(ctx, input, weight, bias=None):\n", + " ctx.save_for_backward(input, weight, bias)\n", + " output = input.mm(weight.t())\n", + " if bias is not None:\n", + " output += bias.unsqueeze(0).expand_as(output)\n", + " return output\n", + "\n", + " # This function has only a single output, so it gets only one gradient\n", + " @staticmethod\n", + " def backward(ctx, grad_output):\n", + " # This is a pattern that is very convenient - at the top of backward\n", + " # unpack saved_tensors and initialize all gradients w.r.t. inputs to\n", + " # None. Thanks to the fact that additional trailing Nones are\n", + " # ignored, the return statement is simple even when the function has\n", + " # optional inputs.\n", + " input, weight, bias = ctx.saved_variables\n", + " grad_input = grad_weight = grad_bias = None\n", + " # These needs_input_grad checks are optional and there only to\n", + " # improve efficiency. If you want to make your code simpler, you can\n", + " # skip them. Returning gradients for inputs that don't require it is\n", + " # not an error.\n", + "\n", + " if ctx.needs_input_grad[0]:\n", + " grad_input = grad_output.mm(weight)\n", + " if ctx.needs_input_grad[1]:\n", + " grad_weight = grad_output.t().mm(input)\n", + " if bias is not None and ctx.needs_input_grad[2]:\n", + " grad_bias = grad_output.sum(0).squeeze(0)\n", + "\n", + " return grad_input, grad_weight, grad_bias\n", + "\n", + "\n", + "class Linear(nn.Module):\n", + " def __init__(self, input_features, output_features, bias=True):\n", + " super(Linear, self).__init__()\n", + " self.input_features = input_features\n", + " self.output_features = output_features\n", + "\n", + " # nn.Parameter is a special kind of Variable, that will get\n", + " # automatically registered as Module's parameter once it's assigned\n", + " # as an attribute. Parameters and buffers need to be registered, or\n", + " # they won't appear in .parameters() (doesn't apply to buffers), and\n", + " # won't be converted when e.g. .cuda() is called. You can use\n", + " # .register_buffer() to register buffers.\n", + " # nn.Parameters can never be volatile and, different than Variables,\n", + " # they require gradients by default.\n", + " self.weight = nn.Parameter(torch.Tensor(output_features, input_features))\n", + " if bias:\n", + " self.bias = nn.Parameter(torch.Tensor(output_features))\n", + " else:\n", + " # You should always register all possible parameters, but the\n", + " # optional ones can be None if you want.\n", + " self.register_parameter('bias', None)\n", + "\n", + " # weight initialization\n", + " torch.nn.init.kaiming_uniform(self.weight)\n", + " torch.nn.init.constant(self.bias, 1)\n", + "\n", + " def forward(self, input):\n", + " # See the autograd section for explanation of what happens here.\n", + " return LinearFunction.apply(input, self.weight, self.bias)\n", + "\n", + "\n", + "class LinearNetwork(nn.Module):\n", + " def __init__(self, in_features, num_layers, num_hidden_list):\n", + " \"\"\"\n", + " :param in_features: dimension of input features (784 for MNIST)\n", + " :param num_layers: number of layers for feed-forward net\n", + " :param num_hidden_list: list of integers indicating hidden nodes of each layer\n", + " \"\"\"\n", + " super(LinearNetwork, self).__init__()\n", + " self.in_features = in_features\n", + " self.num_layers = num_layers\n", + " self.num_hidden_list = num_hidden_list\n", + "\n", + " # create list of linear layers\n", + " # first hidden layer\n", + " self.linear = [Linear(self.in_features, self.num_hidden_list[0])]\n", + " # append additional hidden layers to list\n", + " for idx in range(self.num_layers - 1):\n", + " self.linear.append(Linear(self.num_hidden_list[idx], self.num_hidden_list[idx+1]))\n", + "\n", + " # create ModuleList to make list of layers work\n", + " self.linear = nn.ModuleList(self.linear)\n", + "\n", + "\n", + " def forward(self, inputs):\n", + " \"\"\"\n", + " forward pass, which is same for conventional feed-forward net\n", + " :param inputs: inputs with shape [batch_size, in_features]\n", + " :return: logit outputs from the network\n", + " \"\"\"\n", + " # first layer\n", + " linear1 = F.relu(self.linear[0](inputs))\n", + "\n", + " linear2 = self.linear[1](linear1)\n", + "\n", + " return linear2" + ] + }, + { + "cell_type": "code", + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from torchvision import datasets, transforms\n", + "from torch.utils.data import DataLoader\n", + "from torch.autograd import Variable\n", + "import torch\n", + "import os\n", + "\n", + "BATCH_SIZE = 32\n", + "\n", + "# load mnist dataset\n", + "train_loader = DataLoader(datasets.MNIST('./data', train=True, download=True,\n", + " transform=transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])),\n", + " batch_size=BATCH_SIZE, shuffle=True)\n", + "test_loader = DataLoader(datasets.MNIST('./data', train=False, download=True,\n", + " transform=transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])),\n", + " batch_size=BATCH_SIZE, shuffle=True)\n", + "\n", + "# load feedforward dfa model\n", + "model_fa = LinearFANetwork(in_features=784, num_layers=2, num_hidden_list=[1000, 10]).to(device)\n", + "\n", + "# load reference linear model\n", + "model_bp = LinearNetwork(in_features=784, num_layers=2, num_hidden_list=[1000, 10]).to(device)\n", + "\n", + "# optimizers\n", + "optimizer_fa = torch.optim.SGD(model_fa.parameters(),\n", + " lr=1e-4, momentum=0.9, weight_decay=0.001, nesterov=True)\n", + "optimizer_bp = torch.optim.SGD(model_bp.parameters(),\n", + " lr=1e-4, momentum=0.9, weight_decay=0.001, nesterov=True)\n", + "\n", + "loss_crossentropy = torch.nn.CrossEntropyLoss()\n", + "\n", + "# make log file\n", + "results_path = 'bp_vs_fa_'\n", + "logger_train = open(results_path + 'train_log.txt', 'w')\n", + "\n", + "# train loop\n", + "epochs = 5\n", + "# Plotting vessel: --> maybe change to validation accuracy later\n", + "fa_loss = []\n", + "bp_loss = []\n", + "for epoch in range(epochs):\n", + " for idx_batch, (inputs, targets) in enumerate(train_loader):\n", + " # flatten the inputs from square image to 1d vector\n", + " inputs = inputs.view(BATCH_SIZE, -1)\n", + " # wrap them into varaibles\n", + " inputs, targets = Variable(inputs), Variable(targets)\n", + " # get outputs from the model\n", + " outputs_fa = model_fa(inputs.to(device))\n", + " outputs_bp = model_bp(inputs.to(device))\n", + " # calculate loss\n", + " loss_fa = loss_crossentropy(outputs_fa, targets.to(device))\n", + " loss_bp = loss_crossentropy(outputs_bp, targets.to(device))\n", + "\n", + " model_fa.zero_grad()\n", + " loss_fa.backward()\n", + " optimizer_fa.step()\n", + "\n", + " model_bp.zero_grad()\n", + " loss_bp.backward()\n", + " optimizer_bp.step()\n", + "\n", + " if (idx_batch + 1) % 500 == 0: # do this every 10 batch?\n", + " train_log = 'epoch ' + str(epoch) + ' step ' + str(idx_batch + 1) + \\\n", + " ' loss_fa ' + str(loss_fa.item()) + ' loss_bp ' + str(loss_bp.item())\n", + " fa_loss.append(loss_fa.item())\n", + " bp_loss.append(loss_bp.item())\n", + " print(train_log)\n", + " logger_train.write(train_log + '\\n')\n", + "\n", + "plt.plot(fa_loss)\n", + "plt.plot(bp_loss)" + ], + "metadata": { + "id": "GisHz4Vuuluz" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Drafts" + ], + "metadata": { + "id": "_yYEac57QMmy" + } + }, + { + "cell_type": "code", + "source": [ + "# Define your CNN model\n", + "model_fa = keras.Sequential([\n", + " layers.Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)),\n", + " layers.MaxPooling2D((2, 2)),\n", + "\n", + " layers.Flatten(),\n", + " layers.Dense(128, activation='relu'),\n", + " layers.Dense(10, activation='softmax') # Assuming 10 classes for example\n", + "])\n", + "\n", + "# Compile the model\n", + "model_fa.compile(optimizer='adam',\n", + " loss='sparse_categorical_crossentropy',\n", + " metrics=['accuracy'])\n", + "\n", + "# Display the model summary\n", + "model_fa.summary()\n", + "\n", + "# Define a new optimizer class for Feedback Alignment\n", + "class FeedbackAlignmentOptimizer(tf.keras.optimizers.Optimizer):\n", + " def __init__(self, learning_rate=0.001, name=\"FeedbackAlignmentOptimizer\", **kwargs):\n", + " super(FeedbackAlignmentOptimizer, self).__init__(name, **kwargs)\n", + " self.learning_rate = learning_rate\n", + "\n", + " def apply_gradients(self, grads_and_vars, name=None):\n", + " for grad, var in grads_and_vars:\n", + " random_matrix = tf.random.normal(var.shape)\n", + " var.assign_add(-self.learning_rate * grad * random_matrix)\n", + "\n", + "# Create an instance of the custom optimizer\n", + "fa_optimizer = FeedbackAlignmentOptimizer(learning_rate=0.001)\n", + "\n", + "# Compile the model with the custom optimizer\n", + "model_fa.compile(optimizer=fa_optimizer,\n", + " loss='sparse_categorical_crossentropy',\n", + " metrics=['accuracy'])\n", + "\n", + "# Now, freeze the convolutional layers\n", + "for layer in model_fa.layers[:-2]:\n", + " layer.trainable = False\n", + "\n", + "# Train the model with Feedback Alignment on the fully connected layers\n", + "model_fa.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))\n", + "\n", + "\n", + "#you need to build fc layers first, freeze the fc and train CNN, freeze CNN and train fc with new optimizer\n" + ], + "metadata": { + "id": "seuBsXspasM8" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ja0DuDbMfD_X" + }, + "outputs": [], + "source": [ + "def create_random_dataset(n_in, n_out, len_samples):\n", + " '''Creates randomly a matrix n_out x n_in which will be\n", + " the target function to learn. Then generates\n", + " len_samples examples which will be the training set'''\n", + " M = np.random.randint(low=-10, high=10, size=(n_out, n_in))\n", + " samples = []\n", + " targets = []\n", + " for i in range(len_samples):\n", + " sample = np.random.randn(n_in)\n", + " samples.append(sample)\n", + " targets.append(np.dot(M, sample))\n", + "\n", + " return M, np.asarray(samples), np.asarray(targets)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Using captum" + ], + "metadata": { + "id": "_q1U4wdjnOeO" + } + }, + { + "cell_type": "code", + "source": [ + "# setup\n", + "!pip install torchvision==0.13.0\n", + "!pip install git+https://github.com/pytorch/captum.git\n", + "\n", + "import numpy as np\n", + "import json\n", + "from PIL import Image\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import torchvision.models as models\n", + "import captum\n", + "from captum.attr import visualization as viz\n", + "from torchvision import transforms\n", + "\n", + "# MAKE SURE WE ARE USING GPU\n", + "# This checks if GPU is available, and uses it only if so\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "print(\"Running on\", device)" + ], + "metadata": { + "id": "Wpgk-7opnMqM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!wget https://mylearningsinaiml.files.wordpress.com/2018/09/mnistdata1.jpg -O number.jpg" + ], + "metadata": { + "id": "QnzX--Leor8B" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# -- Your code here -- #\n", + "conv_withfa.to(device)\n", + "# set model to eval mode\n", + "conv_withfa.eval()\n", + "# did not visualize the model\n", + "\n", + "!gdown /content/mnist_classes.json -O mnist_classes.json\n", + "\n", + "\n", + "# Define the preprocessing pipeline for the input image\n", + "preprocess = transforms.Compose([\n", + " transforms.Resize((28, 28)), # resize the image to have a longer side of 256 pixels while maintaining the image ratio\n", + " #transforms.CenterCrop(224), # crop to size 224*224\n", + " transforms.ToTensor(), # convert to tensor not numpy\n", + " transforms.Normalize(mean=[0.5], std=[0.5]), # normalize pixel values, provide mean & std for RGB channels\n", + "])\n", + "\n", + "\n", + "PATH_TO_LABELS = 'mnist_classes.json'\n", + "with open(PATH_TO_LABELS, 'r') as f:\n", + " imagenet_classes = json.load(f)\n", + "\n", + "def decode_preds(outputs, class_names=mnist_classes):\n", + " # Assuming outputs is the tensor of model outputs\n", + " softmax_outputs = F.softmax(outputs, dim=1)\n", + " probability, predicted_class = torch.max(softmax_outputs, dim=1)\n", + "\n", + " predicted_class_labels = [class_names[str(idx)] for idx in predicted_class.cpu().numpy()]\n", + " probability_scores = probability.cpu().numpy()\n", + "\n", + " # Print or return the results\n", + " for label, score in zip(predicted_class_labels, probability_scores):\n", + " print(f'\\nClass: {label}, Probability: {score}')\n", + "\n", + "\n", + "# Function to load and preprocess the image\n", + "def load_and_preprocess_image(image_path):\n", + " img = Image.open(image_path)\n", + " img_tensor = preprocess(img)\n", + " img_tensor.unsqueeze_(0) # Add a batch dimension\n", + " return img_tensor\n", + "\n", + "# -- Your code here -- #\n", + "image = load_and_preprocess_image('number.jpg').to(device)\n", + "\n", + "with torch.no_grad(): # no gradient descent, perform eval\n", + " # Pass images through the model\n", + " outputs = conv_withfa(image)\n", + "\n", + "decode_preds(outputs)\n", + "\n", + "# --------------------- #" + ], + "metadata": { + "id": "BcSldVQIqST8" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def forward_func(i):\n", + " return conv_withfa(i)\n", + "salient_object = captum.attr.Saliency(forward_func)\n", + "attributions = salient_object.attribute(image, target = torch.tensor([0]))\n", + "# type(attributions)\n", + "# attributions is of type tensor/tuple. It means each feature's gradient.\n", + "# image - your preprocessed image\n", + "# attributions - the attributions obtained from captum\n", + "\n", + "# The preprocessing included normalizing the image which we need invert here\n", + "inv_normalize = transforms.Normalize(\n", + " mean=[-0.485/0.229, -0.456/0.224, -0.406/0.255],\n", + " std=[1/0.229, 1/0.224, 1/0.255]\n", + ")\n", + "\n", + "# Your input image to the model.\n", + "unnorm_image = inv_normalize(image)\n", + "\n", + "# Display the image and the attribution\n", + "_ = viz.visualize_image_attr_multiple(np.transpose(attributions.squeeze().cpu().detach().numpy(), (1,2,0)),\n", + " np.transpose(unnorm_image.squeeze().cpu().detach().numpy(), (1,2,0)),\n", + " [\"original_image\", \"heat_map\"],\n", + " [\"all\", \"absolute_value\"],\n", + " cmap=None,\n", + " show_colorbar=True, outlier_perc=2)" + ], + "metadata": { + "id": "wNKnFfbzqVUl" + }, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [], + "collapsed_sections": [ + "gDpwu5vnSwWF", + "W9nOgx7r3M7y", + "BHQ-6ijb3Urj", + "41tHk6e8jvaz", + "_yYEac57QMmy" + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file