{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tensorizing Neural Networks\n",
    "\n",
    "This is an example of how one can tensorize layers of pre-trained neural network\n",
    "models, as described in [[NPOV15']](https://arxiv.org/abs/1509.06569).\n",
    "\n",
    "A small fully connected classifier is trained on FashionMNIST, its first linear\n",
    "layer is replaced by a Matrix Product Operator (MPO) built from the trained\n",
    "weights, and the tensorized model is then trained further and finally\n",
    "canonicalized with a lower `cum_percentage`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import tensorkrowch as tk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed PyTorch's random number generator to make runs easier to reproduce\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Create the folders for the dataset and the saved models; `exist_ok` makes\n",
    "# this cell safe to run again when the folders are already there\n",
    "os.makedirs('data', exist_ok=True)\n",
    "os.makedirs('models', exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=0)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Use a CUDA GPU if there is one, otherwise Apple's MPS, otherwise the CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda:0')\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = torch.device('mps:0')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FashionMNIST Dataset\n",
    "dataset_name = 'fashion_mnist'\n",
    "batch_size = 64\n",
    "image_size = 28\n",
    "input_size = image_size ** 2\n",
    "num_classes = 10\n",
    "\n",
    "transform = transforms.Compose([transforms.ToTensor(),\n",
    "                                transforms.Resize(image_size, antialias=True),\n",
    "                                ])\n",
    "\n",
    "# Load data\n",
    "train_dataset = datasets.FashionMNIST(root='data/',\n",
    "                                      train=True,\n",
    "                                      transform=transform,\n",
    "                                      download=True)\n",
    "test_dataset = datasets.FashionMNIST(root='data/',\n",
    "                                     train=False,\n",
    "                                     transform=transform,\n",
    "                                     download=True)\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset,\n",
    "                          batch_size=batch_size,\n",
    "                          shuffle=True)\n",
    "# The test set is only used to compute accuracies, which do not depend on\n",
    "# the order of the samples, so it is not shuffled\n",
    "test_loader = DataLoader(dataset=test_dataset,\n",
    "                         batch_size=batch_size,\n",
    "                         shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n"
     ]
    }
   ],
   "source": [
    "# Show a random training image and print its label\n",
    "random_sample = torch.randint(low=0, high=len(train_dataset), size=(1,)).item()\n",
    "\n",
    "# Index the dataset only once, so the transform is applied a single time\n",
    "image, label = train_dataset[random_sample]\n",
    "\n",
    "plt.imshow(image.squeeze(0), cmap='Greys')\n",
    "plt.show()\n",
    "\n",
    "print(label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train Neural Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FFFC(nn.Module):\n",
    "    \"\"\"Feed forward fully connected network with one hidden layer of 50 units.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    input_size : int\n",
    "        Number of features of each (flattened) input.\n",
    "    num_classes : int\n",
    "        Number of scores returned for each input.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, num_classes):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(input_size, 50)\n",
    "        self.fc2 = nn.Linear(50, num_classes)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the scores of a batch of flattened inputs.\"\"\"\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize network\n",
    "model_name = 'fffc'\n",
    "model = FFFC(input_size=input_size, num_classes=num_classes)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "learning_rate = 1e-3\n",
    "weight_decay = 1e-5\n",
    "num_epochs = 10\n",
    "\n",
    "# Loss and optimizer\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(),\n",
    "                       lr=learning_rate,\n",
    "                       weight_decay=weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_accuracy(loader, model):\n",
    "    \"\"\"Return the accuracy (in %) of `model` on the samples of `loader`.\n",
    "\n",
    "    Each batch is moved to `device` and flattened before it is fed to the\n",
    "    model. The model is evaluated in eval mode without gradients, and the\n",
    "    mode it was in before the call is restored at the end.\n",
    "    \"\"\"\n",
    "    num_correct = 0\n",
    "    num_samples = 0\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for x, y in loader:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "            x = x.reshape(x.shape[0], -1)\n",
    "\n",
    "            scores = model(x)\n",
    "            _, predictions = scores.max(1)\n",
    "            # .item() keeps the running count as a plain Python number\n",
    "            num_correct += (predictions == y).sum().item()\n",
    "            num_samples += predictions.size(0)\n",
    "\n",
    "    accuracy = float(num_correct) / float(num_samples) * 100\n",
    "    model.train(was_training)\n",
    "    return accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, train_loader, test_loader, criterion, optimizer,\n",
    "                num_epochs):\n",
    "    \"\"\"Train `model` for `num_epochs` epochs on the batches of `train_loader`.\n",
    "\n",
    "    Batches are moved to `device` and flattened before the forward pass.\n",
    "    After each epoch, the accuracies on `train_loader` and `test_loader`\n",
    "    are computed with `check_accuracy` and printed.\n",
    "    \"\"\"\n",
    "    for epoch in range(num_epochs):\n",
    "        for data, targets in train_loader:\n",
    "            # Get data to the selected device\n",
    "            data = data.to(device)\n",
    "            targets = targets.to(device)\n",
    "\n",
    "            # Get to correct shape\n",
    "            data = data.reshape(data.shape[0], -1)\n",
    "\n",
    "            # Forward\n",
    "            scores = model(data)\n",
    "            loss = criterion(scores, targets)\n",
    "\n",
    "            # Backward\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "\n",
    "            # Gradient descent\n",
    "            optimizer.step()\n",
    "\n",
    "        train_acc = check_accuracy(train_loader, model)\n",
    "        test_acc = check_accuracy(test_loader, model)\n",
    "\n",
    "        print(f'* Epoch {epoch + 1:<3} => Train. Acc.: {train_acc:.2f},'\n",
    "              f' Test Acc.: {test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "* Epoch 1 => Train. Acc.: 84.52, Test Acc.: 83.08\n",
      "* Epoch 2 => Train. Acc.: 86.19, Test Acc.: 84.96\n",
      "* Epoch 3 => Train. Acc.: 86.57, Test Acc.: 85.15\n",
      "* Epoch 4 => Train. Acc.: 87.06, Test Acc.: 85.08\n",
      "* Epoch 5 => Train. Acc.: 87.96, Test Acc.: 86.15\n",
      "* Epoch 6 => Train. Acc.: 87.83, Test Acc.: 86.07\n",
      "* Epoch 7 => Train. Acc.: 88.56, Test Acc.: 86.62\n",
      "* Epoch 8 => Train. Acc.: 89.33, Test Acc.: 86.69\n",
      "* Epoch 9 => Train. Acc.: 88.94, Test Acc.: 86.45\n",
      "* Epoch 10 => Train. Acc.: 89.30, Test Acc.: 87.03\n"
     ]
    }
   ],
   "source": [
    "# Train network\n",
    "train_model(model, train_loader, test_loader, criterion, optimizer,\n",
    "            num_epochs)\n",
    "\n",
    "torch.save(model.state_dict(), f'models/{model_name}_{dataset_name}.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def n_params(model):\n",
    "    \"\"\"Return the total number of elements in the parameters of `model`.\"\"\"\n",
    "    return sum(p.numel() for p in model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(87.03, 39760)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n = n_params(model)\n",
    "test_acc = check_accuracy(test_loader, model)\n",
    "test_acc, n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define tensorized layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load network\n",
    "model = FFFC(input_size=input_size, num_classes=num_classes)\n",
    "# `map_location` lets a checkpoint saved from one device be loaded on another\n",
    "model.load_state_dict(torch.load(f'models/{model_name}_{dataset_name}.pt',\n",
    "                                 map_location=device))\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TN_Linear(tk.models.MPO):\n",
    "    \"\"\"Tensorized replacement of the first linear layer (`fc1`) of a `FFFC`.\n",
    "\n",
    "    The weight matrix of `model.fc1` is decomposed into MPO tensors with\n",
    "    `tk.decompositions.mat_to_mpo`, and each input batch is decomposed into\n",
    "    MPS tensors before it is contracted with the MPO.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model : FFFC\n",
    "        Pre-trained model whose `fc1` weight and bias are used.\n",
    "    cum_percentage : float\n",
    "        Value passed as `cum_percentage` to `mat_to_mpo`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, cum_percentage):\n",
    "\n",
    "        # Get weight matrix from model and reshape it. The weight of `fc1`\n",
    "        # has shape (50, 784): the 50 outputs are factorized as\n",
    "        # 1 x 1 x 5 x 5 x 2 x 1 and the 784 inputs as 2 x 2 x 7 x 7 x 2 x 2.\n",
    "        # The permutation interleaves both groups of indices as\n",
    "        # (in_1, out_1, ..., in_6, out_6)\n",
    "        weight = model.fc1.weight.detach()\n",
    "        weight = weight.reshape(1, 1, 5, 5, 2, 1,\n",
    "                                2, 2, 7, 7, 2, 2).permute(6, 0, 7, 1, 8, 2,\n",
    "                                                          9, 3, 10, 4, 11, 5)\n",
    "        self.weight = weight\n",
    "\n",
    "        mpo_tensors = tk.decompositions.mat_to_mpo(weight,\n",
    "                                                   cum_percentage=cum_percentage,\n",
    "                                                   renormalize=True)\n",
    "        super().__init__(tensors=mpo_tensors)\n",
    "\n",
    "        # Save bias as parameter of tn layer\n",
    "        self.bias = nn.Parameter(model.fc1.bias.detach())\n",
    "\n",
    "    def set_data_nodes(self):\n",
    "        # The physical dimensions match the factorization of the 784 inputs\n",
    "        # used for the weight in `__init__`\n",
    "        self.mps_data = tk.models.MPSData(n_features=6,\n",
    "                                          phys_dim=[2, 2, 7, 7, 2, 2],\n",
    "                                          bond_dim=10,\n",
    "                                          boundary='obc')\n",
    "\n",
    "    def add_data(self, data):\n",
    "        # Decompose the batch of flattened inputs into MPS tensors\n",
    "        mps_tensors = tk.decompositions.vec_to_mps(data.reshape(-1, 2, 2, 7, 7, 2, 2),\n",
    "                                                   n_batches=1,\n",
    "                                                   cum_percentage=0.95,\n",
    "                                                   renormalize=True)\n",
    "        self.mps_data.add_data(mps_tensors)\n",
    "\n",
    "    def contract(self):\n",
    "        return super().contract(inline_input=True,\n",
    "                                inline_mats=True,\n",
    "                                mps=self.mps_data)\n",
    "\n",
    "    def forward(self, x, *args, **kwargs):\n",
    "        result = super().forward(x, *args, **kwargs)\n",
    "        # One column per output unit of the original layer\n",
    "        result = result.reshape(-1, self.bias.numel())\n",
    "        # Out-of-place sum, so the tensor returned by the contraction is\n",
    "        # not modified\n",
    "        return result + self.bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TN_NN(nn.Module):\n",
    "    \"\"\"`FFFC` network whose first linear layer is replaced by a `TN_Linear`.\"\"\"\n",
    "\n",
    "    def __init__(self, model, cum_percentage):\n",
    "        super().__init__()\n",
    "        self.tn1 = TN_Linear(model, cum_percentage)\n",
    "        # The second layer is the same module object as `model.fc2` (it is\n",
    "        # not copied), so training this network also updates `model.fc2`\n",
    "        self.fc2 = model.fc2\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the scores of a batch of flattened inputs.\"\"\"\n",
    "        x = self.relu(self.tn1(x))\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A separate name is used so that `model_name` still points to the\n",
    "# checkpoint of the original network if earlier cells are run again\n",
    "tn_model_name = 'tn_fffc'\n",
    "tn_model = TN_NN(model, 0.85)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(86.61999999999999, 27944)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n = n_params(tn_model)\n",
    "test_acc = check_accuracy(test_loader, tn_model)\n",
    "test_acc, n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train tensorized model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trace the model to accelerate training\n",
    "tn_model.tn1.trace(torch.zeros(1, input_size, device=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters (kept apart from the ones of the original network)\n",
    "tn_learning_rate = 1e-4\n",
    "tn_weight_decay = 1e-5\n",
    "tn_num_epochs = 5\n",
    "\n",
    "# Loss and optimizer\n",
    "tn_criterion = nn.CrossEntropyLoss()\n",
    "tn_optimizer = optim.Adam(tn_model.parameters(),\n",
    "                          lr=tn_learning_rate,\n",
    "                          weight_decay=tn_weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "* Epoch 1 => Train. Acc.: 90.04, Test Acc.: 87.95\n",
      "* Epoch 2 => Train. Acc.: 90.03, Test Acc.: 87.73\n",
      "* Epoch 3 => Train. Acc.: 90.35, Test Acc.: 87.70\n",
      "* Epoch 4 => Train. Acc.: 90.64, Test Acc.: 87.90\n",
      "* Epoch 5 => Train. Acc.: 90.73, Test Acc.: 88.03\n"
     ]
    }
   ],
   "source": [
    "# Train network\n",
    "train_model(tn_model, train_loader, test_loader, tn_criterion, tn_optimizer,\n",
    "            tn_num_epochs)\n",
    "\n",
    "# Reset before saving the model\n",
    "tn_model.tn1.reset()\n",
    "torch.save(tn_model.state_dict(), f'models/{tn_model_name}_{dataset_name}.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "81.57 16406\n"
     ]
    }
   ],
   "source": [
    "tn_model.tn1.canonicalize(cum_percentage=0.8)\n",
    "\n",
    "test_acc = check_accuracy(test_loader, tn_model)\n",
    "print(test_acc, n_params(tn_model))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test_tk",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}