{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tensorizing Neural Networks\n",
    "\n",
    "This is an example of how one can tensorize layers of pre-trained neural network\n",
    "models, as described in [[NPOV15']](https://arxiv.org/abs/1509.06569).\n",
    "\n",
    "The notebook trains a small fully connected classifier on FashionMNIST,\n",
    "replaces its first linear layer by a Matrix Product Operator (MPO) built\n",
    "from the trained weights, and fine-tunes the resulting tensorized model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import tensorkrowch as tk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folders for the downloaded dataset and the saved models. Unlike `%mkdir`,\n",
    "# this does not fail when the notebook is re-run and they already exist\n",
    "Path('data').mkdir(exist_ok=True)\n",
    "Path('models').mkdir(exist_ok=True)\n",
    "\n",
    "# Seed PyTorch so that data shuffling and weight initialization are repeatable\n",
    "torch.manual_seed(0);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=0)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Prefer a CUDA GPU, then Apple's MPS backend, and fall back to the CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda:0')\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = torch.device('mps:0')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FashionMNIST Dataset\n",
    "dataset_name = 'fashion_mnist'\n",
    "batch_size = 64\n",
    "image_size = 28\n",
    "input_size = image_size ** 2\n",
    "num_classes = 10\n",
    "\n",
    "transform = transforms.Compose([transforms.ToTensor(),\n",
    "                                transforms.Resize(image_size, antialias=True),\n",
    "                                ])\n",
    "\n",
    "# Load data\n",
    "train_dataset = datasets.FashionMNIST(root='data/',\n",
    "                                      train=True,\n",
    "                                      transform=transform,\n",
    "                                      download=True)\n",
    "test_dataset = datasets.FashionMNIST(root='data/',\n",
    "                                     train=False,\n",
    "                                     transform=transform,\n",
    "                                     download=True)\n",
"\n", "train_loader = DataLoader(dataset=train_dataset,\n", " batch_size=batch_size,\n", " shuffle=True)\n", "test_loader = DataLoader(dataset=test_dataset,\n", " batch_size=batch_size,\n", " shuffle=True)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAgf0lEQVR4nO3df2xV9f3H8VcL7aWF9pZS+ksKFkSYAp0idERFHQ3QLUaUGH9lAeMguuKGzGlqVHSa9KsmjmgY/jNhGvEHmUA0jgyqlOBaDAhjzK2jtfLD0qJ1vbcU+4P2fP8gdqsU4XO8977b8nwkN6H33lfPh9PDffVw733fOM/zPAEAEGPx1gsAAFyYKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYGGq9gG/r7u5WfX29UlJSFBcXZ70cAIAjz/PU0tKi3Nxcxcef/Tyn3xVQfX298vLyrJcBAPiejhw5ojFjxpz19n5XQCkpKZJOLzw1NdV4NRcGv9OY+vMZ6oYNG3zlqqqqnDOFhYXOmc7OTufMyZMnnTOBQMA5I0mXXXaZc+aqq67ytS0MPuFwWHl5eT2P52cTtQJavXq1nnvuOTU0NKigoEAvvviiZs6cec7cNw9qqampFFCMDMYCSk5O9pVLTEyMybY6OjqcM35+Tn4LaMSIEc4Z/r3i2871GBGVFyG8+eabWrFihVauXKmPP/5YBQUFmjdvno4fPx6NzQEABqCoFNDzzz+vJUuW6O6779Zll12ml156ScnJyXr55ZejsTkAwAAU8QLq6OjQnj17VFRU9N+NxMerqKhIlZWVZ9y/vb1d4XC41wUAMPhFvIC+/PJLdXV1KSsrq9f1WVlZamhoOOP+ZWVlCgaDPRdeAQcAFwbzN6KWlpYqFAr1XI4cOWK9JABADET8VXAZGRkaMmSIGhsbe13f2Nio7OzsM+4fCAR8v1IHADBwRfwMKDExUdOnT1d5eXnPdd3d3SovL9esWbMivTkAwAAVlfcBrVixQosWLdJVV12lmTNnatWqVWptbdXdd98djc0BAAagqBTQbbfdpi+++EKPP/64Ghoa9MMf/lBbtmw544UJAIALV5zn923wURIOhxUMBhUKhXhndT/X2trqnLn11ludM3/+85+dM36NHDnSOfOf//wnCiuJjMzMTF85P28aP9fYlb4sXrzYOfPCCy84Z/zy8/DYnyeExMr5Po6bvwoOAHBhooAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYCIq07ARGd3d3c6Z+PjY/U4xYsQI54yfiehXXHGFcyYtLc05I0nt7e3Omfr6eudMcnKyc6alpcU5k5OT45yRpIkTJzpn/Oy7tWvXOmf8DCl++umnnTMSg0WjjTMgAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAICJOM/zPOtF/K9wOKxgMKhQKORr6u1gEqtp2IcOHXLOSNJVV13lnMnLy3PO+NkP4XDYOSNJdXV1zhk/k7eHDRvmnGloaHDOJCYmOmck6fLLL3fO+JnwHQqFnDN+frZ+j3E//DykDrap2+f7OM4ZEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQ
QAAAExQQAMAEBQQAMEEBAQBMUEAAABNDrRcAe1VVVb5ynZ2dzpmvv/7aOeNncOfw4cOdM5I0ceJE54yf4Zitra3OmZEjRzpn/AyMlaSWlhbnzFdffeWcGTrU/SHo8OHDzhk/A20lf8N9cf7YuwAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAEwwjLQfi9UgxPLycl+5IUOGOGf8DKwcNWqUc8bP4E7J3+DT+vp650xmZqZzZsyYMc4Zz/OcM5KUmJjonPEzlNXP0Nj8/HznTG1trXNG8jecFuePMyAAgAkKCABgIuIF9MQTTyguLq7XZfLkyZHeDABggIvKc0CXX365tm3b9t+N+PjQKQDA4BaVZhg6dKiys7Oj8a0BAINEVJ4DOnjwoHJzczV+/Hjddddd3/kRuu3t7QqHw70uAIDBL+IFVFhYqHXr1mnLli1as2aN6urqdO211571M+bLysoUDAZ7Lnl5eZFeEgCgH4p4ARUXF+vWW2/VtGnTNG/ePL333ntqbm7WW2+91ef9S0tLFQqFei5HjhyJ9JIAAP1Q1F8dkJaWpksvvVQ1NTV93h4IBBQIBKK9DABAPxP19wGdOHFCtbW1ysnJifamAAADSMQL6MEHH1RFRYU+++wz/fWvf9XNN9+sIUOG6I477oj0pgAAA1jE/wvu6NGjuuOOO9TU1KTRo0frmmuuUVVVlUaPHh3pTQEABrCIF9Abb7wR6W+JKKuqqvKV8zOw8vjx486Zjo4O50xcXJxzRvL3d+rs7IxJ5ujRo86ZtrY254zkbyhrKBRyziQkJDhnurq6nDMVFRXOGYlhpNHGLDgAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmov6BdOj//vGPf/jKZWVlOWcyMzOdMyNHjnTOfP75584Zyd/wyVOnTsUkEx/v/vui36Gshw8fds5MmjTJOfPVV185Z/zsh8rKSueMJP385z93zvjd5xcizoAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACaYhj3InDx50jkzduxYX9tqaWlxzvz0pz91zuTk5DhnnnnmGeeMJLW2tsYkk5SU5JzxM2V56FB//8RDoZBzprS01Dnz1FNPOWf8rO1vf/ubcwbRxxkQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAEwwjHWQ+/fRT54yfAaaS1NTU5Jz55S9/6ZzZtm2bc8avEydOOGcSEhKcM6dOnXLOdHZ2OmcSExOdM34VFRU5Z/70pz85Z3bs2OGcqampcc4g+jgDAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIJhpIPMU0895ZxJTk6Owkr6Nn36dOfMX/7ylyispG9ZWVnOGT9DWZOSkpwz3d3dzpnhw4c7Z/wKBoPOmby8POeMn6Gsfo/xV1991Tnzs5/9zNe2LkScAQEATFBAAAATzgW0Y8cO3XjjjcrNzVVcXJw2bdrU63bP8/T4448rJydHSUlJKioq0sGDByO1XgDAIOFcQK2trSooKNDq1av7vP3ZZ5/VCy+8oJdeekm7du3S8OHDNW/ePLW1tX3vxQIABg/nFyEUFxeruLi4z9s8z9OqVav06KOP6qabbpIkvfLKK8rKytKmTZt0++23f7/VAgAGjYg+B1RXV6eGhoZeH80bDAZVWFioysrKPjPt7e0Kh8O9LgCAwS+iBdTQ0CDpzJeyZmVl9dz2bWVlZQoGgz0XPy/LBAAMPOavgistLVUoFOq5HDlyxHpJAIAYiGgBZWdnS5IaGxt7Xd/Y2Nhz27cFAgGlpqb2ugAABr+IFlB+
fr6ys7NVXl7ec104HNauXbs0a9asSG4KADDAOb8K7sSJE6qpqen5uq6uTvv27VN6errGjh2r5cuX6+mnn9bEiROVn5+vxx57TLm5uVqwYEEk1w0AGOCcC2j37t264YYber5esWKFJGnRokVat26dHnroIbW2tmrp0qVqbm7WNddcoy1btmjYsGGRWzUAYMCL8zzPs17E/wqHwwoGgwqFQhf880HLly93zpztDcLf5aKLLnLOSNKhQ4ecM34Ot+uuu845s3PnTueMJF1yySXOmfb2dufM0KHuc4ATExOdM37/De3atcs5s3v3bufM559/7py55557nDMjR450zkj+Bp9+8sknzhk/w2n7s/N9HDd/FRwA4MJEAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADDhPpIXMfP00087ZyZMmOCcWbVqlXNGkuLi4pwz3/603PPhZ7J1Xl6ec0aShgwZ4pzxM9m6u7vbOdPV1eWcOXnypHNGkhISEpwzr7zyinNm9uzZzpkvv/zSOZOSkuKckaS0tDTnTEtLi3NmsE3DPl+cAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADAR53meZ72I/xUOhxUMBhUKhZSammq9HETYp59+6pzxM2C1oKDAOSNJHR0dzplTp045Zzo7O50zsRxY6Wc/NDQ0OGcOHTrknDl8+LBzJjc31zkjSVlZWb5yF7rzfRznDAgAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAICJodYLwNnFak5sXFxcTLYjSS+//HJMtjN0qL9Du7m52TmTmJjoa1uuurq6YrIdScrIyHDO1NbWOmf8DKedMWOGcwb9E2dAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATDCMtB+L5ZDQWNmwYYNzJicnxznT3t7unJH8DRb183OK1QDT+Hh/v2P6GXzqZz888sgjzpmtW7c6Z/zyMxB4MP67jRbOgAAAJiggAIAJ5wLasWOHbrzxRuXm5iouLk6bNm3qdfvixYsVFxfX6zJ//vxIrRcAMEg4F1Bra6sKCgq0evXqs95n/vz5OnbsWM/l9ddf/16LBAAMPs4vQiguLlZxcfF33icQCCg7O9v3ogAAg19UngPavn27MjMzNWnSJN13331qamo6633b29sVDod7XQAAg1/EC2j+/Pl65ZVXVF5ermeeeUYVFRUqLi4+68s6y8rKFAwGey55eXmRXhIAoB+K+PuAbr/99p4/T506VdOmTdOECRO0fft2zZkz54z7l5aWasWKFT1fh8NhSggALgBRfxn2+PHjlZGRoZqamj5vDwQCSk1N7XUBAAx+US+go0ePqqmpyde72QEAg5fzf8GdOHGi19lMXV2d9u3bp/T0dKWnp+vJJ5/UwoULlZ2drdraWj300EO65JJLNG/evIguHAAwsDkX0O7du3XDDTf0fP3N8zeLFi3SmjVrtH//fv3xj39Uc3OzcnNzNXfuXD311FMKBAKRWzUAYMCL8/xM24uicDisYDCoUCjE80E+xHJ4op9t+RmOmZ+f75wZOtTf62tOnTrlnElISHDO+Bn2OWTIEOeMX35+Th0dHc6ZTz/91DkTy4cshpH6c76P48yCAwCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYiPhHcsNWLKf37t271zkzatQo54yfj/Lo7u52zvjdlp/J1n5+Tn62k5iY6JyR/E3e9jNB288HVX722WfOmYsvvtg5IzENO9o4AwIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCY
oIAAACYoIACACQoIAGCCYaSDjJ+BkH69+uqrzpmhQ90POT8DQk+cOOGckfwN/ExISPC1LVd+BoT6+ftIUkdHh3MmNTU1Jtv58MMPnTOxHEaK88cZEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMMI4Vv7733nnPGzzDSzs5O54yfwZ2Sv+GYfrflys9gUb9r8zPU1s/gTj/Hw0cffeScueuuu5wzUux+thcqzoAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYYBhpP+ZnuGNcXFwUVtK3f//7386ZSy+91DnT3d3tnPE7RDIxMdFXzlWsBov62Xd+t+Un42cY6SeffOKcQf/EGRAAwAQFBAAw4VRAZWVlmjFjhlJSUpSZmakFCxaourq6133a2tpUUlKiUaNGacSIEVq4cKEaGxsjumgAwMDnVEAVFRUqKSlRVVWVtm7dqs7OTs2dO1etra0993nggQf0zjvvaMOGDaqoqFB9fb1uueWWiC8cADCwOT0DuGXLll5fr1u3TpmZmdqzZ49mz56tUCikP/zhD1q/fr1+/OMfS5LWrl2rH/zgB6qqqtKPfvSjyK0cADCgfa/ngEKhkCQpPT1dkrRnzx51dnaqqKio5z6TJ0/W2LFjVVlZ2ef3aG9vVzgc7nUBAAx+vguou7tby5cv19VXX60pU6ZIkhoaGpSYmKi0tLRe983KylJDQ0Of36esrEzBYLDnkpeX53dJAIABxHcBlZSU6MCBA3rjjTe+1wJKS0sVCoV6LkeOHPle3w8AMDD4eiPqsmXL9O6772rHjh0aM2ZMz/XZ2dnq6OhQc3Nzr7OgxsZGZWdn9/m9AoGAAoGAn2UAAAYwpzMgz/O0bNkybdy4Ue+//77y8/N73T59+nQlJCSovLy857rq6modPnxYs2bNisyKAQCDgtMZUElJidavX6/NmzcrJSWl53mdYDCopKQkBYNB3XPPPVqxYoXS09OVmpqq+++/X7NmzeIVcACAXpwKaM2aNZKk66+/vtf1a9eu1eLFiyVJv/vd7xQfH6+FCxeqvb1d8+bN0+9///uILBYAMHjEeX4mXkZROBxWMBhUKBRSamqq9XJMxWoYaVtbm3NGkpKSkpwzl112mXOmvb3dOeN3GKmfbcVqgKmfv1NHR0cUVtK3lJQU58xXX33lnPnfN76fr6amJucM/Dvfx3FmwQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATPj6RFTERqymYX/xxRfOGUkaOXKkc8bPROdYZaTYTbbu7u52znR1dUVhJX1LTk6OyXYSEhKcM34maMdSrP7dDgacAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADDBMFJo586dvnJ+BmqmpaU5Z/wMS/UzEFLy93eKj3f/Pc5Pxo+hQ/39Ez916pRzxs++GzZsmHPGj46ODl+5WA2nvVBxBgQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEw0j7sVgNrPzoo49ish1JOnnypHPGz37o7Ox0zkj+Bmr6EaufbVdXV8xyfgbABoNB54yfgbZNTU3OGUnKycnxlcP54QwIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACYaRQkePHvWVS0pKcs74GXKZkpLinGlra3POxFJHR4dzZtiwYc4Zv0NZExISnDN+9rmfoaxDh7o/bB07dsw5IzGMNNo4AwIAmKCAAAAmnAqorKxMM2bMUEpKijIzM7VgwQJVV1f3us/111+vuLi4Xpd7
7703oosGAAx8TgVUUVGhkpISVVVVaevWrers7NTcuXPV2tra635LlizRsWPHei7PPvtsRBcNABj4nJ7N27JlS6+v161bp8zMTO3Zs0ezZ8/uuT45OVnZ2dmRWSEAYFD6Xs8BhUIhSVJ6enqv61977TVlZGRoypQpKi0t/c6PYW5vb1c4HO51AQAMfr5fht3d3a3ly5fr6quv1pQpU3quv/POOzVu3Djl5uZq//79evjhh1VdXa233367z+9TVlamJ5980u8yAAADlO8CKikp0YEDB7Rz585e1y9durTnz1OnTlVOTo7mzJmj2tpaTZgw4YzvU1paqhUrVvR8HQ6HlZeX53dZAIABwlcBLVu2TO+++6527NihMWPGfOd9CwsLJUk1NTV9FlAgEFAgEPCzDADAAOZUQJ7n6f7779fGjRu1fft25efnnzOzb98+SbyjGADQm1MBlZSUaP369dq8ebNSUlLU0NAgSQoGg0pKSlJtba3Wr1+vn/zkJxo1apT279+vBx54QLNnz9a0adOi8hcAAAxMTgW0Zs0aSaffbPq/1q5dq8WLFysxMVHbtm3TqlWr1Nraqry8PC1cuFCPPvpoxBYMABgcnP8L7rvk5eWpoqLiey0IAHBhYBo2VFlZ6SuXnJzsnPn73//unPEzQdvPxGTp9NsLXMXFxTlnYjXZ2s/U7Vi64oornDPNzc3Omf379ztnJOnKK690zpzrF/W++DmGBgOGkQIATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADDBMFLo6NGjvnKxGtzpZ/hka2urc0aSPvvsM+fMoUOHnDP19fXOmdzcXOfMpEmTnDOSNHr0aOdMdna2c8bPpyHX1dU5Z871yc2RFB/P7/Xniz0FADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABP9bhac53mSpHA4bLwSnEusZsH5ORZOnjzpnJH8zZDzs622traYbOfEiRPOGUkaNmyYcyY5Odk542cWXEtLi3PG7+NJQkKCr9yF7pv9/c3j+dn0uwL65uDKy8szXgkA4PtoaWlRMBg86+1x3rkqKsa6u7tVX1+vlJSUM35bDofDysvL05EjR5Sammq0Qnvsh9PYD6exH05jP5zWH/aD53lqaWlRbm7ud04H73dnQPHx8eccnZ6amnpBH2DfYD+cxn44jf1wGvvhNOv98F1nPt/gRQgAABMUEADAxIAqoEAgoJUrV/p65cxgwn44jf1wGvvhNPbDaQNpP/S7FyEAAC4MA+oMCAAweFBAAAATFBAAwAQFBAAwMWAKaPXq1br44os1bNgwFRYW6qOPPrJeUsw98cQTiouL63WZPHmy9bKibseOHbrxxhuVm5uruLg4bdq0qdftnufp8ccfV05OjpKSklRUVKSDBw/aLDaKzrUfFi9efMbxMX/+fJvFRklZWZlmzJihlJQUZWZmasGCBaquru51n7a2NpWUlGjUqFEaMWKEFi5cqMbGRqMVR8f57Ifrr7/+jOPh3nvvNVpx3wZEAb355ptasWKFVq5cqY8//lgFBQWaN2+ejh8/br20mLv88st17NixnsvOnTutlxR1ra2tKigo0OrVq/u8/dlnn9ULL7ygl156Sbt27dLw4cM1b948XwM/+7Nz7QdJmj9/fq/j4/XXX4/hCqOvoqJCJSUlqqqq0tatW9XZ2am5c+f2GiL7wAMP6J133tGGDRtUUVGh+vp63XLLLYarjrzz2Q+StGTJkl7Hw7PPPmu04rPwBoCZM2d6JSUlPV93dXV5ubm5XllZmeGqYm/lypVeQUGB9TJMSfI2btzY83V3d7eXnZ3tPffccz3XNTc3e4FAwHv99dcNVhgb394Pnud5ixYt8m666SaT9Vg5fvy4J8mrqKjwPO/0zz4hIcHbsGFDz33++c9/epK8yspKq2VG3bf3g+d53nXX
Xef96le/slvUeej3Z0AdHR3as2ePioqKeq6Lj49XUVGRKisrDVdm4+DBg8rNzdX48eN111136fDhw9ZLMlVXV6eGhoZex0cwGFRhYeEFeXxs375dmZmZmjRpku677z41NTVZLymqQqGQJCk9PV2StGfPHnV2dvY6HiZPnqyxY8cO6uPh2/vhG6+99poyMjI0ZcoUlZaW+v6Ykmjpd8NIv+3LL79UV1eXsrKyel2flZWlf/3rX0arslFYWKh169Zp0qRJOnbsmJ588klde+21OnDggFJSUqyXZ6KhoUGS+jw+vrntQjF//nzdcsstys/PV21trR555BEVFxersrJSQ4YMsV5exHV3d2v58uW6+uqrNWXKFEmnj4fExESlpaX1uu9gPh762g+SdOedd2rcuHHKzc3V/v379fDDD6u6ulpvv/224Wp76/cFhP8qLi7u+fO0adNUWFiocePG6a233tI999xjuDL0B7fffnvPn6dOnapp06ZpwoQJ2r59u+bMmWO4sugoKSnRgQMHLojnQb/L2fbD0qVLe/48depU5eTkaM6cOaqtrdWECRNivcw+9fv/gsvIyNCQIUPOeBVLY2OjsrOzjVbVP6SlpenSSy9VTU2N9VLMfHMMcHycafz48crIyBiUx8eyZcv07rvv6oMPPuj18S3Z2dnq6OhQc3Nzr/sP1uPhbPuhL4WFhZLUr46Hfl9AiYmJmj59usrLy3uu6+7uVnl5uWbNmmW4MnsnTpxQbW2tcnJyrJdiJj8/X9nZ2b2Oj3A4rF27dl3wx8fRo0fV1NQ0qI4Pz/O0bNkybdy4Ue+//77y8/N73T59+nQlJCT0Oh6qq6t1+PDhQXU8nGs/9GXfvn2S1L+OB+tXQZyPN954wwsEAt66deu8Tz75xFu6dKmXlpbmNTQ0WC8tpn79619727dv9+rq6rwPP/zQKyoq8jIyMrzjx49bLy2qWlpavL1793p79+71JHnPP/+8t3fvXu/QoUOe53ne//3f/3lpaWne5s2bvf3793s33XSTl5+f73399dfGK4+s79oPLS0t3oMPPuhVVlZ6dXV13rZt27wrr7zSmzhxotfW1ma99Ii57777vGAw6G3fvt07duxYz+XkyZM997n33nu9sWPHeu+//763e/dub9asWd6sWbMMVx1559oPNTU13m9/+1tv9+7dXl1dnbd582Zv/Pjx3uzZs41X3tuAKCDP87wXX3zRGzt2rJeYmOjNnDnTq6qqsl5SzN12221eTk6Ol5iY6F100UXebbfd5tXU1FgvK+o++OADT9IZl0WLFnmed/ql2I899piXlZXlBQIBb86cOV51dbXtoqPgu/bDyZMnvblz53qjR4/2EhISvHHjxnlLliwZdL+k9fX3l+StXbu25z5ff/2194tf/MIbOXKkl5yc7N18883esWPH7BYdBefaD4cPH/Zmz57tpaene4FAwLvkkku83/zmN14oFLJd+LfwcQwAABP9/jkgAMDgRAEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwMT/AxBKjiwvTR3WAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "3\n" ] } ],
   "source": [
    "# Show one random training image and print its integer class label\n",
    "random_sample = torch.randint(low=0, high=len(train_dataset), size=(1,)).item()\n",
    "image, label = train_dataset[random_sample]\n",
    "\n",
    "plt.imshow(image.squeeze(0), cmap='Greys')\n",
    "plt.show()\n",
    "\n",
    "print(label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train Neural Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FFFC(nn.Module):\n",
    "    \"\"\"Feed-forward fully connected classifier with one hidden layer.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    input_size : int\n",
    "        Number of input features (the flattened image size).\n",
    "    num_classes : int\n",
    "        Number of output scores.\n",
    "    hidden_size : int\n",
    "        Width of the hidden layer. `TN_Linear` below hard-codes the\n",
    "        factorization of a 50 x 784 weight matrix, so keep the default\n",
    "        when the model is going to be tensorized.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, num_classes, hidden_size=50):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(input_size, hidden_size)\n",
    "        self.fc2 = nn.Linear(hidden_size, num_classes)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize network\n",
    "model_name = 'fffc'\n",
    "model = FFFC(input_size=input_size, num_classes=num_classes)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "learning_rate = 1e-3\n",
    "weight_decay = 1e-5\n",
    "num_epochs = 10\n",
    "\n",
    "# Loss and optimizer\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(),\n",
    "                       lr=learning_rate,\n",
    "                       weight_decay=weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check accuracy on training & test to see how good our model is\n",
    "def check_accuracy(loader, model):\n",
    "    \"\"\"Return the classification accuracy (in %) of `model` on `loader`.\n",
    "\n",
    "    Every batch is moved to the global `device` and flattened to shape\n",
    "    (batch, features) before it is fed to the model. The model is switched\n",
    "    to evaluation mode for the check and put back in the mode it was in\n",
    "    when the function was called.\n",
    "    \"\"\"\n",
    "    num_correct = 0\n",
    "    num_samples = 0\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for x, y in loader:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "            x = x.reshape(x.shape[0], -1)\n",
    "\n",
    "            scores = model(x)\n",
    "            _, predictions = scores.max(1)\n",
    "            num_correct += (predictions == y).sum().item()\n",
    "            num_samples += predictions.size(0)\n",
    "\n",
    "    accuracy = float(num_correct) / float(num_samples) * 100\n",
    "    # Restore the caller's mode instead of unconditionally enabling training\n",
    "    model.train(was_training)\n",
    "    return accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "* Epoch 1   => Train. Acc.: 84.52, Test Acc.: 83.08\n",
      "* Epoch 2   => Train. Acc.: 86.19, Test Acc.: 84.96\n",
      "* Epoch 3   => Train. Acc.: 86.57, Test Acc.: 85.15\n",
      "* Epoch 4   => Train. Acc.: 87.06, Test Acc.: 85.08\n",
      "* Epoch 5   => Train. Acc.: 87.96, Test Acc.: 86.15\n",
      "* Epoch 6   => Train. Acc.: 87.83, Test Acc.: 86.07\n",
      "* Epoch 7   => Train. Acc.: 88.56, Test Acc.: 86.62\n",
      "* Epoch 8   => Train. Acc.: 89.33, Test Acc.: 86.69\n",
      "* Epoch 9   => Train. Acc.: 88.94, Test Acc.: 86.45\n",
      "* Epoch 10  => Train. Acc.: 89.30, Test Acc.: 87.03\n"
     ]
    }
   ],
   "source": [
    "# Train network\n",
    "for epoch in range(num_epochs):\n",
    "    for batch_idx, (data, targets) in enumerate(train_loader):\n",
    "        # Get data to cuda if possible\n",
    "        data = data.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        # Get to correct shape\n",
    "        data = data.reshape(data.shape[0], -1)\n",
    "\n",
    "        # Forward\n",
    "        scores = model(data)\n",
    "        loss = criterion(scores, targets)\n",
    "\n",
    "        # Backward\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "\n",
    "        # Gradient descent\n",
    "        optimizer.step()\n",
    "\n",
    "    train_acc = check_accuracy(train_loader, model)\n",
    "    test_acc = check_accuracy(test_loader, model)\n",
    "\n",
    "    print(f'* Epoch {epoch + 1:<3} => Train. Acc.: {train_acc:.2f},'\n",
    "          f' Test Acc.: {test_acc:.2f}')\n",
    "\n",
    "torch.save(model.state_dict(), f'models/{model_name}_{dataset_name}.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def n_params(model):\n",
    "    \"\"\"Return the total number of scalar entries in the model's parameters.\"\"\"\n",
    "    return sum(p.numel() for p in model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(87.03, 39760)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n = n_params(model)\n",
    "test_acc = check_accuracy(test_loader, model)\n",
    "test_acc, n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define tensorized layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load network\n",
    "model = FFFC(input_size=input_size, num_classes=num_classes)\n",
    "# map_location lets a checkpoint saved on another device load on this one\n",
    "model.load_state_dict(torch.load(f'models/{model_name}_{dataset_name}.pt',\n",
    "                                 map_location=device))\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TN_Linear(tk.models.MPO):\n",
    "    \"\"\"Tensorized replacement for `model.fc1`, built from its trained weights.\n",
    "\n",
    "    The 50 x 784 weight matrix is factorized with `mat_to_mpo` and the\n",
    "    resulting tensors initialize the MPO. Each input batch is itself\n",
    "    decomposed with `vec_to_mps` before being contracted with the MPO.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model : FFFC\n",
    "        Trained network whose `fc1` layer is tensorized.\n",
    "    cum_percentage : float\n",
    "        Forwarded to `tk.decompositions.mat_to_mpo`; presumably the fraction\n",
    "        of the singular-value mass kept in each truncation (confirm in the\n",
    "        TensorKrowch documentation).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, cum_percentage):\n",
    "\n",
    "        # Get weight matrix from model and reshape it:\n",
    "        #   50 outputs = 1 * 1 * 5 * 5 * 2 * 1  (axes 0-5)\n",
    "        #   784 inputs = 2 * 2 * 7 * 7 * 2 * 2  (axes 6-11)\n",
    "        # The permutation interleaves them as (in_1, out_1, in_2, out_2, ...)\n",
    "        weight = model.fc1.weight.detach()\n",
    "        weight = weight.reshape(1, 1, 5, 5, 2, 1,\n",
    "                                2, 2, 7, 7, 2, 2).permute(6, 0, 7, 1, 8, 2,\n",
    "                                                          9, 3, 10, 4, 11, 5)\n",
    "        # Plain tensor kept for reference; it is not read again in this notebook\n",
    "        self.weight = weight\n",
    "\n",
    "        mpo_tensors = tk.decompositions.mat_to_mpo(weight,\n",
    "                                                   cum_percentage=cum_percentage,\n",
    "                                                   renormalize=True)\n",
    "        super().__init__(tensors=mpo_tensors)\n",
    "\n",
    "        # Save bias as parameter of tn layer\n",
    "        self.bias = nn.Parameter(model.fc1.bias.detach())\n",
    "\n",
    "    def set_data_nodes(self):\n",
    "        # The six physical dimensions match the factorization of the 784 inputs\n",
    "        self.mps_data = tk.models.MPSData(n_features=6,\n",
    "                                          phys_dim=[2, 2, 7, 7, 2, 2],\n",
    "                                          bond_dim=10,\n",
    "                                          boundary='obc')\n",
    "\n",
    "    def add_data(self, data):\n",
    "        # Decompose the batch of flattened images into MPS tensors\n",
    "        mps_tensors = tk.decompositions.vec_to_mps(data.reshape(-1, 2, 2, 7, 7, 2, 2),\n",
    "                                                   n_batches=1,\n",
    "                                                   cum_percentage=0.95,\n",
    "                                                   renormalize=True)\n",
    "        self.mps_data.add_data(mps_tensors)\n",
    "\n",
    "    def contract(self):\n",
    "        return super().contract(inline_input=True,\n",
    "                                inline_mats=True,\n",
    "                                mps=self.mps_data)\n",
    "\n",
    "    def forward(self, x, *args, **kwargs):\n",
    "        result = super().forward(x, *args, **kwargs)\n",
    "        # Flatten the output axes back into the 50 hidden units of fc1\n",
    "        result = result.reshape(-1, 50)\n",
    "        result += self.bias\n",
    "        return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TN_NN(nn.Module):\n",
    "    \"\"\"Same architecture as `FFFC` with `fc1` replaced by a `TN_Linear` layer.\n",
    "\n",
    "    `fc2` is the very same module object as `model.fc2` (not a copy), so\n",
    "    fine-tuning this network also updates the output layer of `model`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, cum_percentage):\n",
    "        super().__init__()\n",
    "        self.tn1 = TN_Linear(model, cum_percentage)\n",
    "        self.fc2 = model.fc2\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.relu(self.tn1(x))\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: this rebinds `model_name`, which the 'Load network' cell above also\n",
    "# reads; re-run the notebook from the top before reloading the baseline model\n",
    "model_name = 'tn_fffc'\n",
    "tn_model = TN_NN(model, 0.85)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(86.61999999999999, 27944)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n = n_params(tn_model)\n",
    "test_acc = check_accuracy(test_loader, tn_model)\n",
    "test_acc, n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trace the model to accelerate training\n",
    "tn_model.tn1.trace(torch.zeros(1, input_size, device=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "learning_rate = 1e-4\n",
    "weight_decay = 1e-5\n",
    "num_epochs = 5\n",
    "\n",
    "# Loss and optimizer\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(tn_model.parameters(),\n",
    "                       
lr=learning_rate,\n",
    "                       weight_decay=weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "* Epoch 1   => Train. Acc.: 90.04, Test Acc.: 87.95\n",
      "* Epoch 2   => Train. Acc.: 90.03, Test Acc.: 87.73\n",
      "* Epoch 3   => Train. Acc.: 90.35, Test Acc.: 87.70\n",
      "* Epoch 4   => Train. Acc.: 90.64, Test Acc.: 87.90\n",
      "* Epoch 5   => Train. Acc.: 90.73, Test Acc.: 88.03\n"
     ]
    }
   ],
   "source": [
    "# Train network\n",
    "for epoch in range(num_epochs):\n",
    "    for batch_idx, (data, targets) in enumerate(train_loader):\n",
    "        # Get data to cuda if possible\n",
    "        data = data.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        # Get to correct shape\n",
    "        data = data.reshape(data.shape[0], -1)\n",
    "\n",
    "        # Forward\n",
    "        scores = tn_model(data)\n",
    "        loss = criterion(scores, targets)\n",
    "\n",
    "        # Backward\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "\n",
    "        # Gradient descent\n",
    "        optimizer.step()\n",
    "\n",
    "    train_acc = check_accuracy(train_loader, tn_model)\n",
    "    test_acc = check_accuracy(test_loader, tn_model)\n",
    "\n",
    "    print(f'* Epoch {epoch + 1:<3} => Train. Acc.: {train_acc:.2f},'\n",
    "          f' Test Acc.: {test_acc:.2f}')\n",
    "\n",
    "# Reset before saving the model\n",
    "tn_model.tn1.reset()\n",
    "torch.save(tn_model.state_dict(), f'models/{model_name}_{dataset_name}.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compress the fine-tuned layer\n",
    "\n",
    "Canonicalizing the MPO with `cum_percentage=0.8` shrinks the model further.\n",
    "In the recorded run below, the parameter count drops to 16406 while the\n",
    "test accuracy falls from 88.03 to 81.57."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "81.57 16406\n"
     ]
    }
   ],
   "source": [
    "tn_model.tn1.canonicalize(cum_percentage=0.8)\n",
    "\n",
    "test_acc = check_accuracy(test_loader, tn_model)\n",
    "print(test_acc, n_params(tn_model))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test_tk",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}